jax.numpy.indices

内容

jax.numpy.indices#

jax.numpy.indices(dimensions: Sequence[int], dtype: DTypeLike | None = None, sparse: Literal[False] = False) Array[source]#
jax.numpy.indices(dimensions: Sequence[int], dtype: DTypeLike | None = None, *, sparse: Literal[True]) tuple[Array, ...]
jax.numpy.indices(dimensions: Sequence[int], dtype: DTypeLike | None = None, sparse: bool = False) Array | tuple[Array, ...]

返回表示网格索引的数组。

LAX 后端实现 numpy.indices().

原始文档字符串如下。

计算一个数组,其中子数组包含索引值 0、1、…,仅沿相应的轴变化。

参数:
  • dimensions (序列 of 整数) – 网格的形状。

  • dtype (dtype, 可选) – 结果的数据类型。

  • sparse (布尔值, 可选) – 返回网格的稀疏表示而不是密集表示。默认为 False。

返回值:

grid

如果 sparse 为 False

返回一个网格索引数组,grid.shape = (len(dimensions),) + tuple(dimensions)

如果 sparse 为 True

返回一个数组元组,其中 grid[i].shape = (1, ..., 1, dimensions[i], 1, ..., 1),其中 dimensions[i] 位于第 i 个位置

返回类型:

一个 ndarray 或 元组 of ndarrays