jax.numpy.maximum

内容

jax.numpy.maximum#

jax.numpy.maximum(x, y, /)[source]#

返回输入数组的逐元素最大值。

JAX 实现 numpy.maximum

参数:
  • x (ArrayLike) – 输入数组或标量。

  • y (ArrayLike) – 输入数组或标量。 xy 应该具有相同的形状或广播兼容。

返回值:

包含 xy 的逐元素最大值的数组。

返回类型:

数组

注意

对于每一对元素,jnp.maximum 返回
  • 如果两个元素都是有限数,则返回较大的那个。

  • 如果一个元素是 nan,则返回 nan

参见

示例

输入满足 x.shape == y.shape

>>> x = jnp.array([1, -5, 3, 2])
>>> y = jnp.array([-2, 4, 7, -6])
>>> jnp.maximum(x, y)
Array([1, 4, 7, 2], dtype=int32)

具有广播兼容性的输入

>>> x1 = jnp.array([[-2, 5, 7, 4],
...                 [1, -6, 3, 8]])
>>> y1 = jnp.array([-5, 3, 6, 9])
>>> jnp.maximum(x1, y1)
Array([[-2,  5,  7,  9],
       [ 1,  3,  6,  9]], dtype=int32)

包含 nan 的输入

>>> nan = jnp.nan
>>> x2 = jnp.array([nan, -3, 9])
>>> y2 = jnp.array([[4, -2, nan],
...                 [-3, -5, 10]])
>>> jnp.maximum(x2, y2)
Array([[nan, -2., nan],
      [nan, -3., 10.]], dtype=float32)