jax.numpy.indices#
- jax.numpy.indices(dimensions: Sequence[int], dtype: DTypeLike | None = None, sparse: Literal[False] = False) Array [source]#
- jax.numpy.indices(dimensions: Sequence[int], dtype: DTypeLike | None = None, *, sparse: Literal[True]) tuple[Array, ...]
- jax.numpy.indices(dimensions: Sequence[int], dtype: DTypeLike | None = None, sparse: bool = False) Array | tuple[Array, ...]
返回表示网格索引的数组。
LAX 后端实现
numpy.indices()
.原始文档字符串如下。
计算一个数组,其中子数组包含索引值 0、1、…,仅沿相应的轴变化。
- 参数:
dimensions (序列 of 整数) – 网格的形状。
dtype (dtype, 可选) – 结果的数据类型。
sparse (布尔值, 可选) – 返回网格的稀疏表示而不是密集表示。默认为 False。
- 返回值:
grid –
- 如果 sparse 为 False
返回一个网格索引数组,
grid.shape = (len(dimensions),) + tuple(dimensions)
。- 如果 sparse 为 True
返回一个数组元组,其中
grid[i].shape = (1, ..., 1, dimensions[i], 1, ..., 1)
,其中 dimensions[i] 位于第 i 个位置
- 返回类型:
一个 ndarray 或 元组 of ndarrays