jax.numpy.diag_indices#
- jax.numpy.diag_indices(n, ndim=2)[source]#
返回用于访问多维数组主对角线的索引。
JAX 实现的
numpy.diag_indices()
.- 参数::
- 返回值::
一个数组元组,每个数组的长度为 n,包含访问主对角线的索引。
- 返回类型::
示例
>>> jnp.diag_indices(3) (Array([0, 1, 2], dtype=int32), Array([0, 1, 2], dtype=int32)) >>> jnp.diag_indices(4, ndim=3) (Array([0, 1, 2, 3], dtype=int32), Array([0, 1, 2, 3], dtype=int32), Array([0, 1, 2, 3], dtype=int32))