jax.numpy.pad#
- jax.numpy.pad(array, pad_width, mode='constant', **kwargs)[源代码]#
为数组添加填充。
numpy.pad()
的 JAX 实现。- 参数:
array (ArrayLike) – 要填充的数组。
pad_width (PadValueLike[int | Array | np.ndarray]) –
指定数组每个维度的填充宽度。可以分别指定数组之前和之后的填充宽度。选项如下:
int
或(int,)
:在每个数组维度的前后填充相同数量的值。(before, after)
:在每个数组之前填充before
个元素,之后填充after
个元素。((before_1, after_1), (before_2, after_2), ... (before_N, after_N))
:为每个数组维度指定不同的before
和after
值。
mode (str | Callable[..., Any]) –
一个字符串或可调用对象。支持的填充模式有:
'constant'
(默认):用一个常量值填充,默认为零。'empty'
:用空值(即零)填充。'edge'
:用数组的边缘值填充。'wrap'
:通过包裹数组进行填充。'linear_ramp'
:使用线性斜坡填充到指定的end_values
。'maximum'
:用最大值填充。'mean'
:用平均值填充。'median'
:用中位数填充。'minimum'
:用最小值填充。'reflect'
:通过反射填充。'symmetric'
:通过对称反射填充。<callable>
:一个可调用函数。请参见下面的“备注”。
constant_values – 用于
mode = 'constant'
。指定要填充的常量值。stat_length – 用于
mode in ['maximum', 'mean', 'median', 'minimum']
。一个整数或元组,指定在计算统计信息时要使用的边缘值的数量。end_values – 用于
mode = 'linear_ramp'
。指定将填充值斜坡到目标的值。reflect_type – 用于
mode in ['reflect', 'symmetric']
。指定是使用偶数反射还是奇数反射。
- 返回:
array
的填充副本。- 返回类型:
说明
当
mode
为可调用对象时,它应具有以下签名:def pad_func(row: Array, pad_width: tuple[int, int], iaxis: int, kwargs: dict) -> Array: ...
这里
row
是沿轴iaxis
的填充数组的一维切片,填充值用零填充。pad_width
是一个元组,指定(before, after)
填充大小,kwargs
是传递给jax.numpy.pad()
函数的任何其他关键字参数。请注意,虽然在 NumPy 中,该函数应就地修改
row
,但在 JAX 中,该函数应返回修改后的row
。在 JAX 中,自定义填充函数将使用jax.vmap()
转换映射到填充轴上。另请参阅
jax.numpy.resize()
:调整数组大小jax.numpy.tile()
:通过平铺较小的数组来创建更大的数组。jax.numpy.repeat()
:通过重复较小数组的值来创建更大的数组。
示例
用零填充一维数组
>>> x = jnp.array([10, 20, 30, 40]) >>> jnp.pad(x, 2) Array([ 0, 0, 10, 20, 30, 40, 0, 0], dtype=int32) >>> jnp.pad(x, (2, 4)) Array([ 0, 0, 10, 20, 30, 40, 0, 0, 0, 0], dtype=int32)
用指定的值填充一维数组
>>> jnp.pad(x, 2, constant_values=99) Array([99, 99, 10, 20, 30, 40, 99, 99], dtype=int32)
用平均数组值填充一维数组
>>> jnp.pad(x, 2, mode='mean') Array([25, 25, 10, 20, 30, 40, 25, 25], dtype=int32)
用反射值填充一维数组
>>> jnp.pad(x, 2, mode='reflect') Array([30, 20, 10, 20, 30, 40, 30, 20], dtype=int32)
在每个维度上使用不同的填充来填充二维数组
>>> x = jnp.array([[1, 2, 3], ... [4, 5, 6]]) >>> jnp.pad(x, ((1, 2), (3, 0))) Array([[0, 0, 0, 0, 0, 0], [0, 0, 0, 1, 2, 3], [0, 0, 0, 4, 5, 6], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0]], dtype=int32)
使用自定义填充函数填充一维数组
>>> def custom_pad(row, pad_width, iaxis, kwargs): ... # row represents a 1D slice of the zero-padded array. ... before, after = pad_width ... before_value = kwargs.get('before_value', 0) ... after_value = kwargs.get('after_value', 0) ... row = row.at[:before].set(before_value) ... return row.at[len(row) - after:].set(after_value) >>> x = jnp.array([2, 3, 4]) >>> jnp.pad(x, 2, custom_pad, before_value=-10, after_value=10) Array([-10, -10, 2, 3, 4, 10, 10], dtype=int32)