jax.numpy.asarray#
- jax.numpy.asarray(a, dtype=None, order=None, *, copy=None, device=None)[源代码]#
将对象转换为 JAX 数组。
JAX 实现的
numpy.asarray()
。- 参数:
a (Any) – 可以转换为数组的对象。这包括 JAX 数组、NumPy 数组、Python 标量、Python 集合(如列表和元组)、具有
__array__
方法的对象以及支持 Python 缓冲区协议的对象。dtype (DTypeLike | None | None) – 可选参数,用于指定输出数组的数据类型。如果未指定,则会从输入中推断。
order (str | None | None) – 在 JAX 中未实现
copy (bool | None | None) – 可选的布尔值,用于指定复制模式。如果为 True,则始终返回副本。如果为 False,则在需要复制时会报错。默认为 None,仅在必要时复制。
device (xc.Device | Sharding | None | None) – 可选的
Device
或Sharding
,创建的数组将提交到该设备或分片策略。
- 返回:
从输入构建的 JAX 数组。
- 返回类型:
另请参阅
jax.numpy.array()
:类似于 asarray,但默认为 copy=True。jax.numpy.from_dlpack()
:从实现 dlpack 接口的对象构造 JAX 数组。jax.numpy.frombuffer()
:从实现缓冲区接口的对象构造 JAX 数组。
示例
从 Python 标量构造 JAX 数组
>>> jnp.asarray(True) Array(True, dtype=bool) >>> jnp.asarray(42) Array(42, dtype=int32, weak_type=True) >>> jnp.asarray(3.5) Array(3.5, dtype=float32, weak_type=True) >>> jnp.asarray(1 + 1j) Array(1.+1.j, dtype=complex64, weak_type=True)
从 Python 集合构造 JAX 数组
>>> jnp.asarray([1, 2, 3]) # list of ints -> 1D array Array([1, 2, 3], dtype=int32) >>> jnp.asarray([(1, 2, 3), (4, 5, 6)]) # list of tuples of ints -> 2D array Array([[1, 2, 3], [4, 5, 6]], dtype=int32) >>> jnp.asarray(range(5)) Array([0, 1, 2, 3, 4], dtype=int32)
从 NumPy 数组构造 JAX 数组
>>> jnp.asarray(np.linspace(0, 2, 5)) Array([0. , 0.5, 1. , 1.5, 2. ], dtype=float32)
通过 Python 缓冲区接口,使用 Python 内置的
array
模块构造 JAX 数组。>>> from array import array >>> pybuffer = array('i', [2, 3, 5, 7]) >>> jnp.asarray(pybuffer) Array([2, 3, 5, 7], dtype=int32)