jax.numpy.clip#
- jax.numpy.clip(arr=None, /, min=None, max=None, *, a=Deprecated, a_min=Deprecated, a_max=Deprecated)[源代码]#
将数组值裁剪到指定的范围。
JAX 实现的
numpy.clip()
。- 参数:
arr (类数组 | None) – 要裁剪的 N 维数组。
min (类数组 | None) – 裁剪范围的可选最小值;如果为
None
(默认),则结果不会被裁剪到任何最小值。如果指定,它应该与arr
和max
具有广播兼容性。max (类数组 | None) – 裁剪范围的可选最大值;如果为
None
(默认),则结果不会被裁剪到任何最大值。如果指定,它应该与arr
和min
具有广播兼容性。a (类数组 | DeprecatedArg) –
arr
参数的已弃用别名。如果使用,将导致DeprecationWarning
。a_min (类数组 | None | DeprecatedArg) –
min
参数的已弃用别名。如果使用,将导致DeprecationWarning
。a_max (类数组 | None | DeprecatedArg) –
max
参数的已弃用别名。如果使用,将导致DeprecationWarning
。
- 返回:
一个包含来自
arr
的值的数组,其中小于min
的值设置为min
,大于max
的值设置为max
。- 返回类型:
另请参阅
jax.numpy.minimum()
:计算两个数组的逐元素最小值。jax.numpy.maximum()
:计算两个数组的逐元素最大值。
示例
>>> arr = jnp.array([0, 1, 2, 3, 4, 5, 6, 7]) >>> jnp.clip(arr, 2, 5) Array([2, 2, 2, 3, 4, 5, 5, 5], dtype=int32)