jax.numpy.searchsorted
- jax.numpy.searchsorted(a, v, side='left', sorter=None, *, method='scan')
Perform a binary search within a sorted array.
JAX implementation of numpy.searchsorted().
This returns the indices within a sorted array a where the values in v can be inserted to maintain its sort order.
- Parameters:
  - a (ArrayLike) – one-dimensional array, assumed to be in sorted order unless sorter is specified.
  - v (ArrayLike) – N-dimensional array of query values.
  - side (str) – 'left' (default) or 'right'; specifies whether insertion indices will be to the left or the right in case of ties.
  - sorter (ArrayLike | None) – optional array of indices specifying the sort order of a. If specified, then the algorithm assumes that a[sorter] is in sorted order.
  - method (str) – one of 'scan' (default), 'scan_unrolled', 'sort' or 'compare_all'. See Note below.
- Returns:
  Array of insertion indices of shape v.shape.
- Return type:
  Array
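The parameters above fix two things: how ties are resolved by side, and that the output has shape v.shape. As a minimal sketch (assuming NaN-free inputs), the insertion index can be written as a count: for side='left' it is the number of elements of a strictly less than the query, and for side='right' the number of elements less than or equal to it. The helper searchsorted_reference below is hypothetical, written only to check this definition against jnp.searchsorted on a 2-D query array. It does O(len(a)) work per query and is not meant to describe how jnp.searchsorted is implemented.

```python
import jax.numpy as jnp


def searchsorted_reference(a, v, side="left"):
    """Hypothetical brute-force helper (not part of JAX), for illustration only.

    side='left'  -> number of elements of `a` strictly less than each query.
    side='right' -> number of elements of `a` less than or equal to each query.
    """
    a = jnp.asarray(a)
    v = jnp.asarray(v)
    # Compare every query against every element of `a`: shape v.shape + (len(a),).
    if side == "left":
        mask = a < v[..., None]
    elif side == "right":
        mask = a <= v[..., None]
    else:
        raise ValueError(f"side must be 'left' or 'right', got {side!r}")
    # Summing over the last axis leaves an array with the same shape as v.
    return mask.sum(axis=-1)


a = jnp.array([1, 2, 2, 3, 4, 5, 5])
v = jnp.array([[0, 2], [3, 8]])  # 2-D array of query values

left = jnp.searchsorted(a, v)
right = jnp.searchsorted(a, v, side="right")
print(left.shape)  # (2, 2), the same as v.shape
```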
Note
The method argument controls the algorithm used to compute the insertion indices:
- 'scan' (the default) tends to be more performant on CPU, particularly when a is very large.
- 'scan_unrolled' is more performant on GPU at the expense of additional compile time.
- 'sort' is often more performant on accelerator backends like GPU and TPU, particularly when v is very large.
- 'compare_all' tends to be the most performant when a is very small.
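Since method only selects the algorithm, every choice should return the same indices; which one is fastest depends on the backend and on the sizes of a and v, so it is worth benchmarking on your own workload. The sketch below assumes a JAX version whose searchsorted accepts all four method strings listed above. Because method is a Python string rather than an array, a jit-compiled wrapper has to mark it static; the wrapper name insertion_indices is hypothetical.

```python
from functools import partial

import jax
import jax.numpy as jnp

a = jnp.array([1, 2, 2, 3, 4, 5, 5])
vals = jnp.array([0, 3, 8, 1.5, 2])

# Each method is a different algorithm for the same computation.
results = {
    method: jnp.searchsorted(a, vals, method=method)
    for method in ("scan", "scan_unrolled", "sort", "compare_all")
}


# Hypothetical wrapper: `method` must be static because strings are not JAX types.
@partial(jax.jit, static_argnames="method")
def insertion_indices(a, v, method="scan"):
    return jnp.searchsorted(a, v, method=method)


via_compare_all = insertion_indices(a, vals, method="compare_all")
```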
Examples
Searching for a single value:
>>> import jax.numpy as jnp
>>> a = jnp.array([1, 2, 2, 3, 4, 5, 5])
>>> jnp.searchsorted(a, 2)
Array(1, dtype=int32)
>>> jnp.searchsorted(a, 2, side='right')
Array(3, dtype=int32)
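Because side='left' places a query before any equal elements and side='right' places it after them, the difference of the two results counts how many times each query occurs in a, and a[lo:hi] is the run of elements equal to the query. A small sketch of this; queries, lo, hi and counts are illustrative names, not part of the API.

```python
import jax.numpy as jnp

a = jnp.array([1, 2, 2, 3, 4, 5, 5])
queries = jnp.array([2, 5, 6])

# Left edge and right edge of the run of elements equal to each query.
lo = jnp.searchsorted(a, queries, side="left")
hi = jnp.searchsorted(a, queries, side="right")

# Number of occurrences of each query in `a` (0 if it does not occur).
counts = hi - lo
```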
Searching for a batch of values:
>>> vals = jnp.array([0, 3, 8, 1.5, 2])
>>> jnp.searchsorted(a, vals)
Array([0, 3, 7, 1, 1], dtype=int32)
Optionally, the sorter argument can be used to find insertion indices into an array sorted via jax.numpy.argsort():
>>> a = jnp.array([4, 3, 5, 1, 2])
>>> sorter = jnp.argsort(a)
>>> jnp.searchsorted(a, vals, sorter=sorter)
Array([0, 2, 5, 1, 1], dtype=int32)
The result is equivalent to passing the sorted array:
>>> jnp.searchsorted(jnp.sort(a), vals)
Array([0, 2, 5, 1, 1], dtype=int32)
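With sorter, the returned indices refer to positions in the sorted view a[sorter], not in a itself. A common follow-up, sketched here under the assumption that you want to locate queries that are present in a, is to map those indices back through sorter. An index equal to len(a) (a query larger than every element) is clipped before the lookup, and a membership check flags queries that are absent; the names pos and found and the -1 sentinel are illustrative choices, not part of the API.

```python
import jax.numpy as jnp

a = jnp.array([4, 3, 5, 1, 2])     # unsorted data
sorter = jnp.argsort(a)            # a[sorter] is in sorted order
queries = jnp.array([3, 1, 5, 9])  # 9 does not occur in a

# Insertion indices into the sorted view a[sorter].
idx = jnp.searchsorted(a, queries, sorter=sorter)

# Clip so that idx == len(a) does not index past the end of sorter,
# then map positions in the sorted view back to positions in a.
pos = sorter[jnp.clip(idx, 0, a.shape[0] - 1)]

# A query is present only if the element at that position equals it.
found = a[pos] == queries
pos = jnp.where(found, pos, -1)  # -1 marks queries not found in a
```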