jax.linear_transpose
jax.linear_transpose(fun, *primals, reduce_axes=())
Transpose a function that is promised to be linear.
For linear functions, this transformation is equivalent to vjp(), but avoids the overhead of computing the forward pass.
The outputs of the transposed function will always have the exact same dtypes as primals, even if some values are truncated (e.g., from complex to float, or from float64 to float32). To avoid truncation, use dtypes in primals that match the full range of desired outputs from the transposed function. Integer dtypes are not supported.
- Parameters:
fun (Callable) – the linear function to be transposed.
*primals – a positional argument tuple of arrays, scalars, or (nested) standard Python containers (tuples, lists, dicts, namedtuples, i.e., pytrees) of those types, used for evaluating the shape/dtype of fun(*primals). These arguments may be real scalars/ndarrays, but that is not required: only the shape and dtype attributes are accessed. See below for an example. (Note that the duck-typed objects cannot be namedtuples because those are treated as standard Python containers.)
- Returns:
A callable that calculates the transpose of fun. Valid input into this function must have the same shape/dtypes/structure as the result of fun(*primals). Output will be a tuple, with the same shape/dtypes/structure as primals.
- Return type:
Callable
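As a rough illustration of the structure rule described under Returns, the sketch below passes a dict and an array as primals. The function weighted_sum and the keys "a" and "b" are hypothetical names chosen for this example only; the function is jointly linear in all of its inputs, as linear_transpose requires.

```python
import jax
import jax.numpy as jnp

# Hypothetical function for illustration; jointly linear in (tree, y).
def weighted_sum(tree, y):
    return 2.0 * tree["a"] + 3.0 * tree["b"] - y

# Only shape and dtype are read from the primals, so specs are enough.
spec = jax.ShapeDtypeStruct(shape=(2,), dtype=jnp.float32)
weighted_sum_t = jax.linear_transpose(weighted_sum, {"a": spec, "b": spec}, spec)

# The input must look like the output of weighted_sum:
# a float32 array of shape (2,).
ct = jnp.ones((2,), dtype=jnp.float32)
tree_ct, y_ct = weighted_sum_t(ct)

# The result mirrors the primals: a tuple of (dict with "a" and "b", array).
print(tree_ct["a"], tree_ct["b"], y_ct)
```

Each entry of the result holds the coefficient that the corresponding input carries in weighted_sum, multiplied by ct, and keeps the float32 dtype given in spec.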
>>> import jax
>>> import numpy as np
>>>
>>> f = lambda x, y: 0.5 * x - 0.5 * y
>>> scalar = jax.ShapeDtypeStruct(shape=(), dtype=np.dtype(np.float32))
>>> f_transpose = jax.linear_transpose(f, scalar, scalar)
>>> f_transpose(1.0)
(Array(0.5, dtype=float32), Array(-0.5, dtype=float32))
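The equivalence with vjp() stated in the description can be checked on a small case. The following is a minimal sketch, assuming a hypothetical matrix A and function matvec that are not part of the JAX API; it compares the transposed function with the pullback returned by jax.vjp.

```python
import jax
import jax.numpy as jnp

# Hypothetical linear map used only for illustration: x -> A @ x.
A = jnp.array([[1.0, 2.0, 3.0],
               [4.0, 5.0, 6.0]], dtype=jnp.float32)

def matvec(x):
    return A @ x

# linear_transpose needs only the shape and dtype of the primal.
x_spec = jax.ShapeDtypeStruct(shape=(3,), dtype=jnp.float32)
matvec_t = jax.linear_transpose(matvec, x_spec)

# A cotangent shaped like the output of matvec.
ct = jnp.array([1.0, -1.0], dtype=jnp.float32)

# The transposed function returns a tuple with one entry per primal.
(via_transpose,) = matvec_t(ct)

# jax.vjp needs a concrete primal value and evaluates matvec on it.
_, vjp_fun = jax.vjp(matvec, jnp.zeros((3,), dtype=jnp.float32))
(via_vjp,) = vjp_fun(ct)

# For this linear map both results should equal A.T @ ct.
print(via_transpose, via_vjp)
```

Because matvec is linear, the value passed to jax.vjp does not affect the pullback; zeros are used here only as a placeholder.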