jax.make_array_from_callback#
- jax.make_array_from_callback(shape, sharding, data_callback)[source]#
通过从
data_callback
中获取的数据返回一个jax.Array
。data_callback
用于获取返回的jax.Array
的每个可寻址分片的数据。此函数必须返回具体数组,这意味着make_array_from_callback
与 JAX 变换(如jit()
或vmap()
)的兼容性有限。- 参数:
shape (Shape) –
jax.Array
的形状。sharding (Sharding | Layout) – 一个
Sharding
实例,描述了jax.Array
如何在设备之间布局。data_callback (Callable[[Index | None], ArrayLike]) – 回调函数,接收全局数组值的索引作为输入,并返回全局数组值的对应数据。数据可以以任何类数组对象返回,例如
numpy.ndarray
。
- 返回:
通过从
data_callback
获取的数据获得的jax.Array
。- 返回类型:
ArrayImpl
示例
>>> import math >>> from jax.sharding import Mesh >>> from jax.sharding import PartitionSpec as P >>> import numpy as np ... >>> input_shape = (8, 8) >>> global_input_data = np.arange(math.prod(input_shape)).reshape(input_shape) >>> global_mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y')) >>> inp_sharding = jax.sharding.NamedSharding(global_mesh, P('x', 'y')) ... >>> def cb(index): ... return global_input_data[index] ... >>> arr = jax.make_array_from_callback(input_shape, inp_sharding, cb) >>> arr.addressable_data(0).shape (4, 2)