jax.numpy.linalg.matrix_power#

jax.numpy.linalg.matrix_power(a, n)[源代码]#

将一个方阵提升到整数幂。

通过重复平方实现的 numpy.linalg.matrix_power() 的 JAX 实现。

参数:
  • a (ArrayLike) – 形状为 (..., M, M) 的数组,将被提升到幂 n

  • n (int) – 矩阵应被提升到的整数指数。

返回:

形状为 (..., M, M) 的数组,包含 a 的 n 次矩阵幂。

返回类型:

Array

示例

>>> a = jnp.array([[1., 2.],
...                [3., 4.]])
>>> jnp.linalg.matrix_power(a, 3)
Array([[ 37.,  54.],
       [ 81., 118.]], dtype=float32)
>>> a @ a @ a  # equivalent evaluated directly
Array([[ 37.,  54.],
       [ 81., 118.]], dtype=float32)

这也支持零次幂

>>> jnp.linalg.matrix_power(a, 0)
Array([[1., 0.],
       [0., 1.]], dtype=float32)

并且也支持负次幂

>>> with jnp.printoptions(precision=3):
...   jnp.linalg.matrix_power(a, -2)
Array([[ 5.5 , -2.5 ],
       [-3.75,  1.75]], dtype=float32)

负次幂等价于逆矩阵的 matmul

>>> inv_a = jnp.linalg.inv(a)
>>> with jnp.printoptions(precision=3):
...   inv_a @ inv_a
Array([[ 5.5 , -2.5 ],
       [-3.75,  1.75]], dtype=float32)