jax.numpy.unravel_index

jax.numpy.unravel_index#

jax.numpy.unravel_index(indices, shape)[source]#

将扁平索引转换为多维索引。

JAX 实现 numpy.unravel_index()。JAX 版本在其对越界索引的处理方面有所不同:与 NumPy 不同,JAX 支持负索引,并且越界索引会被裁剪到最近的有效值。

参数:
  • indices (ArrayLike) – 扁平索引的整数数组

  • shape (Shape) – 要索引的多维数组的形状

返回值:

解开的索引的元组

返回类型:

tuple[Array, …]

参见

jax.numpy.ravel_multi_index(): 此函数的反函数。

示例

从一个一维数组值和索引开始

>>> x = jnp.array([2., 3., 4., 5., 6., 7.])
>>> indices = jnp.array([1, 3, 5])
>>> print(x[indices])
[3. 5. 7.]

现在,如果 x 被重塑,则可以使用 unravel_indices 将扁平索引转换为访问相同条目的索引元组

>>> shape = (2, 3)
>>> x_2D = x.reshape(shape)
>>> indices_2D = jnp.unravel_index(indices, shape)
>>> indices_2D
(Array([0, 1, 1], dtype=int32), Array([1, 0, 2], dtype=int32))
>>> print(x_2D[indices_2D])
[3. 5. 7.]

反函数 ravel_multi_index 可用于获取原始索引

>>> jnp.ravel_multi_index(indices_2D, shape)
Array([1, 3, 5], dtype=int32)