jax.vmap#

jax.vmap(fun, in_axes=0, out_axes=0, axis_name=None, axis_size=None, spmd_axis_name=None)[源代码]#

向量化映射。创建一个函数,该函数将 fun 映射到参数轴上。

参数
  • fun (F) – 要映射到额外轴上的函数。

  • in_axes (int | None | Sequence[Any]) –

    一个整数、None 或指定要映射的输入数组轴的值序列。

    如果 fun 的每个位置参数都是一个数组,则 in_axes 可以是一个整数、None 或一个整数和 None 的元组,其长度等于 fun 的位置参数的数量。一个整数或 None 指示要为所有参数映射的数组轴(None 指示不映射任何轴),而元组指示要为每个对应的位置参数映射的轴。轴整数必须在每个数组的范围 [-ndim, ndim) 内,其中 ndim 是相应输入数组的维度(轴)数。

    如果 fun 的位置参数是容器(pytree)类型,则 in_axes 必须是一个长度等于 fun 的位置参数数量的序列,并且对于每个参数,in_axes 的相应元素可以是具有匹配 pytree 结构的容器,指定其容器元素的映射。换句话说,in_axes 必须是传递给 fun 的位置参数元组的容器树前缀。有关更多详细信息,请参见此链接:https://jax.ac.cn/en/latest/pytrees.html#applying-optional-parameters-to-pytrees

    必须显式提供 axis_size,或者至少一个位置参数的 in_axes 不为 None。所有映射位置参数的映射输入轴的大小必须全部相等。

    作为关键字传递的参数始终映射到它们的领先轴(即轴索引 0)。

    请参见下面的示例。

  • out_axes (Any) – 一个整数、None 或(嵌套的)标准 Python 容器(元组/列表/字典),指示映射的轴应出现在输出中的位置。具有映射轴的所有输出都必须具有非 None 的 out_axes 规范。轴整数必须在每个输出数组的范围 [-ndim, ndim) 内,其中 ndimvmap()-ed 函数返回的数组的维度(轴)数,该数比 fun 返回的相应数组的维度(轴)数多 1。

  • axis_name (AxisName | None | None) – 可选,一个可哈希的 Python 对象,用于标识映射的轴,以便可以应用并行集合。

  • axis_size (int | None | None) – 可选,一个整数,指示要映射的轴的大小。如果未提供,则从参数推断映射的轴大小。

  • spmd_axis_name (AxisName | tuple[AxisName, ...] | None | None)

返回

fun 的批量/向量化版本,其参数与 fun 的参数相对应,但在 in_axes 指示的位置具有额外的数组轴,并且返回值与 fun 的返回值相对应,但在 out_axes 指示的位置具有额外的数组轴。

返回类型

F

例如,我们可以使用向量点积来实现矩阵-矩阵乘积

>>> import jax.numpy as jnp
>>>
>>> vv = lambda x, y: jnp.vdot(x, y)  #  ([a], [a]) -> []
>>> mv = vmap(vv, (0, None), 0)      #  ([b,a], [a]) -> [b]      (b is the mapped axis)
>>> mm = vmap(mv, (None, 1), 1)      #  ([b,a], [a,c]) -> [b,c]  (c is the mapped axis)

在这里,我们使用 [a,b] 来表示形状为 (a,b) 的数组。以下是一些变体

>>> mv1 = vmap(vv, (0, 0), 0)   #  ([b,a], [b,a]) -> [b]        (b is the mapped axis)
>>> mv2 = vmap(vv, (0, 1), 0)   #  ([b,a], [a,b]) -> [b]        (b is the mapped axis)
>>> mm2 = vmap(mv2, (1, 1), 0)  #  ([b,c,a], [a,c,b]) -> [c,b]  (c is the mapped axis)

以下是一个在 in_axes 中使用容器类型来指定要映射的容器元素的轴的示例

>>> A, B, C, D = 2, 3, 4, 5
>>> x = jnp.ones((A, B))
>>> y = jnp.ones((B, C))
>>> z = jnp.ones((C, D))
>>> def foo(tree_arg):
...   x, (y, z) = tree_arg
...   return jnp.dot(x, jnp.dot(y, z))
>>> tree = (x, (y, z))
>>> print(foo(tree))
[[12. 12. 12. 12. 12.]
 [12. 12. 12. 12. 12.]]
>>> from jax import vmap
>>> K = 6  # batch size
>>> x = jnp.ones((K, A, B))  # batch axis in different locations
>>> y = jnp.ones((B, K, C))
>>> z = jnp.ones((C, D, K))
>>> tree = (x, (y, z))
>>> vfoo = vmap(foo, in_axes=((0, (1, 2)),))
>>> print(vfoo(tree).shape)
(6, 2, 5)

以下是另一个在 in_axes 中使用容器类型的示例,这次是字典,用于指定要映射的容器元素

>>> dct = {'a': 0., 'b': jnp.arange(5.)}
>>> x = 1.
>>> def foo(dct, x):
...  return dct['a'] + dct['b'] + x
>>> out = vmap(foo, in_axes=({'a': None, 'b': 0}, None))(dct, x)
>>> print(out)
[1. 2. 3. 4. 5.]

向量化函数的结果可以映射或取消映射。例如,以下函数返回一对,其中第一个元素已映射,第二个元素未映射。只有对于未映射的结果,我们才能将 out_axes 指定为 None(以保持其未映射)。

>>> print(vmap(lambda x, y: (x + y, y * 2.), in_axes=(0, None), out_axes=(0, None))(jnp.arange(2.), 4.))
(Array([4., 5.], dtype=float32), 8.0)

如果为未映射的结果指定了 out_axes,则该结果将在映射的轴上广播

>>> print(vmap(lambda x, y: (x + y, y * 2.), in_axes=(0, None), out_axes=0)(jnp.arange(2.), 4.))
(Array([4., 5.], dtype=float32), Array([8., 8.], dtype=float32, weak_type=True))

如果为映射的结果指定了 out_axes,则该结果会相应地转置。

最后,以下是一个使用 axis_name 与集合的示例

>>> xs = jnp.arange(3. * 4.).reshape(3, 4)
>>> print(vmap(lambda x: lax.psum(x, 'i'), axis_name='i')(xs))
[[12. 15. 18. 21.]
 [12. 15. 18. 21.]
 [12. 15. 18. 21.]]

有关涉及集合的更多示例,请参见 jax.pmap() 文档字符串。