jax.numpy.unravel_index#
- jax.numpy.unravel_index(indices, shape)[source]#
将扁平索引转换为多维索引。
JAX 实现
numpy.unravel_index()
。JAX 版本在其对越界索引的处理方面有所不同:与 NumPy 不同,JAX 支持负索引,并且越界索引会被裁剪到最近的有效值。参见
jax.numpy.ravel_multi_index()
: 此函数的反函数。示例
从一个一维数组值和索引开始
>>> x = jnp.array([2., 3., 4., 5., 6., 7.]) >>> indices = jnp.array([1, 3, 5]) >>> print(x[indices]) [3. 5. 7.]
现在,如果
x
被重塑,则可以使用unravel_indices
将扁平索引转换为访问相同条目的索引元组>>> shape = (2, 3) >>> x_2D = x.reshape(shape) >>> indices_2D = jnp.unravel_index(indices, shape) >>> indices_2D (Array([0, 1, 1], dtype=int32), Array([1, 0, 2], dtype=int32)) >>> print(x_2D[indices_2D]) [3. 5. 7.]
反函数
ravel_multi_index
可用于获取原始索引>>> jnp.ravel_multi_index(indices_2D, shape) Array([1, 3, 5], dtype=int32)