jax.numpy.indices#
- jax.numpy.indices(dimensions, dtype=None, sparse=False)[源代码]#
生成网格索引数组。
numpy.indices()
的 JAX 实现。- 参数:
- 返回:
一个形状为
(len(dimensions), *dimensions)
的数组。如果sparse
为 False,则返回该数组;如果sparse
为 True,则返回一个长度与dimensions
相同的数组序列。- 返回类型:
另请参阅
jax.numpy.meshgrid()
: 从任意输入数组生成网格。jax.numpy.mgrid
: 使用切片语法生成密集索引。jax.numpy.ogrid
: 使用切片语法生成稀疏索引。
示例
>>> jnp.indices((2, 3)) Array([[[0, 0, 0], [1, 1, 1]], [[0, 1, 2], [0, 1, 2]]], dtype=int32) >>> jnp.indices((2, 3), sparse=True) (Array([[0], [1]], dtype=int32), Array([[0, 1, 2]], dtype=int32))