jax.lax.pad#

jax.lax.pad(operand, padding_value, padding_config)[源代码]#

将低、高和/或内部填充应用于数组。

包装 XLA 的 Pad 运算符。

参数:
  • operand (ArrayLike) – 要填充的数组。

  • padding_value (ArrayLike) – 作为填充插入的值。必须与 operand 具有相同的数据类型。

  • padding_config (Sequence[tuple[int, int, int]]) – 一个整数元组 (low, high, interior) 的序列,给出在每个维度中插入的低位、高位和内部(扩张)填充的数量。

返回值:

根据 padding_config 在每个维度中插入填充值 padding_valueoperand 数组。

返回类型:

Array

示例

>>> from jax import lax
>>> import jax.numpy as jnp

用零填充一个一维数组。我们将在开头指定两个零,在结尾指定三个零。

>>> x = jnp.array([1, 2, 3, 4])
>>> lax.pad(x, 0, [(2, 3, 0)])
Array([0, 0, 1, 2, 3, 4, 0, 0, 0], dtype=int32)

内部零填充一个一维数组;即在每个值之间插入一个零。

>>> lax.pad(x, 0, [(0, 0, 1)])
Array([1, 0, 2, 0, 3, 0, 4], dtype=int32)

用值 -1 在开头和结尾填充一个二维数组,在每个维度上的填充大小为 2。

>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> lax.pad(x, -1, [(2, 2, 0), (2, 2, 0)])
Array([[-1, -1, -1, -1, -1, -1, -1],
       [-1, -1, -1, -1, -1, -1, -1],
       [-1, -1,  1,  2,  3, -1, -1],
       [-1, -1,  4,  5,  6, -1, -1],
       [-1, -1, -1, -1, -1, -1, -1],
       [-1, -1, -1, -1, -1, -1, -1]], dtype=int32)