jax.lax.dynamic_slice#
- jax.lax.dynamic_slice(operand, start_indices, slice_sizes)[源代码]#
包装 XLA 的 DynamicSlice 运算符。
- 参数:
- 返回:
包含切片的数组。
- 返回类型:
示例
这是一个简单的二维动态切片
>>> x = jnp.arange(12).reshape(3, 4) >>> x Array([[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11]], dtype=int32)
>>> dynamic_slice(x, (1, 1), (2, 3)) Array([[ 5, 6, 7], [ 9, 10, 11]], dtype=int32)
请注意,当请求的切片超出数组边界时,可能会出现令人惊讶的行为;在这种情况下,起始索引会被调整以返回请求大小的切片。
>>> dynamic_slice(x, (1, 1), (2, 4)) Array([[ 4, 5, 6, 7], [ 8, 9, 10, 11]], dtype=int32)