jax.numpy.subtract
- jax.numpy.subtract = <jnp.ufunc 'subtract'>
Subtract two arrays element-wise.
JAX implementation of numpy.subtract. This is a universal function and supports the additional APIs described in jax.numpy.ufunc. This function provides the implementation of the - operator for JAX arrays.
- Parameters:
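x – array from which y is subtracted. Must be broadcastable with y to a common shape.
y – array subtracted from x. Must be broadcastable with x to a common shape.
args (ArrayLike) – the positional operands x and y.
out (None) – not supported; must be None.
where (None) – not supported; must be None.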
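The broadcasting requirement on x and y can be illustrated with a small sketch. It assumes only `jax.numpy` imported as `jnp`. The shapes are arbitrary choices for illustration. Because the exact exception class for incompatible shapes is not stated in this page, the sketch catches both `TypeError` and `ValueError`.

```python
import jax.numpy as jnp

# A (3, 1) column and a (4,) row broadcast to the common shape (3, 4).
x = jnp.arange(3).reshape(3, 1)  # [[0], [1], [2]]
y = jnp.arange(4)                # [0, 1, 2, 3]

result = jnp.subtract(x, y)      # result[i, j] == x[i, 0] - y[j]
print(result.shape)              # (3, 4)

# Shapes (3,) and (4,) have no common broadcast shape, so the call fails.
raised = False
try:
    jnp.subtract(jnp.arange(3), jnp.arange(4))
except (TypeError, ValueError):
    # Both exception types are caught because the exact class is an assumption here.
    raised = True
print(raised)                    # True
```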
- Returns:
Array containing the result of the element-wise subtraction.
- Return type:
Any
Examples
Calling subtract explicitly:
>>> x = jnp.arange(4)
>>> jnp.subtract(x, 10)
Array([-10,  -9,  -8,  -7], dtype=int32)
Calling subtract via the - operator:
>>> x - 10
Array([-10,  -9,  -8,  -7], dtype=int32)
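The - operator maps onto subtract in both operand orders and also inside compiled code. The following is a small sketch of this. It assumes that Python's reflected `__rsub__` on a JAX array computes `jnp.subtract(10, x)` for `10 - x`. The helper `diff` is hypothetical and exists only for illustration.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(4)

# x - 10 is routed to jnp.subtract(x, 10).
forward = x - 10

# 10 - x goes through Python's reflected __rsub__ and computes subtract(10, x).
reflected = 10 - x

@jax.jit
def diff(a, b):
    # Hypothetical helper: the '-' operator behaves the same on traced arrays inside jit.
    return a - b

print(bool(jnp.array_equal(forward, jnp.subtract(x, 10))))    # True
print(bool(jnp.array_equal(reflected, jnp.subtract(10, x))))  # True
print(bool(jnp.array_equal(diff(x, 10), forward)))            # True
```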