jax.dlpack.to_dlpack#
- jax.dlpack.to_dlpack(x, stream=None, src_device=None, dl_device=None, max_version=None, copy=None)[源代码]#
返回一个封装了
Array
x
的 DLPack 张量。- 参数:
stream (int | Any | None | None) – 可选的平台相关的流,用于等待直到缓冲区准备就绪。这对应于 https://dmlc.github.io/dlpack/latest/python_spec.html 中记录的
__dlpack__
的 stream 参数。src_device (xla_client.Device | None | None) – CPU 或 GPU 的
Device
。dl_device (tuple[DLDeviceType, int] | None | None) – 一个
(dl_device_type, local_hardware_id)
元组,以 DLPack 格式表示,例如由__dlpack_device__
生成。max_version (tuple[int, int] | None | None) – 消费者(即
__dlpack__
的调用者)支持的最大 DLPack 版本,以(major, minor)
的 2 元组形式表示。此函数不保证返回max_version
版本的胶囊。copy (bool | None | None) – 一个布尔值,指示是否复制输入。如果
copy=True
,则函数必须始终复制。当copy=False
时,函数绝不能复制,并且当认为需要复制时必须引发错误。如果copy=None
,则函数必须尽可能避免复制,但如果需要可以复制。
- 返回:
一个 DLPack PyCapsule 对象。
注意
虽然 JAX 数组始终是不可变的,但
DLPackManagedTensor
缓冲区不能标记为不可变,并且 JAX 外部的进程可以对其进行原地修改。如果修改了从 JAX 数组派生的 DLPack 缓冲区,则使用关联的 JAX 数组时可能会导致未定义的行为。当 JAX 最终支持DLManagedTensorVersioned
(DLPack 1.0) 时,将可以指定缓冲区是只读的。