jax.numpy.full

内容

jax.numpy.full#

jax.numpy.full(shape, fill_value, dtype=None, *, device=None)[source]#

创建充满指定值的数组。

JAX 实现 numpy.full().

参数::
  • shape (Any) – int 或 int 序列,指定创建的数组的形状。

  • fill_value (ArrayLike) – 标量或数组,用于填充创建的数组。

  • dtype (DTypeLike | None | None) – 创建的数组的可选 dtype;默认为填充值的 dtype。

  • 设备 (xc.Device | 分片 | | ) – (可选) DeviceSharding,创建的数组将提交到该设备。

返回值:

指定形状和数据类型的数组,如果指定了设备,则在指定设备上。

返回类型:

数组

示例

>>> jnp.full(4, 2, dtype=float)
Array([2., 2., 2., 2.], dtype=float32)
>>> jnp.full((2, 3), 0, dtype=bool)
Array([[False, False, False],
       [False, False, False]], dtype=bool)

fill_value 也可以是广播到指定形状的数组

>>> jnp.full((2, 3), fill_value=jnp.arange(3))
Array([[0, 1, 2],
       [0, 1, 2]], dtype=int32)