jax.numpy.ones#

jax.numpy.ones(shape, dtype=None, *, device=None)[源代码]#

创建一个充满 1 的数组。

JAX 实现的 numpy.ones()

参数:
  • shape (Any) – 指定创建数组形状的 int 或 int 序列。

  • dtype (DTypeLike | None | None) – 创建数组的可选数据类型;默认为浮点数。

  • device (xc.Device | Sharding | None | None) – (可选) DeviceSharding,表示创建的数组将被提交到的设备。

返回:

具有指定形状和数据类型,如果指定了设备,则位于指定设备上的数组。

返回类型:

数组

示例

>>> jnp.ones(4)
Array([1., 1., 1., 1.], dtype=float32)
>>> jnp.ones((2, 3), dtype=bool)
Array([[ True,  True,  True],
       [ True,  True,  True]], dtype=bool)