jax.lax.map#

jax.lax.map(f, xs, *, batch_size=None)[源代码]#

将函数映射到前导数组轴上。

类似于 Python 的内置 map,除了输入和输出是堆叠数组的形式。考虑使用 vmap() 转换,除非你需要逐元素应用函数以减少内存使用或与其他控制流原语进行异构计算。

xs 是数组类型时,map() 的语义由这个 Python 实现给出

def map(f, xs):
  return np.stack([f(x) for x in xs])

scan() 类似,map() 是基于 JAX 原语实现的,因此许多优于 Python 循环的优点同样适用:xs 可以是任意嵌套的 pytree 类型,并且映射的计算只会被编译一次。

如果提供了 batch_size,计算将以该大小的批次执行,并使用 vmap() 进行并行化。这既可以用作 map 的性能更高的版本,也可以用作 vmap 的内存高效版本。如果轴不能被批次大小整除,则余数将在单独的 vmap 中处理,并连接到结果中。

>>> x = jnp.ones((10, 3, 4))
>>> def f(x):
...   print('inner shape:', x.shape)
...   return x + 1
>>> y = lax.map(f, x, batch_size=3)
inner shape: (3, 4)
inner shape: (3, 4)
>>> y.shape
(10, 3, 4)

在上面的示例中,“inner shape” 被打印了两次,一次是在跟踪批处理计算时,另一次是在跟踪余数计算时。

参数:
  • f – 一个 Python 函数,用于对 xs 的第一个轴或多个轴逐元素应用。

  • xs – 沿着前导轴进行映射的值。

  • batch_size (int | None) – (可选) 指定每步并行执行的批次大小的整数。

返回值:

映射后的值。