jax.Array.flatten

abstract Array.flatten(order='C', *, out_sharding=None)

Flatten array into a 1-dimensional shape.

Refer to jax.numpy.ravel() for the full documentation.
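The sketch below assumes a reasonably recent JAX installation. It shows that `flatten()` returns the same 1-D result as `jax.numpy.ravel()`, and how the `order` argument changes the traversal: `'C'` (the default) reads elements row by row, and `'F'` reads them column by column. It does not exercise `out_sharding`, which only matters when the array is sharded across devices.

```python
import jax.numpy as jnp

# A 2x3 array: [[0, 1, 2],
#               [3, 4, 5]]
x = jnp.arange(6).reshape(2, 3)

# Default order='C': row-major traversal.
flat = x.flatten()

# order='F': column-major traversal.
flat_f = x.flatten(order='F')

print(flat.shape)        # (6,)
print(flat.tolist())     # [0, 1, 2, 3, 4, 5]
print(flat_f.tolist())   # [0, 3, 1, 4, 2, 5]

# flatten() agrees with jax.numpy.ravel() for the same order.
print(bool((flat == jnp.ravel(x)).all()))   # True
```

Note that JAX arrays are immutable, so `flatten()` returns a new array and leaves `x` unchanged.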

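Parameters:
order – order in which elements are read when flattening; defaults to 'C' (row-major).
out_sharding – optional sharding for the output array; defaults to None.

Return type:
Array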