jax.numpy.promote_types

jax.numpy.promote_types(a, b)

Returns the type to which a binary operation should cast its arguments.

JAX implementation of numpy.promote_types(). For details of JAX’s type promotion semantics, see Type promotion semantics.
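As a rough illustration of what "the type to which a binary operation should cast its arguments" means in practice, the sketch below compares an explicit cast to the promoted dtype with the implicit promotion performed by `+`. The array contents are arbitrary and only their dtypes matter; this is an illustrative check, not a description of how JAX implements promotion internally.

```python
import jax.numpy as jnp

x = jnp.arange(3, dtype=jnp.int32)
y = jnp.ones(3, dtype=jnp.float32)

# Ask which dtype a binary operation on these two inputs should use.
common = jnp.promote_types(x.dtype, y.dtype)

# Casting both operands explicitly to the promoted dtype gives the same
# result dtype (and values) as letting the addition promote implicitly.
explicit = x.astype(common) + y.astype(common)
implicit = x + y
print(common, explicit.dtype, implicit.dtype)  # float32 float32 float32
```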

Parameters:
  • a (DTypeLike) – a numpy.dtype or a dtype specifier.

  • b (DTypeLike) – a numpy.dtype or a dtype specifier.

Returns:

A numpy.dtype object.

Return type:

DType
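A small sketch that checks two properties of the return value for a few hand-picked pairs (the pairs are arbitrary choices): the result is a `numpy.dtype` instance, and swapping `a` and `b` does not change it. Symmetry is checked only for these inputs here and is not a claim about every possible specifier.

```python
import numpy as np
import jax.numpy as jnp

# Arbitrary pairs of dtype specifiers, mixing strings and scalar types.
pairs = [('int8', 'int16'), (jnp.int32, 'float32'), ('bool', jnp.uint8)]

for a, b in pairs:
    ab = jnp.promote_types(a, b)
    ba = jnp.promote_types(b, a)
    # The result is a numpy.dtype object ...
    assert isinstance(ab, np.dtype)
    # ... and, for these pairs, argument order does not matter.
    assert ab == ba
    print(a, b, '->', ab)
```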

Examples

Type specifiers may be strings, dtypes, or scalar types, and the return value is always a dtype:

>>> import jax.numpy as jnp
>>> jnp.promote_types('int32', 'float32')  # strings
dtype('float32')
>>> jnp.promote_types(jnp.dtype('int32'), jnp.dtype('float32'))  # dtypes
dtype('float32')
>>> jnp.promote_types(jnp.int32, jnp.float32)  # scalar types
dtype('float32')

Built-in scalar types (int, float, or complex) are treated as weakly-typed and will not change the bit width of a strongly-typed counterpart (see discussion in Type promotion semantics):

>>> jnp.promote_types('uint8', int)
dtype('uint8')
>>> jnp.promote_types('float16', float)
dtype('float16')
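The same weak-typing rule shows up in ordinary array arithmetic. In the sketch below (array values are arbitrary), a Python `int` or `float` operand leaves a low-precision array's dtype unchanged, whereas a strongly typed `jnp.int32` scalar widens the result, matching what `jnp.promote_types` reports for that pair.

```python
import jax.numpy as jnp

x = jnp.arange(4, dtype=jnp.uint8)

# A Python int is weakly typed, so uint8 is preserved.
weak_result = x + 1
print(weak_result.dtype)  # uint8

# A strongly typed int32 scalar widens the result to int32,
# consistent with jnp.promote_types('uint8', jnp.int32).
strong_result = x + jnp.int32(1)
print(jnp.promote_types(x.dtype, jnp.int32), strong_result.dtype)  # int32 int32

# A Python float likewise leaves a float16 array at float16.
h = jnp.ones(2, dtype=jnp.float16) * 2.0
print(h.dtype)  # float16
```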

This differs from the NumPy version of this function, which treats built-in scalar types as equivalent to 64-bit types:

>>> import numpy
>>> numpy.promote_types('uint8', int)
dtype('int64')
>>> numpy.promote_types('float16', float)
dtype('float64')
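To put the two behaviours side by side, including the `complex` case not shown above, the following sketch runs both functions over one strongly typed dtype per kind. The `int` row's NumPy result depends on the platform's default integer (typically `int64`), so it is printed but not relied on.

```python
import numpy as np
import jax.numpy as jnp

# Each strongly typed dtype is paired with the built-in scalar type of its kind.
cases = [('uint8', int), ('float16', float), ('complex64', complex)]

results = {}
for strong, builtin in cases:
    # JAX: the built-in is weakly typed, so the strong dtype's width is kept.
    jax_dtype = jnp.promote_types(strong, builtin)
    # NumPy: the built-in maps to its default dtype (int64/float64/complex128
    # on most platforms), which widens the result.
    np_dtype = np.promote_types(strong, builtin)
    results[strong] = (jax_dtype, np_dtype)
    print(f"{strong:>9} + {builtin.__name__:<7} jax: {jax_dtype}  numpy: {np_dtype}")
```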