jax.device_put_replicated

jax.device_put_replicated#

jax.device_put_replicated(x, devices)[source]#

将数组传输到每个指定的设备并形成 Array。

参数:
  • x (Any) – 表示要复制以形成输出的数组的数组、标量或其(嵌套)标准 Python 容器。

  • devices (Sequence[xc.Device]) – Device 实例的序列,表示将 x 传输到的设备。

此函数始终是异步的,即立即返回。

返回值:

一个 Array 或其(嵌套)Python 容器,表示沿大小为 len(devices) 的新前导轴广播的 x 的值,沿该新前导轴的每个切片都由 devices 中对应条目指定的设备上的内存支持。

参数:
  • x (Any)

  • devices (Sequence[xc.Device])

示例

传递数组

>>> import jax
>>> devices = jax.local_devices()
>>> x = jax.numpy.array([1., 2., 3.])
>>> y = jax.device_put_replicated(x, devices)
>>> np.allclose(y, jax.numpy.stack([x for _ in devices]))
True

另请参阅

  • device_put

  • device_put_sharded