jax.numpy.diagonal#

jax.numpy.diagonal(a, offset=0, axis1=0, axis2=1)[源代码]#

返回数组的指定对角线。

JAX 对 numpy.diagonal() 的实现。

JAX 版本总是返回输入的副本,尽管如果在 JIT 编译中使用,编译器可能会避免复制。

参数:
  • a (ArrayLike) – 输入数组。必须至少是 2 维的。

  • offset (int) – 可选,默认值=0。相对于主对角线的对角线偏移量。必须是静态整数值。可以是正数或负数。

  • axis1 (int) – 可选,默认值=0。沿着它获取对角线的第一个轴。

  • axis2 (int) –

    可选,默认值=1。沿着它获取对角线的第二个轴。

    返回

    对于 2D 输入,返回 1D 数组;通常对于 N 维输入,返回 N-1 维数组。

返回类型:

Array

示例

>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6],
...                [7, 8, 9]])
>>> jnp.diagonal(x)
Array([1, 5, 9], dtype=int32)
>>> jnp.diagonal(x, offset=1)
Array([2, 6], dtype=int32)
>>> jnp.diagonal(x, offset=-1)
Array([4, 8], dtype=int32)