jax.numpy.partition

jax.numpy.partition(a, kth, axis=-1)
Returns a partially-sorted copy of an array.
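JAX implementation of numpy.partition(). The JAX version differs from NumPy in the treatment of NaN entries: NaNs which have the negative (sign) bit set are sorted to the beginning of the array.
- Parameters:
  - a: array to be partitioned.
  - kth: static integer index about which to partition the array.
  - axis: static integer axis along which to partition the array; defaults to -1.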
- Returns:
  A copy of a partitioned at the kth position along axis. The entry at position kth is the value that a full sort would place there; the entries before kth are less than or equal to that value, and the entries after kth are greater than or equal to it.
- Return type:
  Array
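To make the return-value contract concrete, here is a minimal sketch that partitions a small 2-D array along axis=0 and checks the guarantees stated above: the pivot matches a full sort, the values on either side respect the ordering, and no values are lost. The array contents and variable names are illustrative only.

```python
import jax.numpy as jnp

# Illustrative 2-D input; each column is partitioned independently (axis=0).
x = jnp.array([[5, 1, 9],
               [2, 8, 3],
               [7, 4, 6],
               [1, 9, 2]])
kth = 2

p = jnp.partition(x, kth, axis=0)

# The entry at position kth equals what a full sort places at that position.
pivot = jnp.take(p, kth, axis=0)
assert (pivot == jnp.take(jnp.sort(x, axis=0), kth, axis=0)).all()

# Entries before kth are <= the pivot; entries after kth are >= the pivot.
assert (p[:kth] <= pivot).all()
assert (p[kth + 1:] >= pivot).all()

# Each column still holds the same multiset of values as the input.
assert (jnp.sort(p, axis=0) == jnp.sort(x, axis=0)).all()
```

Note that the comparisons use <= and >= rather than strict inequalities: with repeated values, copies of the pivot may land on either side of position kth.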
Note
The JAX version requires the kth argument to be a static integer rather than a general array. This is implemented via two calls to jax.lax.top_k(). If you’re only accessing the top or bottom k values of the output, it may be more efficient to call jax.lax.top_k() directly.

See also
- jax.numpy.sort(): full sort
- jax.numpy.argpartition(): indirect partial sort
- jax.lax.top_k(): directly find the top k entries
- jax.lax.approx_max_k(): compute the approximate top k entries
- jax.lax.approx_min_k(): compute the approximate bottom k entries
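The Note above has two practical consequences, sketched below under stated assumptions. First, when jnp.partition is called inside a jitted function, kth has to be marked static; partition_jit is a hypothetical helper name used only here. Second, when only the largest or smallest k values are needed, jax.lax.top_k(operand, k) can be used directly; getting the smallest values by negating the input is a common trick assumed here, not necessarily how jnp.partition does it internally.

```python
from functools import partial

import jax
import jax.numpy as jnp

# kth must be a static Python int, so it is marked static when the caller is jitted.
# `partition_jit` is a hypothetical helper name used only for this sketch.
@partial(jax.jit, static_argnames=("kth",))
def partition_jit(x, kth):
    return jnp.partition(x, kth)

x = jnp.array([6, 8, 4, 3, 1, 9, 7, 5, 2, 3])
pivot = partition_jit(x, 4)[4]  # the value a full sort places at index 4

# When only the k largest values are needed, call jax.lax.top_k directly.
# It returns (values, indices) along the last axis, values in descending order.
k = 3
largest_vals, largest_idx = jax.lax.top_k(x, k)

# For the k smallest values, apply top_k to the negated array and negate back.
# Caveat: negation is unsafe for unsigned dtypes and for the minimum value of a
# signed integer dtype, so this trick is only a sketch for typical signed/float data.
neg_vals, smallest_idx = jax.lax.top_k(-x, k)
smallest_vals = -neg_vals  # ascending order: the smallest value comes first
```

Because kth is static, each distinct kth value passed to partition_jit triggers a separate compilation.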
Examples
>>> import jax.numpy as jnp
>>> x = jnp.array([6, 8, 4, 3, 1, 9, 7, 5, 2, 3])
>>> kth = 4
>>> x_partitioned = jnp.partition(x, kth)
>>> x_partitioned
Array([1, 2, 3, 3, 4, 9, 8, 7, 6, 5], dtype=int32)
The result is a partially-sorted copy of the input. All values before kth are smaller than the pivot value, and all values after kth are larger than the pivot value:

>>> smallest_values = x_partitioned[:kth]
>>> pivot_value = x_partitioned[kth]
>>> largest_values = x_partitioned[kth + 1:]
>>> print(smallest_values, pivot_value, largest_values)
[1 2 3 3] 4 [9 8 7 6 5]
Notice that within smallest_values and largest_values, the returned order is arbitrary and implementation-dependent.
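Because the order within each side is unspecified, tests or downstream code should not rely on it. A minimal sketch of an order-insensitive check, assuming the same input as the example above, sorts each side before comparing it against a full sort:

```python
import jax.numpy as jnp

x = jnp.array([6, 8, 4, 3, 1, 9, 7, 5, 2, 3])
kth = 4
p = jnp.partition(x, kth)

# The order inside each side is unspecified, so sort each side before
# comparing it against the corresponding slice of a full sort.
full = jnp.sort(x)
assert (jnp.sort(p[:kth]) == full[:kth]).all()
assert p[kth] == full[kth]
assert (jnp.sort(p[kth + 1:]) == full[kth + 1:]).all()
```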