jax.lax.pad#
- jax.lax.pad(operand, padding_value, padding_config)[源代码]#
将低、高和/或内部填充应用于数组。
包装 XLA 的 Pad 运算符。
- 参数:
- 返回值:
根据
padding_config
在每个维度中插入填充值padding_value
的operand
数组。- 返回类型:
示例
>>> from jax import lax >>> import jax.numpy as jnp
用零填充一个一维数组。我们将在开头指定两个零,在结尾指定三个零。
>>> x = jnp.array([1, 2, 3, 4]) >>> lax.pad(x, 0, [(2, 3, 0)]) Array([0, 0, 1, 2, 3, 4, 0, 0, 0], dtype=int32)
用内部零填充一个一维数组;即在每个值之间插入一个零。
>>> lax.pad(x, 0, [(0, 0, 1)]) Array([1, 0, 2, 0, 3, 0, 4], dtype=int32)
用值
-1
在开头和结尾填充一个二维数组,在每个维度上的填充大小为 2。>>> x = jnp.array([[1, 2, 3], ... [4, 5, 6]]) >>> lax.pad(x, -1, [(2, 2, 0), (2, 2, 0)]) Array([[-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, 1, 2, 3, -1, -1], [-1, -1, 4, 5, 6, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], dtype=int32)