jax.numpy.issubdtype
- jax.numpy.issubdtype(arg1, arg2)
Return True if arg1 is equal or lower than arg2 in the type hierarchy.
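JAX implementation of numpy.issubdtype(). The main difference in JAX’s implementation is that it handles dtype extensions such as bfloat16 and places them where you would expect in the type hierarchy (for example, bfloat16 counts as a subtype of floating).
- Parameters: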
arg1 (DTypeLike) – dtype-like object. In typical usage, this will be a dtype specifier, such as "float32" (i.e. a string), np.dtype('int32') (i.e. an instance of numpy.dtype), jnp.complex64 (i.e. a JAX scalar constructor), or np.uint8 (i.e. a NumPy scalar type).
arg2 (DTypeLike) – dtype-like object. In typical usage, this will be a generic scalar type, such as jnp.integer, jnp.floating, or jnp.complexfloating.
- Returns:
True if arg1 represents a dtype that is equal or lower in the type hierarchy than arg2.
- Return type:
bool
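The sketch below passes each of the dtype-like forms listed under Parameters as arg1, paired with a generic scalar type as arg2. The `cases` list is only an illustration, and the last two calls show the "equal" case and a sibling branch of the hierarchy that does not match.

```python
import numpy as np
import jax.numpy as jnp

# Each arg1 is one of the dtype-like forms listed under Parameters;
# each arg2 is a generic (abstract) scalar type.
cases = [
    ("float32", jnp.floating),             # a string
    (np.dtype("int32"), jnp.integer),      # a numpy.dtype instance
    (jnp.complex64, jnp.complexfloating),  # a JAX scalar constructor
    (np.uint8, jnp.unsignedinteger),       # a NumPy scalar type
]

results = [jnp.issubdtype(arg1, arg2) for arg1, arg2 in cases]
print(results)  # [True, True, True, True]

# A concrete dtype is "equal" to itself, so this is also True.
print(jnp.issubdtype("float32", jnp.float32))  # True

# uint8 sits under unsignedinteger, not signedinteger.
print(jnp.issubdtype(np.uint8, jnp.signedinteger))  # False
```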
See also
jax.numpy.isdtype(): a similar function that follows the array API standard.
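As a rough comparison, the sketch below checks that jnp.issubdtype with an abstract NumPy scalar type and jnp.isdtype with the corresponding array API kind string give the same answers for a few standard dtypes. `PAIRS` and `queries_agree` are hypothetical names, and the mapping is shown only for these three categories; it is not an exhaustive equivalence between the two functions.

```python
import numpy as np
import jax.numpy as jnp

# Each pair maps a NumPy abstract scalar type (used with jnp.issubdtype)
# to the array API "kind" string naming the same category (used with jnp.isdtype).
PAIRS = [
    (jnp.integer, "integral"),
    (jnp.floating, "real floating"),
    (jnp.complexfloating, "complex floating"),
]

def queries_agree(dtype) -> bool:
    """True if both functions return the same answer for every pair in PAIRS."""
    return all(
        jnp.issubdtype(dtype, generic) == jnp.isdtype(dtype, kind)
        for generic, kind in PAIRS
    )

names = ["int16", "uint8", "float32", "complex64"]
print([queries_agree(np.dtype(name)) for name in names])  # [True, True, True, True]
```

The main practical difference is the spelling of the category: issubdtype takes a scalar type from the NumPy hierarchy, while isdtype takes a string kind (or a dtype, or a tuple of these).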
Examples
>>> import numpy as np
>>> import jax.numpy as jnp
>>> jnp.issubdtype('uint32', jnp.unsignedinteger)
True
>>> jnp.issubdtype(np.int32, jnp.integer)
True
>>> jnp.issubdtype(jnp.bfloat16, jnp.floating)
True
>>> jnp.issubdtype(np.dtype('complex64'), jnp.complexfloating)
True
>>> jnp.issubdtype('complex64', jnp.integer)
False
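To make "equal or lower in the type hierarchy" concrete, the sketch below lists every abstract scalar type that a given dtype falls under. `ABSTRACT` and `ancestors` are hypothetical helpers written for this illustration (JAX does not export them); the list order simply runs from more specific to more general categories.

```python
import jax.numpy as jnp

# Abstract scalar types, roughly from most specific to most general.
# Chosen for illustration only; not an API provided by JAX.
ABSTRACT = [
    ("signedinteger", jnp.signedinteger),
    ("unsignedinteger", jnp.unsignedinteger),
    ("integer", jnp.integer),
    ("floating", jnp.floating),
    ("complexfloating", jnp.complexfloating),
    ("inexact", jnp.inexact),
    ("number", jnp.number),
    ("generic", jnp.generic),
]

def ancestors(dtype) -> list[str]:
    """Names of the abstract types that `dtype` is equal to or lower than."""
    return [name for name, generic in ABSTRACT if jnp.issubdtype(dtype, generic)]

print(ancestors("int8"))        # ['signedinteger', 'integer', 'number', 'generic']
print(ancestors("float32"))     # ['floating', 'inexact', 'number', 'generic']
print(ancestors(jnp.bfloat16))  # ['floating', 'inexact', 'number', 'generic']
```

Under jnp.issubdtype, bfloat16 gets the same chain of categories as float32.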
Be aware that while this is very similar to numpy.issubdtype(), the results differ for JAX’s custom floating-point types:
>>> np.issubdtype('bfloat16', np.floating)
False
>>> jnp.issubdtype('bfloat16', jnp.floating)
True
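One way to see where the two functions diverge is to compare them directly. The sketch below assumes a small, hand-picked set of standard dtypes and abstract categories (`STANDARD`, `GENERICS`, and `disagreements` are hypothetical names). It reports the pairs on which numpy.issubdtype and jnp.issubdtype disagree: none for the standard dtypes, and at least the floating category for bfloat16.

```python
import numpy as np
import jax.numpy as jnp

# Hand-picked standard dtypes and abstract categories (illustration only).
STANDARD = ["bool", "int8", "int32", "uint16", "float16", "float32", "complex64"]
GENERICS = [np.integer, np.signedinteger, np.unsignedinteger,
            np.floating, np.complexfloating, np.number]

def disagreements(dtypes):
    """Return (dtype name, category name) pairs where NumPy and JAX disagree."""
    return [
        (np.dtype(d).name, g.__name__)
        for d in dtypes
        for g in GENERICS
        if np.issubdtype(d, g) != jnp.issubdtype(d, g)
    ]

print(disagreements(STANDARD))        # []
print(disagreements([jnp.bfloat16]))  # includes ('bfloat16', 'floating')
```

For standard NumPy dtypes the JAX function gives the same answers as NumPy, so code written against numpy.issubdtype carries over; the differences show up only for extension types such as bfloat16.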