jax.numpy.select
- jax.numpy.select(condlist, choicelist, default=0)
Select values based on a series of conditions.
JAX implementation of numpy.select(), implemented in terms of jax.lax.select_n().
- Parameters:
condlist (Sequence[ArrayLike]) – sequence of array-like conditions. All entries must be mutually broadcast-compatible.
choicelist (Sequence[ArrayLike]) – sequence of array-like values to choose. Must have the same length as condlist, and all entries must be broadcast-compatible with entries of condlist.
default (ArrayLike) – value to return when every condition is False (default: 0).
- Returns:
Array of selected values from choicelist corresponding to the first True entry in condlist at each location.
- Return type:
Array
See also
jax.numpy.where(): select between two values based on a single condition.
jax.lax.select_n(): select between N values based on an index.
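As a rough sketch of how these functions relate: with a single condition, select picks the same values as where, while jax.lax.select_n() takes an integer index array in place of boolean conditions. The arrays below are illustrative values chosen for this sketch, not taken from the original examples.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([-2.0, -0.5, 0.0, 1.5])

# One condition: select([c], [a], default=b) picks the same values as where(c, a, b).
via_select = jnp.select([x > 0], [x], default=0.0)
via_where = jnp.where(x > 0, x, 0.0)
print(jnp.array_equal(via_select, via_where))  # True

# select_n chooses by integer index: which[i] == k takes element i of the k-th case.
which = jnp.array([2, 0, 1, 0])
picked = lax.select_n(
    which,
    jnp.array([1, 2, 3, 4]),
    jnp.array([10, 20, 30, 40]),
    jnp.array([100, 200, 300, 400]),
)
print(picked.tolist())  # [100, 2, 30, 4]
```

Note that select_n expects its cases to already share one shape and dtype, whereas select broadcasts its inputs first.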
Examples
>>> condlist = [
...     jnp.array([False, True, False, False]),
...     jnp.array([True, False, False, False]),
...     jnp.array([False, True, True, False]),
... ]
>>> choicelist = [
...     jnp.array([1, 2, 3, 4]),
...     jnp.array([10, 20, 30, 40]),
...     jnp.array([100, 200, 300, 400]),
... ]
>>> jnp.select(condlist, choicelist, default=0)
Array([ 10,   2, 300,   0], dtype=int32)
This is logically equivalent to the following nested where statement:
>>> default = 0
>>> jnp.where(condlist[0],
...           choicelist[0],
...           jnp.where(condlist[1],
...                     choicelist[1],
...                     jnp.where(condlist[2],
...                               choicelist[2],
...                               default)))
Array([ 10,   2, 300,   0], dtype=int32)
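Because the nested where tests condlist[0] outermost, an earlier condition takes priority wherever several conditions are True. The sketch below uses made-up scores and thresholds (illustrative only) to show that reordering overlapping conditions changes the result; it also passes Python scalars as choices, which broadcast against the condition arrays.

```python
import jax.numpy as jnp

scores = jnp.array([95, 82, 71, 40])  # illustrative data

# Overlapping conditions: 95 satisfies all three, so order decides the result.
# Scalar choices broadcast to the shape of the conditions.
grades = jnp.select(
    [scores >= 90, scores >= 80, scores >= 70],
    [4, 3, 2],
    default=0,
)
print(grades.tolist())  # [4, 3, 2, 0]

# Same conditions listed loosest first: the first True now always picks 2.
loosest_first = jnp.select(
    [scores >= 70, scores >= 80, scores >= 90],
    [2, 3, 4],
    default=0,
)
print(loosest_first.tolist())  # [2, 2, 2, 0]
```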
However, for efficiency it is implemented in terms of jax.lax.select_n().
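One way to reduce select to jax.lax.select_n() is to turn the condition list into a single integer index per location. The sketch below does that; the helper name select_via_select_n is hypothetical and this is not JAX's actual source. It assumes every condition and choice broadcasts to one common shape, as the parameter descriptions require, and it fills the index by walking the conditions from last to first so that the earliest True condition wins.

```python
import jax.numpy as jnp
from jax import lax


def select_via_select_n(condlist, choicelist, default=0):
    """Hypothetical helper: a sketch of select built on lax.select_n,
    not JAX's actual implementation."""
    if len(condlist) != len(choicelist) or not condlist:
        raise ValueError("condlist and choicelist must be non-empty and equal length")
    conds = [jnp.asarray(c, dtype=bool) for c in condlist]
    # Cases 0..n-1 are the choices; case n is the default.
    cases = [jnp.asarray(c) for c in choicelist] + [jnp.asarray(default)]
    dtype = jnp.result_type(*cases)
    shape = jnp.broadcast_shapes(*(a.shape for a in conds + cases))
    # select_n needs every case to share one shape and dtype.
    cases = [jnp.broadcast_to(c.astype(dtype), shape) for c in cases]

    n = len(choicelist)
    # Start every location at the default's index, then overwrite from the
    # last condition to the first so the earliest True condition wins.
    idx = jnp.full(shape, n, dtype=jnp.int32)
    for i in reversed(range(n)):
        idx = jnp.where(conds[i], i, idx)
    return lax.select_n(idx, *cases)


condlist = [
    jnp.array([False, True, False, False]),
    jnp.array([True, False, False, False]),
    jnp.array([False, True, True, False]),
]
choicelist = [
    jnp.array([1, 2, 3, 4]),
    jnp.array([10, 20, 30, 40]),
    jnp.array([100, 200, 300, 400]),
]
result = select_via_select_n(condlist, choicelist)
print(result.tolist())  # [10, 2, 300, 0]
```

An equivalent index could also be computed with jnp.argmax over the stacked conditions; which strategy JAX uses internally is not stated here.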