jax.Array.reshape

abstract Array.reshape(*args, order='C', out_sharding=None)

Returns an array containing the same data with a new shape.

Refer to jax.numpy.reshape() for full documentation.
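A minimal sketch of basic usage, assuming a small illustrative array built with jax.numpy.arange. It shows that the new shape can be passed either as a single tuple or as separate integers, that -1 lets one dimension be inferred, and that the values themselves are unchanged.

```python
import jax.numpy as jnp

x = jnp.arange(6)          # shape (6,), values 0..5

a = x.reshape((2, 3))      # new shape passed as a single tuple
b = x.reshape(2, 3)        # new shape passed as separate integers
c = x.reshape(3, -1)       # -1 infers the remaining dimension (here 2)

print(a.shape, b.shape, c.shape)  # (2, 3) (2, 3) (3, 2)
print(a)
# [[0 1 2]
#  [3 4 5]]
```

Because reshape only changes the shape, flattening any of these results recovers the original values of x.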

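Parameters:
    *args – the new shape, given either as a single tuple or as separate integers.
    order – element order used to read and place values: 'C' (row-major, the default) or 'F' (column-major).
    out_sharding – optional sharding for the output array; defaults to None.

Return type:
    Array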
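The order argument in the signature decides how elements are laid out in the new shape. A minimal sketch, assuming the same small illustrative array as input, comparing the default row-major order with column-major order:

```python
import jax.numpy as jnp

x = jnp.arange(6)

row_major = x.reshape(2, 3)             # order='C' (default): last axis varies fastest
col_major = x.reshape(2, 3, order='F')  # order='F': first axis varies fastest

print(row_major)
# [[0 1 2]
#  [3 4 5]]
print(col_major)
# [[0 2 4]
#  [1 3 5]]
```

Both results hold the same six values; only their placement differs.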