jax.lax.dynamic_index_in_dim#
- jax.lax.dynamic_index_in_dim(operand, index, axis=0, keepdims=True)[源代码]#
为了执行整数索引的 dynamic_slice 的便捷封装。
这大致等同于沿指定轴应用的以下 Python 索引语法:
operand[..., index]
。- 参数:
operand (Array | np.ndarray) – 要切片的数组。
index (int | Array) – (可能动态的)起始索引
axis (int) – 应用切片的轴(默认为 0)
keepdims (bool) – 一个布尔值,指定输出是否应与输入具有相同的秩(默认 = True)
- 返回:
包含切片的数组。
- 返回类型:
示例
这是一个一维示例
>>> x = jnp.arange(5) >>> dynamic_index_in_dim(x, 1) Array([1], dtype=int32)
>>> dynamic_index_in_dim(x, 1, keepdims=False) Array(1, dtype=int32)
这是一个二维示例
>>> x = jnp.arange(12).reshape(3, 4) >>> x Array([[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11]], dtype=int32)
>>> dynamic_index_in_dim(x, 1, axis=1, keepdims=False) Array([1, 5, 9], dtype=int32)