jax.lax.slice_in_dim#
- jax.lax.slice_in_dim(operand, start_index, limit_index, stride=1, axis=0)[source]#
围绕
lax.slice()
的便捷包装器,仅应用于一个维度。这实际上等同于
operand[..., start_index:limit_index:stride]
,索引应用于指定的轴。- 参数:
- 返回值:
包含切片的数组。
- 返回类型:
示例
这是一个一维示例
>>> x = jnp.arange(4) >>> lax.slice_in_dim(x, 1, 3) Array([1, 2], dtype=int32)
以下是一些二维示例
>>> x = jnp.arange(12).reshape(4, 3) >>> x Array([[ 0, 1, 2], [ 3, 4, 5], [ 6, 7, 8], [ 9, 10, 11]], dtype=int32)
>>> lax.slice_in_dim(x, 1, 3) Array([[3, 4, 5], [6, 7, 8]], dtype=int32)
>>> lax.slice_in_dim(x, 1, 3, axis=1) Array([[ 1, 2], [ 4, 5], [ 7, 8], [10, 11]], dtype=int32)