自动矢量化#

在上一节中,我们讨论了通过 jax.jit() 函数进行 JIT 编译。本笔记本将讨论 JAX 的另一个变换:通过 jax.vmap() 进行矢量化。

手动矢量化#

考虑以下计算两个一维向量卷积的简单代码

import jax
import jax.numpy as jnp

x = jnp.arange(5)
w = jnp.array([2., 3., 4.])

def convolve(x, w):
  output = []
  for i in range(1, len(x)-1):
    output.append(jnp.dot(x[i-1:i+2], w))
  return jnp.array(output)

convolve(x, w)
Array([11., 20., 29.], dtype=float32)

假设我们希望将此函数应用于一批权重 w 到一批向量 x 上。

xs = jnp.stack([x, x])
ws = jnp.stack([w, w])

最简单的选择是简单地在 Python 中循环遍历批次

def manually_batched_convolve(xs, ws):
  output = []
  for i in range(xs.shape[0]):
    output.append(convolve(xs[i], ws[i]))
  return jnp.stack(output)

manually_batched_convolve(xs, ws)
Array([[11., 20., 29.],
       [11., 20., 29.]], dtype=float32)

这会产生正确的结果,但是效率不高。

为了有效地批处理计算,你通常需要手动重写函数以确保它以矢量化的形式完成。这实现起来并不特别困难,但确实涉及更改函数处理索引、轴和其他输入部分的方式。

例如,我们可以手动重写 convolve() 以支持跨批次维度的矢量化计算,如下所示

def manually_vectorized_convolve(xs, ws):
  output = []
  for i in range(1, xs.shape[-1] -1):
    output.append(jnp.sum(xs[:, i-1:i+2] * ws, axis=1))
  return jnp.stack(output, axis=1)

manually_vectorized_convolve(xs, ws)
Array([[11., 20., 29.],
       [11., 20., 29.]], dtype=float32)

随着函数复杂度的增加,这种重新实现可能会变得很麻烦且容易出错;幸运的是,JAX 提供了另一种方法。

自动矢量化#

在 JAX 中,jax.vmap() 变换旨在自动生成函数的这种矢量化实现

auto_batch_convolve = jax.vmap(convolve)

auto_batch_convolve(xs, ws)
Array([[11., 20., 29.],
       [11., 20., 29.]], dtype=float32)

它是通过类似于 jax.jit() 的方式跟踪函数,并在每个输入的开头自动添加批次轴来实现的。

如果批次维度不是第一个,您可以使用 in_axesout_axes 参数来指定输入和输出中批次维度的位置。如果所有输入和输出的批次轴都相同,则它们可以是整数;否则,它们可以是列表。

auto_batch_convolve_v2 = jax.vmap(convolve, in_axes=1, out_axes=1)

xst = jnp.transpose(xs)
wst = jnp.transpose(ws)

auto_batch_convolve_v2(xst, wst)
Array([[11., 11.],
       [20., 20.],
       [29., 29.]], dtype=float32)

jax.vmap() 还支持仅对其中一个参数进行批处理的情况:例如,如果您希望将一组权重 w 与一批向量 x 进行卷积;在这种情况下,in_axes 参数可以设置为 None

batch_convolve_v3 = jax.vmap(convolve, in_axes=[0, None])

batch_convolve_v3(xs, w)
Array([[11., 20., 29.],
       [11., 20., 29.]], dtype=float32)

组合变换#

与所有 JAX 变换一样,jax.jit()jax.vmap() 被设计为可组合的,这意味着您可以用 jit 包裹一个 vmapped 函数,或用 vmap 包裹一个 jitted 函数,并且所有内容都能正常工作

jitted_batch_convolve = jax.jit(auto_batch_convolve)

jitted_batch_convolve(xs, ws)
Array([[11., 20., 29.],
       [11., 20., 29.]], dtype=float32)