jax.nn.identity

jax.nn.identity(x)

Identity activation function.

Returns the argument unmodified.

Parameters:

x (ArrayLike) – input array

Returns:

The argument x unmodified.

Return type:

Array

Examples

>>> import jax
>>> jax.nn.identity(jax.numpy.array([-2., -1., -0.5, 0, 0.5, 1., 2.]))
Array([-2. , -1. , -0.5,  0. ,  0.5,  1. ,  2. ], dtype=float32)