jax.numpy.fmin

内容

jax.numpy.fmin#

jax.numpy.fmin(x1, x2)[source]#

返回输入数组的逐元素最小值。

JAX 实现 numpy.fmin()

参数:
  • x1 (ArrayLike) – 输入数组或标量。

  • x2 (ArrayLike) – 输入数组或标量。x1 和 x2 必须具有相同的形状或广播兼容。

返回值:

包含 x1 和 x2 逐元素最小值的数组。

返回类型:

数组

注意

对于每对元素,jnp.fmin 返回
  • 如果两个元素都是有限数,则返回较小的那个。

  • 如果一个元素是 nan,则返回有限数。

  • -inf 如果一个元素是 -inf 另一个是有限值或 nan

  • inf 如果一个元素是 inf 另一个是 nan

  • nan 如果两个元素都是 nan

示例

>>> jnp.fmin(2, 3)
Array(2, dtype=int32, weak_type=True)
>>> jnp.fmin(2, jnp.array([1, 4, 2, -1]))
Array([ 1,  2,  2, -1], dtype=int32)
>>> x1 = jnp.array([1, 3, 2])
>>> x2 = jnp.array([2, 1, 4])
>>> jnp.fmin(x1, x2)
Array([1, 1, 2], dtype=int32)
>>> x3 = jnp.array([1, 5, 3])
>>> x4 = jnp.array([[2, 3, 1],
...                 [5, 6, 7]])
>>> jnp.fmin(x3, x4)
Array([[1, 3, 1],
       [1, 5, 3]], dtype=int32)
>>> nan = jnp.nan
>>> x5 = jnp.array([jnp.inf, 5, nan])
>>> x6 = jnp.array([[2, 3, nan],
...                 [nan, 6, 7]])
>>> jnp.fmin(x5, x6)
Array([[ 2.,  3., nan],
       [inf,  5.,  7.]], dtype=float32)