jax.scipy.special.softmax
- jax.scipy.special.softmax(x, /, *, axis=None)
Softmax function.
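JAX implementation of scipy.special.softmax().

Computes the function which rescales elements to the range \([0, 1]\) such that the elements along axis sum to \(1\):

\[\mathrm{softmax}(x)_i = \frac{\exp(x_i)}{\sum_j \exp(x_j)}\]

where the sum over \(j\) runs over the elements along axis.

- Parameters:
  - x: input array.
  - axis: the axis or axes along which the softmax is computed. The default, None, computes it over all elements of x.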
- Returns:
  An array of the same shape as x.
- Return type:
  Array
Note

If any input value is +inf, the result will be all NaN along axis (the whole array when axis is None): this reflects the fact that inf / inf is not well-defined in the context of floating-point math.

See also