jax.scipy.linalg.expm_frechet#
- jax.scipy.linalg.expm_frechet(A: ArrayLike, E: ArrayLike, *, method: str | None = None, compute_expm: Literal[True] = True) tuple[Array, Array] [source]#
- jax.scipy.linalg.expm_frechet(A: ArrayLike, E: ArrayLike, *, method: str | None = None, compute_expm: Literal[False]) Array
- jax.scipy.linalg.expm_frechet(A: ArrayLike, E: ArrayLike, *, method: str | None = None, compute_expm: bool = True) Array | tuple[Array, Array]
计算矩阵指数的 Fréchet 导数。
JAX 中对
scipy.linalg.expm_frechet()
的实现- 参数:
A – 形状为
(..., N, N)
的数组E – 形状为
(..., N, N)
的数组;指定导数的方向。compute_expm – 如果为 True (默认),则计算并返回
expm(A)
。method – JAX 忽略此参数
- 返回:
如果
compute_expm
为 True,则返回一个元组(expm_A, expm_frechet_AE)
,否则返回数组expm_frechet_AE
。两个返回数组的形状都为(..., N, N)
。
示例
我们可以使用此 API 计算
A
的矩阵指数,以及其在方向E
上的导数>>> key1, key2 = jax.random.split(jax.random.key(3372)) >>> A = jax.random.normal(key1, (3, 3)) >>> E = jax.random.normal(key2, (3, 3)) >>> expmA, expm_frechet_AE = jax.scipy.linalg.expm_frechet(A, E)
这可以通过 JAX 的自动微分方法等效地计算;在这里,我们将使用
jax.jvp()
计算expm()
在E
方向上的导数,并得到相同的结果>>> expmA2, expm_frechet_AE2 = jax.jvp(jax.scipy.linalg.expm, (A,), (E,)) >>> jnp.allclose(expmA, expmA2) Array(True, dtype=bool) >>> jnp.allclose(expm_frechet_AE, expm_frechet_AE2) Array(True, dtype=bool)