jax.nn.initializers.constant

内容

jax.nn.initializers.constant#

jax.nn.initializers.constant(value, dtype=<class 'jax.numpy.float64'>)[source]#

构建一个初始化器,它返回填充有常量 value 的数组。

参数:
  • value (ArrayLike) – 用于填充初始化器的常量值。

  • dtype (DTypeLikeInexact) – 可选;初始化器的默认 dtype。

返回类型:

Initializer

>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.constant(-7)
>>> initializer(jax.random.key(42), (2, 3), jnp.float32)
Array([[-7., -7., -7.],
       [-7., -7., -7.]], dtype=float32)