jax.lax.pad#

jax.lax.pad(operand, padding_value, padding_config)[源代码]#

将低位、高位和/或内部填充应用于数组。

包装 XLA 的 Pad 操作符。

参数
  • operand (ArrayLike) – 要填充的数组。

  • padding_value (ArrayLike) – 要作为填充插入的值。必须与 operand 具有相同的数据类型。

  • padding_config (Sequence[tuple[int, int, int]]) – 一个 (low, high, interior) 整数元组序列,给出要在每个维度中插入的低位、高位和内部(扩张)填充量。

返回

根据 padding_config,在每个维度中插入填充值 padding_value 后的 operand 数组。

返回类型:

数组

示例

>>> from jax import lax
>>> import jax.numpy as jnp

用零填充一维数组,我们将在开头指定两个零,在结尾指定三个零

>>> x = jnp.array([1, 2, 3, 4])
>>> lax.pad(x, 0, [(2, 3, 0)])
Array([0, 0, 1, 2, 3, 4, 0, 0, 0], dtype=int32)

内部零填充一维数组;即,在每个值之间插入一个零

>>> lax.pad(x, 0, [(0, 0, 1)])
Array([1, 0, 2, 0, 3, 0, 4], dtype=int32)

用值 -1 在开头和结尾填充二维数组,每个维度的填充大小为 2

>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> lax.pad(x, -1, [(2, 2, 0), (2, 2, 0)])
Array([[-1, -1, -1, -1, -1, -1, -1],
       [-1, -1, -1, -1, -1, -1, -1],
       [-1, -1,  1,  2,  3, -1, -1],
       [-1, -1,  4,  5,  6, -1, -1],
       [-1, -1, -1, -1, -1, -1, -1],
       [-1, -1, -1, -1, -1, -1, -1]], dtype=int32)