jax.numpy.linalg.diagonal#
- jax.numpy.linalg.diagonal(x, /, *, offset=0)[source]#
提取矩阵或矩阵堆栈的对角线。
JAX 实现
numpy.linalg.diagonal()
.- 参数::
x (ArrayLike) – 形状为
(..., M, N)
的数组,从中提取对角线。offset (int) – 从主对角线开始的正或负偏移量。
- 返回值::
形状为
(..., K)
的数组,其中K
是指定对角线的长度。- 返回类型::
另请参见
jax.numpy.diagonal()
:提取对角线的更通用功能。jax.numpy.diag()
:从值创建对角矩阵。
示例
单个矩阵的对角线
>>> x = jnp.array([[1, 2, 3, 4], ... [5, 6, 7, 8], ... [9, 10, 11, 12]]) >>> jnp.linalg.diagonal(x) Array([ 1, 6, 11], dtype=int32) >>> jnp.linalg.diagonal(x, offset=1) Array([ 2, 7, 12], dtype=int32) >>> jnp.linalg.diagonal(x, offset=-1) Array([ 5, 10], dtype=int32)
批处理对角线
>>> x = jnp.arange(24).reshape(2, 3, 4) >>> jnp.linalg.diagonal(x) Array([[ 0, 5, 10], [12, 17, 22]], dtype=int32)