jax.lax.slice_in_dim

jax.lax.slice_in_dim#

jax.lax.slice_in_dim(operand, start_index, limit_index, stride=1, axis=0)[source]#

围绕 lax.slice() 的便捷包装器,仅应用于一个维度。

这实际上等同于 operand[..., start_index:limit_index:stride],索引应用于指定的轴。

参数:
  • operand (Array | np.ndarray) – 要切片的数组。

  • start_index (int | None) – 可选的起始索引(默认为零)

  • limit_index (int | None) – 可选的结束索引(默认为 operand.shape[axis])

  • stride (int) – 可选的步长(默认为 1)

  • (int) – 应用切片的轴(默认为 0)

返回值:

包含切片的数组。

返回类型:

数组

示例

这是一个一维示例

>>> x = jnp.arange(4)
>>> lax.slice_in_dim(x, 1, 3)
Array([1, 2], dtype=int32)

以下是一些二维示例

>>> x = jnp.arange(12).reshape(4, 3)
>>> x
Array([[ 0,  1,  2],
       [ 3,  4,  5],
       [ 6,  7,  8],
       [ 9, 10, 11]], dtype=int32)
>>> lax.slice_in_dim(x, 1, 3)
Array([[3, 4, 5],
       [6, 7, 8]], dtype=int32)
>>> lax.slice_in_dim(x, 1, 3, axis=1)
Array([[ 1,  2],
       [ 4,  5],
       [ 7,  8],
       [10, 11]], dtype=int32)