jax.numpy.clip#
- jax.numpy.clip(arr=None, /, min=None, max=None, *, a=Deprecated, a_min=Deprecated, a_max=Deprecated)[source]#
将数组值裁剪到指定范围。
JAX 实现
numpy.clip()
.- 参数:
arr (ArrayLike | None) – 要裁剪的 N 维数组。
min (ArrayLike | None) – 裁剪范围的可选最小值;如果
None
(默认值),则结果不会裁剪到任何最小值。如果指定,它应该与arr
和max
广播兼容。max (ArrayLike | None) – 裁剪范围的可选最大值;如果
None
(默认值),则结果不会裁剪到任何最大值。如果指定,它应该与arr
和min
广播兼容。a (ArrayLike | DeprecatedArg) –
arr
参数的已弃用别名。如果使用,将导致DeprecationWarning
。a_min (ArrayLike | None | DeprecatedArg) –
min
参数的已弃用别名。如果使用,将导致DeprecationWarning
。a_max (ArrayLike | None | DeprecatedArg) –
max
参数的已弃用别名。如果使用,将导致DeprecationWarning
。
- 返回:
一个包含来自
arr
的值的数组,其中小于min
的值设置为min
,大于max
的值设置为max
。- 返回类型:
另请参阅
jax.numpy.minimum()
: 计算两个数组的元素级最小值。jax.numpy.maximum()
: 计算两个数组的元素级最大值。
示例
>>> arr = jnp.array([0, 1, 2, 3, 4, 5, 6, 7]) >>> jnp.clip(arr, 2, 5) Array([2, 2, 2, 3, 4, 5, 5, 5], dtype=int32)