jax.device_put#

jax.device_put(x, device=None, *, src=None, donate=False, may_alias=None)[源代码]#

x 传输到 device

参数:
  • x – 一个数组、标量或它们的(嵌套)标准 Python 容器。

  • device (None | xc.Device | Sharding | Layout | Any | TransferToMemoryKind | None) – (可选的)DeviceSharding,或者标准 Python 容器中(必须是 x 的树前缀)的(嵌套)Sharding,表示 x 应传输到的设备。如果给出,则结果将提交到设备。

  • src (None | xc.Device | Sharding | Layout | Any | TransferToMemoryKind | None) – (可选的)DeviceSharding,或者标准 Python 容器中(必须是 x 的树前缀)的(嵌套)Sharding,表示 x 所在的设备。

  • donate (bool | Any) – 布尔值或标准 Python 容器中的(嵌套)布尔值(必须是 x 的树前缀)。如果为 True,则可以覆盖 x 并在调用方中标记为删除。这是尽力而为。JAX 会尽可能捐赠,否则不会。如果捐赠,输入缓冲区(将来)始终会被删除。

  • may_alias (bool | None | Any | None) – 布尔值或 None 或标准 Python 容器中的(嵌套)布尔值(必须是 x 的树前缀)。如果为 False,则会复制 x。如果为 True,则 x 可能会根据运行时的实现进行别名。

返回:

位于 device 上的 x 的副本。

如果 device 参数为 None,则如果操作数已在任何设备上,则此操作的行为类似于恒等函数,否则它会将数据传输到默认设备,不提交。

有关数据放置的更多详细信息,请参阅 关于数据放置的常见问题

此函数始终是异步的,即立即返回,而不会阻塞调用 Python 线程,直到任何传输完成。