jax.grad#
- jax.grad(fun, argnums=0, has_aux=False, holomorphic=False, allow_int=False, reduce_axes=())[source]#
创建一个函数,该函数计算
fun
的梯度。- 参数:
**fun** (Callable) – 要微分的函数。其参数在由
argnums
指定的位置应为数组、标量或标准 Python 容器。由argnums
指定的位置处的参数数组必须为非精确(即浮点或复数)类型。它应该返回一个标量(其中包括形状为()
的数组,但不包括形状为(1,)
等的数组)argnums (int | Sequence[int]) – 可选,整数或整数序列。指定要对其求导的位置参数(默认值为 0)。
has_aux (bool) – 可选,布尔值。指示
fun
是否返回一对值,其中第一个元素被视为要微分的数学函数的输出,而第二个元素是辅助数据。默认为 False。holomorphic (bool) – 可选,布尔值。指示
fun
是否承诺为全纯函数。如果为 True,则输入和输出必须为复数。默认为 False。allow_int (bool) – 可选,布尔值。是否允许针对整数值输入求导。整数输入的梯度将具有一个微不足道的向量空间类型(float0)。默认为 False。
reduce_axes (Sequence[AxisName])
- 返回:
一个与
fun
具有相同参数的函数,该函数计算fun
的梯度。如果argnums
是一个整数,则梯度具有与该整数指示的位置参数相同的形状和类型。如果 argnums 是一个整数元组,则梯度是具有与相应参数相同形状和类型的值的元组。如果has_aux
为 True,则返回一对 (梯度,辅助数据)。- 返回类型:
可调用对象
例如
>>> import jax >>> >>> grad_tanh = jax.grad(jax.numpy.tanh) >>> print(grad_tanh(0.2)) 0.961043