jax.device_put

内容

jax.device_put#

jax.device_put(x, device=None, *, src=None)[source]#

x 传输到 device

参数::
  • x – 数组、标量或其(嵌套)标准 Python 容器。

  • device (None | xc.Device | Sharding | Layout | Any | TransferToMemoryKind | None) – (可选)DeviceSharding 或标准 Python 容器中的(嵌套)Sharding(必须是 x 的树前缀),表示应将 x 传输到的设备。如果已给出,则结果将提交到设备。

  • src ( | xc.Device | 分片 | 布局 | 任何 | TransferToMemoryKind | )

返回值:

位于 device 上的 x 的副本。

如果 device 参数为 None,则如果操作数已在任何设备上,此操作的行为类似于恒等函数,否则它将数据传输到默认设备,未提交。

有关数据放置的更多详细信息,请参阅 关于数据放置的常见问题解答

此函数始终是异步的,即立即返回,不会阻塞调用 Python 线程,直到任何传输完成。