jax.device_get#
- jax.device_get(x)[source]#
将
x
传输到主机。如果
x
是一个 pytree,则并行复制各个缓冲区。- 参数:
x (Any) – 一个数组、标量、Array 或(嵌套的)标准 Python 容器,表示要传输到主机的数组。
- 返回值:
一个数组或(嵌套的)Python 容器,表示
x
的值。
示例
传递一个 Array
>>> import jax >>> x = jax.numpy.array([1., 2., 3.]) >>> jax.device_get(x) array([1., 2., 3.], dtype=float32)
传递一个标量(没有效果)
>>> jax.device_get(1) 1
另请参阅
device_put
device_put_sharded
device_put_replicated