jax.custom_gradient#
- jax.custom_gradient(fun)[source]#
定义自定义 VJP 规则(又名自定义梯度)的便捷函数。
虽然定义自定义 VJP 规则的规范方法是通过
jax.custom_vjp
,但custom_gradient
方便包装器遵循 TensorFlow 的tf.custom_gradient
API。这里的区别在于custom_gradient
可以用作一个装饰器,作用于一个返回原始值(表示要微分的数学函数的输出)和 VJP(梯度)函数的函数。请参阅 https://tensorflowcn.cn/api_docs/python/tf/custom_gradient。如果要微分的数学函数具有类似 Haskell 的签名
a -> b
,则 Python 可调用对象fun
应该具有签名a -> (b, CT b --o CT a)
,其中我们使用CT x
表示x
的余切类型,并使用--o
箭头表示线性函数。请参见下面的示例。也就是说,fun
应该返回一个对,其中第一个元素表示要微分的数学函数的值,第二个元素是一个函数,将在反向模式自动微分的反向传递中调用(即“自定义梯度”函数)。作为
fun
输出的第二个元素返回的函数可以闭包在计算要微分的函数时计算的中间值。也就是说,使用词法闭包在反向模式自动微分的正向传递和反向传递之间共享工作。但是,它不能执行依赖于闭包的中间值或其余切参数值的 Python 控制流;如果函数包含此类控制流,则会引发错误。- 参数:
fun – 一个 Python 可调用对象,指定要微分的数学函数及其反向模式微分规则。它应该返回一个对,该对由一个输出值和一个表示自定义梯度函数的 Python 可调用对象组成。
- 返回值:
一个 Python 可调用对象,它接受与
fun
相同的参数,并返回由fun
输出对的第一个元素指定的值。
例如
>>> @jax.custom_gradient ... def f(x): ... return x ** 2, lambda g: (g * x,) ... >>> print(f(3.)) 9.0 >>> print(jax.grad(f)(3.)) 3.0
一个具有两个参数的函数的示例,因此 VJP 函数必须返回一个长度为二的元组
>>> @jax.custom_gradient ... def f(x, y): ... return x * y, lambda g: (g * y, g * x) ... >>> print(f(3., 4.)) 12.0 >>> print(jax.grad(f, argnums=(0, 1))(3., 4.)) (Array(4., dtype=float32, weak_type=True), Array(3., dtype=float32, weak_type=True))