jax.scipy.linalg.det#

jax.scipy.linalg.det(a, overwrite_a=False, check_finite=True)[源代码]#

计算矩阵的行列式

scipy.linalg.det() 的 JAX 实现。

参数:
  • a (ArrayLike) – 输入数组,形状为 (..., N, N)

  • overwrite_a (bool) – JAX 未使用

  • check_finite (bool) – JAX 未使用

返回类型:

Array

返回

行列式,形状为 a.shape[:-2]

另请参阅

jax.numpy.linalg.det():NumPy 风格的行列式 API

示例

小型 2D 数组的行列式

>>> x = jnp.array([[1., 2.],
...                [3., 4.]])
>>> jax.scipy.linalg.det(x)
Array(-2., dtype=float32)

多个 2D 数组的批处理行列式

>>> x = jnp.array([[[1., 2.],
...                 [3., 4.]],
...                [[8., 5.],
...                 [7., 9.]]])
>>> jax.scipy.linalg.det(x)
Array([-2., 37.], dtype=float32)