jax.numpy.unpackbits

jax.numpy.unpackbits(a, axis=None, count=None, bitorder='big')

Unpack the bits in a uint8 array.

JAX implementation of numpy.unpackbits().

Parameters:
  • a (ArrayLike) – N-dimensional array of type uint8.

  • axis (int | None) – optional axis along which to unpack. If not specified, a is flattened before unpacking.

  • count (int | None) – optional number of bits to unpack along the axis (if non-negative) or number of bits to trim from the end (if negative). If None (the default), all bits are unpacked.

  • bitorder (str) – "big" (default) or "little": whether bits are unpacked most-significant first (big-endian) or least-significant first (little-endian).
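As a rough illustration of what these parameters control, the sketch below reproduces the flattened case (axis=None, count=None) using shifts and masks. The helper name unpack_bits_manual is hypothetical and not part of the JAX API; it only shows how big- and little-endian ordering differ.

```python
import jax.numpy as jnp

def unpack_bits_manual(a, bitorder="big"):
    """Hypothetical helper: mimics jnp.unpackbits(a, bitorder=...) for axis=None."""
    flat = jnp.ravel(jnp.asarray(a, dtype=jnp.uint8))
    # Shift amounts 0..7 select bits from least to most significant.
    shifts = jnp.arange(8, dtype=jnp.uint8)
    if bitorder == "big":
        # Big-endian: emit the most significant bit first.
        shifts = shifts[::-1]
    # Broadcast each byte against the 8 shifts, then keep only the lowest bit.
    bits = (flat[:, None] >> shifts[None, :]) & 1
    # Concatenate the 8-bit groups into one flat array.
    return bits.reshape(-1).astype(jnp.uint8)

x = jnp.array([27, 154], dtype=jnp.uint8)
print(unpack_bits_manual(x))
print(jnp.unpackbits(x))  # should match the line above
```

The real implementation also handles axis and count; this sketch deliberately leaves them out.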

Returns:

A uint8 array of unpacked bits, each element 0 or 1.

Return type:

Array

See also

  • jax.numpy.packbits(): the inverse of unpackbits.

Examples

Unpacking bits from a scalar:

>>> import jax.numpy as jnp
>>> jnp.unpackbits(jnp.uint8(27))  # big-endian by default
Array([0, 0, 0, 1, 1, 0, 1, 1], dtype=uint8)
>>> jnp.unpackbits(jnp.uint8(27), bitorder="little")
Array([1, 1, 0, 1, 1, 0, 0, 0], dtype=uint8)

Compare this to the Python binary representation:

>>> 0b00011011
27
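The correspondence can also be checked numerically: weighting the big-endian bits by descending powers of two recovers the original byte, and the little-endian bits use the same weights in ascending order. This is a small illustrative check built only on jnp.unpackbits and Python's format().

```python
import jax.numpy as jnp

bits = jnp.unpackbits(jnp.uint8(27))  # big-endian: most significant bit first
# Weights for bit positions 7..0, matching big-endian order.
weights = jnp.array([128, 64, 32, 16, 8, 4, 2, 1])
value = int(jnp.sum(bits * weights))

little = jnp.unpackbits(jnp.uint8(27), bitorder="little")
# Little-endian: least significant bit first, so the weights are reversed.
value_little = int(jnp.sum(little * weights[::-1]))

# Python's zero-padded binary string lists bits in big-endian order.
python_bits = [int(c) for c in format(27, "08b")]
print(value, value_little, python_bits)
```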

Unpacking bits along an axis:

>>> vals = jnp.array([[154],
...                   [ 49]], dtype='uint8')
>>> bits = jnp.unpackbits(vals, axis=1)
>>> bits
Array([[1, 0, 0, 1, 1, 0, 1, 0],
       [0, 0, 1, 1, 0, 0, 0, 1]], dtype=uint8)
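The axis argument determines which dimension grows by a factor of eight. The sketch below compares output shapes for the same vals array; the shapes follow from the rule that each byte along the chosen axis expands to eight bits, and from a being flattened when axis is None.

```python
import jax.numpy as jnp

vals = jnp.array([[154],
                  [49]], dtype=jnp.uint8)  # shape (2, 1)

by_row = jnp.unpackbits(vals, axis=1)   # each row expands along axis 1
by_col = jnp.unpackbits(vals, axis=0)   # bits stacked along axis 0
flat = jnp.unpackbits(vals)             # axis=None flattens first

print(by_row.shape, by_col.shape, flat.shape)  # (2, 8) (16, 1) (16,)
```

In this particular case all three results contain the same bit sequence; only the layout differs.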

Using packbits() to invert this:

>>> jnp.packbits(bits, axis=1)
Array([[154],
       [ 49]], dtype=uint8)
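More generally, for any uint8 array, packing the unpacked bits along the same axis and with the same bitorder should return the original array, because each byte expands to exactly eight bits. A quick property-style check on random bytes generated with jax.random:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
# Random integers in [0, 256), cast to uint8 bytes.
x = jax.random.randint(key, (3, 5), 0, 256).astype(jnp.uint8)

for axis in (0, 1):
    for order in ("big", "little"):
        bits = jnp.unpackbits(x, axis=axis, bitorder=order)
        restored = jnp.packbits(bits, axis=axis, bitorder=order)
        # The round trip should reproduce the input exactly.
        assert bool(jnp.array_equal(restored, x))
```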

The count keyword lets unpackbits serve as an inverse of packbits when the number of bits is not a multiple of 8:

>>> bits = jnp.array([1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1])  # 11 bits
>>> vals = jnp.packbits(bits)
>>> vals
Array([219,  96], dtype=uint8)
>>> jnp.unpackbits(vals)  # 16 zero-padded bits
Array([1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0], dtype=uint8)
>>> jnp.unpackbits(vals, count=11)  # specify 11 output bits
Array([1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1], dtype=uint8)
>>> jnp.unpackbits(vals, count=-5)  # specify 5 bits to be trimmed
Array([1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1], dtype=uint8)
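The same trimming works with little-endian packing, provided bitorder matches on both sides. In the sketch below (bitorder="little" is chosen purely for illustration), the first bit of each group lands in the least significant position, so the trailing three bits pack to 6 rather than 96, and count=11 or count=-5 again removes the padding.

```python
import jax.numpy as jnp

bits = jnp.array([1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1], dtype=jnp.uint8)  # 11 bits
# The first bit of each group of 8 goes to the least significant position.
vals = jnp.packbits(bits, bitorder="little")
print(vals.tolist())  # [219, 6]

restored = jnp.unpackbits(vals, count=11, bitorder="little")  # keep 11 bits
trimmed = jnp.unpackbits(vals, count=-5, bitorder="little")   # drop 5 padding bits
assert bool(jnp.array_equal(restored, bits))
assert bool(jnp.array_equal(trimmed, bits))
```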