jax.numpy.cov
jax.numpy.cov(m, y=None, rowvar=True, bias=False, ddof=None, fweights=None, aweights=None)
Estimate the weighted sample covariance.
JAX implementation of numpy.cov().

The covariance \(C_{ij}\) between variable \(i\) and variable \(j\) is defined as

\[\mathrm{cov}[X_i, X_j] = E[(X_i - E[X_i])(X_j - E[X_j])]\]

Given an array of \(N\) observations of the variables \(X_i\) and \(X_j\), this can be estimated via the sample covariance:

\[C_{ij} = \frac{1}{N - 1} \sum_{n=1}^N (X_{in} - \overline{X_i})(X_{jn} - \overline{X_j})\]

where \(\overline{X_i} = \frac{1}{N} \sum_{k=1}^N X_{ik}\) is the mean of the observations.
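As a quick sanity check (not part of the original reference; the two-variable data below are arbitrary), the sample-covariance estimator above can be reproduced by hand and compared against jnp.cov:

>>> import jax.numpy as jnp
>>> x = jnp.array([[1., 2., 4.],
...                [3., 5., 7.]])
>>> xm = x - x.mean(axis=1, keepdims=True)   # deviations from each variable's mean
>>> manual = xm @ xm.T / (x.shape[1] - 1)    # normalize by N - 1
>>> jnp.allclose(manual, jnp.cov(x))
Array(True, dtype=bool)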
- Parameters:
  - m (ArrayLike) – array of shape (M, N) (if rowvar is True), or (N, M) (if rowvar is False), representing N observations of M variables. m may also be one-dimensional, representing N observations of a single variable.
  - y (ArrayLike | None) – optional set of additional observations, with the same form as m. If specified, then y is combined with m, i.e. for the default rowvar = True case, m becomes jnp.vstack([m, y]).
  - rowvar (bool) – if True (default), then each row of m represents a variable. If False, then each column represents a variable.
  - bias (bool) – if False (default), then normalize the covariance by N - 1. If True, then normalize the covariance by N.
  - ddof (int | None) – specify the degrees of freedom. Defaults to 1 if bias is False, or to 0 if bias is True.
  - fweights (ArrayLike | None) – optional array of integer frequency weights of shape (N,). This is an absolute weight specifying the number of times each observation is included in the computation (see the weighted sketch at the end of the Examples section).
  - aweights (ArrayLike | None) – optional array of observation weights of shape (N,). This is a relative weight specifying the “importance” of each observation. In the ddof=0 case, it is equivalent to assigning probabilities to each observation.
- Returns:
  A covariance matrix of shape (M, M), or a scalar with shape () if M = 1.
- Return type:
  Array
See also
jax.numpy.corrcoef(): compute the correlation coefficient, a normalized version of the covariance matrix.
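As a rough sketch of that relationship (the input below is arbitrary and not part of the original reference), dividing each covariance entry by the product of the corresponding standard deviations reproduces jnp.corrcoef:

>>> x = jnp.array([[1., 2., 4.],
...                [2., 1., 0.]])
>>> c = jnp.cov(x)
>>> d = jnp.sqrt(jnp.diag(c))    # per-variable standard deviations
>>> jnp.allclose(c / jnp.outer(d, d), jnp.corrcoef(x))
Array(True, dtype=bool)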
Examples
Consider these observations of two variables that correlate perfectly. The covariance matrix in this case is a 2x2 matrix of ones:
>>> x = jnp.array([[0, 1, 2],
...                [0, 1, 2]])
>>> jnp.cov(x)
Array([[1., 1.],
       [1., 1.]], dtype=float32)
Now consider these observations of two variables that are perfectly anti-correlated. The covariance matrix in this case has -1 in the off-diagonal:

>>> x = jnp.array([[-1, 0, 1],
...                [ 1, 0, -1]])
>>> jnp.cov(x)
Array([[ 1., -1.],
       [-1.,  1.]], dtype=float32)
Equivalently, these sequences can be specified as separate arguments, in which case they are stacked before continuing the computation.
>>> x = jnp.array([-1, 0, 1])
>>> y = jnp.array([1, 0, -1])
>>> jnp.cov(x, y)
Array([[ 1., -1.],
       [-1.,  1.]], dtype=float32)
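A single one-dimensional input is also accepted; as noted under Returns, the result is then a scalar of shape () holding the variance of that one variable. A minimal sketch:

>>> x = jnp.array([-1., 0., 1.])
>>> jnp.cov(x)
Array(1., dtype=float32)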
In general, the entries of the covariance matrix may be any positive or negative real value. For example, here is the covariance of 100 points drawn from a 3-dimensional standard normal distribution:
>>> key = jax.random.key(0)
>>> x = jax.random.normal(key, shape=(3, 100))
>>> with jnp.printoptions(precision=2):
...   print(jnp.cov(x))
[[0.9  0.03 0.1 ]
 [0.03 1.   0.01]
 [0.1  0.01 0.85]]
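The weighting arguments described under Parameters are not exercised above. The following sketch, with arbitrarily chosen data and weights, illustrates two points: integer fweights behave like repeating observations, and bias=True switches the normalization from N - 1 to N:

>>> x = jnp.array([[0., 1., 2.],
...                [2., 1., 0.]])
>>> fw = jnp.array([1, 2, 1])                      # count the middle observation twice
>>> x_rep = jnp.array([[0., 1., 1., 2.],
...                    [2., 1., 1., 0.]])          # same data with that column repeated
>>> jnp.allclose(jnp.cov(x, fweights=fw), jnp.cov(x_rep))
Array(True, dtype=bool)
>>> jnp.allclose(jnp.cov(x, bias=True), jnp.cov(x) * (3 - 1) / 3)
Array(True, dtype=bool)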