jax.numpy.clip

内容

jax.numpy.clip#

jax.numpy.clip(arr=None, /, min=None, max=None, *, a=Deprecated, a_min=Deprecated, a_max=Deprecated)[source]#

将数组值裁剪到指定范围。

JAX 实现 numpy.clip().

参数:
  • arr (ArrayLike | None) – 要裁剪的 N 维数组。

  • min (ArrayLike | None) – 裁剪范围的可选最小值;如果 None(默认值),则结果不会裁剪到任何最小值。如果指定,它应该与 arrmax 广播兼容。

  • max (ArrayLike | None) – 裁剪范围的可选最大值;如果 None(默认值),则结果不会裁剪到任何最大值。如果指定,它应该与 arrmin 广播兼容。

  • a (ArrayLike | DeprecatedArg) – arr 参数的已弃用别名。如果使用,将导致 DeprecationWarning

  • a_min (ArrayLike | None | DeprecatedArg) – min 参数的已弃用别名。如果使用,将导致 DeprecationWarning

  • a_max (ArrayLike | None | DeprecatedArg) – max 参数的已弃用别名。如果使用,将导致 DeprecationWarning

返回:

一个包含来自 arr 的值的数组,其中小于 min 的值设置为 min,大于 max 的值设置为 max

返回类型:

数组

另请参阅

示例

>>> arr = jnp.array([0, 1, 2, 3, 4, 5, 6, 7])
>>> jnp.clip(arr, 2, 5)
Array([2, 2, 2, 3, 4, 5, 5, 5], dtype=int32)