jax.random.multinomial
- jax.random.multinomial(key, n, p, *, shape=None, dtype=<class 'float'>, unroll=1)
Sample from a multinomial distribution.
The probability mass function, for counts \(x = (x_1, \ldots, x_k)\) with \(x_1 + \ldots + x_k = n\), is
\[f(x;n,p) = \frac{n!}{x_1! \ldots x_k!} p_1^{x_1} \ldots p_k^{x_k}\]
- Parameters:
key (Array) – PRNG key.
n (RealArray) – number of trials. Should have shape broadcastable to p.shape[:-1].
p (RealArray) – probability of each outcome, with outcomes along the last axis.
shape (Shape | None) – optional, a tuple of nonnegative integers specifying the result batch shape, that is, the prefix of the result shape excluding the last axis. Must be broadcast-compatible with p.shape[:-1]. The default (None) produces a result shape equal to p.shape.
dtype (DTypeLikeFloat) – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
unroll (int | bool) – optional, unroll parameter passed to jax.lax.scan() inside the implementation of this function.
- Returns:
- An array of counts for each outcome with the specified dtype and with shape p.shape if shape is None, otherwise shape + (p.shape[-1],).