jax.dlpack.to_dlpack

内容

jax.dlpack.to_dlpack#

jax.dlpack.to_dlpack(x, stream=None, src_device=None, dl_device=None, max_version=None, copy=None)[source]#

返回一个 DLPack 张量,它封装了一个 Array x

参数:
  • x (Array) – 一个 Array,位于 CPU 或 GPU 上。

  • stream (int | Any | None | None) – 可选的平台相关流,用于等待缓冲区准备就绪。这对应于 stream 参数,该参数在 https://dmlc.github.io/dlpack/latest/python_spec.html 中对 __dlpack__ 的说明中进行了描述。

  • src_device (xla_client.Device | None | None) – CPU 或 GPU Device

  • dl_device (tuple[DLDeviceType, int] | None | None) – 一个由 (dl_device_type, local_hardware_id) 组成的元组,以 DLPack 格式表示,例如由 __dlpack_device__ 生成。

  • max_version (tuple[int, int] | None | None) – 消费者(即 __dlpack__ 的调用方)支持的最高 DLPack 版本,以 (major, minor) 形式的 2 元组表示。此函数不保证返回版本为 max_version 的胶囊。

  • copy (bool | None | None) – 一个布尔值,指示是否复制输入。如果 copy=True,则该函数必须始终复制。当 copy=False 时,该函数绝不能复制,并且在需要复制时必须引发错误。如果 copy=None,则该函数必须尽可能避免复制,但必要时也可以复制。

返回:

一个 DLPack PyCapsule 对象。

注意

虽然 JAX 数组始终是不可变的,但 DLPackManagedTensor 缓冲区不能标记为不可变,并且外部进程可能会就地修改它们。如果从 JAX 数组派生的 DLPack 缓冲区被修改,则在使用关联的 JAX 数组时可能会导致未定义的行为。当 JAX 最终支持 DLManagedTensorVersioned(DLPack 1.0)时,将有可能指定缓冲区是只读的。