jax.device_put_sharded#
- jax.device_put_sharded(shards, devices)[source]#
将数组分片传输到指定的设备并形成 Array。
- 参数:
shards (Sequence[Any]) – 一系列数组、标量或(嵌套的)标准 Python 容器,表示要堆叠在一起以形成输出的分片。
shards
的长度必须等于devices
的长度。devices (Sequence[xc.Device]) – 一系列
Device
实例,表示将相应的shards
中的分片传输到的设备。
此函数始终是异步的,即立即返回。
- 返回值:
一个 Array 或(嵌套的)Python 容器,表示
shards
中元素的堆叠版本,每个分片由devices
中相应条目指定的物理设备内存支持。- 参数:
shards (Sequence[Any])
devices (Sequence[xc.Device])
示例
为
shards
传递数组列表会导致一个分片数组,其中包含输入的堆叠版本>>> import jax >>> devices = jax.local_devices() >>> x = [jax.numpy.ones(5) for device in devices] >>> y = jax.device_put_sharded(x, devices) >>> np.allclose(y, jax.numpy.stack(x)) True
为
shards
传递包含叶子节点为数组的嵌套容器对象的列表,对应于在每个叶子节点堆叠分片。这要求列表中的所有条目具有相同的树结构>>> x = [(i, jax.numpy.arange(i, i + 4)) for i in range(len(devices))] >>> y = jax.device_put_sharded(x, devices) >>> type(y) <class 'tuple'> >>> y0 = jax.device_put_sharded([a for a, b in x], devices) >>> y1 = jax.device_put_sharded([b for a, b in x], devices) >>> np.allclose(y[0], y0) True >>> np.allclose(y[1], y1) True
另请参阅
device_put
device_put_replicated