jax.custom_jvp.defjvp#
- custom_jvp.defjvp(jvp, symbolic_zeros=False)[source]#
为该实例所表示的函数定义一个自定义 JVP 规则。
- 参数::
jvp (Callable[..., tuple[ReturnValue, ReturnValue]]) – 表示自定义 JVP 规则的 Python 可调用对象。当没有
nondiff_argnums
时,jvp
函数应接受两个参数,第一个是原始输入元组,第二个是切线输入元组。两个元组的长度都等于custom_jvp
函数的参数数量。jvp
函数应生成一对输出,其中第一个元素是原始输出,第二个元素是切线输出。输入和输出元组的元素可以是数组,或者任何嵌套的元组/列表/字典。symbolic_zeros (bool) – 布尔值,指示在与未扰动值相对应的情况下,规则是否应该传递表示静态符号零的静态符号零的对象;否则,只传递标准的 JAX 类型(例如,类似数组的类型)。将此选项设置为
True
允许 JVP 规则检测某些输入是否没有参与微分,但需要对这些对象进行特殊处理(例如,不能传递给 jax.numpy 函数)。默认值为False
。
- 返回:
返回
jvp
,以便defjvp
可以用作装饰器。- 返回类型:
Callable[…, tuple[返回值, 返回值]]
示例
>>> @jax.custom_jvp ... def f(x, y): ... return jnp.sin(x) * y ... >>> @f.defjvp ... def f_jvp(primals, tangents): ... x, y = primals ... x_dot, y_dot = tangents ... primal_out = f(x, y) ... tangent_out = jnp.cos(x) * x_dot * y + jnp.sin(x) * y_dot ... return primal_out, tangent_out
>>> x = jnp.float32(1.0) >>> y = jnp.float32(2.0) >>> with jnp.printoptions(precision=2): ... print(jax.value_and_grad(f)(x, y)) (Array(1.68, dtype=float32), Array(1.08, dtype=float32))