jax.numpy.linalg.det

内容

jax.numpy.linalg.det#

jax.numpy.linalg.det(a)[source]#

计算数组的行列式。

JAX 实现 numpy.linalg.det().

参数:

**a** (ArrayLike) – 形状为 (..., M, M) 的数组,需要计算其行列式。

返回值:

形状为 a.shape[:-2] 的行列式数组。

返回类型:

数组

另请参阅

jax.scipy.linalg.det(): 行列式的 SciPy 风格 API。

示例

>>> a = jnp.array([[1, 2],
...                [3, 4]])
>>> jnp.linalg.det(a)
Array(-2., dtype=float32)