jax.numpy.diag_indices#
- jax.numpy.diag_indices(n, ndim=2)[源代码]#
返回用于访问多维数组主对角线的索引。
numpy.diag_indices()
的 JAX 实现。- 参数:
- 返回:
一个数组元组,每个长度为 n,包含访问主对角线的索引。
- 返回类型:
示例
>>> jnp.diag_indices(3) (Array([0, 1, 2], dtype=int32), Array([0, 1, 2], dtype=int32)) >>> jnp.diag_indices(4, ndim=3) (Array([0, 1, 2, 3], dtype=int32), Array([0, 1, 2, 3], dtype=int32), Array([0, 1, 2, 3], dtype=int32))