jax.numpy.unravel_index#
- jax.numpy.unravel_index(indices, shape)[源代码]#
将扁平索引转换为多维索引。
JAX 中
numpy.unravel_index()
的实现。 JAX 版本在处理越界索引方面有所不同:与 NumPy 不同,它支持负索引,并且越界索引会被裁剪到最近的有效值。另请参阅
jax.numpy.ravel_multi_index()
:此函数的逆函数。示例
从一维数组值和索引开始
>>> x = jnp.array([2., 3., 4., 5., 6., 7.]) >>> indices = jnp.array([1, 3, 5]) >>> print(x[indices]) [3. 5. 7.]
现在,如果
x
被重塑,unravel_indices
可以用来将扁平索引转换为访问相同条目的索引元组>>> shape = (2, 3) >>> x_2D = x.reshape(shape) >>> indices_2D = jnp.unravel_index(indices, shape) >>> indices_2D (Array([0, 1, 1], dtype=int32), Array([1, 0, 2], dtype=int32)) >>> print(x_2D[indices_2D]) [3. 5. 7.]
逆函数
ravel_multi_index
可用于获取原始索引>>> jnp.ravel_multi_index(indices_2D, shape) Array([1, 3, 5], dtype=int32)