jax.numpy.size

jax.numpy.size(a, axis=None)

Return the number of elements along a given axis.

JAX implementation of numpy.size(). Unlike np.size, this function raises a TypeError if the input is a collection such as a list or tuple.

Parameters:
  • a (ArrayLike | SupportsSize | SupportsShape) – array-like object, or any object with a size attribute when axis is not specified, or with a shape attribute when axis is specified.

  • axis (int | None) – optional axis along which to count elements. By default, return the total number of elements.
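The sketch below illustrates both parameters. With no axis, the result is the product of all dimensions. With an axis, the result is the length of that single dimension, and negative axes are assumed to count from the end as in NumPy. The `FakeTensor` class is hypothetical and is not part of JAX. It only demonstrates the duck-typed `size`/`shape` inputs described for `a`.

```python
import math
import numpy as np
import jax.numpy as jnp

# `a` may be a NumPy array: it is array-like, so it is accepted directly.
host = np.zeros((2, 3, 4))

# With no axis, the result is the product of all dimensions.
total = jnp.size(host)
assert total == math.prod(host.shape)

# With an axis, the result is the length of that one dimension;
# negative axes count from the end, as in NumPy.
per_axis = [jnp.size(host, axis=i) for i in range(host.ndim)]
last = jnp.size(host, axis=-1)

# Hypothetical duck-typed container (not part of JAX): per the parameter
# description, its `size` attribute is used when axis is None and its
# `shape` attribute is used when an axis is given.
class FakeTensor:
    shape = (2, 7)
    size = 14

t = FakeTensor()
print(total, per_axis, last, jnp.size(t), jnp.size(t, axis=1))
# 24 [2, 3, 4] 4 14 7
```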

Returns:

An integer specifying the number of elements in a, or along axis if it is specified.

Return type:

int
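The result is a plain Python int rather than an array. Array shapes are static under jax.jit, so the value stays concrete even while tracing and can be used to size new arrays. The sketch below assumes this behavior, and `pad_to_double` is a hypothetical helper written for illustration.

```python
import jax
import jax.numpy as jnp

@jax.jit
def pad_to_double(x):
    # Shapes are static under jit, so jnp.size yields a Python int here
    # rather than a traced value, and it can be used as a static length.
    n = jnp.size(x)
    return jnp.concatenate([x.ravel(), jnp.zeros(n, dtype=x.dtype)])

out = pad_to_double(jnp.ones((2, 3)))
print(type(jnp.size(jnp.ones((2, 3)))).__name__, out.shape)  # int (12,)
```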

Examples

Size for arrays:

>>> import jax.numpy as jnp
>>> x = jnp.arange(10)
>>> jnp.size(x)
10
>>> y = jnp.ones((2, 3))
>>> jnp.size(y)
6
>>> jnp.size(y, axis=1)
3

This also works for scalars:

>>> jnp.size(3.14)
1
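A 0-d array behaves like a scalar and holds one element. An array with any zero-length dimension holds no elements, although jnp.size still reports the lengths of its other axes. The values used below are arbitrary.

```python
import jax.numpy as jnp

# A 0-d array holds exactly one element, like a Python scalar.
print(jnp.size(jnp.array(7)))  # 1

# An array with a zero-length dimension holds no elements,
# even though its other axes are non-empty.
empty = jnp.zeros((3, 0))
print(jnp.size(empty), jnp.size(empty, axis=0))  # 0 3
```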

For arrays, this can also be accessed via the jax.Array.size property:

>>> y.size
6