jax.hessian#

jax.hessian(fun, argnums=0, has_aux=False, holomorphic=False)[源代码]#

作为密集数组的 fun 的 Hessian。

参数:
  • fun (Callable) – 要计算 Hessian 的函数。其由 argnums 指定位置的参数应为数组、标量或它们的标准 Python 容器。它应返回数组、标量或它们的标准 Python 容器。

  • argnums (int | Sequence[int]) – 可选,整数或整数序列。指定要对其进行微分的位置参数(默认为 0)。

  • has_aux (bool) – 可选,布尔值。指示 fun 是否返回一个对,其中第一个元素被认为是需要微分的数学函数的输出,而第二个元素是辅助数据。默认为 False。

  • holomorphic (bool) – 可选,布尔值。指示 fun 是否保证是全纯的。默认为 False。

返回:

一个与 fun 具有相同参数的函数,该函数计算 fun 的黑塞矩阵。

返回类型:

Callable

>>> import jax
>>>
>>> g = lambda x: x[0]**3 - 2*x[0]*x[1] - x[1]**6
>>> print(jax.hessian(g)(jax.numpy.array([1., 2.])))
[[   6.   -2.]
 [  -2. -480.]]

hessian() 是黑塞矩阵的常用定义的推广,它支持嵌套的 Python 容器(即 pytrees)作为输入和输出。jax.hessian(fun)(x) 的树结构是通过将 fun(x) 的结构与 x 的两个副本的树乘积形成树乘积而给出的。两个树结构的树乘积是通过将第一个树的每个叶子替换为第二个树的副本而形成的。例如

>>> import jax.numpy as jnp
>>> f = lambda dct: {"c": jnp.power(dct["a"], dct["b"])}
>>> print(jax.hessian(f)({"a": jnp.arange(2.) + 1., "b": jnp.arange(2.) + 2.}))
{'c': {'a': {'a': Array([[[ 2.,  0.], [ 0.,  0.]],
                         [[ 0.,  0.], [ 0., 12.]]], dtype=float32),
             'b': Array([[[ 1.      ,  0.      ], [ 0.      ,  0.      ]],
                         [[ 0.      ,  0.      ], [ 0.      , 12.317766]]], dtype=float32)},
       'b': {'a': Array([[[ 1.      ,  0.      ], [ 0.      ,  0.      ]],
                         [[ 0.      ,  0.      ], [ 0.      , 12.317766]]], dtype=float32),
             'b': Array([[[0.      , 0.      ], [0.      , 0.      ]],
                         [[0.      , 0.      ], [0.      , 3.843624]]], dtype=float32)}}}

因此,jax.hessian(fun)(x) 的树结构中的每个叶子对应于 fun(x) 的一个叶子和 x 的一对叶子。对于 jax.hessian(fun)(x) 中的每个叶子,如果 fun(x) 的相应数组叶子的形状为 (out_1, out_2, ...),并且 x 的相应数组叶子的形状分别为 (in_1_1, in_1_2, ...)(in_2_1, in_2_2, ...),则黑塞矩阵叶子的形状为 (out_1, out_2, ..., in_1_1, in_1_2, ..., in_2_1, in_2_2, ...)。换句话说,Python 树结构表示黑塞矩阵的块结构,块由输入和输出 pytrees 确定。

特别是,当函数输入 x 和输出 fun(x) 各自都是单个数组时,会生成一个数组(不涉及 pytrees),如上面的 g 示例所示。如果 fun(x) 的形状为 (out1, out2, ...),并且 x 的形状为 (in1, in2, ...),则 jax.hessian(fun)(x) 的形状为 (out1, out2, ..., in1, in2, ..., in1, in2, ...)。要将 pytrees 展平为 1D 向量,请考虑使用 jax.flatten_util.flatten_pytree()