jax.numpy.unravel_index#

jax.numpy.unravel_index(indices, shape)[源代码]#

将扁平索引转换为多维索引。

JAX 中 numpy.unravel_index() 的实现。 JAX 版本在处理越界索引方面有所不同:与 NumPy 不同,它支持负索引,并且越界索引会被裁剪到最近的有效值。

参数:
  • indices (ArrayLike) – 扁平索引的整数数组

  • shape (Shape) – 要索引的多维数组的形状

返回:

解开的索引的元组

返回类型:

tuple[Array, …]

另请参阅

jax.numpy.ravel_multi_index():此函数的逆函数。

示例

从一维数组值和索引开始

>>> x = jnp.array([2., 3., 4., 5., 6., 7.])
>>> indices = jnp.array([1, 3, 5])
>>> print(x[indices])
[3. 5. 7.]

现在,如果 x 被重塑,unravel_indices 可以用来将扁平索引转换为访问相同条目的索引元组

>>> shape = (2, 3)
>>> x_2D = x.reshape(shape)
>>> indices_2D = jnp.unravel_index(indices, shape)
>>> indices_2D
(Array([0, 1, 1], dtype=int32), Array([1, 0, 2], dtype=int32))
>>> print(x_2D[indices_2D])
[3. 5. 7.]

逆函数 ravel_multi_index 可用于获取原始索引

>>> jnp.ravel_multi_index(indices_2D, shape)
Array([1, 3, 5], dtype=int32)