jax.vmap#
- jax.vmap(fun, in_axes=0, out_axes=0, axis_name=None, axis_size=None, spmd_axis_name=None)[source]#
矢量化映射。创建一个函数,该函数在参数轴上映射
fun
。- 参数:
fun (F) – 要在其他轴上映射的函数。
in_axes (int | None | Sequence[Any]) –
一个整数、None 或值序列,指定要映射的哪些输入数组轴。
如果传递给
fun
的每个位置参数都是数组,则in_axes
可以是整数、None或整数和None的元组,其长度等于传递给fun
的位置参数的数量。整数或None
表示为所有参数映射哪个数组轴(其中None
表示不映射任何轴),而元组表示为每个相应的位置参数映射哪个轴。轴整数必须在每个数组的范围[-ndim, ndim)
内,其中ndim
是相应输入数组的维度(轴)数。如果传递给
fun
的位置参数是容器(pytree)类型,则in_axes
必须是一个序列,其长度等于传递给fun
的位置参数的数量,并且对于每个参数,in_axes
的相应元素可以是一个具有匹配的pytree结构的容器,用于指定其容器元素的映射。换句话说,in_axes
必须是传递给fun
的位置参数元组的容器树前缀。有关更多详细信息,请参阅此链接:https://jax.ac.cn/en/latest/pytrees.html#applying-optional-parameters-to-pytrees必须显式提供
axis_size
,或者至少一个位置参数必须具有非None的in_axes
。所有已映射位置参数的已映射输入轴的大小必须全部相等。作为关键字传递的参数始终在其前导轴(即轴索引0)上进行映射。
请参阅下面的示例。
out_axes (Any) – 一个整数、None或(嵌套的)标准 Python 容器(元组/列表/字典),指示映射的轴应出现在输出中的位置。所有具有映射轴的输出都必须具有非None的
out_axes
规范。轴整数必须在每个输出数组的范围[-ndim, ndim)
内,其中ndim
是vmap()
函数返回的数组的维度(轴)数,比fun
返回的相应数组的维度(轴)数多一。axis_name (AxisName | None | None) – 可选,一个可散列的 Python 对象,用于标识映射的轴,以便可以应用并行集体操作。
axis_size (int | None | None) – 可选,一个整数,指示要映射的轴的大小。如果未提供,则从参数推断映射轴的大小。
spmd_axis_name (AxisName | tuple[AxisName, ...] | None | None)
- 返回:
fun
的批处理/矢量化版本,其参数对应于fun
的参数,但在由in_axes
指示的位置处具有额外的数组轴,并且返回值对应于fun
的返回值,但在由out_axes
指示的位置处具有额外的数组轴。- 返回类型:
F
例如,我们可以使用向量点积来实现矩阵-矩阵乘积
>>> import jax.numpy as jnp >>> >>> vv = lambda x, y: jnp.vdot(x, y) # ([a], [a]) -> [] >>> mv = vmap(vv, (0, None), 0) # ([b,a], [a]) -> [b] (b is the mapped axis) >>> mm = vmap(mv, (None, 1), 1) # ([b,a], [a,c]) -> [b,c] (c is the mapped axis)
在这里,我们使用
[a,b]
来表示形状为 (a,b) 的数组。以下是一些变体>>> mv1 = vmap(vv, (0, 0), 0) # ([b,a], [b,a]) -> [b] (b is the mapped axis) >>> mv2 = vmap(vv, (0, 1), 0) # ([b,a], [a,b]) -> [b] (b is the mapped axis) >>> mm2 = vmap(mv2, (1, 1), 0) # ([b,c,a], [a,c,b]) -> [c,b] (c is the mapped axis)
这是一个在
in_axes
中使用容器类型来指定要映射的容器元素的哪些轴的示例>>> A, B, C, D = 2, 3, 4, 5 >>> x = jnp.ones((A, B)) >>> y = jnp.ones((B, C)) >>> z = jnp.ones((C, D)) >>> def foo(tree_arg): ... x, (y, z) = tree_arg ... return jnp.dot(x, jnp.dot(y, z)) >>> tree = (x, (y, z)) >>> print(foo(tree)) [[12. 12. 12. 12. 12.] [12. 12. 12. 12. 12.]] >>> from jax import vmap >>> K = 6 # batch size >>> x = jnp.ones((K, A, B)) # batch axis in different locations >>> y = jnp.ones((B, K, C)) >>> z = jnp.ones((C, D, K)) >>> tree = (x, (y, z)) >>> vfoo = vmap(foo, in_axes=((0, (1, 2)),)) >>> print(vfoo(tree).shape) (6, 2, 5)
这是一个在
in_axes
中使用容器类型的另一个示例,这次是一个字典,用于指定要映射的容器的元素>>> dct = {'a': 0., 'b': jnp.arange(5.)} >>> x = 1. >>> def foo(dct, x): ... return dct['a'] + dct['b'] + x >>> out = vmap(foo, in_axes=({'a': None, 'b': 0}, None))(dct, x) >>> print(out) [1. 2. 3. 4. 5.]
矢量化函数的结果可以是映射的或未映射的。例如,下面的函数返回一个对,其中第一个元素已映射,第二个元素未映射。仅对于未映射的结果,我们可以将
out_axes
指定为None
(以保持其未映射)。>>> print(vmap(lambda x, y: (x + y, y * 2.), in_axes=(0, None), out_axes=(0, None))(jnp.arange(2.), 4.)) (Array([4., 5.], dtype=float32), 8.0)
如果为未映射的结果指定了
out_axes
,则结果将在映射轴上进行广播。>>> print(vmap(lambda x, y: (x + y, y * 2.), in_axes=(0, None), out_axes=0)(jnp.arange(2.), 4.)) (Array([4., 5.], dtype=float32), Array([8., 8.], dtype=float32, weak_type=True))
如果为已映射的结果指定了
out_axes
,则结果将相应地进行转置。最后,这是一个使用
axis_name
以及集体操作的示例>>> xs = jnp.arange(3. * 4.).reshape(3, 4) >>> print(vmap(lambda x: lax.psum(x, 'i'), axis_name='i')(xs)) [[12. 15. 18. 21.] [12. 15. 18. 21.] [12. 15. 18. 21.]]
有关涉及集体操作的更多示例,请参阅
jax.pmap()
文档字符串。