jax.numpy.logaddexp

jax.numpy.logaddexp = <jnp.ufunc 'logaddexp'>

Compute log(exp(x1) + exp(x2)) elementwise, without overflowing in the intermediate exponentials.

JAX implementation of numpy.logaddexp.
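One standard way to avoid the overflow is the identity log(e^a + e^b) = max(a, b) + log1p(exp(-|a - b|)), which only ever exponentiates a non-positive number. The sketch below illustrates that identity; the helper name `stable_logaddexp` is hypothetical, it is not claimed to be JAX's internal implementation, it assumes float32 inputs, and it only covers forward values (no custom gradient rules).

```python
import jax.numpy as jnp


def stable_logaddexp(x1, x2):
    """Hypothetical helper: log(exp(x1) + exp(x2)) without overflow (float32 assumed)."""
    x1 = jnp.asarray(x1, dtype=jnp.float32)
    x2 = jnp.asarray(x2, dtype=jnp.float32)
    amax = jnp.maximum(x1, x2)
    delta = x1 - x2
    # exp(-|delta|) lies in [0, 1], so it can never overflow.
    shifted = amax + jnp.log1p(jnp.exp(-jnp.abs(delta)))
    # delta is NaN when both inputs are the same infinity (or either is NaN);
    # in those cases x1 + x2 already gives the right answer (+inf, -inf or NaN).
    return jnp.where(jnp.isnan(delta), x1 + x2, shifted)


x1 = jnp.array([1.0, 1000.0, 3.0, -jnp.inf])
x2 = jnp.array([4.0, 1000.0, -2.0, -jnp.inf])
print(stable_logaddexp(x1, x2))
print(jnp.logaddexp(x1, x2))
```

Because `jnp.where` evaluates both branches, the NaN produced by `shifted` in the infinite cases is computed and then discarded; this is harmless for values but is one reason a production version would need extra care for gradients.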

Parameters:
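  • x1 – first input array.

  • x2 – second input array; must be broadcast-compatible with x1.

  • args (ArrayLike) – the positional inputs (x1, x2) as received by the ufunc call.

  • out (None) – not supported; only the default None is accepted.

  • where (None) – not supported; only the default None is accepted.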
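Like other binary ufuncs, x1 and x2 follow NumPy broadcasting rules. The short sketch below (array values chosen purely for illustration) broadcasts a shape (3,) array against a shape (2, 1) array, and also shows that -inf acts as an identity: exp(-inf) = 0, so logaddexp(x, -inf) is just x.

```python
import jax.numpy as jnp

# x1 has shape (3,), x2 has shape (2, 1); they broadcast to shape (2, 3).
x1 = jnp.array([0.0, 1.0, 2.0])
x2 = jnp.array([[0.0], [-jnp.inf]])

out = jnp.logaddexp(x1, x2)
print(out.shape)  # (2, 3)

# Row 0 equals log(exp(x1) + exp(0)) = log1p(exp(x1)).
# Row 1 equals x1, because the -inf term contributes exp(-inf) = 0.
print(out)
```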

Returns:

Array containing log(exp(x1) + exp(x2)), computed elementwise over the broadcast shape of x1 and x2.

Return type:

Any

Examples:

>>> import jax.numpy as jnp
>>> x1 = jnp.array([1, 2, 3])
>>> x2 = jnp.array([4, 5, 6])
>>> result1 = jnp.logaddexp(x1, x2)
>>> result2 = jnp.log(jnp.exp(x1) + jnp.exp(x2))
>>> print(jnp.allclose(result1, result2))
True
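The example above uses small inputs, where the naive formula and jnp.logaddexp agree. The sketch below uses inputs above roughly 710, where exp() overflows in both float32 and float64, to show where the two diverge: the naive expression returns inf, while jnp.logaddexp stays finite (for equal inputs it returns x + log(2)).

```python
import jax.numpy as jnp

# exp(800.) and exp(1000.) overflow to inf in both float32 and float64.
x1 = jnp.array([800.0, 1000.0])
x2 = jnp.array([800.0, 999.0])

naive = jnp.log(jnp.exp(x1) + jnp.exp(x2))  # intermediate exp() overflows
stable = jnp.logaddexp(x1, x2)              # computed without overflow

print(naive)   # [inf inf]
print(stable)  # finite; first entry is 800 + log(2)
```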