jax.lax.stop_gradient#
- jax.lax.stop_gradient(x)[source]#
停止梯度计算。
在操作上,
stop_gradient
是恒等函数,即它返回参数 x 不变。但是,stop_gradient
会阻止正向或反向模式自动微分期间梯度的流动。如果有多个嵌套的梯度计算,stop_gradient
会停止所有这些梯度。例如
>>> jax.grad(lambda x: x**2)(3.) Array(6., dtype=float32, weak_type=True) >>> jax.grad(lambda x: jax.lax.stop_gradient(x)**2)(3.) Array(0., dtype=float32, weak_type=True) >>> jax.grad(jax.grad(lambda x: x**2))(3.) Array(2., dtype=float32, weak_type=True) >>> jax.grad(jax.grad(lambda x: jax.lax.stop_gradient(x)**2))(3.) Array(0., dtype=float32, weak_type=True)
- 参数:
x (T)
- 返回值类型:
T