jax.numpy.linalg.diagonal

内容

jax.numpy.linalg.diagonal#

jax.numpy.linalg.diagonal(x, /, *, offset=0)[source]#

提取矩阵或矩阵堆栈的对角线。

JAX 实现 numpy.linalg.diagonal().

参数::
  • x (ArrayLike) – 形状为 (..., M, N) 的数组,从中提取对角线。

  • offset (int) – 从主对角线开始的正或负偏移量。

返回值::

形状为 (..., K) 的数组,其中 K 是指定对角线的长度。

返回类型::

数组

另请参见

示例

单个矩阵的对角线

>>> x = jnp.array([[1,  2,  3,  4],
...                [5,  6,  7,  8],
...                [9, 10, 11, 12]])
>>> jnp.linalg.diagonal(x)
Array([ 1,  6, 11], dtype=int32)
>>> jnp.linalg.diagonal(x, offset=1)
Array([ 2,  7, 12], dtype=int32)
>>> jnp.linalg.diagonal(x, offset=-1)
Array([ 5, 10], dtype=int32)

批处理对角线

>>> x = jnp.arange(24).reshape(2, 3, 4)
>>> jnp.linalg.diagonal(x)
Array([[ 0,  5, 10],
       [12, 17, 22]], dtype=int32)