jax.numpy.sort_complex
- jax.numpy.sort_complex(a)
Return a sorted copy of a complex array.
JAX implementation of numpy.sort_complex(). Complex numbers are sorted lexicographically: by their real part first, and then by their imaginary part when the real parts are equal.
- Parameters:
a (Array | ndarray | bool | number | int | float | complex) – input array. If its dtype is not complex, the array is upcast to a complex dtype.
- Returns:
A sorted array with the same shape as the input and a complex dtype. If a is multi-dimensional, it is sorted along the last axis.
- Return type:
Array
See also
jax.numpy.sort(): Return a sorted copy of an array.
Examples
>>> a = jnp.array([1+2j, 2+4j, 3-1j, 2+3j])
>>> jnp.sort_complex(a)
Array([1.+2.j, 2.+3.j, 2.+4.j, 3.-1.j], dtype=complex64)
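The lexicographic ordering can also be reproduced by hand, which may help when reasoning about ties in the real part. The following is a minimal sketch using jax.numpy.lexsort(); it is an illustration of the ordering rule, not a description of how sort_complex is implemented internally. The names order and manual are chosen here for illustration only.

```python
import jax.numpy as jnp

a = jnp.array([1+2j, 2+4j, 3-1j, 2+3j])

# jnp.lexsort treats the LAST key as the primary one, so the imaginary
# parts go first (tie-breaker) and the real parts go last (primary key).
order = jnp.lexsort((a.imag, a.real))
manual = a[order]

# For this input, the manual ordering should agree with sort_complex.
same = bool(jnp.array_equal(manual, jnp.sort_complex(a)))
print(manual, same)
```

Here the two entries with real part 2 are ordered by their imaginary parts, so 2+3j comes before 2+4j.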
Multi-dimensional arrays are sorted along the last axis:
>>> a = jnp.array([[5, 3, 4],
...                [6, 9, 2]])
>>> jnp.sort_complex(a)
Array([[3.+0.j, 4.+0.j, 5.+0.j],
       [2.+0.j, 6.+0.j, 9.+0.j]], dtype=complex64)