jax.numpy.linalg.tensorsolve

内容

jax.numpy.linalg.tensorsolve#

jax.numpy.linalg.tensorsolve(a, b, axes=None)[source]#

求解张量方程 a x = b 的 x。

JAX 实现 numpy.linalg.tensorsolve().

参数:
  • a (ArrayLike) – 输入数组。在通过 axes 重新排序后(见下文),形状必须为 (*b.shape, *x.shape).

  • b (ArrayLike) – 右侧数组。

  • axes (tuple[int, ...] | None | None) – 可选元组,指定应移至末尾的 a 的轴

返回值:

数组 x,使得在对 a 的轴进行重新排序后,tensordot(a, x, x.ndim) 等效于 b

返回类型:

Array

示例

>>> key1, key2 = jax.random.split(jax.random.key(8675309))
>>> a = jax.random.normal(key1, shape=(2, 2, 4))
>>> b = jax.random.normal(key2, shape=(2, 2))
>>> x = jnp.linalg.tensorsolve(a, b)
>>> x.shape
(4,)

现在证明可以使用 x 通过 tensordot() 来重建 b

>>> b_reconstructed = jnp.linalg.tensordot(a, x, axes=x.ndim)
>>> jnp.allclose(b, b_reconstructed)
Array(True, dtype=bool)