jax.numpy.indices#

jax.numpy.indices(dimensions, dtype=None, sparse=False)[源代码]#

生成网格索引数组。

JAX 实现的 numpy.indices()

参数:
  • dimensions (Sequence[int]) – 网格的形状。

  • dtype (DTypeLike | None | None) – 索引的 dtype(默认为整数)。

  • sparse (bool) – 如果为 True,则返回稀疏索引。默认为 False,返回稠密索引。

返回:

形状为 (len(dimensions), *dimensions) 的数组。如果 sparse 为 False,或者如果 sparse 为 True,则返回与 dimensions 长度相同的数组序列。

返回类型:

Array | tuple[Array, …]

另请参阅

示例

>>> jnp.indices((2, 3))
Array([[[0, 0, 0],
        [1, 1, 1]],

       [[0, 1, 2],
        [0, 1, 2]]], dtype=int32)
>>> jnp.indices((2, 3), sparse=True)
(Array([[0],
       [1]], dtype=int32), Array([[0, 1, 2]], dtype=int32))