jax.numpy.zeros

内容

jax.numpy.zeros#

jax.numpy.zeros(shape, dtype=None, *, device=None)[source]#

创建全为零的数组。

numpy.zeros() 的 JAX 实现。

参数:
  • shape (Any) – int 或 int 序列,指定创建数组的形状。

  • dtype (DTypeLike | None | None) – 创建数组的可选数据类型;默认为浮点数。

  • device (xc.Device | Sharding | None | None) – (可选)DeviceSharding,创建的数组将被提交到该设备。

返回值:

指定形状和数据类型的数组,如果指定了设备,则在指定设备上。

返回类型:

数组

示例

>>> jnp.zeros(4)
Array([0., 0., 0., 0.], dtype=float32)
>>> jnp.zeros((2, 3), dtype=bool)
Array([[False, False, False],
       [False, False, False]], dtype=bool)