jax.numpy.clip#

jax.numpy.clip(arr=None, /, min=None, max=None, *, a=Deprecated, a_min=Deprecated, a_max=Deprecated)[源代码]#

将数组值裁剪到指定的范围。

JAX 实现的 numpy.clip()

参数:
  • arr (类数组 | None) – 要裁剪的 N 维数组。

  • min (类数组 | None) – 裁剪范围的可选最小值;如果为 None (默认),则结果不会被裁剪到任何最小值。如果指定,它应该与 arrmax 具有广播兼容性。

  • max (类数组 | None) – 裁剪范围的可选最大值;如果为 None (默认),则结果不会被裁剪到任何最大值。如果指定,它应该与 arrmin 具有广播兼容性。

  • a (类数组 | DeprecatedArg) – arr 参数的已弃用别名。如果使用,将导致 DeprecationWarning

  • a_min (类数组 | None | DeprecatedArg) – min 参数的已弃用别名。如果使用,将导致 DeprecationWarning

  • a_max (类数组 | None | DeprecatedArg) – max 参数的已弃用别名。如果使用,将导致 DeprecationWarning

返回:

一个包含来自 arr 的值的数组,其中小于 min 的值设置为 min,大于 max 的值设置为 max

返回类型:

数组

另请参阅

示例

>>> arr = jnp.array([0, 1, 2, 3, 4, 5, 6, 7])
>>> jnp.clip(arr, 2, 5)
Array([2, 2, 2, 3, 4, 5, 5, 5], dtype=int32)