jax.custom_batching.custom_vmap#

class jax.custom_batching.custom_vmap(fun)[源代码]#

自定义 JAX 可转换函数在 vmap 下的行为。

此装饰器用于自定义 JAX 函数在 jax.vmap() 转换下的行为。一个用 custom_vmap 装饰的函数,除了在使用 jax.vmap() 进行批处理时,其余时候(关于注意事项请看下面)行为与其底层函数基本相同。当进行批处理时,将使用 def_vmap() 定义的规则。

例如

>>> @jax.custom_batching.custom_vmap
... def f(x, y):
...   return x + y
...
>>> @f.def_vmap
... def f_vmap_rule(axis_size, in_batched, xs, ys):
...   assert all(in_batched)
...   assert xs.shape[0] == axis_size
...   assert ys.shape[0] == axis_size
...   out_batched = True
...   return xs * ys, out_batched
...
>>> xs = jnp.arange(3)
>>> ys = jnp.arange(1, 4)
>>> jax.vmap(f)(xs, ys)  # prints xs * ys instead of xs + ys
Array([0, 2, 6], dtype=int32)

值得注意的是,custom_vmap 函数不支持反向模式自动微分。要自定义 vmap 和反向模式自动微分,请将 custom_vmapjax.custom_vjp 结合使用。例如

>>> @jax.custom_vjp
... @jax.custom_batching.custom_vmap
... def f(x, y):
...   return jnp.sin(x) * y
...
>>> @f.def_vmap
... def f_vmap_rule(axis_size, in_batched, xs, ys):
...   return jnp.cos(xs) * ys, True
...
>>> def f_fwd(x, y):
...   return f(x, y), (jnp.cos(x), jnp.sin(x), y)
...
>>> def f_bwd(res, g):
...   cos_x, sin_x, y = res
...   return (cos_x * g * y, sin_x * g)
...
>>> f.defvjp(f_fwd, f_bwd)
>>> jax.vmap(f)(jnp.zeros(3), jnp.ones(3))
Array([1., 1., 1.], dtype=float32)
>>> jax.grad(f)(jnp.zeros(()), jnp.ones(()))
Array(1., dtype=float32)

请注意,jax.custom_vjp 必须在外层,包裹着 custom_vmap 装饰的函数。

参数:

fun (Callable[..., Any])

__init__(fun)[源代码]#
参数:

fun (Callable[..., Any])

方法

__init__(fun)

def_vmap(vmap_rule)

为此 custom_vmap 函数定义 vmap 规则。

属性

fun

vmap_rule