jax.random.multinomial

jax.random.multinomial(key, n, p, *, shape=None, dtype=float, unroll=1)

Sample from a multinomial distribution.

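The probability mass function is

\[f(x;n,p) = \frac{n!}{x_1! \ldots x_k!} p_1^{x_1} \ldots p_k^{x_k}\]

where the counts \(x_i\) are nonnegative integers with \(x_1 + \ldots + x_k = n\), and \(k\) is the number of outcomes (the size of the last axis of p).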
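As a worked check of the formula, the helper below evaluates \(\log f(x; n, p)\) using jax.scipy.special.gammaln (since \(\log m! = \) gammaln(m + 1)) and jax.scipy.special.xlogy (so a zero count contributes zero even when \(p_i = 0\)). The name multinomial_log_pmf is hypothetical and not part of jax.random; the sketch assumes the counts already sum to n and that p sums to 1, and it does not check either. For n = 3 and p = (0.2, 0.3, 0.5), the count vector (1, 1, 1) has probability 3! · 0.2 · 0.3 · 0.5 = 0.18.

```python
import jax.numpy as jnp
from jax.scipy.special import gammaln, xlogy


def multinomial_log_pmf(x, n, p):
    """Hypothetical helper: log f(x; n, p), with outcomes along the last axis.

    Assumes x holds nonnegative integer counts summing to n and that p sums
    to 1 along its last axis; neither condition is checked here.
    """
    x = jnp.asarray(x, dtype=jnp.float32)
    p = jnp.asarray(p, dtype=jnp.float32)
    # log n! - sum_i log x_i!, using log m! = gammaln(m + 1).
    log_coef = gammaln(jnp.float32(n) + 1.0) - jnp.sum(gammaln(x + 1.0), axis=-1)
    # sum_i x_i * log p_i; xlogy returns 0 wherever x_i == 0.
    log_powers = jnp.sum(xlogy(x, p), axis=-1)
    return log_coef + log_powers


# Worked example: n = 3, p = (0.2, 0.3, 0.5), counts (1, 1, 1).
prob = jnp.exp(multinomial_log_pmf(jnp.array([1, 1, 1]), 3, jnp.array([0.2, 0.3, 0.5])))
```

Working in log space keeps the factorials from overflowing for large n; exponentiate only at the end if the probability itself is needed.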
Parameters:
  • key (Array) – PRNG key.

  • n (RealArray) – number of trials. Should have shape broadcastable to p.shape[:-1].

  • p (RealArray) – probability of each outcome, with outcomes along the last axis.

  • shape (Shape | None) – optional, a tuple of nonnegative integers specifying the result batch shape, that is, the prefix of the result shape excluding the last axis. Must be broadcast-compatible with p.shape[:-1]. The default (None) produces a result shape equal to p.shape.

  • dtype (DTypeLikeFloat) – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).

  • unroll (int | bool) – optional, unroll parameter passed to jax.lax.scan() inside the implementation of this function.
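The unroll parameter implies a sequential loop, but the description above states only that jax.lax.scan runs inside the implementation, not what it scans over. One standard construction that fits is to walk over the outcomes and give outcome i a Binomial(r, p_i / (p_i + ... + p_k)) share of the r trials not yet assigned; the last ratio is 1, so every trial ends up assigned. The sketch below assumes that decomposition and is not the library's source: multinomial_by_conditional_binomials is a hypothetical name, it handles only a scalar n and a 1-D p, and it forwards unroll to jax.lax.scan in the way the parameter description suggests.

```python
import jax
import jax.numpy as jnp
from jax import lax


def multinomial_by_conditional_binomials(key, n, p, unroll=1):
    """Hypothetical sketch: draw multinomial counts one outcome at a time.

    Outcome i receives Binomial(remaining, p_i / sum_{j >= i} p_j) of the
    trials not yet assigned. n is a scalar and p a 1-D probability vector.
    """
    p = jnp.asarray(p, dtype=jnp.float32)
    # Suffix sums: probability mass of outcomes i, i+1, ..., k-1.
    tail = jnp.cumsum(p[::-1])[::-1]
    # Conditional success probability per outcome; guard against 0 / 0.
    ratios = jnp.clip(p / jnp.where(tail == 0, 1.0, tail), 0.0, 1.0)
    # One independent key per outcome, consumed one per scan step.
    keys = jax.random.split(key, p.shape[0])

    def step(remaining, ratio_and_key):
        ratio, k = ratio_and_key
        count = jax.random.binomial(k, remaining, ratio, dtype=jnp.float32)
        return remaining - count, count

    _, counts = lax.scan(step, jnp.float32(n), (ratios, keys), unroll=unroll)
    return counts


counts = multinomial_by_conditional_binomials(jax.random.key(0), 10, [0.2, 0.3, 0.5])
```

Each scan step depends on the count left over from the previous one, which is why the loop cannot simply be vectorized over outcomes and why an unroll knob is useful for tuning it.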

Returns:

An array of counts for each outcome, with the specified dtype and with shape p.shape if shape is None, otherwise shape + (p.shape[-1],).