jax.nn.initializers.variance_scaling

jax.nn.initializers.variance_scaling#

jax.nn.initializers.variance_scaling(scale, mode, distribution, in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)[source]#

初始化器,将其比例调整为权重张量的形状。

使用 distribution="truncated_normal"distribution="normal",从平均值为零、标准差(如果适用,则在截断后)为 \(\sqrt{\frac{scale}{n}}\) 的(截断)正态分布中抽取样本,其中 n

  • 权重张量中输入单元的数量,如果 mode="fan_in"

  • 输出单元的数量,如果 mode="fan_out",或

  • 输入和输出单元数量的平均值,如果 mode="fan_avg"

此初始化器可以使用 in_axisout_axisbatch_axis 配置,以用于一般的卷积或密集层;不在任何这些参数中的轴被认为是“感受野”(卷积核空间轴)。

distribution="truncated_normal" 时,样本的绝对值在缩放之前会被截断在 2 个标准差以内。

distribution="uniform" 时,样本从以下分布中抽取:

  • 如果 dtype 是实数,则从均匀区间中抽取。

  • 如果 dtype 是复数,则从均匀圆盘中抽取。

均值为零,标准差为 \(\sqrt{\frac{scale}{n}}\),其中 n 如上定义。

参数:
  • scale (RealNumeric) – 缩放因子(正浮点数)。

  • mode (Literal['fan_in'] | Literal['fan_out'] | Literal['fan_avg']) – "fan_in""fan_out""fan_avg" 之一。

  • distribution (Literal['truncated_normal'] | Literal['normal'] | Literal['uniform']) – 要使用的随机分布。 "truncated_normal""normal""uniform" 之一。

  • in_axis (int | Sequence[int]) – 权重数组中输入维度轴或轴序列。

  • out_axis (int | Sequence[int]) – 权重数组中输出维度轴或轴序列。

  • batch_axis (Sequence[int]) – 权重数组中应该忽略的轴或轴序列。

  • dtype (DTypeLikeInexact) – 权重的 dtype。

返回类型:

Initializer