jax.device_get#

jax.device_get(x)[源代码]#

x 传输到主机。

如果 x 是一个 PyTree,则各个缓冲区将并行复制。

参数:

x (Any) – 一个数组、标量、Array 或其(嵌套的)标准 Python 容器,表示要传输到主机的数组。

返回:

一个数组或其(嵌套的)Python 容器,表示 x 的值。

示例

传递一个 Array

>>> import jax
>>> x = jax.numpy.array([1., 2., 3.])
>>> jax.device_get(x)
array([1., 2., 3.], dtype=float32)

传递一个标量(没有效果)

>>> jax.device_get(1)
1

另请参阅

  • device_put

  • device_put_sharded

  • device_put_replicated