jax.lax.stop_gradient

jax.lax.stop_gradient#

jax.lax.stop_gradient(x)[source]#

停止梯度计算。

在操作上,stop_gradient 是恒等函数,即它返回参数 x 不变。但是,stop_gradient 会阻止正向或反向模式自动微分期间梯度的流动。如果有多个嵌套的梯度计算,stop_gradient 会停止所有这些梯度。

例如

>>> jax.grad(lambda x: x**2)(3.)
Array(6., dtype=float32, weak_type=True)
>>> jax.grad(lambda x: jax.lax.stop_gradient(x)**2)(3.)
Array(0., dtype=float32, weak_type=True)
>>> jax.grad(jax.grad(lambda x: x**2))(3.)
Array(2., dtype=float32, weak_type=True)
>>> jax.grad(jax.grad(lambda x: jax.lax.stop_gradient(x)**2))(3.)
Array(0., dtype=float32, weak_type=True)
参数:

x (T)

返回值类型:

T