jax.numpy.ravel
jax.numpy.ravel(a, order='C', *, out_sharding=None)
Flatten array into a 1-dimensional shape.
JAX implementation of numpy.ravel(), implemented in terms of jax.lax.reshape().
ravel(arr, order=order) is equivalent to reshape(arr, -1, order=order).
- Parameters:
a (ArrayLike) – array to be flattened.
order (str) – 'F' or 'C', specifies whether the reshape should apply column-major (Fortran-style, "F") or row-major (C-style, "C") order; default is "C". JAX does not support order="A" or order="K".
- Returns:
flattened copy of the input array.
- Return type:
Array
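The equivalence with reshape and the meaning of the two order values can be checked directly. This is a minimal sketch, assuming a recent JAX release in which jnp.reshape() accepts order as a keyword argument; the input array and variable names are illustrative only.

```python
import jax.numpy as jnp

# Illustrative 2x3 input: [[0, 1, 2], [3, 4, 5]]
x = jnp.arange(6).reshape(2, 3)

# Row-major ("C"): the last axis varies fastest.
c_flat = jnp.ravel(x)
# Column-major ("F"): the first axis varies fastest.
f_flat = jnp.ravel(x, order="F")

# ravel(arr, order=order) matches reshape(arr, -1, order=order).
assert (c_flat == jnp.reshape(x, -1, order="C")).all()
assert (f_flat == jnp.reshape(x, -1, order="F")).all()

# The result is always 1-D, with as many elements as the input.
assert c_flat.shape == f_flat.shape == (x.size,)

print(c_flat.tolist())  # [0, 1, 2, 3, 4, 5]
print(f_flat.tolist())  # [0, 3, 1, 4, 2, 5]
```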
Notes
Unlike numpy.ravel(), jax.numpy.ravel() returns a copy rather than a view of the input array. However, under JIT the compiler optimizes away such copies where possible, so in practice this does not affect performance.
See also
- jax.Array.ravel(): equivalent functionality via an array method.
- jax.numpy.reshape(): general array reshape.
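The copy-versus-view difference from NumPy shows up when writing through the flattened result. This sketch assumes the NumPy input is C-contiguous, so that numpy.ravel() returns a view; the helper flatten_and_sum is hypothetical and exists only to show ravel inside a jitted function.

```python
import numpy as np
import jax
import jax.numpy as jnp

# NumPy: for a C-contiguous array, np.ravel returns a view that shares memory.
a_np = np.array([[1, 2, 3], [4, 5, 6]])
r_np = np.ravel(a_np)
r_np[0] = 100                          # in-place write is visible through a_np
assert np.shares_memory(a_np, r_np)
assert a_np[0, 0] == 100

# JAX: arrays are immutable, so updates are functional and never alias the input.
a_jx = jnp.array([[1, 2, 3], [4, 5, 6]])
r_jx = jnp.ravel(a_jx).at[0].set(100)  # returns a new array
assert int(r_jx[0]) == 100
assert int(a_jx[0, 0]) == 1            # original is unchanged

# Hypothetical helper: under jit, the compiler is free to elide the intermediate copy.
@jax.jit
def flatten_and_sum(arr):
    return jnp.ravel(arr).sum()

print(int(flatten_and_sum(a_jx)))  # 21
```

Because JAX arrays cannot be mutated in place, the copy is not observable from user code; it only matters for memory use, which is what the JIT optimization addresses.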
Examples
>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
By default, ravel in C-style, row-major order:
>>> jnp.ravel(x)
Array([1, 2, 3, 4, 5, 6], dtype=int32)
Optionally ravel in Fortran-style, column-major order:
>>> jnp.ravel(x, order='F')
Array([1, 4, 2, 5, 3, 6], dtype=int32)
For convenience, the same functionality is available via the jax.Array.ravel() method:
>>> x.ravel()
Array([1, 2, 3, 4, 5, 6], dtype=int32)
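The same calls extend to arrays with more than two dimensions. The sketch below uses only jnp.ravel(), the jax.Array.ravel() method, jnp.arange and transpose; it checks that the method accepts the same order argument, and that Fortran-order flattening equals C-order flattening of the array with all axes reversed. The shapes chosen are illustrative.

```python
import jax.numpy as jnp

# Illustrative 3-D input of shape (2, 3, 4); element [i, j, k] equals 12*i + 4*j + k.
x = jnp.arange(24).reshape(2, 3, 4)

# The array method and the function agree for both orders.
assert (x.ravel() == jnp.ravel(x)).all()
assert (x.ravel(order="F") == jnp.ravel(x, order="F")).all()

# Fortran order varies the first axis fastest, which is the same as
# C-order flattening of the transpose (all axes reversed).
f_flat = jnp.ravel(x, order="F")
assert (f_flat == jnp.ravel(x.transpose())).all()

# A 0-d input becomes a length-1 vector.
s = jnp.ravel(jnp.float32(3.0))
assert s.shape == (1,)

print(f_flat[:4].tolist())  # [0, 12, 4, 16]
```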