jax.numpy.size
- jax.numpy.size(a, axis=None)
Return the number of elements along a given axis.
JAX implementation of numpy.size(). Unlike np.size, this function raises a TypeError if the input is a collection such as a list or tuple.
- Parameters:
a (ArrayLike | SupportsSize | SupportsShape) – array-like object, or any object with a size attribute when axis is not specified, or with a shape attribute when axis is specified.
axis (int | None) – optional integer axis along which to count elements. By default, return the total number of elements.
- Returns:
An integer specifying the number of elements in a.
- Return type:
int
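The parameter rules above (use `size` when `axis` is omitted, index into `shape` when it is given, and reject lists and tuples) can be modelled in a few lines of plain Python. The sketch below is a hypothetical helper, `size_sketch`, with a hypothetical `Grid` class standing in for "any object with a size or shape attribute"; it illustrates the documented behavior only, is not JAX's actual implementation, and skips the array conversion that `jnp.size` performs for other array-like inputs.

```python
from typing import Any, Optional


def size_sketch(a: Any, axis: Optional[int] = None) -> int:
    """Hypothetical model of the documented jnp.size rules (not JAX's code)."""
    # Lists and tuples are rejected, mirroring the documented TypeError.
    if isinstance(a, (list, tuple)):
        raise TypeError(f"size requires an array-like input, got {type(a).__name__}")
    # Python scalars are treated as 0-d values: shape () and one element.
    if isinstance(a, (bool, int, float, complex)):
        shape = ()
        # Any integer axis is out of range for a 0-d value (IndexError).
        return 1 if axis is None else shape[axis]
    if axis is None:
        # No axis: rely on the object's `size` attribute.
        return a.size
    # Axis given: index into the object's `shape` attribute.
    return a.shape[axis]


class Grid:
    """Hypothetical non-array object exposing `shape` and `size`."""

    def __init__(self, rows: int, cols: int) -> None:
        self.shape = (rows, cols)
        self.size = rows * cols


g = Grid(2, 3)
print(size_sketch(g))          # 6
print(size_sketch(g, axis=1))  # 3
print(size_sketch(3.14))       # 1
try:
    size_sketch([1, 2, 3])
except TypeError as err:
    print("TypeError:", err)
```

Because the model only reads `size` or `shape`, any object exposing those attributes works, which is the duck typing that the `SupportsSize | SupportsShape` annotation expresses.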
Examples
Size for arrays:
>>> import jax.numpy as jnp
>>> x = jnp.arange(10)
>>> jnp.size(x)
10
>>> y = jnp.ones((2, 3))
>>> jnp.size(y)
6
>>> jnp.size(y, axis=1)
3
This also works for scalars:
>>> jnp.size(3.14)
1
For arrays, this can also be accessed via the jax.Array.size property:
>>> y.size
6
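The values in these examples follow from the shapes alone: the total element count is the product of the shape entries, the count along one axis is that axis's shape entry, and a scalar has the empty shape `()`, whose product is 1. The short check below uses only the Python standard library (`math.prod`, available since Python 3.8); the shapes are copied from the examples, and no JAX call is made.

```python
import math

# Shapes of the values used in the examples above.
shapes = {
    "jnp.arange(10)": (10,),
    "jnp.ones((2, 3))": (2, 3),
    "3.14 (scalar)": (),
}

for name, shape in shapes.items():
    # Total element count is the product of the dimensions; prod(()) == 1.
    total = math.prod(shape)
    print(f"{name}: shape={shape}, size={total}")

# The count along a single axis is just that entry of the shape.
print((2, 3)[1])  # 3
```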