jax.numpy.left_shift#

jax.numpy.left_shift(x, y, /)[源代码]#

x 的位按元素方式左移 y 指定的位数。

numpy.left_shift 的 JAX 实现。

参数:
  • x (ArrayLike) – 输入数组,必须是整数类型。

  • y (ArrayLike) – 将 x 中每个元素的位向左移动的位数,仅接受整数子类型。xy 必须具有相同的形状或广播兼容。

返回:

一个数组,其中包含 x 的元素按 y 指定的位数左移后的结果,其形状与 xy 广播后的形状相同。

返回类型:

数组

注意

在所涉及的数据类型范围内,将 x 左移 y 位等价于 x * (2**y)

另请参阅

示例

>>> def print_binary(x):
...   return [bin(int(val)) for val in x]
>>> x1 = jnp.arange(5)
>>> x1
Array([0, 1, 2, 3, 4], dtype=int32)
>>> print_binary(x1)
['0b0', '0b1', '0b10', '0b11', '0b100']
>>> x2 = 1
>>> result = jnp.left_shift(x1, x2)
>>> result
Array([0, 2, 4, 6, 8], dtype=int32)
>>> print_binary(result)
['0b0', '0b10', '0b100', '0b110', '0b1000']
>>> x3 = 4
>>> print_binary([x3])
['0b100']
>>> x4 = jnp.array([1, 2, 3, 4])
>>> result1 = jnp.left_shift(x3, x4)
>>> result1
Array([ 8, 16, 32, 64], dtype=int32)
>>> print_binary(result1)
['0b1000', '0b10000', '0b100000', '0b1000000']