jax.numpy.linalg.tensorsolve#
- jax.numpy.linalg.tensorsolve(a, b, axes=None)[source]#
求解张量方程 a x = b 的 x。
JAX 实现
numpy.linalg.tensorsolve()
.- 参数:
- 返回值:
数组 x,使得在对
a
的轴进行重新排序后,tensordot(a, x, x.ndim)
等效于b
。- 返回类型:
示例
>>> key1, key2 = jax.random.split(jax.random.key(8675309)) >>> a = jax.random.normal(key1, shape=(2, 2, 4)) >>> b = jax.random.normal(key2, shape=(2, 2)) >>> x = jnp.linalg.tensorsolve(a, b) >>> x.shape (4,)
现在证明可以使用
x
通过tensordot()
来重建b
。>>> b_reconstructed = jnp.linalg.tensordot(a, x, axes=x.ndim) >>> jnp.allclose(b, b_reconstructed) Array(True, dtype=bool)