jax.numpy.vstack
- jax.numpy.vstack(tup, dtype=None)
Vertically stack arrays.
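JAX implementation of numpy.vstack(). For arrays of two or more dimensions, this is equivalent to jax.numpy.concatenate() with axis=0. Inputs with fewer dimensions are first promoted to 2D, as with jax.numpy.atleast_2d(): a scalar becomes shape (1, 1) and a 1D array of shape (N,) becomes (1, N).
- Parameters: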
tup (np.ndarray | Array | Sequence[ArrayLike]) – a sequence of arrays to stack; each must have the same shape along all but the first axis. If a single array is given, it is treated as equivalent to tup = unstack(tup), but the implementation avoids explicit unstacking.
dtype (DTypeLike | None) – optional dtype of the resulting array. If not specified, the dtype is determined by the type promotion rules described in Type promotion semantics.
- Returns:
the stacked result.
- Return type:
Array
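Both parameters can be exercised directly. The sketch below assumes a JAX version whose jnp.vstack accepts the dtype keyword from the signature above. It checks that a single 3D array behaves like the sequence of its leading-axis slices, and that dtype overrides the promoted result type. The variable names are illustrative only.

```python
import jax.numpy as jnp

# A single array of shape (2, 3, 4) is treated like the sequence of its two
# (3, 4) slices along the leading axis, so the result has shape (6, 4).
a = jnp.arange(24).reshape(2, 3, 4)
from_array = jnp.vstack(a)
from_list = jnp.vstack([a[0], a[1]])
print(from_array.shape)                               # (6, 4)
print(bool(jnp.array_equal(from_array, from_list)))   # True

# dtype overrides the dtype that type promotion would otherwise choose
# (here, integer inputs are stacked into a float32 result).
ints = jnp.arange(3)
floats = jnp.vstack([ints, ints], dtype=jnp.float32)
print(floats.shape, floats.dtype)                     # (2, 3) float32
```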
See also
- jax.numpy.stack(): stack along arbitrary axes.
- jax.numpy.concatenate(): concatenation along existing axes.
- jax.numpy.hstack(): stack horizontally, i.e. along axis 1.
- jax.numpy.dstack(): stack depth-wise, i.e. along axis 2.
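The relationship between vstack, jax.numpy.concatenate(), and the other stacking functions is easiest to see from result shapes. This sketch compares them on small inputs. It uses jnp.atleast_2d to mirror the promotion of 1D inputs to 2D, and the comments state the shapes these calls produce.

```python
import jax.numpy as jnp

x = jnp.arange(3)      # shape (3,)
y = jnp.arange(3, 6)   # shape (3,)

# For 1D inputs, vstack matches concatenating their 2D promotions on axis 0.
v = jnp.vstack([x, y])
c = jnp.concatenate([jnp.atleast_2d(x), jnp.atleast_2d(y)], axis=0)
print(v.shape, bool(jnp.array_equal(v, c)))    # (2, 3) True

# For 2D inputs, vstack is exactly concatenate with axis=0.
a = jnp.ones((2, 3))
b = jnp.zeros((4, 3))
print(jnp.vstack([a, b]).shape)                # (6, 3)
print(jnp.concatenate([a, b], axis=0).shape)   # (6, 3)

# Related functions applied to the same 1D inputs.
print(jnp.stack([x, y]).shape)    # (2, 3): a new leading axis is created
print(jnp.hstack([x, y]).shape)   # (6,): 1D inputs are joined along axis 0
print(jnp.dstack([x, y]).shape)   # (1, 3, 2): inputs promoted to 3D first
```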
Examples
Scalar values:
>>> import jax.numpy as jnp
>>> jnp.vstack([1, 2, 3])
Array([[1],
       [2],
       [3]], dtype=int32, weak_type=True)
1D arrays:
>>> x = jnp.arange(4)
>>> y = jnp.ones(4)
>>> jnp.vstack([x, y])
Array([[0., 1., 2., 3.],
       [1., 1., 1., 1.]], dtype=float32)
2D arrays:
>>> x = x.reshape(1, 4)
>>> y = y.reshape(1, 4)
>>> jnp.vstack([x, y])
Array([[0., 1., 2., 3.],
       [1., 1., 1., 1.]], dtype=float32)
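The requirement that inputs agree on every axis except the first can also be checked directly. This sketch stacks arrays with different row counts and shows that mismatched column counts are rejected. The exact exception type and message come from JAX's concatenation shape checks, so both TypeError and ValueError are caught here as an assumption. It also uses vstack inside jax.jit. The function add_bias_row is a hypothetical helper for illustration, not part of JAX.

```python
import jax
import jax.numpy as jnp

# Inputs may differ along the first axis: (1, 4) and (3, 4) give (4, 4).
top = jnp.zeros((1, 4))
bottom = jnp.ones((3, 4))
stacked = jnp.vstack([top, bottom])
print(stacked.shape)   # (4, 4)

# Inputs that differ along any other axis are rejected.
raised = False
try:
    jnp.vstack([jnp.zeros((2, 3)), jnp.zeros((2, 4))])
except (TypeError, ValueError) as err:
    raised = True
    print("rejected:", type(err).__name__)

# Hypothetical helper: append a row of ones below a matrix. The shapes are
# static under jit, so m.shape[1] is a plain Python int while tracing.
@jax.jit
def add_bias_row(m):
    return jnp.vstack([m, jnp.ones((1, m.shape[1]), dtype=m.dtype)])

print(add_bias_row(jnp.zeros((2, 3))).shape)   # (3, 3)
```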