jax.numpy.diag_indices#

jax.numpy.diag_indices(n, ndim=2)[源代码]#

返回用于访问多维数组主对角线的索引。

numpy.diag_indices() 的 JAX 实现。

参数:
  • n (int) – int。正方形数组的每个维度的大小。

  • ndim (int) – 可选,int,默认值=2。数组的维度数。

返回:

一个数组元组,每个长度为 n,包含访问主对角线的索引。

返回类型:

tuple[Array, …]

示例

>>> jnp.diag_indices(3)
(Array([0, 1, 2], dtype=int32), Array([0, 1, 2], dtype=int32))
>>> jnp.diag_indices(4, ndim=3)
(Array([0, 1, 2, 3], dtype=int32),
Array([0, 1, 2, 3], dtype=int32),
Array([0, 1, 2, 3], dtype=int32))