jax.custom_batching.custom_vmap#
- class jax.custom_batching.custom_vmap(fun)[源代码]#
自定义 JAX 可转换函数在 vmap 下的行为。
此装饰器用于自定义 JAX 函数在
jax.vmap()
转换下的行为。一个用custom_vmap
装饰的函数,除了在使用jax.vmap()
进行批处理时,其余时候(关于注意事项请看下面)行为与其底层函数基本相同。当进行批处理时,将使用def_vmap()
定义的规则。例如
>>> @jax.custom_batching.custom_vmap ... def f(x, y): ... return x + y ... >>> @f.def_vmap ... def f_vmap_rule(axis_size, in_batched, xs, ys): ... assert all(in_batched) ... assert xs.shape[0] == axis_size ... assert ys.shape[0] == axis_size ... out_batched = True ... return xs * ys, out_batched ... >>> xs = jnp.arange(3) >>> ys = jnp.arange(1, 4) >>> jax.vmap(f)(xs, ys) # prints xs * ys instead of xs + ys Array([0, 2, 6], dtype=int32)
值得注意的是,
custom_vmap
函数不支持反向模式自动微分。要自定义 vmap 和反向模式自动微分,请将custom_vmap
与jax.custom_vjp
结合使用。例如>>> @jax.custom_vjp ... @jax.custom_batching.custom_vmap ... def f(x, y): ... return jnp.sin(x) * y ... >>> @f.def_vmap ... def f_vmap_rule(axis_size, in_batched, xs, ys): ... return jnp.cos(xs) * ys, True ... >>> def f_fwd(x, y): ... return f(x, y), (jnp.cos(x), jnp.sin(x), y) ... >>> def f_bwd(res, g): ... cos_x, sin_x, y = res ... return (cos_x * g * y, sin_x * g) ... >>> f.defvjp(f_fwd, f_bwd) >>> jax.vmap(f)(jnp.zeros(3), jnp.ones(3)) Array([1., 1., 1.], dtype=float32) >>> jax.grad(f)(jnp.zeros(()), jnp.ones(())) Array(1., dtype=float32)
请注意,
jax.custom_vjp
必须在外层,包裹着custom_vmap
装饰的函数。- 参数:
fun (Callable[..., Any])
方法
__init__
(fun)def_vmap
(vmap_rule)为此 custom_vmap 函数定义 vmap 规则。
属性
fun
vmap_rule