jax.lax.dynamic_slice#
- jax.lax.dynamic_slice(operand, start_indices, slice_sizes)[source]#
包装 XLA 的 DynamicSlice 运算符。
- 参数:
- 返回值:
包含切片的数组。
- 返回值类型:
示例
这是一个简单的二维动态切片
>>> x = jnp.arange(12).reshape(3, 4) >>> x Array([[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11]], dtype=int32)
>>> dynamic_slice(x, (1, 1), (2, 3)) Array([[ 5, 6, 7], [ 9, 10, 11]], dtype=int32)
注意,对于请求的切片超出数组边界的这种情况,可能会出现意外的行为;在这种情况下,起始索引将被调整以返回请求大小的切片
>>> dynamic_slice(x, (1, 1), (2, 4)) Array([[ 4, 5, 6, 7], [ 8, 9, 10, 11]], dtype=int32)