jax.numpy.array_equal#
- jax.numpy.array_equal(a1, a2, equal_nan=False)[source]#
检查两个数组是否按元素相等。
JAX 实现
numpy.array_equal()
.- 参数:
a1 (ArrayLike) – 要比较的第一个输入数组。
a2 (ArrayLike) – 要比较的第二个输入数组。
equal_nan (bool) – 布尔值。如果为
True
,则a1
中的 NaN 将被视为与a2
中的 NaN 相等。默认值为False
。
- 返回值:
布尔标量数组,指示输入数组是否按元素相等。
- 返回类型:
示例
>>> jnp.array_equal(jnp.array([1, 2, 3]), jnp.array([1, 2, 3])) Array(True, dtype=bool) >>> jnp.array_equal(jnp.array([1, 2, 3]), jnp.array([1, 2])) Array(False, dtype=bool) >>> jnp.array_equal(jnp.array([1, 2, 3]), jnp.array([1, 2, 4])) Array(False, dtype=bool) >>> jnp.array_equal(jnp.array([1, 2, float('nan')]), ... jnp.array([1, 2, float('nan')])) Array(False, dtype=bool) >>> jnp.array_equal(jnp.array([1, 2, float('nan')]), ... jnp.array([1, 2, float('nan')]), equal_nan=True) Array(True, dtype=bool)