jax.numpy.linalg.tensorsolve#
- jax.numpy.linalg.tensorsolve(a, b, axes=None)[源代码]#
求解张量方程 a x = b 中的 x。
JAX 实现的
numpy.linalg.tensorsolve()
。- 参数:
a (类数组) – 输入数组。 通过
axes
(见下文) 重新排序后,形状必须为(*b.shape, *x.shape)
。b (类数组) – 等式右侧的数组。
axes (tuple[int, ...] | None | None) – 可选元组,指定应移动到
a
末尾的轴。
- 返回:
数组 x,使得在重新排序
a
的轴后,tensordot(a, x, x.ndim)
等效于b
。- 返回类型:
示例
>>> key1, key2 = jax.random.split(jax.random.key(8675309)) >>> a = jax.random.normal(key1, shape=(2, 2, 4)) >>> b = jax.random.normal(key2, shape=(2, 2)) >>> x = jnp.linalg.tensorsolve(a, b) >>> x.shape (4,)
现在展示如何使用
tensordot()
使用x
重建b
>>> b_reconstructed = jnp.linalg.tensordot(a, x, axes=x.ndim) >>> jnp.allclose(b, b_reconstructed) Array(True, dtype=bool)