jax.lax.slice#
- jax.lax.slice(operand, start_indices, limit_indices, strides=None)[source]#
包装 XLA 的 Slice 运算符。
- 参数:
- 返回值:
切片的数组
- 返回值类型:
示例
以下是一些简单二维切片的示例
>>> x = jnp.arange(12).reshape(3, 4) >>> x Array([[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11]], dtype=int32)
>>> lax.slice(x, (1, 0), (3, 2)) Array([[4, 5], [8, 9]], dtype=int32)
>>> lax.slice(x, (0, 0), (3, 4), (1, 2)) Array([[ 0, 2], [ 4, 6], [ 8, 10]], dtype=int32)
这两个示例等效于以下 Python 切片语法
>>> x[1:3, 0:2] Array([[4, 5], [8, 9]], dtype=int32)
>>> x[0:3, 0:4:2] Array([[ 0, 2], [ 4, 6], [ 8, 10]], dtype=int32)