jax.scipy.linalg.det

内容

jax.scipy.linalg.det#

jax.scipy.linalg.det(a, overwrite_a=False, check_finite=True)[source]#

计算矩阵的行列式

JAX 实现 scipy.linalg.det().

参数:
  • a (ArrayLike) – 输入数组,形状为 (..., N, N)

  • overwrite_a (bool) – JAX 未使用

  • check_finite (bool) – JAX 未使用

返回类型:

数组

返回值

形状为 a.shape[:-2] 的行列式

另请参阅

jax.numpy.linalg.det(): NumPy风格的行列式API

示例

小型二维数组的行列式

>>> x = jnp.array([[1., 2.],
...                [3., 4.]])
>>> jax.scipy.linalg.det(x)
Array(-2., dtype=float32)

多个二维数组的批量行列式

>>> x = jnp.array([[[1., 2.],
...                 [3., 4.]],
...                [[8., 5.],
...                 [7., 9.]]])
>>> jax.scipy.linalg.det(x)
Array([-2., 37.], dtype=float32)