jax.nn.initializers.lecun_normal
- jax.nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)
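Builds a LeCun normal initializer.
A LeCun normal initializer is a specialization of
jax.nn.initializers.variance_scaling()
where scale = 1.0, mode="fan_in", and distribution="truncated_normal".
In other words, weights are drawn from a zero-mean normal distribution truncated at two standard deviations, with its scale chosen so that the variance of the weights is 1 / fan_in.
- Parameters: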
in_axis (int | Sequence[int]) – axis or sequence of axes of the input dimension in the weights array.
out_axis (int | Sequence[int]) – axis or sequence of axes of the output dimension in the weights array.
batch_axis (int | Sequence[int]) – axis or sequence of axes in the weights array that should be ignored when computing the fan-in and fan-out.
dtype (DTypeLikeInexact) – the dtype of the weights.
- Returns:
An initializer.
- Return type:
Initializer
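The axis arguments control how the fan-in is computed. The sketch below assumes the fan computation used by variance_scaling(), in which every axis not named by in_axis, out_axis, or batch_axis counts toward the receptive field. Under that assumption a 3x3 convolution kernel in (height, width, in_channels, out_channels) layout with 16 input channels has fan_in = 16 * 3 * 3 = 144, so the sampled weights should have a standard deviation close to 1 / sqrt(144) ≈ 0.083. The shapes used here are illustrative choices, not values from the JAX documentation.

```python
import jax
import jax.numpy as jnp

# Defaults: in_axis=-2, out_axis=-1, batch_axis=()
init = jax.nn.initializers.lecun_normal()

# Illustrative 3x3 conv kernel in (H, W, in_channels, out_channels) layout.
kernel = init(jax.random.key(0), (3, 3, 16, 32), jnp.float32)

# fan_in = in_channels * receptive field = 16 * (3 * 3) = 144
expected_std = 1.0 / 144 ** 0.5
print(float(kernel.std()), expected_std)

# With batch_axis=0 the leading axis is excluded from the fan computation:
# for shape (8, 64, 128) the fan-in is 64 rather than 8 * 64 = 512.
batched_init = jax.nn.initializers.lecun_normal(batch_axis=0)
batched = batched_init(jax.random.key(1), (8, 64, 128), jnp.float32)
print(float(batched.std()), 1.0 / 64 ** 0.5)
```

Because the samples are random, the printed standard deviations only approximate the expected values; larger arrays give closer agreement.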
Examples:
>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.lecun_normal()
>>> initializer(jax.random.key(42), (2, 3), jnp.float32)
Array([[ 0.46700746,  0.8414632 ,  0.8518669 ],
       [-0.61677957, -0.67402434,  0.09683388]], dtype=float32)
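Since lecun_normal is a specialization of variance_scaling() with scale = 1.0, mode="fan_in", and distribution="truncated_normal", constructing that variance_scaling() initializer by hand should give identical samples for the same key, shape, and dtype. The following is a minimal consistency check under that assumption; the key and the (64, 128) shape are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp

key = jax.random.key(42)
shape = (64, 128)  # dense kernel: fan_in = 64 with the default in_axis=-2

lecun = jax.nn.initializers.lecun_normal()
manual = jax.nn.initializers.variance_scaling(
    scale=1.0, mode="fan_in", distribution="truncated_normal"
)

w_lecun = lecun(key, shape, jnp.float32)
w_manual = manual(key, shape, jnp.float32)

# Same configuration and same key, so the samples should match exactly.
print(bool(jnp.array_equal(w_lecun, w_manual)))

# The truncated normal is rescaled so the variance stays at scale / fan_in,
# so the sample standard deviation should be close to 1 / sqrt(64) = 0.125.
print(float(w_lecun.std()))
```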