jax.lax.map#

jax.lax.map(f, xs, *, batch_size=None)[源代码]#

将函数映射到前导数组轴上。

类似于 Python 的内置 map,但输入和输出的形式是堆叠的数组。除非您需要逐元素应用函数以减少内存使用或与其他控制流原语进行异构计算,否则请考虑使用 vmap() 转换。

xs 是数组类型时,map() 的语义由以下 Python 实现给出:

def map(f, xs):
  return np.stack([f(x) for x in xs])

scan() 类似,map() 是用 JAX 原语实现的,因此与 Python 循环相比,它具有许多相同的优势:xs 可以是任意嵌套的 PyTree 类型,并且映射的计算只会被编译一次。

如果提供了 batch_size,则计算将以该大小的批次执行,并使用 vmap() 进行并行化。这既可以作为 map 的更高性能版本,也可以作为 vmap 的内存高效版本。如果轴不能被批次大小整除,则剩余部分将在单独的 vmap 中处理,并与结果连接起来。

>>> x = jnp.ones((10, 3, 4))
>>> def f(x):
...   print('inner shape:', x.shape)
...   return x + 1
>>> y = lax.map(f, x, batch_size=3)
inner shape: (3, 4)
inner shape: (3, 4)
>>> y.shape
(10, 3, 4)

在上面的示例中,“inner shape” 被打印了两次,一次是在跟踪批处理计算时,另一次是在跟踪剩余计算时。

参数:
  • f – 一个 Python 函数,用于在 xs 的第一个轴或多个轴上按元素方式应用。

  • xs – 沿前导轴映射的值。

  • batch_size (int | None) – (可选) 指定每个步骤并行执行的批次大小的整数。

返回值:

映射后的值。