jax.lax.map

内容

jax.lax.map#

jax.lax.map(f, xs, *, batch_size=None)[source]#

在领先的数组轴上映射函数。

类似于 Python 的内置 map,除了输入和输出采用堆叠数组的形式。考虑使用 vmap() 转换,除非您需要逐元素应用函数以减少内存使用或使用其他控制流原语进行异构计算。

xs 为数组类型时,map() 的语义由以下 Python 实现给出

def map(f, xs):
  return np.stack([f(x) for x in xs])

类似于 scan()map() 是用 JAX 原语实现的,因此与 Python 循环相比,许多相同的优点适用:xs 可以是任意嵌套的 pytree 类型,并且映射的计算只编译一次。

如果提供了 batch_size,则计算将以该大小的批次执行并使用 vmap() 并行化。这可以用作 map 的更高性能版本或 vmap 的内存高效版本。如果轴不能被批次大小整除,则剩余部分将在单独的 vmap 中处理并连接到结果。

>>> x = jnp.ones((10, 3, 4))
>>> def f(x):
...   print('inner shape:', x.shape)
...   return x + 1
>>> y = lax.map(f, x, batch_size=3)
inner shape: (3, 4)
inner shape: (3, 4)
>>> y.shape
(10, 3, 4)

在上面的示例中,“inner shape” 打印了两次,一次是在跟踪批量计算时,另一次是在跟踪剩余计算时。

参数:
  • f – 一个 Python 函数,用于对 xs 的第一个轴或多个轴进行逐元素应用。

  • xs – 沿着前导轴映射的值。

  • batch_size (int | None) – (可选)指定每个步骤并行执行的批次大小的整数。

返回值:

映射后的值。