jax.lax.dynamic_index_in_dim

jax.lax.dynamic_index_in_dim#

jax.lax.dynamic_index_in_dim(operand, index, axis=0, keepdims=True)[source]#

围绕 dynamic_slice 的便捷包装器,用于执行整数索引。

这大致等同于以下应用于指定轴的 Python 索引语法:operand[..., index]

参数:
  • operand (Array | np.ndarray) – 要切片的数组。

  • index (int | Array) – (可能动态的)起始索引

  • axis (int) – 应用切片的轴(默认为 0)

  • keepdims (bool) – 布尔值,指定输出是否应与输入具有相同的秩(默认值为 True)

返回值:

包含切片的数组。

返回类型:

Array

示例

这是一个一维示例

>>> x = jnp.arange(5)
>>> dynamic_index_in_dim(x, 1)
Array([1], dtype=int32)
>>> dynamic_index_in_dim(x, 1, keepdims=False)
Array(1, dtype=int32)

这是一个二维示例

>>> x = jnp.arange(12).reshape(3, 4)
>>> x
Array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11]], dtype=int32)
>>> dynamic_index_in_dim(x, 1, axis=1, keepdims=False)
Array([1, 5, 9], dtype=int32)