jax.make_array_from_callback

jax.make_array_from_callback#

jax.make_array_from_callback(shape, sharding, data_callback)[source]#

通过从 data_callback 中获取的数据返回一个 jax.Array

data_callback 用于获取返回的 jax.Array 的每个可寻址分片的数据。此函数必须返回具体数组,这意味着 make_array_from_callback 与 JAX 变换(如 jit()vmap())的兼容性有限。

参数:
  • shape (Shape) – jax.Array 的形状。

  • sharding (Sharding | Layout) – 一个 Sharding 实例,描述了 jax.Array 如何在设备之间布局。

  • data_callback (Callable[[Index | None], ArrayLike]) – 回调函数,接收全局数组值的索引作为输入,并返回全局数组值的对应数据。数据可以以任何类数组对象返回,例如 numpy.ndarray

返回:

通过从 data_callback 获取的数据获得的 jax.Array

返回类型:

ArrayImpl

示例

>>> import math
>>> from jax.sharding import Mesh
>>> from jax.sharding import PartitionSpec as P
>>> import numpy as np
...
>>> input_shape = (8, 8)
>>> global_input_data = np.arange(math.prod(input_shape)).reshape(input_shape)
>>> global_mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
>>> inp_sharding = jax.sharding.NamedSharding(global_mesh, P('x', 'y'))
...
>>> def cb(index):
...  return global_input_data[index]
...
>>> arr = jax.make_array_from_callback(input_shape, inp_sharding, cb)
>>> arr.addressable_data(0).shape
(4, 2)