jax.random.f
- jax.random.f(key, dfnum, dfden, shape=None, dtype=float)
Sample F-distribution random values with given shape and float dtype.
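As a quick illustration of the call pattern, here is a minimal sketch. It assumes a JAX release recent enough to include jax.random.f. The seed, degrees of freedom, and sample counts are arbitrary illustrative choices. The sketch draws samples, compares their mean with dfden / (dfden - 2), which is a standard property of the F-distribution for dfden > 2 but is not stated on this page, and shows how an array-valued dfnum broadcasts against shape.

```python
import jax
import jax.numpy as jnp

# Arbitrary seed; split it into independent subkeys, one per draw.
key = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(key, 3)

# Scalar degrees of freedom: 10,000 draws from F(dfnum=5, dfden=10).
samples = jax.random.f(k1, 5.0, 10.0, shape=(10_000,))
print(samples.shape)  # (10000,)
print(samples.dtype)  # float32 unless jax_enable_x64 is enabled

# Rough sanity check: for dfden > 2 the mean is dfden / (dfden - 2) = 1.25.
print(float(samples.mean()))

# An array-valued dfnum broadcasts against an explicit shape:
# column j of `batch` is drawn with dfnum[j].
dfnum = jnp.array([1.0, 5.0, 20.0])
batch = jax.random.f(k2, dfnum, 10.0, shape=(1_000, 3))
print(batch.shape)  # (1000, 3)

# With shape=None, the result takes the broadcast shape of dfnum and dfden.
default = jax.random.f(k3, dfnum, 10.0)
print(default.shape)  # (3,)
```

Passing a fresh subkey to each call keeps the three draws statistically independent. Reusing one key would produce correlated samples.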
The values are distributed according to the probability density function:
\[f(x; \nu_1, \nu_2) \propto x^{\nu_1/2 - 1}\left(1 + \frac{\nu_1}{\nu_2}x\right)^{-(\nu_1 + \nu_2) / 2}\]
on the domain \(0 < x < \infty\). Here \(\nu_1\) is the degrees of freedom of the numerator (dfnum), and \(\nu_2\) is the degrees of freedom of the denominator (dfden).
- Parameters:
key (ArrayLike) – a PRNG key used as the random key.
dfnum (RealArray) – a float or array of floats, broadcast-compatible with shape, giving the numerator's degrees of freedom (df) of the distribution.
dfden (RealArray) – a float or array of floats, broadcast-compatible with shape, giving the denominator's degrees of freedom (df) of the distribution.
shape (Shape | None) – optional, a tuple of nonnegative integers specifying the result shape. Must be broadcast-compatible with dfnum and dfden. The default (None) produces a result whose shape is the broadcast of dfnum.shape and dfden.shape.
dtype (DTypeLikeFloat) – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
- Returns:
A random array with the specified dtype and with shape given by shape if shape is not None, or else by the broadcast shape of dfnum and dfden.
- Return type: