jax.lax.stop_gradient#

jax.lax.stop_gradient(x)[源代码]#

停止梯度计算。

在操作上,stop_gradient 是恒等函数,也就是说,它返回参数 x 不变。但是,stop_gradient 会阻止在前向或反向模式自动微分期间梯度的流动。如果存在多个嵌套的梯度计算,则 stop_gradient 会停止所有这些计算的梯度。有关这在哪些方面有用的讨论,请参阅 停止梯度

参数:

x (T) – 数组或数组的 pytree

返回:

输入值保持不变返回,但在自动微分中将被视为常量。

返回类型:

T

示例

考虑一个简单的函数,它返回输入值的平方。

>>> def f1(x):
...   return x ** 2
>>> x = jnp.float32(3.0)
>>> f1(x)
Array(9.0, dtype=float32)
>>> jax.grad(f1)(x)
Array(6.0, dtype=float32)

同样的函数,如果 x 周围使用了 stop_gradient,在正常求值的情况下,其结果将相同,但会返回零梯度,因为 x 被有效地视为一个常量。

>>> def f2(x):
...   return jax.lax.stop_gradient(x) ** 2
>>> f2(x)
Array(9.0, dtype=float32)
>>> jax.grad(f2)(x)
Array(0.0, dtype=float32)

这在 JAX 代码库中的很多地方都有使用;例如,jax.nn.softmax() 在内部通过输入的最大值来标准化输入,为了提高效率,这个最大值会被包裹在 stop_gradient 中。有关 stop_gradient 的适用性的更多讨论,请参阅 停止梯度