jax.numpy.histogram

jax.numpy.histogram(a, bins=10, range=None, weights=None, density=None)

Compute a 1-dimensional histogram.

JAX implementation of numpy.histogram().

Parameters:
  • a (ArrayLike) – array of values to be binned. May be of any size or dimension; the histogram is computed over all of its elements.

  • bins (ArrayLike) – the number of equal-width bins in the histogram (default: 10). bins may also be an array specifying the locations of the bin edges.

  • range (Sequence[ArrayLike] | None) – a pair of scalars (lower, upper) specifying the range of the bins. If not specified, the range is inferred from the data (its minimum and maximum values).

  • weights (ArrayLike | None) – An optional array specifying the weights of the data points. Should be broadcast-compatible with a. If not specified, each data point is weighted equally.

  • density (bool | None) – If True, return the normalized histogram in units of counts per unit length, so that its integral over the range is 1. If False or None (default), return the (weighted) counts per bin.
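None of the examples below pass weights. As a minimal sketch, assuming weights has exactly the same shape as a (the simplest compatible case), each bin accumulates the sum of the weights of the values that fall into it instead of a plain count. The array names a and w are illustrative only.

```python
import jax.numpy as jnp

a = jnp.array([1.0, 2.0, 3.0, 10.0])
w = jnp.array([0.25, 0.25, 0.5, 4.0])  # one weight per data point

# Two equal-width bins over [0, 10], so the edges are [0, 5, 10].
counts, edges = jnp.histogram(a, bins=2, range=(0, 10), weights=w)
print(counts)  # [1. 4.]  (unweighted counts would be [3. 1.])
```

The first bin sums 0.25 + 0.25 + 0.5 for the values 1, 2 and 3; the value 10 sits on the last edge and is counted in the final bin with weight 4.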

Returns:

A tuple of arrays (histogram, bin_edges), where histogram contains the (weighted or normalized) value of each bin, and bin_edges specifies the boundaries of the bins, one more than the number of bins.

Return type:

tuple[Array, Array]
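A short sketch of how the two returned arrays relate: n bins are delimited by n + 1 edges, and with the default range (the data's minimum to maximum) and no weights, the counts add up to the number of input values.

```python
import jax.numpy as jnp

a = jnp.array([1, 2, 3, 10, 11, 15, 19, 25])
hist, bin_edges = jnp.histogram(a, bins=8)

# n bins are delimited by n + 1 edges.
assert bin_edges.shape[0] == hist.shape[0] + 1
# With the default range, every value lands in some bin.
assert int(hist.sum()) == a.size
```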

See also

Examples

>>> import jax.numpy as jnp
>>> a = jnp.array([1, 2, 3, 10, 11, 15, 19, 25])
>>> counts, bin_edges = jnp.histogram(a, bins=8)
>>> print(counts)
[3. 0. 0. 2. 1. 0. 1. 1.]
>>> print(bin_edges)
[ 1.  4.  7. 10. 13. 16. 19. 22. 25.]
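When bins is an integer, the edges in the output above are evenly spaced between the smallest and largest values of a, here with a spacing of (25 - 1) / 8 = 3. The following sketch reproduces them with jnp.linspace; it illustrates the arithmetic and is not meant as a description of the library's internal code path.

```python
import jax.numpy as jnp

a = jnp.array([1, 2, 3, 10, 11, 15, 19, 25])
n_bins = 8

# n_bins + 1 equally spaced points spanning [a.min(), a.max()].
edges = jnp.linspace(a.min(), a.max(), n_bins + 1)

_, bin_edges = jnp.histogram(a, bins=n_bins)
print(jnp.allclose(edges, bin_edges))  # True
```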

Specifying the bin range:

>>> counts, bin_edges = jnp.histogram(a, range=(0, 25), bins=5)
>>> print(counts)
[3. 0. 2. 2. 1.]
>>> print(bin_edges)
[ 0.  5. 10. 15. 20. 25.]
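If only the edges are needed, jnp.histogram_bin_edges accepts the same bins and range arguments and skips the counting. A quick check, reusing the data from the examples above:

```python
import jax.numpy as jnp

a = jnp.array([1, 2, 3, 10, 11, 15, 19, 25])

# Compute the edges alone, without counting.
edges = jnp.histogram_bin_edges(a, bins=5, range=(0, 25))
print(edges)  # [ 0.  5. 10. 15. 20. 25.]

# They agree with the edges returned by jnp.histogram for the same arguments.
_, bin_edges = jnp.histogram(a, bins=5, range=(0, 25))
print(jnp.array_equal(edges, bin_edges))  # True
```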

Specifying the bin edges explicitly:

>>> bin_edges = jnp.array([0, 10, 20, 30])
>>> counts, _ = jnp.histogram(a, bins=bin_edges)
>>> print(counts)
[3. 4. 1.]
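Each bin includes its left edge and excludes its right edge, except the last bin, which is closed on both sides. This follows NumPy's convention and is why 25 is counted in the final bin of the first example. A small check with values placed exactly on the edges (the array x is illustrative):

```python
import jax.numpy as jnp

bin_edges = jnp.array([0, 10, 20, 30])
# 0 falls in [0, 10), 10 in [10, 20), and both 20 and 30 in the closed last bin [20, 30].
x = jnp.array([0, 10, 20, 30])
counts, _ = jnp.histogram(x, bins=bin_edges)
print(counts)  # [1. 1. 2.]
```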

Using density=True returns a normalized histogram:

>>> density, bin_edges = jnp.histogram(a, density=True)
>>> dx = jnp.diff(bin_edges)
>>> normed_sum = jnp.sum(density * dx)
>>> jnp.allclose(normed_sum, 1.0)
Array(True, dtype=bool)
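Concretely, the density values are the counts divided by the total count and by each bin's width, which makes the area under the histogram equal to 1. A sketch of that arithmetic, checked against density=True:

```python
import jax.numpy as jnp

a = jnp.array([1, 2, 3, 10, 11, 15, 19, 25])

counts, bin_edges = jnp.histogram(a)  # default: 10 equal-width bins
widths = jnp.diff(bin_edges)

# Normalize by the total count and by the width of each bin.
manual = counts / (counts.sum() * widths)

density, _ = jnp.histogram(a, density=True)
print(jnp.allclose(manual, density))  # True
```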