jax.numpy.amax

jax.numpy.amax(a, axis=None, out=None, keepdims=False, initial=None, where=None)

Alias of jax.numpy.max().

Parameters:
    a (ArrayLike)
    axis (Axis | None)
    out (None)
    keepdims (bool)
    initial (ArrayLike | None)
    where (ArrayLike | None)

Return type: Array
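The short sketch below shows how the parameters in this signature are typically used. It is an illustration, not part of the official documentation. It assumes a recent JAX release in which jnp.amax behaves exactly like jnp.max. One behavior is worth flagging: in the JAX versions we are aware of, using where with a max reduction requires an explicit initial value, because max has no identity element for masked-out or empty slices. The out parameter is annotated as None, so it is left at its default here.

```python
import jax.numpy as jnp

x = jnp.array([[1.0, 5.0, 3.0],
               [4.0, 2.0, 6.0]])

# amax is documented as an alias of max, so both calls should agree.
full = jnp.amax(x)                              # reduce over every element
assert full == jnp.max(x)

# Reduce along axis 0, i.e. take the maximum of each column.
col_max = jnp.amax(x, axis=0)

# keepdims=True keeps the reduced axis as a size-1 dimension.
row_max = jnp.amax(x, axis=1, keepdims=True)    # shape (2, 1)

# where excludes masked-out elements from the reduction. max has no identity,
# so an explicit initial value is supplied (assumed to be required here).
mask = jnp.array([[True, False, True],
                  [True, True, False]])
masked = jnp.amax(x, axis=1, where=mask, initial=-jnp.inf)

# initial also acts as a lower bound that takes part in the reduction.
floored = jnp.amax(x, axis=1, initial=5.5)

print(full, col_max, row_max, masked, floored)
```

With the inputs above, masked should be [3., 4.], because 5.0 and 6.0 are excluded. floored should be [5.5, 6.], because 5.5 exceeds the first row's maximum of 5.0.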