jax.numpy.linalg.tensorsolve#

jax.numpy.linalg.tensorsolve(a, b, axes=None)[源代码]#

求解张量方程 a x = b 中的 x。

JAX 实现的 numpy.linalg.tensorsolve()

参数:
  • a (类数组) – 输入数组。 通过 axes (见下文) 重新排序后,形状必须为 (*b.shape, *x.shape)

  • b (类数组) – 等式右侧的数组。

  • axes (tuple[int, ...] | None | None) – 可选元组,指定应移动到 a 末尾的轴。

返回:

数组 x,使得在重新排序 a 的轴后, tensordot(a, x, x.ndim) 等效于 b

返回类型:

数组

示例

>>> key1, key2 = jax.random.split(jax.random.key(8675309))
>>> a = jax.random.normal(key1, shape=(2, 2, 4))
>>> b = jax.random.normal(key2, shape=(2, 2))
>>> x = jnp.linalg.tensorsolve(a, b)
>>> x.shape
(4,)

现在展示如何使用 tensordot() 使用 x 重建 b

>>> b_reconstructed = jnp.linalg.tensordot(a, x, axes=x.ndim)
>>> jnp.allclose(b, b_reconstructed)
Array(True, dtype=bool)