jax.numpy.promote_types
- jax.numpy.promote_types(a, b)
Returns the type to which a binary operation should cast its arguments.
JAX implementation of numpy.promote_types(). For details of JAX’s type promotion semantics, see Type promotion semantics.
- Parameters:
  - a (DTypeLike) – a numpy.dtype or a dtype specifier.
  - b (DTypeLike) – a numpy.dtype or a dtype specifier.
- Returns:
  A numpy.dtype object.
- Return type:
  DType
Examples
Type specifiers may be strings, dtypes, or scalar types, and the return value is always a dtype:
>>> jnp.promote_types('int32', 'float32')  # strings
dtype('float32')
>>> jnp.promote_types(jnp.dtype('int32'), jnp.dtype('float32'))  # dtypes
dtype('float32')
>>> jnp.promote_types(jnp.int32, jnp.float32)  # scalar types
dtype('float32')
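Because promote_types describes how a binary operation casts its arguments, its result for two strongly-typed array dtypes should agree with the dtype JAX produces when arrays of those dtypes are combined. The following is a minimal sketch of that consistency check; the int32/float32 pair and the array names are illustrative choices, not part of the API:

```python
import jax.numpy as jnp

# Two strongly-typed arrays with different dtypes (illustrative choice).
x = jnp.zeros(3, dtype=jnp.int32)
y = jnp.zeros(3, dtype=jnp.float32)

# Dtype that promote_types predicts for this pair of operand dtypes.
expected = jnp.promote_types(x.dtype, y.dtype)

# Dtype that the binary operation actually produces.
result = x + y

print(expected, result.dtype)  # float32 float32
assert result.dtype == expected
```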
Built-in scalar types (int, float, or complex) are treated as weakly-typed and will not change the bit width of a strongly-typed counterpart (see discussion in Type promotion semantics):
>>> jnp.promote_types('uint8', int)
dtype('uint8')
>>> jnp.promote_types('float16', float)
dtype('float16')
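The same weak-typing rule governs arithmetic in which a Python scalar appears directly alongside a JAX array. A minimal sketch, assuming JAX's default "standard" dtype promotion mode; the array names are hypothetical and chosen only for illustration:

```python
import jax.numpy as jnp

# Strongly-typed arrays with narrow dtypes.
u8 = jnp.arange(3, dtype=jnp.uint8)
f16 = jnp.ones(3, dtype=jnp.float16)
c64 = jnp.ones(3, dtype=jnp.complex64)

# Python scalars (int, float, complex) are weakly typed, so the array's
# dtype is kept and nothing is widened to a 64-bit type.
print((u8 + 1).dtype)     # uint8
print((f16 * 2.0).dtype)  # float16
print((c64 + 1j).dtype)   # complex64

# promote_types reports the same dtypes when given the built-in types.
assert (u8 + 1).dtype == jnp.promote_types(u8.dtype, int)
assert (f16 * 2.0).dtype == jnp.promote_types(f16.dtype, float)
assert (c64 + 1j).dtype == jnp.promote_types(c64.dtype, complex)
```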
This differs from the NumPy version of this function, which treats built-in scalar types as equivalent to 64-bit types:
>>> import numpy
>>> numpy.promote_types('uint8', int)
dtype('int64')
>>> numpy.promote_types('float16', float)
dtype('float64')