jax.lax.select
- jax.lax.select(pred, on_true, on_false)
Selects between two branches based on a boolean predicate.
Wraps XLA’s Select operator.
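A minimal sketch of the elementwise behaviour. The array values are arbitrary and chosen only for illustration. All three arguments have the same shape, and the two branches share a dtype.

```python
import jax.numpy as jnp
from jax import lax

# Boolean predicate and two candidate arrays, all of shape (3,).
pred = jnp.array([True, False, True])
on_true = jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32)
on_false = jnp.array([10.0, 20.0, 30.0], dtype=jnp.float32)

# Elementwise: take on_true where pred is True, otherwise on_false.
out = lax.select(pred, on_true, on_false)
# out == [1.0, 20.0, 3.0], with the same shape and dtype as the inputs
```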
In general, select() leads to evaluation of both branches, although the compiler may elide computations if possible. For a similar function that usually evaluates only a single branch, see cond().
- Parameters:
pred (ArrayLike) – boolean array.
on_true (ArrayLike) – array containing entries to return where pred is True. Must have the same shape as pred, and the same shape and dtype as on_false.
on_false (ArrayLike) – array containing entries to return where pred is False. Must have the same shape as pred, and the same shape and dtype as on_true.
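The shape and dtype requirements above can be illustrated with a ReLU-style sketch. Here the replacement value is built as a full array with zeros_like so that it matches x exactly. The sketch does not rely on any implicit broadcasting. The dtype check assumes that mismatched branch dtypes raise a TypeError rather than being promoted.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([-2.0, 0.5, 3.0], dtype=jnp.float32)
pred = x > 0  # boolean array with the same shape as x

# on_true and on_false must match in shape and dtype, so the
# replacement value is materialized as a full float32 array first.
zeros = jnp.zeros_like(x)
relu = lax.select(pred, x, zeros)  # [0.0, 0.5, 3.0]

# Mismatched branch dtypes (float32 vs int32) are rejected, not promoted.
try:
    lax.select(pred, x, jnp.zeros(x.shape, dtype=jnp.int32))
    mismatch_rejected = False
except TypeError:
    mismatch_rejected = True
```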
- Returns:
array with the same shape and dtype as on_true and on_false.
- Return type:
Array
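The note above that select() generally evaluates both branches can be contrasted with cond() in a short sketch. The branch functions f_true and f_false are hypothetical and chosen only for illustration. With select(), both branch results exist as full arrays before the elementwise choice is made. cond() instead takes a scalar boolean and branch functions, and usually runs only one of them. (When its predicate is batched under vmap, cond() is converted to a select and both branches run, which is why "usually" applies.)

```python
import jax.numpy as jnp
from jax import lax

def f_true(x):   # hypothetical branch, for illustration only
    return jnp.sin(x)

def f_false(x):  # hypothetical branch, for illustration only
    return jnp.cos(x)

x = jnp.arange(4.0)

# select: both branch arrays are computed up front, then combined
# elementwise according to a boolean array predicate.
pred = x > 1.5
y_select = lax.select(pred, f_true(x), f_false(x))

# cond: a single scalar predicate picks which branch function to apply.
y_cond = lax.cond(x.sum() > 3.0, f_true, f_false, x)  # sum is 6.0, so f_true
```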