jax.numpy.diag_indices

jax.numpy.diag_indices#

jax.numpy.diag_indices(n, ndim=2)[source]#

返回用于访问多维数组主对角线的索引。

JAX 实现的 numpy.diag_indices().

参数::
  • n (int) – int。正方形数组每个维度的尺寸。

  • ndim (int) – 可选,int,默认=2。数组的维度数。

返回值::

一个数组元组,每个数组的长度为 n,包含访问主对角线的索引。

返回类型::

tuple[Array, …]

示例

>>> jnp.diag_indices(3)
(Array([0, 1, 2], dtype=int32), Array([0, 1, 2], dtype=int32))
>>> jnp.diag_indices(4, ndim=3)
(Array([0, 1, 2, 3], dtype=int32),
Array([0, 1, 2, 3], dtype=int32),
Array([0, 1, 2, 3], dtype=int32))