jax.numpy.from_dlpack#

jax.numpy.from_dlpack(x, /, *, device=None, copy=None)[源代码]#

通过 DLPack 构建一个 JAX 数组。

numpy.from_dlpack() 的 JAX 实现。

参数:
  • x (Any) – 一个通过 __dlpack____dlpack_device__ 方法实现 DLPack 协议的对象,或者一个在 CPU 或 GPU 上的旧版 DLPack 张量。

  • device (xc.Device | Sharding | None | None) – 一个可选的 DeviceSharding,表示返回的数组应该被放置到的单个设备。 如果给定,则结果将提交到该设备。 如果未指定,则生成的数组将解包到它最初所在的同一设备上。 将 device 设置为与 external_array 的来源不同的设备将需要复制,这意味着 copy 必须设置为 TrueNone

  • copy (bool | None | None) – 一个可选的布尔值,控制是否执行复制。 如果 copy=True,则始终执行复制,即使解包到同一设备上也是如此。 如果 copy=False,则永远不会执行复制,并且在必要时会引发错误。 当 copy=None (默认)时,如果需要进行设备传输,则可以执行复制。

返回:

输入缓冲区的 JAX 数组。

返回类型:

Array

注意

虽然 JAX 数组始终是不可变的,但 dlpack 缓冲区不能标记为不可变的,并且 JAX 外部的进程可能会就地对其进行修改。 如果 JAX 数组是从 dlpack 缓冲区构建的,而没有复制,并且源缓冲区稍后被就地修改,则在使用关联的 JAX 数组时可能会导致未定义的行为。

示例

通过 DLPack 在 NumPy 和 JAX 之间传递数据

>>> import numpy as np
>>> rng = np.random.default_rng(42)
>>> x_numpy = rng.random(4, dtype='float32')
>>> print(x_numpy)
[0.08925092 0.773956   0.6545715  0.43887842]
>>> hasattr(x_numpy, "__dlpack__")  # NumPy supports the DLPack interface
True
>>> import jax.numpy as jnp
>>> x_jax = jnp.from_dlpack(x_numpy)
>>> print(x_jax)
[0.08925092 0.773956   0.6545715  0.43887842]
>>> hasattr(x_jax, "__dlpack__")  # JAX supports the DLPack interface
True
>>> x_numpy_round_trip = np.from_dlpack(x_jax)
>>> print(x_numpy_round_trip)
[0.08925092 0.773956   0.6545715  0.43887842]