jax.custom_gradient#

jax.custom_gradient(fun)[源代码]#

用于定义自定义 VJP 规则(又名自定义梯度)的便捷函数。

虽然定义自定义 VJP 规则的规范方法是通过 jax.custom_vjp,但 custom_gradient 便捷包装器遵循 TensorFlow 的 tf.custom_gradient API。 此处的区别在于 custom_gradient 可以用作一个函数的装饰器,该函数返回原始值(表示要微分的数学函数的输出)和 VJP(梯度)函数。 请参阅 https://tensorflowcn.cn/api_docs/python/tf/custom_gradient

如果要微分的数学函数具有类似 Haskell 的签名 a -> b,那么 Python 可调用对象 fun 应该具有签名 a -> (b, CT b --o CT a),其中我们使用 CT x 来表示 x 的余切类型,并使用 --o 箭头来表示线性函数。请参见下面的示例。也就是说,fun 应该返回一个对,其中第一个元素表示要微分的数学函数的值,第二个元素是一个在反向模式自动微分的反向传播中被调用的函数(即“自定义梯度”函数)。

作为 fun 输出的第二个元素返回的函数可以闭包在计算要微分的函数时计算的中间值上。也就是说,使用词法闭包来共享反向模式自动微分的正向传播和反向传播之间的工作。但是,它不能执行依赖于闭包的中间值或其余切参数值的 Python 控制流;如果该函数包含此类控制流,则会引发错误。

参数:

fun – 一个 Python 可调用对象,指定要微分的数学函数及其反向模式微分规则。它应该返回一个由输出值和表示自定义梯度函数的 Python 可调用对象组成的对。

返回:

一个 Python 可调用对象,它接受与 fun 相同的参数,并返回 fun 的输出对的第一个元素指定的输出值。

例如

>>> @jax.custom_gradient
... def f(x):
...   return x ** 2, lambda g: (g * x,)
...
>>> print(f(3.))
9.0
>>> print(jax.grad(f)(3.))
3.0

一个带有两个参数的函数的示例,因此 VJP 函数必须返回一个长度为二的元组

>>> @jax.custom_gradient
... def f(x, y):
...   return x * y, lambda g: (g * y, g * x)
...
>>> print(f(3., 4.))
12.0
>>> print(jax.grad(f, argnums=(0, 1))(3., 4.))
(Array(4., dtype=float32, weak_type=True), Array(3., dtype=float32, weak_type=True))