jax.make_array_from_callback#
- jax.make_array_from_callback(shape, sharding, data_callback)[源代码]#
通过从
data_callback
获取的数据返回一个jax.Array
。data_callback
用于获取返回的jax.Array
中每个可寻址分片的数据。此函数必须返回具体的数组,这意味着make_array_from_callback
与 JAX 转换(如jit()
或vmap()
)的兼容性有限。- 参数:
- 返回:
通过从
data_callback
获取的数据得到的jax.Array
。- 返回类型:
ArrayImpl
示例
>>> import math >>> from jax.sharding import Mesh >>> from jax.sharding import PartitionSpec as P >>> import numpy as np ... >>> input_shape = (8, 8) >>> global_input_data = np.arange(math.prod(input_shape)).reshape(input_shape) >>> global_mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y')) >>> inp_sharding = jax.sharding.NamedSharding(global_mesh, P('x', 'y')) ... >>> def cb(index): ... return global_input_data[index] ... >>> arr = jax.make_array_from_callback(input_shape, inp_sharding, cb) >>> arr.addressable_data(0).shape (4, 2)