jax.numpy.histogram2d#
- jax.numpy.histogram2d(x, y, bins=10, range=None, weights=None, density=None)[source]#
Compute a 2-dimensional histogram.
JAX implementation of
numpy.histogram2d()
.- Parameters:
x (ArrayLike) – one-dimensional array of x-values for points to be binned.
y (ArrayLike) – one-dimensional array of y-values for points to be binned.
bins (ArrayLike | list[ArrayLike]) – Specify the number of bins in the histogram (default: 10).
bins
may also be an array specifying the locations of the bin edges, or a pair of integers or pair of arrays specifying the number of bins in each dimension.range (Sequence[None | Array | Sequence[ArrayLike]] | None | None) – Pair of arrays or lists of the form
[[xmin, xmax], [ymin, ymax]]
specifying the range of the data in each dimension. If not specified, the range is inferred from the data.weights (ArrayLike | None | None) – An optional array specifying the weights of the data points. Should be the same shape as
x
andy
. If not specified, each data point is weighted equally.density (bool | None | None) – If True, return the normalized histogram in units of counts per unit area. If False (default) return the (weighted) counts per bin.
- Returns:
A tuple of arrays
(histogram, x_edges, y_edges)
, wherehistogram
contains the aggregated data, andx_edges
andy_edges
specify the boundaries of the bins.- Return type:
See also
jax.numpy.histogram()
: Compute the histogram of a 1D array.jax.numpy.histogramdd()
: Compute the histogram of an N-dimensional array.jax.numpy.histogram_bin_edges()
: Compute the bin edges for a histogram.
Examples
>>> x = jnp.array([1, 2, 3, 10, 11, 15, 19, 25]) >>> y = jnp.array([2, 5, 6, 8, 13, 16, 17, 18]) >>> counts, x_edges, y_edges = jnp.histogram2d(x, y, bins=8) >>> counts.shape (8, 8) >>> x_edges Array([ 1., 4., 7., 10., 13., 16., 19., 22., 25.], dtype=float32) >>> y_edges Array([ 2., 4., 6., 8., 10., 12., 14., 16., 18.], dtype=float32)
Specifying the bin range:
>>> counts, x_edges, y_edges = jnp.histogram2d(x, y, range=[(0, 25), (0, 25)], bins=5) >>> counts.shape (5, 5) >>> x_edges Array([ 0., 5., 10., 15., 20., 25.], dtype=float32) >>> y_edges Array([ 0., 5., 10., 15., 20., 25.], dtype=float32)
Specifying the bin edges explicitly:
>>> x_edges = jnp.array([0, 10, 20, 30]) >>> y_edges = jnp.array([0, 10, 20, 30]) >>> counts, _, _ = jnp.histogram2d(x, y, bins=[x_edges, y_edges]) >>> counts Array([[3, 0, 0], [1, 3, 0], [0, 1, 0]], dtype=int32)
Using
density=True
returns a normalized histogram:>>> density, x_edges, y_edges = jnp.histogram2d(x, y, density=True) >>> dx = jnp.diff(x_edges) >>> dy = jnp.diff(y_edges) >>> normed_sum = jnp.sum(density * dx[:, None] * dy[None, :]) >>> jnp.allclose(normed_sum, 1.0) Array(True, dtype=bool)