jax.numpy.linalg.trace#
- jax.numpy.linalg.trace(x, /, *, offset=0, dtype=None)[source]#
计算矩阵的迹。
JAX 实现
numpy.linalg.trace()
.- 参数::
x (ArrayLike) – 形状为
(..., M, N)
的数组,其最内层两个维度形成 MxN 矩阵,需要对其取迹。offset (int) – 主对角线的正偏移或负偏移(默认值:0)。
dtype (DTypeLike | None | None) – 返回数组的数据类型(默认值:
None
)。如果为None
,则输出数据类型将与x
的数据类型匹配,在整数类型的情况下将提升到默认精度。
- 返回值:
形状为
x.shape[:-2]
的批次跟踪数组- 返回类型:
另请参阅
jax.numpy.trace()
:jax.numpy
命名空间中的类似 API。
示例
单个矩阵的迹
>>> x = jnp.array([[1, 2, 3, 4], ... [5, 6, 7, 8], ... [9, 10, 11, 12]]) >>> jnp.linalg.trace(x) Array(18, dtype=int32) >>> jnp.linalg.trace(x, offset=1) Array(21, dtype=int32) >>> jnp.linalg.trace(x, offset=-1, dtype="float32") Array(15., dtype=float32)
批次跟踪
>>> x = jnp.arange(24).reshape(2, 3, 4) >>> jnp.linalg.trace(x) Array([15, 51], dtype=int32)