jax.numpy.left_shift#

jax.numpy.left_shift(x, y, /)[源代码]#

按元素方式将 x 的位向左移动 y 指定的量。

numpy.left_shift 的 JAX 实现。

参数:
  • x (ArrayLike) – 输入数组,必须为整数类型。

  • y (ArrayLike) – 将 x 中每个元素向左移动的位数,仅接受整数子类型。xy 必须具有相同的形状或广播兼容。

返回:

一个数组,其中包含 x 中元素按 y 指定量向左移动后的结果,其形状与 xy 的广播形状相同。

返回类型:

Array

注意

在所涉及的 dtype 范围内,将 x 向左移动 y 等价于 x * (2**y)

参见

示例

>>> def print_binary(x):
...   return [bin(int(val)) for val in x]
>>> x1 = jnp.arange(5)
>>> x1
Array([0, 1, 2, 3, 4], dtype=int32)
>>> print_binary(x1)
['0b0', '0b1', '0b10', '0b11', '0b100']
>>> x2 = 1
>>> result = jnp.left_shift(x1, x2)
>>> result
Array([0, 2, 4, 6, 8], dtype=int32)
>>> print_binary(result)
['0b0', '0b10', '0b100', '0b110', '0b1000']
>>> x3 = 4
>>> print_binary([x3])
['0b100']
>>> x4 = jnp.array([1, 2, 3, 4])
>>> result1 = jnp.left_shift(x3, x4)
>>> result1
Array([ 8, 16, 32, 64], dtype=int32)
>>> print_binary(result1)
['0b1000', '0b10000', '0b100000', '0b1000000']