jax.numpy.trace

内容

jax.numpy.trace#

jax.numpy.trace(a, offset=0, axis1=0, axis2=1, dtype=None, out=None)[source]#

返回数组对角线的总和。

LAX 后端实现的 numpy.trace().

原始文档字符串如下。

如果 a 是二维数组,则返回其对角线上具有给定偏移量的元素之和,即所有 a[i,i+offset] 元素的和,其中 i 为索引。

如果 a 的维度超过二维,则使用由 axis1 和 axis2 指定的轴来确定返回其对角线的二维子数组。结果数组的形状与 a 的形状相同,但去掉了 axis1axis2 轴。

参数:
  • a (array_like) – 输入数组,从中提取对角线。

  • offset (int, 可选) – 对角线相对于主对角线的偏移量。可以是正数或负数。默认值为 0。

  • axis1 (int, 可选) – 用于指定从其中提取对角线的二维子数组的第一个和第二个轴。默认值为 a 的前两个轴。

  • axis2 (int, 可选) – 用于指定从其中提取对角线的二维子数组的第一个和第二个轴。默认值为 a 的前两个轴。

  • dtype (dtype, 可选) – 确定返回数组和累加器的数据类型。如果 dtype 的值为 None 且 a 是精度低于默认整数精度的整数类型,则使用默认整数精度。否则,精度与 a 的精度相同。

  • out (None)

返回值:

sum_along_diagonals – 如果 a 是二维数组,则返回对角线上的元素之和。如果 a 的维度大于二维,则返回一个包含对角线元素之和的数组。

返回类型:

ndarray