jax.numpy.cov

jax.numpy.cov(m, y=None, rowvar=True, bias=False, ddof=None, fweights=None, aweights=None)

Estimate the weighted sample covariance.

JAX implementation of numpy.cov().

The covariance between variables \(X_i\) and \(X_j\) is defined as

\[cov[X_i, X_j] = E[(X_i - E[X_i])(X_j - E[X_j])]\]

Given an array of N observations of the variables \(X_i\) and \(X_j\), this can be estimated via the sample covariance:

\[C_{ij} = \frac{1}{N - 1} \sum_{n=1}^N (X_{in} - \overline{X_i})(X_{jn} - \overline{X_j})\]

where \(\overline{X_i} = \frac{1}{N} \sum_{k=1}^N X_{ik}\) is the mean of the N observations of \(X_i\).
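As a rough illustration of the formula, the sketch below transcribes the sample covariance directly and compares it with jnp.cov. The helper sample_cov and the input values are hypothetical, chosen only for this example; passing ddof=0 instead of 1 gives the 1/N normalization that bias=True selects.

```python
import jax.numpy as jnp

# Hypothetical helper: a direct transcription of the sample-covariance
# formula for an (M, N) array with one variable per row.
def sample_cov(x, ddof=1):
    x = jnp.asarray(x, dtype=jnp.float32)
    n = x.shape[1]                                  # number of observations N
    centered = x - x.mean(axis=1, keepdims=True)    # X_in minus the mean of X_i
    return centered @ centered.T / (n - ddof)       # sum over n, divided by N - ddof

# Illustrative data: 2 variables, 4 observations each.
x = jnp.array([[1.0, 2.0, 4.0, 7.0],
               [0.5, 0.0, 1.5, 3.0]])

print(jnp.allclose(sample_cov(x), jnp.cov(x)))                      # True
print(jnp.allclose(sample_cov(x, ddof=0), jnp.cov(x, bias=True)))   # True
```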

Parameters:
  • m (ArrayLike) – array of shape (M, N) (if rowvar is True), or (N, M) (if rowvar is False) representing N observations of M variables. m may also be one-dimensional, representing N observations of a single variable.

  • y (ArrayLike | None) – optional set of additional observations, with the same form as m. If specified, then y is combined with m, i.e. for the default rowvar = True case, m becomes jnp.vstack([m, y]).

  • rowvar (bool) – if True (default) then each row of m represents a variable. If False, then each column represents a variable.

  • bias (bool) – if False (default) then normalize the covariance by N - 1. If True, then normalize the covariance by N.

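  • ddof (int | None) – specify the degrees of freedom. Defaults to 1 if bias is False, or to 0 if bias is True. If ddof is given explicitly, it overrides the value implied by bias; without weights, the covariance is normalized by N - ddof.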

  • fweights (ArrayLike | None) – optional array of integer frequency weights of shape (N,). This is an absolute weight specifying the number of times each observation is included in the computation.

  • aweights (ArrayLike | None) – optional array of observation weights of shape (N,). This is a relative weight specifying the “importance” of each observation. In the ddof=0 case, it is equivalent to assigning probabilities to each observation.
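The two weighting options can be checked against simpler computations. The sketch below, using illustrative data, relies on two assumptions worth verifying: integer fweights behave like repeating each observation that many times, and aweights act as relative weights, so uniform weights reproduce the unweighted result and rescaling all weights by a constant does not change the estimate.

```python
import jax.numpy as jnp

# Illustrative data: 2 variables, N = 3 observations.
x = jnp.array([[0.0, 1.0, 3.0],
               [2.0, 1.0, 5.0]])

# fweights: observation 2 counts twice and observation 3 three times.
fw = jnp.array([1, 2, 3])
by_freq = jnp.cov(x, fweights=fw)

# Repeating each column by its frequency weight should give the same estimate.
by_repeat = jnp.cov(jnp.repeat(x, fw, axis=1))
print(jnp.allclose(by_freq, by_repeat))            # True

# aweights: uniform weights match the unweighted result, and
# rescaling every weight by the same constant leaves the estimate unchanged.
aw = jnp.array([0.2, 0.5, 0.3])
print(jnp.allclose(jnp.cov(x, aweights=jnp.ones(3)), jnp.cov(x)))            # True
print(jnp.allclose(jnp.cov(x, aweights=aw), jnp.cov(x, aweights=10 * aw)))  # True
```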

Returns:

A covariance matrix of shape (M, M), or a scalar with shape () if M = 1.
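A quick way to see the output shapes is to pass the same observations in both layouts and a single variable on its own. The array obs below is an arbitrary example with 3 variables and 4 observations.

```python
import jax.numpy as jnp

obs = jnp.arange(12.0).reshape(3, 4)   # 3 variables x 4 observations

print(jnp.cov(obs).shape)                   # (3, 3): rows are variables
print(jnp.cov(obs.T, rowvar=False).shape)   # (3, 3): columns are variables
print(jnp.cov(obs[0]).shape)                # (): a single variable gives a scalar

# Both layouts describe the same data, so the matrices should agree.
print(jnp.allclose(jnp.cov(obs), jnp.cov(obs.T, rowvar=False)))   # True
```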

Return type:

Array

See also

  • jax.numpy.corrcoef(): compute the correlation coefficient, a normalized version of the covariance matrix.
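The normalization relating the two functions can be sketched by dividing each covariance entry by the product of the corresponding standard deviations, \(R_{ij} = C_{ij} / \sqrt{C_{ii} C_{jj}}\). The data are illustrative, and the comparison uses a small tolerance to allow for float32 rounding.

```python
import jax.numpy as jnp

# Illustrative data: 3 variables, 4 observations each.
x = jnp.array([[1.0, 2.0, 4.0, 7.0],
               [0.5, 0.0, 1.5, 3.0],
               [3.0, 1.0, 0.0, 2.0]])

c = jnp.cov(x)
std = jnp.sqrt(jnp.diag(c))         # per-variable standard deviations
corr = c / jnp.outer(std, std)      # R_ij = C_ij / sqrt(C_ii * C_jj)

print(jnp.allclose(corr, jnp.corrcoef(x), atol=1e-6))   # True
```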

Examples

Consider these observations of two variables that correlate perfectly. The covariance matrix in this case is a 2x2 matrix of ones:

>>> import jax
>>> import jax.numpy as jnp
>>> x = jnp.array([[0, 1, 2],
...                [0, 1, 2]])
>>> jnp.cov(x)
Array([[1., 1.],
       [1., 1.]], dtype=float32)
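Working the formula by hand confirms this result. Both rows are [0, 1, 2] with mean 1 and N = 3, so every entry \(C_{ij}\) reduces to the same sum:

\[C_{ij} = \frac{(0-1)(0-1) + (1-1)(1-1) + (2-1)(2-1)}{3 - 1} = \frac{1 + 0 + 1}{2} = 1\]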

Now consider these observations of two variables that are perfectly anti-correlated. The covariance matrix in this case has -1 in the off-diagonal:

>>> x = jnp.array([[-1,  0,  1],
...                [ 1,  0, -1]])
>>> jnp.cov(x)
Array([[ 1., -1.],
       [-1.,  1.]], dtype=float32)
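By hand, with the first row as variable 1 and the second row as variable 2, both rows have mean 0 and N = 3. Each diagonal entry is \((1 + 0 + 1)/2 = 1\), and the off-diagonal entry is

\[C_{12} = \frac{(-1)(1) + (0)(0) + (1)(-1)}{3 - 1} = \frac{-2}{2} = -1\]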

Equivalently, these sequences can be passed as separate arguments m and y, in which case they are stacked into a single array before the covariance is computed.

>>> x = jnp.array([-1, 0, 1])
>>> y = jnp.array([1, 0, -1])
>>> jnp.cov(x, y)
Array([[ 1., -1.],
       [-1.,  1.]], dtype=float32)
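The stacking also applies when m already holds several variables: each row of y adds one more variable, so the result grows to cover all of them. A small sketch with illustrative values, following the jnp.vstack([m, y]) behavior described for the y parameter in the default rowvar=True case:

```python
import jax.numpy as jnp

m = jnp.array([[0.0, 1.0, 2.0],
               [2.0, 0.0, 1.0]])    # 2 variables, 3 observations
y = jnp.array([1.0, 3.0, 2.0])     # 1 additional variable

c = jnp.cov(m, y)
print(c.shape)                                          # (3, 3)
print(jnp.allclose(c, jnp.cov(jnp.vstack([m, y]))))     # True
```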

In general, the off-diagonal entries of the covariance matrix may take any positive or negative real value, while the diagonal entries are variances and therefore non-negative. For example, here is the covariance of 100 points drawn from a 3-dimensional standard normal distribution:

>>> key = jax.random.key(0)
>>> x = jax.random.normal(key, shape=(3, 100))
>>> with jnp.printoptions(precision=2):
...   print(jnp.cov(x))
[[0.9  0.03 0.1 ]
 [0.03 1.   0.01]
 [0.1  0.01 0.85]]
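Two structural properties of this matrix can be checked directly: the diagonal matches each variable's variance computed with the same N - 1 normalization, and the matrix is symmetric. The tolerances below are assumptions chosen to absorb float32 rounding from different summation orders.

```python
import jax
import jax.numpy as jnp

key = jax.random.key(0)
x = jax.random.normal(key, shape=(3, 100))

c = jnp.cov(x)
# Diagonal entries are per-variable variances with ddof=1 (the default bias=False).
print(jnp.allclose(jnp.diag(c), jnp.var(x, axis=1, ddof=1), rtol=1e-4, atol=1e-6))  # True
# Covariance is symmetric: C_ij equals C_ji.
print(jnp.allclose(c, c.T))   # True
```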