jax.numpy.issubdtype

jax.numpy.issubdtype(arg1, arg2)

Return True if arg1 is equal to or lower than arg2 in the type hierarchy.

JAX implementation of numpy.issubdtype().

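The main difference from NumPy is that JAX's implementation correctly handles dtype extensions such as bfloat16, placing them under the appropriate abstract types (for example, bfloat16 is treated as a subtype of floating).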
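As a rough illustration of the type hierarchy, the sketch below checks a few concrete dtypes against NumPy's abstract scalar types (generic, number, inexact, floating, and so on), which jax.numpy re-exports. The `checks` mapping is a name chosen only for this sketch, and its expected values follow NumPy's standard scalar hierarchy; it assumes jax and NumPy are installed.

```python
import jax.numpy as jnp

# Abstract scalar types form a tree, roughly:
#   generic -> number -> inexact -> floating, complexfloating
#   generic -> number -> integer -> signedinteger, unsignedinteger
#   generic -> bool_   (booleans are not numbers in this hierarchy)
checks = {  # hypothetical name: (arg1, arg2) -> expected result
    ("float32", jnp.floating): True,
    ("float32", jnp.inexact): True,
    ("complex64", jnp.inexact): True,
    ("int8", jnp.number): True,
    ("uint8", jnp.signedinteger): False,
    ("bool", jnp.number): False,
    ("float32", "float32"): True,   # a dtype counts as "equal" to itself
    ("float32", "float64"): False,  # sibling concrete dtypes are unrelated
}

for (arg1, arg2), expected in checks.items():
    result = jnp.issubdtype(arg1, arg2)
    print(f"issubdtype({arg1!r}, {arg2}) -> {result}")
    assert result == expected
```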

Parameters:
  • arg1 (DTypeLike) – dtype-like object. In typical usage, this will be a dtype specifier, such as "float32" (i.e. a string), np.dtype('int32') (i.e. an instance of numpy.dtype), jnp.complex64 (i.e. a JAX scalar constructor), or np.uint8 (i.e. a NumPy scalar type).

  • arg2 (DTypeLike) – dtype-like object. In typical usage, this will be a generic scalar type, such as jnp.integer, jnp.floating, or jnp.complexfloating.
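The dtype-like forms listed for arg1 are interchangeable. A small sketch, assuming jax and NumPy are installed, passes four spellings of int32, plus the dtype of an actual JAX array, against the abstract type jnp.signedinteger:

```python
import numpy as np
import jax.numpy as jnp

x = jnp.arange(3, dtype=jnp.int32)  # an int32 JAX array

# Different spellings of the same concrete dtype, used as arg1.
int32_specs = ["int32", np.dtype("int32"), jnp.int32, np.int32, x.dtype]

# arg2 is typically an abstract scalar type such as jnp.signedinteger.
results = [jnp.issubdtype(spec, jnp.signedinteger) for spec in int32_specs]
print(results)  # [True, True, True, True, True]
```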

Returns:

True if arg1 represents a dtype that is equal to or lower than arg2 in the type hierarchy.

Return type:

bool

Examples

>>> jnp.issubdtype('uint32', jnp.unsignedinteger)
True
>>> jnp.issubdtype(np.int32, jnp.integer)
True
>>> jnp.issubdtype(jnp.bfloat16, jnp.floating)
True
>>> jnp.issubdtype(np.dtype('complex64'), jnp.complexfloating)
True
>>> jnp.issubdtype('complex64', jnp.integer)
False

Although this function is very similar to numpy.issubdtype(), the two give different results for JAX's custom floating-point types:

>>> np.issubdtype('bfloat16', np.floating)
False
>>> jnp.issubdtype('bfloat16', jnp.floating)
True
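In practice, jnp.issubdtype is often applied to an array's .dtype to branch on its category. The helper below, dtype_category, is a hypothetical name used only for this sketch (it is not part of jax.numpy). Because it uses jnp.issubdtype rather than np.issubdtype, a bfloat16 array is classified as floating.

```python
import jax.numpy as jnp

def dtype_category(x) -> str:
    """Hypothetical helper: map an array's dtype to a coarse category name."""
    dtype = jnp.asarray(x).dtype
    if jnp.issubdtype(dtype, jnp.bool_):
        return "bool"
    if jnp.issubdtype(dtype, jnp.integer):
        return "integer"
    if jnp.issubdtype(dtype, jnp.floating):  # also matches bfloat16
        return "floating"
    if jnp.issubdtype(dtype, jnp.complexfloating):
        return "complex"
    return "other"

print(dtype_category(jnp.zeros(3, dtype=jnp.bfloat16)))  # floating
print(dtype_category(jnp.arange(3)))                     # integer
```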