jax.numpy.array_equal#

jax.numpy.array_equal(a1, a2, equal_nan=False)[源代码]#

检查两个数组是否逐元素相等。

numpy.array_equal() 的 JAX 实现。

参数:
  • a1 (ArrayLike) – 要比较的第一个输入数组。

  • a2 (ArrayLike) – 要比较的第二个输入数组。

  • equal_nan (bool) – 布尔值。如果为 True,则 a1 中的 NaN 将被视为等于 a2 中的 NaN。默认为 False

返回:

一个布尔标量数组,指示输入数组是否逐元素相等。

返回类型:

数组

示例

>>> jnp.array_equal(jnp.array([1, 2, 3]), jnp.array([1, 2, 3]))
Array(True, dtype=bool)
>>> jnp.array_equal(jnp.array([1, 2, 3]), jnp.array([1, 2]))
Array(False, dtype=bool)
>>> jnp.array_equal(jnp.array([1, 2, 3]), jnp.array([1, 2, 4]))
Array(False, dtype=bool)
>>> jnp.array_equal(jnp.array([1, 2, float('nan')]),
...                 jnp.array([1, 2, float('nan')]))
Array(False, dtype=bool)
>>> jnp.array_equal(jnp.array([1, 2, float('nan')]),
...                 jnp.array([1, 2, float('nan')]), equal_nan=True)
Array(True, dtype=bool)