jax.numpy.partition

jax.numpy.partition(a, kth, axis=-1)

Returns a partially-sorted copy of an array.

JAX implementation of numpy.partition(). The JAX version differs from NumPy in the treatment of NaN entries: NaNs that have the sign bit set are sorted to the beginning of the array.

Parameters:
  • a (Array | numpy.ndarray | numpy.bool_ | numpy.number | bool | int | float | complex) – array to be partitioned.

  • kth (int) – static integer index about which to partition the array.

  • axis (int) – static integer axis along which to partition the array; default is -1.
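The kth and axis parameters combine as follows: kth is a position along the chosen axis, and each 1-D slice along that axis is partitioned independently. A minimal sketch on a small 2-D array (the array values are illustrative only):

```python
import jax.numpy as jnp

# A small 2-D array: partition each row (axis=-1) and each column (axis=0).
x = jnp.array([[5, 1, 4, 2, 3],
               [9, 7, 8, 6, 0]])

kth = 2
rows = jnp.partition(x, kth)        # axis=-1 (default): partition within each row
cols = jnp.partition(x, 0, axis=0)  # axis=0: partition within each column

# Along the partitioned axis, position kth holds the value a full sort puts there.
print(rows[:, kth])  # [3 7], same as jnp.sort(x, axis=-1)[:, kth]
print(cols[0])       # [5 1 4 2 0], same as jnp.sort(x, axis=0)[0]
```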

Returns:

A copy of a partitioned at the kth value along axis. The entries before kth are values less than or equal to take(a, kth, axis), and the entries after kth are values greater than or equal to take(a, kth, axis).
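The return-value guarantee can be checked directly. The sketch below uses a hypothetical helper, check_partition, that tests the invariant on a 1-D array: everything before kth is at most the pivot, everything after is at least the pivot, and the pivot equals the kth entry of a full sort. Comparisons are non-strict because repeated values (such as the two 3s here) can land on either side of a pivot equal to them.

```python
import jax.numpy as jnp

def check_partition(x, kth):
    """Hypothetical helper: check the partition invariant on a 1-D array."""
    p = jnp.partition(x, kth)
    pivot = p[kth]
    # Non-strict comparisons: ties with the pivot may appear on either side.
    before_ok = bool(jnp.all(p[:kth] <= pivot))
    after_ok = bool(jnp.all(p[kth + 1:] >= pivot))
    # The pivot is the value a full sort would place at position kth.
    pivot_ok = bool(pivot == jnp.sort(x)[kth])
    return before_ok and after_ok and pivot_ok

x = jnp.array([6, 8, 4, 3, 1, 9, 7, 5, 2, 3])
# kth=3 exercises a tie: the pivot is 3 and another 3 sits before it.
print(all(check_partition(x, k) for k in (0, 3, 4, 8)))  # True
```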

Return type:

Array

Note

The JAX version requires the kth argument to be a static integer rather than a general array. This is implemented via two calls to jax.lax.top_k(). If you’re only accessing the top or bottom k values of the output, it may be more efficient to call jax.lax.top_k() directly.
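Two consequences of the note, sketched with hypothetical helper names (bottom_k_partition and bottom_k_top_k are illustrations, not JAX APIs): when wrapping jnp.partition in jax.jit, the value that determines kth must be marked static; and if only the k smallest values are needed, jax.lax.top_k() applied to the negated input returns them without building the rest of the partition. The negation trick assumes a signed integer or floating-point dtype; it does not work as written for unsigned integers.

```python
from functools import partial

import jax
import jax.numpy as jnp

# k determines kth, so it must be static under jit.
@partial(jax.jit, static_argnames="k")
def bottom_k_partition(x, k):
    # The first k entries of a partition at kth = k - 1 are the k smallest values.
    return jnp.partition(x, k - 1)[:k]

def bottom_k_top_k(x, k):
    # lax.top_k returns the k largest values (and their indices);
    # negating the input and the result yields the k smallest values.
    # Assumes a signed or floating-point dtype.
    values, _ = jax.lax.top_k(-x, k)
    return -values

x = jnp.array([6, 8, 4, 3, 1, 9, 7, 5, 2, 3])
a = bottom_k_partition(x, k=3)
b = bottom_k_top_k(x, 3)
print(jnp.sort(a), jnp.sort(b))  # [1 2 3] [1 2 3]
```

Sorting both results before comparing avoids relying on the order in which either function returns its values.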


Examples

>>> import jax.numpy as jnp
>>> x = jnp.array([6, 8, 4, 3, 1, 9, 7, 5, 2, 3])
>>> kth = 4
>>> x_partitioned = jnp.partition(x, kth)
>>> x_partitioned
Array([1, 2, 3, 3, 4, 9, 8, 7, 6, 5], dtype=int32)

The result is a partially-sorted copy of the input. All values before kth are less than or equal to the pivot value, and all values after kth are greater than or equal to it:

>>> smallest_values = x_partitioned[:kth]
>>> pivot_value = x_partitioned[kth]
>>> largest_values = x_partitioned[kth + 1:]
>>> print(smallest_values, pivot_value, largest_values)
[1 2 3 3] 4 [9 8 7 6 5]

Note that within smallest_values and largest_values, the order of the returned values is arbitrary and implementation-dependent.
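Because the order on each side of the pivot is not guaranteed, code that compares partition results should not depend on it. A minimal sketch of an order-independent comparison: sort each side, splice the pivot back in, and the result matches a full sort of the input.

```python
import jax.numpy as jnp

x = jnp.array([6, 8, 4, 3, 1, 9, 7, 5, 2, 3])
kth = 4
p = jnp.partition(x, kth)

# Sort each side independently, then reinsert the pivot (a length-1 slice).
recovered = jnp.concatenate([
    jnp.sort(p[:kth]),
    p[kth:kth + 1],
    jnp.sort(p[kth + 1:]),
])
print(bool(jnp.array_equal(recovered, jnp.sort(x))))  # True
```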