jax.numpy.invert#

jax.numpy.invert(x, /)[源代码]#

计算输入的按位反转。

numpy.invert() 的 JAX 实现。此函数为 JAX 数组提供 ~ 运算符的实现。

参数:

x (类数组) – 输入数组,必须为布尔型或整型。

返回:

`x 具有相同形状和 dtype 的数组,其中位被反转。

返回类型:

数组

另请参阅

示例

>>> x = jnp.arange(5, dtype='uint8')
>>> print(x)
[0 1 2 3 4]
>>> print(jnp.invert(x))
[255 254 253 252 251]

此函数实现 JAX 数组的单目运算符 ~

>>> print(~x)
[255 254 253 252 251]

invert() 对输入进行按位运算,因此通过显示按位表示可以更清楚地了解其输出的含义

>>> with jnp.printoptions(formatter={'int': lambda x: format(x, '#010b')}):
...   print(f"{x  = }")
...   print(f"{~x = }")
x  = Array([0b00000000, 0b00000001, 0b00000010, 0b00000011, 0b00000100], dtype=uint8)
~x = Array([0b11111111, 0b11111110, 0b11111101, 0b11111100, 0b11111011], dtype=uint8)

对于布尔输入,invert() 等效于 logical_not()

>>> x = jnp.array([True, False, True, True, False])
>>> jnp.invert(x)
Array([False,  True, False, False,  True], dtype=bool)