jax.nn.initializers.uniform#
- jax.nn.initializers.uniform(scale=0.01, dtype=<class 'jax.numpy.float64'>)[source]#
构建一个初始化器,返回实数均匀分布的随机数组。
- 参数:
scale (RealNumeric) – 可选;随机分布的上界。
dtype (DTypeLikeInexact) – 可选;初始化器的默认数据类型。
- 返回值:
一个初始化器,返回其值在
[0, scale)
范围内均匀分布的数组。- 返回类型:
初始化器
>>> import jax, jax.numpy as jnp >>> initializer = jax.nn.initializers.uniform(10.0) >>> initializer(jax.random.key(42), (2, 3), jnp.float32) Array([[7.298188 , 8.691938 , 8.7230015], [2.0818567, 1.8662417, 5.5022564]], dtype=float32)