jax.nn.initializers.variance_scaling
- jax.nn.initializers.variance_scaling(scale, mode, distribution, in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)
Initializer that adapts its scale to the shape of the weights tensor.
With distribution="truncated_normal" or distribution="normal", samples are drawn from a (truncated) normal distribution with a mean of zero and a standard deviation (after truncation, if applicable) of \(\sqrt{\frac{scale}{n}}\), where n is, for each mode:
- "fan_in": the number of inputs
- "fan_out": the number of outputs
- "fan_avg": the arithmetic average of the numbers of inputs and outputs
- "fan_geo_avg": the geometric average of the numbers of inputs and outputs
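As a rough check of the standard deviation formula above, the following sketch draws a dense weight matrix with mode="fan_in" and compares the empirical standard deviation with \(\sqrt{\frac{scale}{n}}\). The shape (512, 256), the scale of 2.0, and the PRNG seed are arbitrary choices for illustration, not values prescribed by the library.

```python
import jax
import jax.numpy as jnp
from jax.nn.initializers import variance_scaling

# Illustrative values only: a dense layer with 512 inputs and 256 outputs.
scale = 2.0
shape = (512, 256)  # default in_axis=-2 (512 inputs), out_axis=-1 (256 outputs)

init = variance_scaling(scale, mode="fan_in", distribution="normal")
key = jax.random.PRNGKey(0)
w = init(key, shape, jnp.float32)

fan_in = shape[-2]
expected_std = (scale / fan_in) ** 0.5
empirical_std = float(jnp.std(w))

print("expected std :", expected_std)
print("empirical std:", empirical_std)
```

The two printed values should agree closely but not exactly, since the empirical figure is estimated from a finite sample.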
This initializer can be configured with in_axis, out_axis, and batch_axis to work with general convolutional or dense layers; axes that are not in any of those arguments are assumed to be the "receptive field" (convolution kernel spatial axes).
With distribution="truncated_normal", the absolute values of the samples are truncated at 2 standard deviations before scaling.
With distribution="uniform", samples are drawn from:
- a uniform interval, if dtype is real, or
- a uniform disk, if dtype is complex,
with a mean of zero and a standard deviation of \(\sqrt{\frac{scale}{n}}\), where n is defined above.
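The sketch below illustrates the receptive-field and distribution behaviour described above on a hypothetical convolution kernel in HWIO layout (3x3 window, 16 input channels, 32 output channels). It assumes the receptive-field axes multiply into the fan, so that n = 3 * 3 * 16 for "fan_in". The bounds it checks follow from the stated standard deviation: a zero-mean uniform interval with standard deviation s has half-width \(\sqrt{3}\,s\), and a normal truncated at 2 standard deviations and then rescaled to standard deviation s is bounded by roughly 2.27 s.

```python
import jax
import jax.numpy as jnp
from jax.nn.initializers import variance_scaling

# Hypothetical HWIO kernel: the two leading axes are the receptive field.
shape = (3, 3, 16, 32)
scale = 1.0
fan_in = 3 * 3 * 16  # receptive field size times number of input channels
target_std = (scale / fan_in) ** 0.5

key = jax.random.PRNGKey(0)

# Truncated normal: bounded samples, rescaled to the target standard deviation.
trunc_init = variance_scaling(scale, "fan_in", "truncated_normal")
w_trunc = trunc_init(key, shape, jnp.float32)

# Uniform (real dtype): samples from a symmetric interval around zero.
unif_init = variance_scaling(scale, "fan_in", "uniform")
w_unif = unif_init(key, shape, jnp.float32)

trunc_std = float(jnp.std(w_trunc))
unif_std = float(jnp.std(w_unif))
trunc_max = float(jnp.max(jnp.abs(w_trunc)))
unif_max = float(jnp.max(jnp.abs(w_unif)))

print("target std:", target_std)
print("truncated_normal std / max |w|:", trunc_std, trunc_max)
print("uniform std / max |w|:", unif_std, unif_max)
```

Both empirical standard deviations should land near the target, while the largest absolute values differ because the two distributions have different supports.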
- Parameters:
scale (RealNumeric) – scaling factor (positive float).
mode (Literal['fan_in'] | Literal['fan_out'] | Literal['fan_avg'] | Literal['fan_geo_avg']) – one of "fan_in", "fan_out", "fan_avg", and "fan_geo_avg".
distribution (Literal['truncated_normal'] | Literal['normal'] | Literal['uniform']) – random distribution to use. One of "truncated_normal", "normal" and "uniform".
in_axis (int | Sequence[int]) – axis or sequence of axes of the input dimension in the weights array.
out_axis (int | Sequence[int]) – axis or sequence of axes of the output dimension in the weights array.
batch_axis (int | Sequence[int]) – axis or sequence of axes in the weight array that should be ignored.
dtype (DTypeLikeInexact) – the dtype of the weights.
- Return type:
Initializer
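The returned Initializer is a callable taking a PRNG key, a shape, and a dtype. As a minimal sketch of the axis parameters, the example below uses two hypothetical weight layouts: an OIHW convolution kernel, which needs non-default in_axis and out_axis, and a stack of 8 independent dense layers, where the leading axis is passed as batch_axis so that it does not count toward the fan. The shapes, scales, and variable names are illustrative assumptions.

```python
import jax
import jax.numpy as jnp
from jax.nn.initializers import variance_scaling

key_conv, key_stack = jax.random.split(jax.random.PRNGKey(0))

# Hypothetical OIHW kernel: 32 output channels, 16 input channels, 3x3 window.
# Outputs sit on axis 0 and inputs on axis 1, so the defaults do not apply.
oihw_init = variance_scaling(2.0, "fan_out", "normal", in_axis=1, out_axis=0)
w_conv = oihw_init(key_conv, (32, 16, 3, 3), jnp.float32)
conv_fan_out = 32 * 3 * 3  # output channels times receptive field size
conv_target_std = (2.0 / conv_fan_out) ** 0.5

# Hypothetical stack of 8 independent dense layers, each mapping 64 -> 128.
# Axis 0 is marked as a batch axis, so fan_in stays 64 rather than 8 * 64.
stack_init = variance_scaling(1.0, "fan_in", "normal", batch_axis=(0,))
w_stack = stack_init(key_stack, (8, 64, 128), jnp.float32)
stack_target_std = (1.0 / 64) ** 0.5

conv_std = float(jnp.std(w_conv))
stack_std = float(jnp.std(w_stack))

print("conv  target / empirical std:", conv_target_std, conv_std)
print("stack target / empirical std:", stack_target_std, stack_std)
```

Without batch_axis, the leading axis of the stacked weights would be treated as a receptive-field axis and the resulting standard deviation would be smaller.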