jax.hessian

内容

jax.hessian#

jax.hessian(fun, argnums=0, has_aux=False, holomorphic=False)[source]#

作为密集数组的 fun 的 Hessian。

参数:
  • fun (Callable) – 要计算其 Hessian 的函数。其参数在由 argnums 指定的位置应为数组、标量或其标准 Python 容器。它应返回数组、标量或其标准 Python 容器。

  • argnums (int | Sequence[int]) – 可选,整数或整数序列。指定要相对于其进行微分的 positional 参数(默认值为 0)。

  • has_aux (bool) – 可选,布尔值。指示 fun 是否返回一对,其中第一个元素被认为是将要微分的数学函数的输出,而第二个元素是辅助数据。默认值为 False。

  • 全纯 (布尔) – 可选,布尔值。指示 fun 是否保证为全纯函数。默认值为 False。

返回值:

一个与 fun 具有相同参数的函数,用于评估 fun 的 Hessian 矩阵。

返回类型:

可调用对象

>>> import jax
>>>
>>> g = lambda x: x[0]**3 - 2*x[0]*x[1] - x[1]**6
>>> print(jax.hessian(g)(jax.numpy.array([1., 2.])))
[[   6.   -2.]
 [  -2. -480.]]

hessian() 是 Hessian 矩阵通常定义的泛化,支持嵌套的 Python 容器(即 pytrees)作为输入和输出。jax.hessian(fun)(x) 的树结构是通过形成 fun(x) 结构与两个 x 结构的树积而得到的。两个树结构的树积是通过用第二个树的副本替换第一个树的每个叶节点来形成的。例如

>>> import jax.numpy as jnp
>>> f = lambda dct: {"c": jnp.power(dct["a"], dct["b"])}
>>> print(jax.hessian(f)({"a": jnp.arange(2.) + 1., "b": jnp.arange(2.) + 2.}))
{'c': {'a': {'a': Array([[[ 2.,  0.], [ 0.,  0.]],
                         [[ 0.,  0.], [ 0., 12.]]], dtype=float32),
             'b': Array([[[ 1.      ,  0.      ], [ 0.      ,  0.      ]],
                         [[ 0.      ,  0.      ], [ 0.      , 12.317766]]], dtype=float32)},
       'b': {'a': Array([[[ 1.      ,  0.      ], [ 0.      ,  0.      ]],
                         [[ 0.      ,  0.      ], [ 0.      , 12.317766]]], dtype=float32),
             'b': Array([[[0.      , 0.      ], [0.      , 0.      ]],
                         [[0.      , 0.      ], [0.      , 3.843624]]], dtype=float32)}}}

因此,jax.hessian(fun)(x) 的树结构中的每个叶节点都对应于 fun(x) 的一个叶节点和 x 的两个叶节点。对于 jax.hessian(fun)(x) 中的每个叶节点,如果相应的 fun(x) 的数组叶节点的形状为 (out_1, out_2, ...),而相应的 x 的数组叶节点的形状分别为 (in_1_1, in_1_2, ...)(in_2_1, in_2_2, ...),那么 Hessian 矩阵的叶节点的形状为 (out_1, out_2, ..., in_1_1, in_1_2, ..., in_2_1, in_2_2, ...)。换句话说,Python 树结构表示 Hessian 矩阵的块结构,块由输入和输出 pytrees 决定。

特别是,当函数输入 x 和输出 fun(x) 都是单个数组时(如上面的 g 示例),会生成一个数组(不涉及 pytrees)。如果 fun(x) 的形状为 (out1, out2, ...),而 x 的形状为 (in1, in2, ...),那么 jax.hessian(fun)(x) 的形状为 (out1, out2, ..., in1, in2, ..., in1, in2, ...)。要将 pytrees 展平成 1D 向量,请考虑使用 jax.flatten_util.flatten_pytree()