jax.numpy.left_shift

jax.numpy.left_shift(x, y, /)

Shift the bits of x to the left by the amounts specified in y, element-wise.

JAX implementation of numpy.left_shift.
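The element-wise behavior can be checked against Python's built-in << applied to each pair of elements. The following is a minimal sketch under the assumption of default JAX settings; it also uses the << operator on a JAX array, which performs the same element-wise left shift.

```python
import jax.numpy as jnp

x = jnp.array([1, 3, 5, 7])
y = jnp.array([0, 1, 2, 3])

# Element-wise shift: each x[i] is shifted left by y[i] bits.
shifted = jnp.left_shift(x, y)

# Reference values computed with Python's << on each pair of elements.
expected = [int(a) << int(b) for a, b in zip(x, y)]

# The << operator on JAX arrays performs the same element-wise shift.
via_operator = x << y

print(shifted.tolist())  # [1, 6, 20, 56]
print(expected)          # [1, 6, 20, 56]
```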

Parameters:
  • x (ArrayLike) – Input array; must have an integer dtype.

  • y (ArrayLike) – The number of bits by which to shift each element of x to the left; must have an integer dtype. x and y must either have the same shape or be broadcast-compatible.

Returns:

An array containing the elements of x shifted left by the corresponding amounts in y. Its shape is the broadcast shape of x and y.

Return type:

Array
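Because x and y only need to be broadcast-compatible, a column of values can be shifted by a row of shift amounts in one call. The sketch below (the shapes are chosen purely for illustration) shows a (3, 1) input and a (4,) shift array producing a (3, 4) result.

```python
import jax.numpy as jnp

# x has shape (3, 1): one value per row.
x = jnp.array([[1], [2], [3]])
# y has shape (4,): one shift amount per column.
y = jnp.arange(4)

# (3, 1) and (4,) broadcast to (3, 4); entry [i, j] is x[i, 0] << y[j].
result = jnp.left_shift(x, y)

print(result.shape)     # (3, 4)
print(result.tolist())  # [[1, 2, 4, 8], [2, 4, 8, 16], [3, 6, 12, 24]]
```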

Note

Left-shifting x by y is equivalent to x * (2**y), provided the result fits within the range of the dtypes involved.
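The equivalence with multiplication, and what happens when a result no longer fits in its dtype, can be illustrated with a short sketch. The int8 overflow case assumes the usual two's complement wraparound of fixed-width integers: 64 shifted left by one bit would be 128, which int8 cannot represent, so the sign bit is set.

```python
import jax.numpy as jnp

x = jnp.array([-3, -1, 0, 1, 5])
y = 3

# Within the dtype's range, shifting left by y multiplies by 2**y,
# including for negative (two's complement) values.
shifted = jnp.left_shift(x, y)
multiplied = x * (2 ** y)
print(shifted.tolist())  # [-24, -8, 0, 8, 40]

# Outside the range the bits wrap around: 64 << 1 would be 128, which
# does not fit in int8, so the result has only the sign bit set (-128).
overflowed = jnp.left_shift(jnp.int8(64), 1)
print(overflowed)        # -128
```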

See also

Examples

>>> import jax.numpy as jnp
>>> def print_binary(x):
...   return [bin(int(val)) for val in x]
>>> x1 = jnp.arange(5)
>>> x1
Array([0, 1, 2, 3, 4], dtype=int32)
>>> print_binary(x1)
['0b0', '0b1', '0b10', '0b11', '0b100']
>>> x2 = 1
>>> result = jnp.left_shift(x1, x2)
>>> result
Array([0, 2, 4, 6, 8], dtype=int32)
>>> print_binary(result)
['0b0', '0b10', '0b100', '0b110', '0b1000']
>>> x3 = 4
>>> print_binary([x3])
['0b100']
>>> x4 = jnp.array([1, 2, 3, 4])
>>> result1 = jnp.left_shift(x3, x4)
>>> result1
Array([ 8, 16, 32, 64], dtype=int32)
>>> print_binary(result1)
['0b1000', '0b10000', '0b100000', '0b1000000']
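A common practical use of left shifts is building single-bit masks. The sketch below is illustrative only (the names masks, value, and bits are hypothetical): it generates the masks 1 << k for k = 0, ..., 7 and combines them with jnp.bitwise_and to read the individual bits of a value.

```python
import jax.numpy as jnp

# Single-bit masks 1 << k for k = 0, ..., 7 (the powers of two).
masks = jnp.left_shift(1, jnp.arange(8))
print(masks.tolist())  # [1, 2, 4, 8, 16, 32, 64, 128]

# A bit of value is set when value & mask is nonzero.
# The result is ordered from the least significant bit upward.
value = 0b1011  # decimal 11
bits = jnp.bitwise_and(value, masks) != 0
print(bits.tolist())   # [True, True, False, True, False, False, False, False]
```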