jax.numpy.right_shift
- jax.numpy.right_shift(x1, x2, /)
Right shift the bits of x1 by the amount specified in x2.
JAX implementation of numpy.right_shift.
- Parameters:
x1 (ArrayLike) – Input array. Only integer dtypes are accepted.
x2 (ArrayLike) – The number of bits by which to shift each element of x1 to the right. Only integer dtypes are accepted.
- Returns:
An array containing the elements of x1 right shifted by the amounts specified in x2, with the shape obtained by broadcasting x1 and x2.
- Return type:
Array
Note
If x1.shape != x2.shape, they must be compatible for broadcasting to a shared shape; this shared shape is also the shape of the output. Right shifting a scalar x1 by a scalar x2 is equivalent to x1 // 2**x2.
Examples
>>> import jax.numpy as jnp
>>> def print_binary(x):
...   return [bin(int(val)) for val in x]
>>> x1 = jnp.array([1, 2, 4, 8])
>>> print_binary(x1)
['0b1', '0b10', '0b100', '0b1000']
>>> x2 = 1
>>> result = jnp.right_shift(x1, x2)
>>> result
Array([0, 1, 2, 4], dtype=int32)
>>> print_binary(result)
['0b0', '0b1', '0b10', '0b100']
>>> x1 = 16
>>> print_binary([x1])
['0b10000']
>>> x2 = jnp.array([1, 2, 3, 4])
>>> result = jnp.right_shift(x1, x2)
>>> result
Array([8, 4, 2, 1], dtype=int32)
>>> print_binary(result)
['0b1000', '0b100', '0b10', '0b1']
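The note on broadcasting and on the equivalence with x1 // 2**x2 can be checked with a short sketch. The input values, the shapes (2, 1) and (3,), and the variable names shifted and floor_div are illustrative assumptions, not part of the original examples.

```python
import jax.numpy as jnp

# Illustrative inputs (assumed for this sketch):
# x1 has shape (2, 1) and x2 has shape (3,), so they broadcast to (2, 3).
x1 = jnp.array([[64], [40]], dtype=jnp.int32)
x2 = jnp.array([1, 2, 3], dtype=jnp.int32)

# Each element of x1 is shifted right by each amount in x2.
shifted = jnp.right_shift(x1, x2)

# Floor division by the matching power of two, for comparison.
floor_div = x1 // (2 ** x2)

print(shifted.shape)  # (2, 3)
print(shifted)        # rows: [32, 16, 8] and [20, 10, 5]
print(bool(jnp.array_equal(shifted, floor_div)))  # True
```

Both results agree elementwise here, which matches the scalar equivalence stated in the note applied across the broadcast shape.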