jax.numpy.fmod
jax.numpy.fmod(x1, x2, /)
Calculate the element-wise floating-point modulo operation.
JAX implementation of numpy.fmod.
Parameters:
x1 (ArrayLike) – scalar or array. Specifies the dividend.
x2 (ArrayLike) – scalar or array. Specifies the divisor.
x1 and x2 should either have the same shape or be broadcast-compatible.
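As a minimal sketch of the broadcasting rule above (the array values are illustrative assumptions), a scalar divisor is applied to every element of the dividend, and a (2, 3) dividend pairs with a length-3 divisor row by row:

```python
import jax.numpy as jnp

# A scalar divisor broadcasts against every element of the dividend.
x = jnp.array([5.0, -5.0, 7.5])
r_scalar = jnp.fmod(x, 2.0)          # → [1.0, -1.0, 1.5]

# A (2, 3) dividend and a (3,) divisor are broadcast-compatible:
# the divisor row is reused for each row of the dividend.
a = jnp.arange(6.0).reshape(2, 3)    # [[0., 1., 2.], [3., 4., 5.]]
b = jnp.array([4.0, 3.0, 2.0])
r_rows = jnp.fmod(a, b)              # → [[0., 1., 0.], [3., 1., 1.]], shape (2, 3)
```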
Returns:
An array containing the result of the element-wise floating-point modulo operation of x1 by x2, with the same sign as the corresponding elements of x1.
Return type:
Array
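The sign rule for the return value can be checked directly. This sketch uses arbitrary float inputs (an illustrative assumption) that cover all four sign combinations of dividend and divisor, and it also checks that each result is smaller in magnitude than its divisor, which follows from truncating the quotient toward zero.

```python
import jax.numpy as jnp

# Dividends and divisors covering every sign combination.
x1 = jnp.array([7.5, -7.5, 7.5, -7.5])
x2 = jnp.array([2.0, 2.0, -2.0, -2.0])

r = jnp.fmod(x1, x2)  # → [1.5, -1.5, 1.5, -1.5]

# Nonzero results carry the sign of the dividend, never of the divisor...
same_sign = jnp.all((r == 0) | (jnp.sign(r) == jnp.sign(x1)))
# ...and their magnitude is strictly smaller than that of the divisor.
bounded = jnp.all(jnp.abs(r) < jnp.abs(x2))
```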
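Note
The result of jnp.fmod is equivalent to x1 - x2 * jnp.fix(x1 / x2), where jnp.fix rounds the quotient toward zero; this truncation is why the result takes the sign of x1. Because the identity divides in floating point, evaluating it on integer inputs yields a floating-point array.
See also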
jax.numpy.mod() and jax.numpy.remainder(): Returns the element-wise remainder of the division, with the same sign as x2.
jax.numpy.divmod(): Calculates the integer quotient and remainder of x1 by x2, element-wise.
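To contrast jnp.fmod with the related functions listed above, the following sketch (inputs chosen for illustration) shows that jnp.fmod and jnp.mod agree when x1 and x2 share a sign and differ otherwise, and that jnp.divmod returns the floored quotient together with the same remainder as jnp.mod.

```python
import jax.numpy as jnp

x1 = jnp.array([7, -7, 7, -7])
x2 = jnp.array([3, 3, -3, -3])

# fmod truncates the quotient toward zero: the result takes the sign of x1.
r_fmod = jnp.fmod(x1, x2)   # → [1, -1, 1, -1]
# mod / remainder floor the quotient: the result takes the sign of x2.
r_mod = jnp.mod(x1, x2)     # → [1, 2, -2, -1]

# divmod pairs the floored quotient with the same remainder as jnp.mod,
# so q * x2 + r reconstructs x1.
q, r = jnp.divmod(x1, x2)   # q → [2, -3, -3, 2]
reconstructed = q * x2 + r  # → [7, -7, 7, -7]
```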
Examples
>>> import jax.numpy as jnp
>>> x1 = jnp.array([[3, -1, 4],
...                 [8, 5, -2]])
>>> x2 = jnp.array([2, 3, -5])
>>> jnp.fmod(x1, x2)
Array([[ 1, -1,  4],
       [ 0,  2, -2]], dtype=int32)
>>> x1 - x2 * jnp.fix(x1 / x2)
Array([[ 1., -1.,  4.],
       [ 0.,  2., -2.]], dtype=float32)
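The identity from the Note can also be written as a small reference function. The name fmod_via_fix is hypothetical and not part of JAX, and the float inputs are chosen so that every intermediate value is exact in float32. For inputs with very large or nearly integral quotients, rounding in x1 / x2 means the identity is not guaranteed to match jnp.fmod bit for bit, so treat this as an illustration rather than a replacement.

```python
import jax
import jax.numpy as jnp

def fmod_via_fix(x1, x2):
    """Hypothetical reference helper: computes x1 - x2 * fix(x1 / x2).

    jnp.fix rounds the quotient toward zero, which gives the result
    the sign of x1.
    """
    return x1 - x2 * jnp.fix(x1 / x2)

x1 = jnp.array([5.5, -5.5, 10.25, -0.75])
x2 = jnp.array([2.0, 2.0, -3.0, 0.5])

expected = jnp.fmod(x1, x2)                  # → [1.5, -1.5, 1.25, -0.25]
via_fix = jax.jit(fmod_via_fix)(x1, x2)      # same values for these inputs
```

Wrapping the helper in jax.jit is optional; it only shows that the identity traces and compiles like any other jax.numpy expression.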