jax.nn.initializers.orthogonal#
- jax.nn.initializers.orthogonal(scale=1.0, column_axis=-1, dtype=<class 'jax.numpy.float64'>)[源代码]#
构建一个初始化器,该初始化器返回均匀分布的正交矩阵。
如果形状不是正方形,则矩阵将具有正交的行或列,具体取决于哪一侧较小。
- 参数:
scale (RealNumeric) – 均匀分布的上限。
column_axis (int) – 包含应为正交的列的轴。
dtype (DTypeLikeInexact) – 权重的默认 dtype。
- 返回:
一个正交初始化器。
- 返回类型:
Initializer
示例
>>> import jax, jax.numpy as jnp >>> initializer = jax.nn.initializers.orthogonal() >>> initializer(jax.random.key(42), (2, 3), jnp.float32) Array([[ 3.9026976e-01, 7.2495741e-01, -5.6756169e-01], [ 8.8047469e-01, -4.7409311e-01, -1.3157725e-04]], dtype=float32)