jax.numpy.packbits
- jax.numpy.packbits(a, axis=None, bitorder='big')
Pack array of bits into a uint8 array.
JAX implementation of numpy.packbits().
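Since this is described as the JAX implementation of numpy.packbits(), one way to sanity-check it is to compare it against NumPy on the same input. The sketch below assumes NumPy 1.17 or later (for the bitorder argument); the random input shape and seed are arbitrary choices for illustration.

```python
import numpy as np
import jax.numpy as jnp

# Arbitrary test bits; 3 and 13 are deliberately not multiples of 8.
rng = np.random.default_rng(0)
bits = rng.integers(0, 2, size=(3, 13), dtype=np.uint8)

# Compare against NumPy for every axis setting and both bit orders.
for axis in (None, 0, 1):
    for bitorder in ("big", "little"):
        expected = np.packbits(bits, axis=axis, bitorder=bitorder)
        actual = jnp.packbits(jnp.asarray(bits), axis=axis, bitorder=bitorder)
        assert np.array_equal(expected, np.asarray(actual))

print("jnp.packbits matches np.packbits")
```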
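- Parameters:
  - a (ArrayLike): N-dimensional array of bits to pack.
  - axis (int | None): optional axis along which to pack bits. If not specified, a is flattened before packing.
  - bitorder (str): "big" (default) or "little"; specifies whether the first bit of each group of 8 becomes the most significant (big-endian) or least significant (little-endian) bit of the packed byte.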
- Returns:
A uint8 array of packed values.
- Return type:
Array
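A short sketch of the default behavior: with axis=None, a multi-dimensional input is flattened in row-major order and packed into a one-dimensional uint8 array. The 2x4 input below is an arbitrary example.

```python
import jax.numpy as jnp

# A 2x4 array of bits. With axis=None (the default), packbits flattens it in
# row-major order to the 8 bits [1, 0, 0, 1, 0, 1, 1, 0] before packing.
bits = jnp.array([[1, 0, 0, 1],
                  [0, 1, 1, 0]])
packed = jnp.packbits(bits)

print(packed.shape)   # (1,)
print(packed.dtype)   # uint8
print(packed)         # [150]  (0b10010110 == 150)
```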
See also
jax.numpy.unpackbits(): inverse of packbits.
Examples
Packing bits in one dimension:
>>> import jax.numpy as jnp
>>> bits = jnp.array([0, 0, 0, 0, 0, 1, 1, 1])
>>> jnp.packbits(bits)
Array([7], dtype=uint8)
>>> 0b00000111  # equivalent bit-wise representation:
7
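To make the big-endian arithmetic explicit, the sketch below recomputes the packed byte as a weighted sum in which the first bit of the group is the most significant. This is an illustration of the bit weighting only, not a description of how JAX implements packbits internally.

```python
import jax.numpy as jnp

bits = jnp.array([0, 0, 0, 0, 0, 1, 1, 1])

# Big-endian ("big"): the first bit is the most significant,
# so bit i of an 8-bit group carries weight 2**(7 - i).
weights = 2 ** jnp.arange(7, -1, -1)   # [128, 64, 32, 16, 8, 4, 2, 1]
manual = jnp.dot(bits, weights)

print(manual)                   # 7
print(jnp.packbits(bits)[0])    # 7
```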
Optionally specifying the little-endian bit order:
>>> jnp.packbits(bits, bitorder="little")
Array([224], dtype=uint8)
>>> 0b11100000  # equivalent bit-wise representation:
224
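The same weighted-sum view applies to bitorder="little", except that the first bit is now the least significant. For a single 8-bit group this is equivalent to big-endian packing of the reversed bits, as the sketch below checks; the reversal trick is an illustration, not a documented identity for multi-byte inputs.

```python
import jax.numpy as jnp

bits = jnp.array([0, 0, 0, 0, 0, 1, 1, 1])

# Little-endian ("little"): the first bit is the least significant,
# so bit i of an 8-bit group carries weight 2**i.
weights = 2 ** jnp.arange(8)          # [1, 2, 4, 8, 16, 32, 64, 128]
manual = jnp.dot(bits, weights)

little = jnp.packbits(bits, bitorder="little")[0]
# For a single 8-bit group this equals big-endian packing of the reversed bits.
reversed_big = jnp.packbits(bits[::-1])[0]

print(manual, little, reversed_big)   # 224 224 224
```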
If the number of bits is not a multiple of 8, the input is right-padded with zeros:
>>> jnp.packbits(jnp.array([1, 0, 1]))
Array([160], dtype=uint8)
>>> jnp.packbits(jnp.array([1, 0, 1, 0, 0, 0, 0, 0]))
Array([160], dtype=uint8)
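A consequence of the zero padding is that n bits always pack into ceil(n / 8) bytes, and padding the input yourself to a multiple of 8 gives the same result. The 13-bit input below is an arbitrary example.

```python
import math
import jax.numpy as jnp

bits = jnp.array([1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1])  # 13 bits

packed = jnp.packbits(bits)
# The output has ceil(13 / 8) == 2 bytes.
print(packed.shape[0] == math.ceil(bits.shape[0] / 8))      # True

# Zero-padding explicitly to a multiple of 8 gives the same bytes.
n_pad = (-bits.shape[0]) % 8                                 # 3 zeros
padded = jnp.pad(bits, (0, n_pad))
print(jnp.array_equal(packed, jnp.packbits(padded)))         # True
print(packed)                                                # [178 232]
```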
For a multi-dimensional input, bits may be packed along a specified axis:
>>> a = jnp.array([[1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0],
...                [0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1]])
>>> vals = jnp.packbits(a, axis=1)
>>> vals
Array([[212, 150],
       [ 69, 207]], dtype=uint8)
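Only the packed axis changes length: it shrinks to ceil(length / 8) bytes, while every other axis keeps its size. The sketch below reuses the array from the example above and also assumes that a negative axis counts from the end, as in NumPy.

```python
import jax.numpy as jnp

a = jnp.array([[1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0],
               [0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1]])

# Only the packed axis shrinks, to ceil(length / 8) bytes; other axes keep their size.
print(jnp.packbits(a, axis=1).shape)   # (2, 2): 16 bits per row -> 2 bytes
print(jnp.packbits(a, axis=0).shape)   # (1, 16): 2 bits per column, zero-padded to 1 byte

# A negative axis counts from the end, so axis=-1 is the same as axis=1 here.
print(jnp.array_equal(jnp.packbits(a, axis=-1), jnp.packbits(a, axis=1)))  # True
```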
The inverse of packbits is provided by unpackbits():
>>> jnp.unpackbits(vals, axis=1)
Array([[1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0],
       [0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1]], dtype=uint8)
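When the number of bits is not a multiple of 8, unpacking also returns the zero padding. The sketch below assumes that jnp.unpackbits accepts a count argument, mirroring numpy.unpackbits, to trim the output back to the original length; the same bitorder must be used for packing and unpacking.

```python
import jax.numpy as jnp

bits = jnp.array([1, 0, 1], dtype=jnp.uint8)
packed = jnp.packbits(bits)            # [160]: the 3 bits are right-padded to 8

# Unpacking returns all 8 bits, including the zero padding.
print(jnp.unpackbits(packed))          # [1 0 1 0 0 0 0 0]

# Passing the original length as `count` trims the padding, giving an exact round trip.
restored = jnp.unpackbits(packed, count=bits.shape[0])
print(jnp.array_equal(restored, bits))  # True
```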