jax.numpy.packbits

jax.numpy.packbits(a, axis=None, bitorder='big')

Pack an array of bits into a uint8 array.

JAX implementation of numpy.packbits().

Parameters:
  • a (ArrayLike) – N-dimensional array of bits to pack.

  • axis (int | None) – Optional axis along which to pack bits. If not specified, a is flattened before packing.

  • bitorder (str) – "big" (default) or "little": specify whether the bit order is big-endian or little-endian.

Returns:

A uint8 array of packed values.

Return type:

Array

See also

  • jax.numpy.unpackbits(): inverse of packbits().

Examples

Packing bits in one dimension:

>>> import jax.numpy as jnp
>>> bits = jnp.array([0, 0, 0, 0, 0, 1, 1, 1])
>>> jnp.packbits(bits)
Array([7], dtype=uint8)
>>> 0b00000111  # equivalent bit-wise representation
7
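
Each packed byte is a weighted sum of one group of 8 bits. The following is a minimal sketch of that arithmetic for the default big-endian order, assuming a 1-D input of 0/1 values whose length is already a multiple of 8. The helper pack_big_endian_sketch is hypothetical and is not how JAX implements packbits; it only illustrates the place values involved.

```python
import jax.numpy as jnp


def pack_big_endian_sketch(bits):
    """Hypothetical helper: pack 0/1 bits (length a multiple of 8), big-endian."""
    bits = jnp.asarray(bits, dtype=jnp.uint8)
    # One row of 8 bits per output byte.
    groups = bits.reshape(-1, 8)
    # Big-endian place values: the first bit of each group is the most significant.
    weights = jnp.array([128, 64, 32, 16, 8, 4, 2, 1], dtype=jnp.uint8)
    # Each byte is the weighted sum of its bits (at most 255, so it fits in uint8).
    return (groups * weights).sum(axis=1).astype(jnp.uint8)


bits = jnp.array([0, 0, 0, 0, 0, 1, 1, 1])
print(pack_big_endian_sketch(bits).tolist())  # [7]
print(jnp.packbits(bits).tolist())            # [7]
```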

Optionally specifying the little-endian convention:

>>> jnp.packbits(bits, bitorder="little")
Array([224], dtype=uint8)
>>> 0b11100000  # equivalent bit-wise representation
224
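
With bitorder="little", the first bit of each group of 8 is the least significant. One way to check this, sketched below under the assumption that the input length is a multiple of 8, is that little-endian packing matches big-endian packing of each 8-bit group with its bits reversed.

```python
import jax.numpy as jnp

bits = jnp.array([0, 0, 0, 0, 0, 1, 1, 1,
                  1, 0, 0, 0, 0, 0, 0, 0])

little = jnp.packbits(bits, bitorder="little")

# Reverse the bit order within each 8-bit group, then pack big-endian.
reversed_groups = bits.reshape(-1, 8)[:, ::-1].ravel()
big_of_reversed = jnp.packbits(reversed_groups)

print(little.tolist())           # [224, 1]
print(big_of_reversed.tolist())  # [224, 1]
```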

If the number of bits is not a multiple of 8, the input is right-padded with zeros:

>>> jnp.packbits(jnp.array([1, 0, 1]))
Array([160], dtype=uint8)
>>> jnp.packbits(jnp.array([1, 0, 1, 0, 0, 0, 0, 0]))
Array([160], dtype=uint8)
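
The padding rule can also be checked for a longer input: explicitly right-padding a 1-D input with zeros to the next multiple of 8 (here with jnp.pad) and then packing should give the same bytes as packing the unpadded input. A short sketch:

```python
import jax.numpy as jnp

bits = jnp.array([1, 0, 1, 1, 0, 0, 1, 1, 1, 1])  # 10 bits: not a multiple of 8

# Number of zeros needed to reach the next multiple of 8 (here, 6).
pad_width = (-bits.shape[0]) % 8
padded = jnp.pad(bits, (0, pad_width))  # jnp.pad fills with zeros by default

print(jnp.packbits(bits).tolist())    # [179, 192]
print(jnp.packbits(padded).tolist())  # [179, 192]
```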

For a multi-dimensional input, bits may be packed along a specified axis:

>>> a = jnp.array([[1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0],
...                [0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1]])
>>> vals = jnp.packbits(a, axis=1)
>>> vals
Array([[212, 150],
       [ 69, 207]], dtype=uint8)
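
The choice of axis changes how bits are grouped. With axis left unspecified, the array is flattened in row-major order first, so for this input the result is the same four bytes as the axis=1 result, in one 1-D array. Packing along axis=0 instead groups the 2 bits in each column, pads them to 8, and yields one byte per column. A sketch (the array a is redefined so the block runs on its own):

```python
import jax.numpy as jnp

a = jnp.array([[1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0],
               [0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1]])

# axis=None: the input is flattened (row-major) before packing.
flat = jnp.packbits(a)
print(flat.shape, flat.tolist())  # (4,) [212, 150, 69, 207]

# axis=0: each column's 2 bits are right-padded to 8, giving one byte per column.
cols = jnp.packbits(a, axis=0)
print(cols.shape)  # (1, 16)
```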

The inverse of packbits is provided by unpackbits():

>>> jnp.unpackbits(vals, axis=1)
Array([[1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0],
       [0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1]], dtype=uint8)
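
When the original number of bits is not a multiple of 8, unpacking also returns the zero padding, so the original length has to be tracked separately. A minimal sketch that trims the padding by slicing with the known length (one possible approach, not the only one):

```python
import jax.numpy as jnp

bits = jnp.array([1, 0, 1], dtype=jnp.uint8)
packed = jnp.packbits(bits)

unpacked = jnp.unpackbits(packed)
print(unpacked.tolist())  # [1, 0, 1, 0, 0, 0, 0, 0]

# Slice off the padding using the known original length.
restored = unpacked[:bits.shape[0]]
print(restored.tolist())  # [1, 0, 1]
```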