jax.device_get#
- jax.device_get(x)[源代码]#
将
x
传输到主机。如果
x
是一个 pytree,则各个缓冲区将并行复制。- 参数:
x (Any) – 表示要传输到主机的数组、标量、Array 或其(嵌套的)标准 Python 容器。
- 返回:
表示
x
值的数组或其(嵌套的)Python 容器。
示例
传递 Array
>>> import jax >>> x = jax.numpy.array([1., 2., 3.]) >>> jax.device_get(x) array([1., 2., 3.], dtype=float32)
传递标量(没有效果)
>>> jax.device_get(1) 1
另请参阅
device_put
device_put_sharded
device_put_replicated