jax.numpy.logaddexp
- jax.numpy.logaddexp = <jnp.ufunc 'logaddexp'>
Compute log(exp(x1) + exp(x2)) elementwise, avoiding overflow.
JAX implementation of numpy.logaddexp.
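To see the overflow problem, compare the naive formula with jnp.logaddexp on large and very negative inputs. The sketch also defines a hypothetical helper, stable_logaddexp, which applies the standard identity log(exp(a) + exp(b)) = max(a, b) + log1p(exp(-|a - b|)). It illustrates the general technique only. It is not JAX's actual implementation, and unlike jnp.logaddexp it does not special-case infinite inputs.

```python
import jax.numpy as jnp

def stable_logaddexp(x1, x2):
    # Hypothetical helper illustrating the identity
    #   log(exp(a) + exp(b)) = max(a, b) + log1p(exp(-|a - b|)).
    # exp() is only applied to a non-positive number, so it cannot overflow.
    # Note: does not handle infinite inputs (inf - inf gives NaN).
    amax = jnp.maximum(x1, x2)
    return amax + jnp.log1p(jnp.exp(-jnp.abs(x1 - x2)))

x1 = jnp.array([1000.0, -1000.0, 0.0])
x2 = jnp.array([1000.0, -1000.0, 1.0])

# Naive formula: exp(1000.0) overflows to inf, and exp(-1000.0) underflows
# to 0, so log(0 + 0) gives -inf.
naive = jnp.log(jnp.exp(x1) + jnp.exp(x2))
library = jnp.logaddexp(x1, x2)
sketch = stable_logaddexp(x1, x2)

print(naive)
print(library)
print(sketch)
```

Only the stable forms return finite values here, close to 1000 + log(2), -1000 + log(2), and log(1 + e).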
- Parameters:
x1 – input array
x2 – input array
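args (ArrayLike) – positional inputs to the ufunc; x1 and x2 are passed through *args.
out (None) – not supported; must be None.
where (None) – not supported; must be None.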
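Because x1 and x2 are ArrayLike, they may be arrays of different but broadcast-compatible shapes, or Python scalars. The following is a short sketch that assumes the standard NumPy broadcasting rules, which jax.numpy follows:

```python
import jax.numpy as jnp

# x1 has shape (3, 1) and x2 has shape (4,); under NumPy broadcasting rules
# the result has shape (3, 4), with table[i, j] = logaddexp(col[i, 0], row[j]).
col = jnp.array([[0.0], [1.0], [2.0]])
row = jnp.arange(4.0)
table = jnp.logaddexp(col, row)
print(table.shape)  # (3, 4)

# Python scalars are ArrayLike too: log(exp(0) + exp(0)) = log(2).
scalar = jnp.logaddexp(0.0, 0.0)
print(scalar)
```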
- Returns:
array containing the result.
- Return type:
Any
Examples:
>>> import jax.numpy as jnp
>>> x1 = jnp.array([1, 2, 3])
>>> x2 = jnp.array([4, 5, 6])
>>> result1 = jnp.logaddexp(x1, x2)
>>> result2 = jnp.log(jnp.exp(x1) + jnp.exp(x2))
>>> print(jnp.allclose(result1, result2))
True
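As a JAX function, jnp.logaddexp can also be differentiated with jax.grad. The derivative of logaddexp(a, b) with respect to a is exp(a) / (exp(a) + exp(b)), the logistic sigmoid of a - b, so the gradient can be checked against jax.nn.sigmoid. The helper f below is hypothetical and only fixes b = 0 for illustration. The second part shows that the gradient stays finite for a large input, where differentiating the naive formula breaks down.

```python
import jax
import jax.numpy as jnp

def f(a):
    # Hypothetical helper: logaddexp(a, 0) = log(exp(a) + 1), whose
    # derivative with respect to a is sigmoid(a).
    return jnp.logaddexp(a, 0.0)

g = jax.grad(f)(1.0)
print(jnp.allclose(g, jax.nn.sigmoid(1.0)))  # True

# For a large input the gradient stays finite (approximately 1.0).
g_big = jax.grad(f)(1000.0)

# Differentiating the naive formula instead multiplies exp(1000.0) = inf
# by 1 / inf = 0 in the backward pass, which gives NaN.
naive_g_big = jax.grad(lambda a: jnp.log(jnp.exp(a) + 1.0))(1000.0)

print(g_big)
print(naive_g_big)
```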