jax.device_get

内容

jax.device_get#

jax.device_get(x)[source]#

x 传输到主机。

如果 x 是一个 pytree,则并行复制各个缓冲区。

参数:

x (Any) – 一个数组、标量、Array 或(嵌套的)标准 Python 容器,表示要传输到主机的数组。

返回值:

一个数组或(嵌套的)Python 容器,表示 x 的值。

示例

传递一个 Array

>>> import jax
>>> x = jax.numpy.array([1., 2., 3.])
>>> jax.device_get(x)
array([1., 2., 3.], dtype=float32)

传递一个标量(没有效果)

>>> jax.device_get(1)
1

另请参阅

  • device_put

  • device_put_sharded

  • device_put_replicated