jax.numpy.linalg.trace

内容

jax.numpy.linalg.trace#

jax.numpy.linalg.trace(x, /, *, offset=0, dtype=None)[source]#

计算矩阵的迹。

JAX 实现 numpy.linalg.trace().

参数::
  • x (ArrayLike) – 形状为 (..., M, N) 的数组,其最内层两个维度形成 MxN 矩阵,需要对其取迹。

  • offset (int) – 主对角线的正偏移或负偏移(默认值:0)。

  • dtype (DTypeLike | None | None) – 返回数组的数据类型(默认值:None)。如果为 None,则输出数据类型将与 x 的数据类型匹配,在整数类型的情况下将提升到默认精度。

返回值:

形状为 x.shape[:-2] 的批次跟踪数组

返回类型:

数组

另请参阅

示例

单个矩阵的迹

>>> x = jnp.array([[1,  2,  3,  4],
...                [5,  6,  7,  8],
...                [9, 10, 11, 12]])
>>> jnp.linalg.trace(x)
Array(18, dtype=int32)
>>> jnp.linalg.trace(x, offset=1)
Array(21, dtype=int32)
>>> jnp.linalg.trace(x, offset=-1, dtype="float32")
Array(15., dtype=float32)

批次跟踪

>>> x = jnp.arange(24).reshape(2, 3, 4)
>>> jnp.linalg.trace(x)
Array([15, 51], dtype=int32)