jax.device_get#

jax.device_get(x)[源代码]#

x 传输到主机。

如果 x 是一个 pytree,则各个缓冲区将并行复制。

参数:

x (Any) – 表示要传输到主机的数组、标量、Array 或其(嵌套的)标准 Python 容器。

返回:

表示 x 值的数组或其(嵌套的)Python 容器。

示例

传递 Array

>>> import jax
>>> x = jax.numpy.array([1., 2., 3.])
>>> jax.device_get(x)
array([1., 2., 3.], dtype=float32)

传递标量(没有效果)

>>> jax.device_get(1)
1

另请参阅

  • device_put

  • device_put_sharded

  • device_put_replicated