jax.numpy.indices#
- jax.numpy.indices(dimensions, dtype=None, sparse=False)[源代码]#
生成网格索引数组。
JAX 实现的
numpy.indices()
。- 参数:
- 返回:
形状为
(len(dimensions), *dimensions)
的数组。如果sparse
为 False,或者如果sparse
为 True,则返回与dimensions
长度相同的数组序列。- 返回类型:
另请参阅
jax.numpy.meshgrid()
:从任意输入数组生成网格。jax.numpy.mgrid
:使用切片语法生成密集索引。jax.numpy.ogrid
:使用切片语法生成稀疏索引。
示例
>>> jnp.indices((2, 3)) Array([[[0, 0, 0], [1, 1, 1]], [[0, 1, 2], [0, 1, 2]]], dtype=int32) >>> jnp.indices((2, 3), sparse=True) (Array([[0], [1]], dtype=int32), Array([[0, 1, 2]], dtype=int32))