jax.random.ball
- jax.random.ball(key, d, p=2, shape=(), dtype=<class 'float'>)
Sample uniformly from the unit Lp ball.
Reference: https://arxiv.org/abs/math/0503650.
- Parameters:
key (ArrayLike) – a PRNG key used as the random key.
d (int) – a nonnegative int representing the dimensionality of the ball.
p (float) – optional, a float representing the p parameter of the Lp norm. Default 2.
shape (Shape) – optional, the batch dimensions of the result. Default ().
dtype (DTypeLikeFloat) – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
- Returns:
A random array of shape (*shape, d) and the specified dtype. Each slice along the last axis is one point in the ball.
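Example: a minimal usage sketch, assuming a recent JAX release in which `jax.random.ball` is available. The seed values, the batch size of 1000, and the small floating-point tolerance are illustrative choices, not part of the API.

```python
import jax
import jax.numpy as jnp

# Illustrative seed; any PRNG key works.
key = jax.random.PRNGKey(0)

# Draw 1000 points from the unit L2 ball in 3 dimensions.
samples = jax.random.ball(key, d=3, p=2, shape=(1000,))

# Batch dimensions come first, the ball dimension d comes last.
print(samples.shape)  # (1000, 3)

# Every point should have L2 norm at most 1 (tolerance is an assumption
# to absorb floating-point rounding).
norms = jnp.linalg.norm(samples, ord=2, axis=-1)
print(bool(jnp.all(norms <= 1.0 + 1e-5)))

# The same call with p=1.0 samples from the unit L1 ball instead,
# where the sum of absolute coordinates is at most 1.
l1_samples = jax.random.ball(jax.random.PRNGKey(1), d=3, p=1.0, shape=(1000,))
l1_norms = jnp.sum(jnp.abs(l1_samples), axis=-1)
print(bool(jnp.all(l1_norms <= 1.0 + 1e-5)))
```

Note that `d` and `shape` fix the output shape, so they are passed as plain Python values rather than traced arrays.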