jax.numpy.select

jax.numpy.select(condlist, choicelist, default=0)

Select values based on a series of conditions.

JAX implementation of numpy.select(), implemented in terms of jax.lax.select_n().

Parameters:
  • condlist (Sequence[ArrayLike]) – sequence of array-like conditions. All entries must be mutually broadcast-compatible.

  • choicelist (Sequence[ArrayLike]) – sequence of array-like values to choose. Must have the same length as condlist, and all entries must be broadcast-compatible with entries of condlist.

  • default (ArrayLike) – value to return when every condition is False (default: 0).
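The broadcasting requirements above can be illustrated with a small sketch. The names row, col, and result are chosen here for illustration only; the conditions have shape (3, 3) after broadcasting, while the choices are 0-d arrays that broadcast against them.

```python
import jax.numpy as jnp

row = jnp.arange(3)[:, None]  # shape (3, 1)
col = jnp.arange(3)[None, :]  # shape (1, 3)

# Each condition broadcasts to shape (3, 3).
condlist = [row > col, row < col]
# 0-d choices broadcast against the conditions.
choicelist = [jnp.array(1), jnp.array(-1)]

# Entries where neither condition holds (the diagonal) receive the default.
result = jnp.select(condlist, choicelist, default=0)
print(result)
```

Here the output has the broadcast shape of all inputs, with 1 below the diagonal, -1 above it, and the default 0 on it.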

Returns:

Array of selected values from choicelist corresponding to the first True entry in condlist at each location.

Return type:

Array
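Because the first True condition wins at each location, the order of condlist matters when conditions overlap. The following sketch, using illustrative names (x, small, large), shows the same two conditions giving different results depending on their order.

```python
import jax.numpy as jnp

x = jnp.array([-5, 0, 3, 12])
small = jnp.full_like(x, 1)
large = jnp.full_like(x, 2)

# Most specific condition first: 12 matches x > 10 before x > 0 is considered.
specific_first = jnp.select([x > 10, x > 0], [large, small], default=0)

# Broad condition first: x > 0 already captures 12, so x > 10 never wins.
broad_first = jnp.select([x > 0, x > 10], [small, large], default=0)

print(specific_first)
print(broad_first)
```

Listing the most specific condition first is usually the intent when conditions are nested in this way.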

Examples

>>> import jax.numpy as jnp
>>> condlist = [
...    jnp.array([False, True, False, False]),
...    jnp.array([True, False, False, False]),
...    jnp.array([False, True, True, False]),
... ]
>>> choicelist = [
...    jnp.array([1, 2, 3, 4]),
...    jnp.array([10, 20, 30, 40]),
...    jnp.array([100, 200, 300, 400]),
... ]
>>> jnp.select(condlist, choicelist, default=0)
Array([ 10,   2, 300,   0], dtype=int32)

This is logically equivalent to the following nested jnp.where() expression:

>>> default = 0
>>> jnp.where(condlist[0],
...   choicelist[0],
...   jnp.where(condlist[1],
...     choicelist[1],
...     jnp.where(condlist[2],
...       choicelist[2],
...       default)))
Array([ 10,   2, 300,   0], dtype=int32)

However, for efficiency it is implemented in terms of jax.lax.select_n().
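As a rough illustration of how a single jax.lax.select_n() call can replace the nested form, the sketch below computes the index of the first True condition at each location and uses it to pick among the cases. This is an assumption-laden sketch for intuition, not necessarily the library's actual code; the names conds, idx, and default_arr are introduced here.

```python
import jax.numpy as jnp
from jax import lax

condlist = [
    jnp.array([False, True, False, False]),
    jnp.array([True, False, False, False]),
    jnp.array([False, True, True, False]),
]
choicelist = [
    jnp.array([1, 2, 3, 4]),
    jnp.array([10, 20, 30, 40]),
    jnp.array([100, 200, 300, 400]),
]
# select_n needs all cases to share a shape and dtype, so build the
# default as an array matching the choices.
default_arr = jnp.zeros_like(choicelist[0])

# Prepend an all-False row standing in for the default: argmax returns the
# first maximal index, so it yields 0 when no condition is True.
conds = jnp.stack([jnp.zeros(4, dtype=bool), *condlist])
idx = jnp.argmax(conds, axis=0).astype(jnp.int32)

# Case 0 is the default; case k is choicelist[k - 1].
result = lax.select_n(idx, default_arr, *choicelist)
print(result)
```

The result agrees with jnp.select(condlist, choicelist, default=0) for these inputs, while evaluating one selection instead of a chain of nested ones.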