jax.numpy.pad#

jax.numpy.pad(array, pad_width, mode='constant', **kwargs)[源代码]#

为数组添加填充。

numpy.pad() 的 JAX 实现。

参数:
  • array (ArrayLike) – 要填充的数组。

  • pad_width (PadValueLike[int | Array | np.ndarray]) –

    指定数组每个维度的填充宽度。可以分别指定数组之前之后的填充宽度。选项如下:

    • int(int,):在每个数组维度的前后填充相同数量的值。

    • (before, after):在每个数组之前填充 before 个元素,之后填充 after 个元素。

    • ((before_1, after_1), (before_2, after_2), ... (before_N, after_N)):为每个数组维度指定不同的 beforeafter 值。

  • mode (str | Callable[..., Any]) –

    一个字符串或可调用对象。支持的填充模式有:

    • 'constant'(默认):用一个常量值填充,默认为零。

    • 'empty':用空值(即零)填充。

    • 'edge':用数组的边缘值填充。

    • 'wrap':通过包裹数组进行填充。

    • 'linear_ramp':使用线性斜坡填充到指定的 end_values

    • 'maximum':用最大值填充。

    • 'mean':用平均值填充。

    • 'median':用中位数填充。

    • 'minimum':用最小值填充。

    • 'reflect':通过反射填充。

    • 'symmetric':通过对称反射填充。

    • <callable>:一个可调用函数。请参见下面的“备注”。

  • constant_values – 用于 mode = 'constant'。指定要填充的常量值。

  • stat_length – 用于 mode in ['maximum', 'mean', 'median', 'minimum']。一个整数或元组,指定在计算统计信息时要使用的边缘值的数量。

  • end_values – 用于 mode = 'linear_ramp'。指定将填充值斜坡到目标的值。

  • reflect_type – 用于 mode in ['reflect', 'symmetric']。指定是使用偶数反射还是奇数反射。

返回:

array 的填充副本。

返回类型:

Array

说明

mode 为可调用对象时,它应具有以下签名:

def pad_func(row: Array, pad_width: tuple[int, int],
             iaxis: int, kwargs: dict) -> Array:
  ...

这里 row 是沿轴 iaxis 的填充数组的一维切片,填充值用零填充。pad_width 是一个元组,指定 (before, after) 填充大小,kwargs 是传递给 jax.numpy.pad() 函数的任何其他关键字参数。

请注意,虽然在 NumPy 中,该函数应就地修改 row,但在 JAX 中,该函数应返回修改后的 row。在 JAX 中,自定义填充函数将使用 jax.vmap() 转换映射到填充轴上。

另请参阅

示例

用零填充一维数组

>>> x = jnp.array([10, 20, 30, 40])
>>> jnp.pad(x, 2)
Array([ 0,  0, 10, 20, 30, 40,  0,  0], dtype=int32)
>>> jnp.pad(x, (2, 4))
Array([ 0,  0, 10, 20, 30, 40,  0,  0,  0,  0], dtype=int32)

用指定的值填充一维数组

>>> jnp.pad(x, 2, constant_values=99)
Array([99, 99, 10, 20, 30, 40, 99, 99], dtype=int32)

用平均数组值填充一维数组

>>> jnp.pad(x, 2, mode='mean')
Array([25, 25, 10, 20, 30, 40, 25, 25], dtype=int32)

用反射值填充一维数组

>>> jnp.pad(x, 2, mode='reflect')
Array([30, 20, 10, 20, 30, 40, 30, 20], dtype=int32)

在每个维度上使用不同的填充来填充二维数组

>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> jnp.pad(x, ((1, 2), (3, 0)))
Array([[0, 0, 0, 0, 0, 0],
       [0, 0, 0, 1, 2, 3],
       [0, 0, 0, 4, 5, 6],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0]], dtype=int32)

使用自定义填充函数填充一维数组

>>> def custom_pad(row, pad_width, iaxis, kwargs):
...   # row represents a 1D slice of the zero-padded array.
...   before, after = pad_width
...   before_value = kwargs.get('before_value', 0)
...   after_value = kwargs.get('after_value', 0)
...   row = row.at[:before].set(before_value)
...   return row.at[len(row) - after:].set(after_value)
>>> x = jnp.array([2, 3, 4])
>>> jnp.pad(x, 2, custom_pad, before_value=-10, after_value=10)
Array([-10, -10,   2,   3,   4,  10,  10], dtype=int32)