jax.scipy.stats.multivariate_normal.pdf
jax.scipy.stats.multivariate_normal.pdf(x, mean, cov)
Multivariate normal probability density function.
JAX implementation of scipy.stats.multivariate_normal pdf.
The multivariate normal PDF is defined as
\[f(x) = \frac{1}{\sqrt{(2\pi)^k\det\Sigma}}\exp\left(-\frac{(x-\mu)^T\Sigma^{-1}(x-\mu)}{2}\right)\]
where \(\mu\) is the mean, \(\Sigma\) is the covariance matrix (cov), and \(k\) is the rank of \(\Sigma\).
- Parameters:
x (Array | ndarray | bool | number | bool | int | float | complex) – arraylike, value at which to evaluate the PDF
mean (Array | ndarray | bool | number | bool | int | float | complex) – arraylike, centroid of distribution
cov (Array | ndarray | bool | number | bool | int | float | complex) – arraylike, covariance matrix of distribution
allow_singular – not supported
- Returns:
array of pdf values.
- Return type:
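The following is a minimal sketch of how the function might be called. The input values are illustrative assumptions and do not come from the documentation above. It evaluates the PDF at a single point, checks the result against a direct evaluation of the standard multivariate normal density (with the square-root normalization), and then evaluates a small batch of points.

```python
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal

# Illustrative inputs (assumed for this sketch).
mean = jnp.array([0.0, 1.0])
cov = jnp.array([[2.0, 0.5],
                 [0.5, 1.0]])   # symmetric, positive definite
x = jnp.array([0.5, 0.5])

# Evaluate the PDF at a single point.
p = multivariate_normal.pdf(x, mean, cov)

# Direct evaluation of the density for comparison.
k = mean.shape[0]                            # dimension (rank of a nonsingular cov)
diff = x - mean
quad = diff @ jnp.linalg.solve(cov, diff)    # (x - mu)^T Sigma^{-1} (x - mu)
p_manual = jnp.exp(-0.5 * quad) / jnp.sqrt((2 * jnp.pi) ** k * jnp.linalg.det(cov))

# Batched evaluation: the last axis of x indexes the vector components.
xs = jnp.stack([x, mean])
ps = multivariate_normal.pdf(xs, mean, cov)  # one PDF value per row of xs
```

Here `p` and `p_manual` should agree up to floating-point tolerance, and `ps` holds one value per input point. Because singular covariance matrices are not supported, `cov` is assumed to be positive definite.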