jax.custom_vjp#

class jax.custom_vjp(fun, nondiff_argnums=())[源代码]#

为自定义 VJP 规则定义设置 JAX 可转换函数。

这个类旨在用作函数装饰器。实例是可调用的,其行为类似于应用该装饰器的底层函数,但当应用反向模式微分转换(如 jax.grad())时除外,在这种情况下,会使用用户提供的自定义 VJP 规则函数,而不是跟踪并对底层函数的实现进行自动微分。有一个单实例方法 defvjp(),可用于定义自定义 VJP 规则。

此装饰器排除了使用前向模式自动微分。

例如

@jax.custom_vjp
def f(x, y):
  return jnp.sin(x) * y

def f_fwd(x, y):
  return f(x, y), (jnp.cos(x), jnp.sin(x), y)

def f_bwd(res, g):
  cos_x, sin_x, y = res
  return (cos_x * g * y, sin_x * g)

f.defvjp(f_fwd, f_bwd)

有关更详细的介绍,请参阅教程

参数:
  • fun (Callable[..., ReturnValue])

  • nondiff_argnums (Sequence[int])

__init__(fun, nondiff_argnums=())[源代码]#
参数:
  • fun (Callable[..., ReturnValue])

  • nondiff_argnums (Sequence[int])

方法

__init__(fun[, nondiff_argnums])

defvjp(fwd, bwd[, symbolic_zeros, ...])

为该实例表示的函数定义自定义 VJP 规则。