jax.lax.index_in_dim

jax.lax.index_in_dim#

jax.lax.index_in_dim(operand, index, axis=0, keepdims=True)[source]#

围绕 lax.slice() 的便捷包装器,用于执行 int 索引。

这实际上等同于 operand[..., start_index:limit_index:stride],索引应用于指定的轴。

参数:
  • **operand** (Array | np.ndarray) – 要索引的数组。

  • **index** (int) – 整数索引

  • **axis** (int) – 应用索引的轴(默认为 0)

  • **keepdims** (bool) – 布尔值,指定输出数组是否应保留输入的秩(默认值为 True)

返回值:

指定索引处的子数组。

返回类型:

数组

示例

这是一个一维示例

>>> x = jnp.arange(4)
>>> lax.index_in_dim(x, 2)
Array([2], dtype=int32)
>>> lax.index_in_dim(x, 2, keepdims=False)
Array(2, dtype=int32)

这里是一些二维示例

>>> x = jnp.arange(12).reshape(3, 4)
>>> x
Array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11]], dtype=int32)
>>> lax.index_in_dim(x, 1)
Array([[4, 5, 6, 7]], dtype=int32)
>>> lax.index_in_dim(x, 1, axis=1, keepdims=False)
Array([1, 5, 9], dtype=int32)