jax.nn.initializers.ones#
- jax.nn.initializers.ones(key, shape, dtype=<class 'jax.numpy.float64'>)[source]#
一个返回全为 1 的常量数组的初始化器。
key
参数被忽略。>>> import jax, jax.numpy as jnp >>> jax.nn.initializers.ones(jax.random.key(42), (3, 2), jnp.float32) Array([[1., 1.], [1., 1.], [1., 1.]], dtype=float32)
- 参数:
key (KeyArray)
shape (core.Shape)
dtype (DTypeLikeInexact)
- 返回类型: