jax.nn.initializers.orthogonal

目录

jax.nn.initializers.orthogonal#

jax.nn.initializers.orthogonal(scale=1.0, column_axis=-1, dtype=<class 'jax.numpy.float64'>)[source]#

构建一个返回均匀分布正交矩阵的初始化器。

如果形状不是正方形,矩阵将具有正交行或列,具体取决于哪边更小。

参数:
  • **scale** (RealNumeric) – 均匀分布的上限。

  • **column_axis** (int) – 包含应正交的列的轴。

  • **dtype** (DTypeLikeInexact) – 权重的默认 dtype。

返回值:

一个正交初始化器。

返回类型:

Initializer

示例

>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.orthogonal()
>>> initializer(jax.random.key(42), (2, 3), jnp.float32)  
Array([[ 3.9026976e-01,  7.2495741e-01, -5.6756169e-01],
       [ 8.8047469e-01, -4.7409311e-01, -1.3157725e-04]],            dtype=float32)