jax.numpy.array_equal#
- jax.numpy.array_equal(a1, a2, equal_nan=False)[源代码]#
检查两个数组是否逐元素相等。
numpy.array_equal()
的 JAX 实现。- 参数:
a1 (ArrayLike) – 要比较的第一个输入数组。
a2 (ArrayLike) – 要比较的第二个输入数组。
equal_nan (bool) – 布尔值。如果为
True
,则a1
中的 NaN 将被视为等于a2
中的 NaN。默认为False
。
- 返回:
一个布尔标量数组,指示输入数组是否逐元素相等。
- 返回类型:
示例
>>> jnp.array_equal(jnp.array([1, 2, 3]), jnp.array([1, 2, 3])) Array(True, dtype=bool) >>> jnp.array_equal(jnp.array([1, 2, 3]), jnp.array([1, 2])) Array(False, dtype=bool) >>> jnp.array_equal(jnp.array([1, 2, 3]), jnp.array([1, 2, 4])) Array(False, dtype=bool) >>> jnp.array_equal(jnp.array([1, 2, float('nan')]), ... jnp.array([1, 2, float('nan')])) Array(False, dtype=bool) >>> jnp.array_equal(jnp.array([1, 2, float('nan')]), ... jnp.array([1, 2, float('nan')]), equal_nan=True) Array(True, dtype=bool)