jax.device_put_sharded

jax.device_put_sharded#

jax.device_put_sharded(shards, devices)[source]#

将数组分片传输到指定的设备并形成 Array。

参数:
  • shards (Sequence[Any]) – 一系列数组、标量或(嵌套的)标准 Python 容器,表示要堆叠在一起以形成输出的分片。 shards 的长度必须等于 devices 的长度。

  • devices (Sequence[xc.Device]) – 一系列 Device 实例,表示将相应的 shards 中的分片传输到的设备。

此函数始终是异步的,即立即返回。

返回值:

一个 Array 或(嵌套的)Python 容器,表示 shards 中元素的堆叠版本,每个分片由 devices 中相应条目指定的物理设备内存支持。

参数:
  • shards (Sequence[Any])

  • devices (Sequence[xc.Device])

示例

shards 传递数组列表会导致一个分片数组,其中包含输入的堆叠版本

>>> import jax
>>> devices = jax.local_devices()
>>> x = [jax.numpy.ones(5) for device in devices]
>>> y = jax.device_put_sharded(x, devices)
>>> np.allclose(y, jax.numpy.stack(x))
True

shards 传递包含叶子节点为数组的嵌套容器对象的列表,对应于在每个叶子节点堆叠分片。这要求列表中的所有条目具有相同的树结构

>>> x = [(i, jax.numpy.arange(i, i + 4)) for i in range(len(devices))]
>>> y = jax.device_put_sharded(x, devices)
>>> type(y)
<class 'tuple'>
>>> y0 = jax.device_put_sharded([a for a, b in x], devices)
>>> y1 = jax.device_put_sharded([b for a, b in x], devices)
>>> np.allclose(y[0], y0)
True
>>> np.allclose(y[1], y1)
True

另请参阅

  • device_put

  • device_put_replicated