jax.lax.pad#
- jax.lax.pad(operand, padding_value, padding_config)[源代码]#
将低位、高位和/或内部填充应用于数组。
包装 XLA 的 Pad 操作符。
- 参数:
- 返回:
根据
padding_config
,在每个维度中插入填充值padding_value
后的operand
数组。- 返回类型:
示例
>>> from jax import lax >>> import jax.numpy as jnp
用零填充一维数组,我们将在开头指定两个零,在结尾指定三个零
>>> x = jnp.array([1, 2, 3, 4]) >>> lax.pad(x, 0, [(2, 3, 0)]) Array([0, 0, 1, 2, 3, 4, 0, 0, 0], dtype=int32)
用内部零填充一维数组;即,在每个值之间插入一个零
>>> lax.pad(x, 0, [(0, 0, 1)]) Array([1, 0, 2, 0, 3, 0, 4], dtype=int32)
用值
-1
在开头和结尾填充二维数组,每个维度的填充大小为 2>>> x = jnp.array([[1, 2, 3], ... [4, 5, 6]]) >>> lax.pad(x, -1, [(2, 2, 0), (2, 2, 0)]) Array([[-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, 1, 2, 3, -1, -1], [-1, -1, 4, 5, 6, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], dtype=int32)