jax.numpy.left_shift
jax.numpy.left_shift(x, y, /)
Shift the bits of x to the left by the amount specified in y, element-wise.

JAX implementation of numpy.left_shift.

Parameters:
x (ArrayLike) – Input array; must be integer-typed.
y (ArrayLike) – The number of bits by which to shift each element of x to the left; must be integer-typed. x and y must either have the same shape or be broadcast-compatible.
Returns:
An array containing the elements of x shifted left by the amounts specified in y, with the broadcast shape of x and y.
Return type:
Array
Note
Left shifting x by y is equivalent to x * (2**y) within the bounds of the dtypes involved.

See also
jax.numpy.right_shift() and jax.numpy.bitwise_right_shift(): Shift the bits of x1 to the right by the amount specified in x2, element-wise.
jax.numpy.bitwise_left_shift(): Alias of jax.numpy.left_shift().
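The equivalence stated in the Note can be checked with a short sketch. It assumes the default 32-bit configuration (jax_enable_x64 disabled), so jnp.arange yields int32; the input values are arbitrary choices for illustration. It also uses jax.numpy.right_shift() to undo the shift, and shows what "within the bounds of the dtypes involved" means once bits are pushed out of a small unsigned type.

```python
import jax.numpy as jnp

# Assumes the default 32-bit mode, so these are int32 values that stay
# well inside the int32 range after shifting by 2 bits.
x = jnp.arange(-4, 5)      # [-4, -3, ..., 4]
y = jnp.array(2)           # shift every element by 2 bits

shifted = jnp.left_shift(x, y)
scaled = x * (2 ** y)      # multiply by 2**y instead of shifting

# Within the bounds of the dtype, shifting and multiplying agree.
print(jnp.array_equal(shifted, scaled))                  # True

# No set bits were lost, so a right shift (arithmetic for signed ints)
# recovers the original values.
print(jnp.array_equal(jnp.right_shift(shifted, y), x))   # True

# Outside those bounds the high bits are discarded: in uint8,
# 200 << 1 = 400, which wraps modulo 256 to 144, just like 200 * 2.
small = jnp.array([200], dtype=jnp.uint8)
print(jnp.left_shift(small, 1))                          # [144]
print(small * 2)                                         # [144]
```

In the uint8 case both expressions still agree with each other; what no longer holds is agreement with the mathematical value 400, which is the practical limit of the Note's equivalence.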
Examples
>>> import jax.numpy as jnp
>>> def print_binary(x):
...   return [bin(int(val)) for val in x]
>>> x1 = jnp.arange(5)
>>> x1
Array([0, 1, 2, 3, 4], dtype=int32)
>>> print_binary(x1)
['0b0', '0b1', '0b10', '0b11', '0b100']
>>> x2 = 1
>>> result = jnp.left_shift(x1, x2)
>>> result
Array([0, 2, 4, 6, 8], dtype=int32)
>>> print_binary(result)
['0b0', '0b10', '0b100', '0b110', '0b1000']

>>> x3 = 4
>>> print_binary([x3])
['0b100']
>>> x4 = jnp.array([1, 2, 3, 4])
>>> result1 = jnp.left_shift(x3, x4)
>>> result1
Array([ 8, 16, 32, 64], dtype=int32)
>>> print_binary(result1)
['0b1000', '0b10000', '0b100000', '0b1000000']
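The examples above broadcast a scalar against a 1-D array. The same rule applies to higher-dimensional inputs, as in this sketch; the shapes (3, 1) and (4,) are arbitrary choices for illustration, and each row of the result is that row's value of x multiplied by 2**y.

```python
import jax.numpy as jnp

# x has shape (3, 1) and y has shape (4,); they broadcast to (3, 4).
x = jnp.array([[1], [2], [3]])
y = jnp.arange(4)              # shift amounts 0, 1, 2, 3

result = jnp.left_shift(x, y)
print(result.shape)            # (3, 4)
print(result.tolist())         # [[1, 2, 4, 8], [2, 4, 8, 16], [3, 6, 12, 24]]
```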