jax.numpy.linalg.diagonal#

jax.numpy.linalg.diagonal(x, /, *, offset=0)[源代码]#

提取矩阵或矩阵堆栈的对角线。

JAX实现的numpy.linalg.diagonal()

参数:
  • x (ArrayLike) – 形状为(..., M, N)的数组,将从中提取对角线。

  • offset (int) – 偏离主对角线的正或负偏移量。

返回:

形状为(..., K)的数组,其中K是指定对角线的长度。

返回类型:

数组

另请参阅

示例

单个矩阵的对角线

>>> x = jnp.array([[1,  2,  3,  4],
...                [5,  6,  7,  8],
...                [9, 10, 11, 12]])
>>> jnp.linalg.diagonal(x)
Array([ 1,  6, 11], dtype=int32)
>>> jnp.linalg.diagonal(x, offset=1)
Array([ 2,  7, 12], dtype=int32)
>>> jnp.linalg.diagonal(x, offset=-1)
Array([ 5, 10], dtype=int32)

批量对角线

>>> x = jnp.arange(24).reshape(2, 3, 4)
>>> jnp.linalg.diagonal(x)
Array([[ 0,  5, 10],
       [12, 17, 22]], dtype=int32)