jax.hessian#
- jax.hessian(fun, argnums=0, has_aux=False, holomorphic=False)[source]#
作为密集数组的
fun
的 Hessian。- 参数:
fun (Callable) – 要计算其 Hessian 的函数。其参数在由
argnums
指定的位置应为数组、标量或其标准 Python 容器。它应返回数组、标量或其标准 Python 容器。argnums (int | Sequence[int]) – 可选,整数或整数序列。指定要相对于其进行微分的 positional 参数(默认值为
0
)。has_aux (bool) – 可选,布尔值。指示
fun
是否返回一对,其中第一个元素被认为是将要微分的数学函数的输出,而第二个元素是辅助数据。默认值为 False。全纯 (布尔) – 可选,布尔值。指示
fun
是否保证为全纯函数。默认值为 False。
- 返回值:
一个与
fun
具有相同参数的函数,用于评估fun
的 Hessian 矩阵。- 返回类型:
可调用对象
>>> import jax >>> >>> g = lambda x: x[0]**3 - 2*x[0]*x[1] - x[1]**6 >>> print(jax.hessian(g)(jax.numpy.array([1., 2.]))) [[ 6. -2.] [ -2. -480.]]
hessian()
是 Hessian 矩阵通常定义的泛化,支持嵌套的 Python 容器(即 pytrees)作为输入和输出。jax.hessian(fun)(x)
的树结构是通过形成fun(x)
结构与两个x
结构的树积而得到的。两个树结构的树积是通过用第二个树的副本替换第一个树的每个叶节点来形成的。例如>>> import jax.numpy as jnp >>> f = lambda dct: {"c": jnp.power(dct["a"], dct["b"])} >>> print(jax.hessian(f)({"a": jnp.arange(2.) + 1., "b": jnp.arange(2.) + 2.})) {'c': {'a': {'a': Array([[[ 2., 0.], [ 0., 0.]], [[ 0., 0.], [ 0., 12.]]], dtype=float32), 'b': Array([[[ 1. , 0. ], [ 0. , 0. ]], [[ 0. , 0. ], [ 0. , 12.317766]]], dtype=float32)}, 'b': {'a': Array([[[ 1. , 0. ], [ 0. , 0. ]], [[ 0. , 0. ], [ 0. , 12.317766]]], dtype=float32), 'b': Array([[[0. , 0. ], [0. , 0. ]], [[0. , 0. ], [0. , 3.843624]]], dtype=float32)}}}
因此,
jax.hessian(fun)(x)
的树结构中的每个叶节点都对应于fun(x)
的一个叶节点和x
的两个叶节点。对于jax.hessian(fun)(x)
中的每个叶节点,如果相应的fun(x)
的数组叶节点的形状为(out_1, out_2, ...)
,而相应的x
的数组叶节点的形状分别为(in_1_1, in_1_2, ...)
和(in_2_1, in_2_2, ...)
,那么 Hessian 矩阵的叶节点的形状为(out_1, out_2, ..., in_1_1, in_1_2, ..., in_2_1, in_2_2, ...)
。换句话说,Python 树结构表示 Hessian 矩阵的块结构,块由输入和输出 pytrees 决定。特别是,当函数输入
x
和输出fun(x)
都是单个数组时(如上面的g
示例),会生成一个数组(不涉及 pytrees)。如果fun(x)
的形状为(out1, out2, ...)
,而x
的形状为(in1, in2, ...)
,那么jax.hessian(fun)(x)
的形状为(out1, out2, ..., in1, in2, ..., in1, in2, ...)
。要将 pytrees 展平成 1D 向量,请考虑使用jax.flatten_util.flatten_pytree()
。