jax.numpy.signbit

jax.numpy.signbit(x, /)

Return the sign bit of array elements.

JAX implementation of numpy.signbit.

Parameters:

x (ArrayLike) – input array. Complex values are not supported.

Returns:

A boolean array of the same shape as x, containing True where the sign bit of x is set (negative values, including -0.0 and negatively signed NaN), and False otherwise.

Return type:

Array

See also

  • jax.numpy.sign(): return the mathematical sign of array elements, i.e. -1, 0, or +1.

Examples

signbit() on boolean values is always False:

>>> import jax.numpy as jnp
>>> x = jnp.array([True, False])
>>> jnp.signbit(x)
Array([False, False], dtype=bool)

signbit() on integer values is equivalent to x < 0:

>>> x = jnp.array([-2, -1, 0, 1, 2])
>>> jnp.signbit(x)
Array([ True,  True, False, False, False], dtype=bool)
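The equivalence with x < 0 for integer inputs can be checked directly. This is a minimal sketch; the variable names sign_mask, negative_mask, and integers_agree are illustrative and not part of the JAX API.

```python
import jax.numpy as jnp

x = jnp.array([-2, -1, 0, 1, 2])

# Integer dtypes have no signed zero, so the sign-bit mask and the
# comparison mask are expected to agree element by element.
sign_mask = jnp.signbit(x)
negative_mask = x < 0

integers_agree = bool(jnp.array_equal(sign_mask, negative_mask))
print(integers_agree)
```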

signbit() on floating-point values returns the actual sign bit of the float representation, including for signed zero:

>>> x = jnp.array([-1.5, -0.0, 0.0, 1.5])
>>> jnp.signbit(x)
Array([ True,  True, False, False], dtype=bool)
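For floating-point inputs, signbit() and x < 0 are no longer interchangeable, because -0.0 carries a set sign bit but does not compare as less than zero. The sketch below illustrates the difference; the names sign_mask, less_than_zero, and differs are illustrative only.

```python
import jax.numpy as jnp

x = jnp.array([-1.5, -0.0, 0.0, 1.5])

sign_mask = jnp.signbit(x)   # reads the sign bit, so -0.0 is flagged
less_than_zero = x < 0       # numeric comparison, so -0.0 is not flagged

# The two masks differ only at the negative-zero entry.
differs = sign_mask != less_than_zero
print(differs)
```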

This also returns the sign bit for special values such as signed NaN and signed infinity:

>>> x = jnp.array([jnp.nan, -jnp.nan, jnp.inf, -jnp.inf])
>>> jnp.signbit(x)
Array([False,  True, False,  True], dtype=bool)
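Because signbit() distinguishes -0.0 and signed infinities, one possible use is transferring the sign of one array onto the magnitude of another. The sketch below does this with jnp.where and compares the result against jnp.copysign for these particular inputs. It is an illustration rather than a statement about how JAX implements copysign; the names magnitude, sign_source, manual, and reference are hypothetical.

```python
import jax.numpy as jnp

magnitude = jnp.array([1.0, 2.0, 3.0, 4.0])
sign_source = jnp.array([-0.0, 0.0, -jnp.inf, jnp.inf])

# Negate the magnitude wherever the sign bit of sign_source is set.
negative = jnp.signbit(sign_source)
manual = jnp.where(negative, -jnp.abs(magnitude), jnp.abs(magnitude))

# jnp.copysign should give the same values for these inputs.
reference = jnp.copysign(magnitude, sign_source)
print(manual)
print(reference)
```

A plain comparison such as sign_source < 0 would miss the -0.0 entry, which is why the mask is built with signbit() here.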