jax.numpy.asarray

内容

jax.numpy.asarray#

jax.numpy.asarray(a, dtype=None, order=None, *, copy=None, device=None)[source]#

将对象转换为 JAX 数组。

JAX 实现 numpy.asarray()

参数:
  • a (Any) – 可转换为数组的对象。这包括 JAX 数组、NumPy 数组、Python 标量、Python 集合(如列表和元组)、具有 __array__ 方法的对象以及支持 Python 缓冲区协议的对象。

  • dtype (DTypeLike | None | None) – 可选地指定输出数组的 dtype。如果未指定,则将从输入推断。

  • order (str | None | None) – 在 JAX 中未实现

  • copy (bool | None | None) – 可选布尔值,指定复制模式。如果为 True,则始终返回副本。如果为 False,则如果需要副本则报错。默认为 None,仅在必要时复制。

  • device (xc.Device | Sharding | None | None) – 可选的 DeviceSharding,创建的数组将提交到该设备或分片。

返回:

根据输入构建的 JAX 数组。

返回类型:

数组

另请参阅

示例

从 Python 标量构建 JAX 数组

>>> jnp.asarray(True)
Array(True, dtype=bool)
>>> jnp.asarray(42)
Array(42, dtype=int32, weak_type=True)
>>> jnp.asarray(3.5)
Array(3.5, dtype=float32, weak_type=True)
>>> jnp.asarray(1 + 1j)
Array(1.+1.j, dtype=complex64, weak_type=True)

从 Python 集合构建 JAX 数组

>>> jnp.asarray([1, 2, 3])  # list of ints -> 1D array
Array([1, 2, 3], dtype=int32)
>>> jnp.asarray([(1, 2, 3), (4, 5, 6)])  # list of tuples of ints -> 2D array
Array([[1, 2, 3],
       [4, 5, 6]], dtype=int32)
>>> jnp.asarray(range(5))
Array([0, 1, 2, 3, 4], dtype=int32)

从 NumPy 数组构建 JAX 数组

>>> jnp.asarray(np.linspace(0, 2, 5))
Array([0. , 0.5, 1. , 1.5, 2. ], dtype=float32)

通过 Python 缓冲区接口构建 JAX 数组,使用 Python 内置的 array 模块。

>>> from array import array
>>> pybuffer = array('i', [2, 3, 5, 7])
>>> jnp.asarray(pybuffer)
Array([2, 3, 5, 7], dtype=int32)