jax.custom_jvp.defjvps#
- custom_jvp.defjvps(*jvps)[source]#
用于分别为每个参数定义 JVP 的便利包装器。
此便利包装器不能与
nondiff_argnums
一起使用。- 参数:
*jvps (Callable[..., ReturnValue] | None) – 一系列函数,每个函数对应于
custom_jvp
函数的每个位置参数。每个函数都将相应的原始输入的切线值、原始输出和 ß原始输入作为参数。请参见下面的示例。- 返回值:
无。
- 返回类型:
None
示例
>>> @jax.custom_jvp ... def f(x, y): ... return jnp.sin(x) * y ... >>> f.defjvps(lambda x_dot, primal_out, x, y: jnp.cos(x) * x_dot * y, ... lambda y_dot, primal_out, x, y: jnp.sin(x) * y_dot)
>>> x = jnp.float32(1.0) >>> y = jnp.float32(2.0) >>> with jnp.printoptions(precision=2): ... print(jax.value_and_grad(f)(x, y)) (Array(1.68, dtype=float32), Array(1.08, dtype=float32))