jax.numpy.linalg.tensorinv

内容

jax.numpy.linalg.tensorinv#

jax.numpy.linalg.tensorinv(a, ind=2)[source]#

计算数组的张量逆。

JAX 实现 numpy.linalg.tensorinv()

这计算了具有相同 ind 值的 tensordot() 操作的逆。

参数:
  • a (ArrayLike) – 要反转的数组。必须有 prod(a.shape[:ind]) == prod(a.shape[ind:])

  • ind (int) – 指定张量积中索引数量的正整数。

返回:

形状为 (*a.shape[ind:], *a.shape[:ind]) 的数组,包含 a 的张量逆。

返回类型:

Array

示例

>>> key = jax.random.key(1337)
>>> x = jax.random.normal(key, shape=(2, 2, 4))
>>> xinv = jnp.linalg.tensorinv(x, 2)
>>> xinv_x = jnp.linalg.tensordot(xinv, x, axes=2)
>>> jnp.allclose(xinv_x, jnp.eye(4), atol=1E-4)
Array(True, dtype=bool)