jax.numpy.matvec

jax.numpy.matvec(x1, x2, /)

Batched matrix-vector product.

JAX implementation of numpy.matvec().

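Parameters:

x1 (ArrayLike) – array of shape (..., M, N)

x2 (ArrayLike) – array of shape (..., N). Leading dimensions must be broadcast-compatible with the leading dimensions of x1.

Returns: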

An array of shape (..., M) containing the batched matrix-vector product.

Return type:

Array

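See also

jax.numpy.linalg.vecdot(): batched vector product.

jax.numpy.vecmat(): vector-matrix product.

jax.numpy.matmul(): general matrix multiplication.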

Examples

Simple matrix-vector product:

>>> import jax.numpy as jnp
>>> x1 = jnp.array([[1, 2, 3],
...                 [4, 5, 6]])
>>> x2 = jnp.array([7, 8, 9])
>>> jnp.matvec(x1, x2)
Array([ 50, 122], dtype=int32)

Batched matrix-vector product, with the single matrix x1 broadcast against each row of x2:

>>> x2 = jnp.array([[7, 8, 9],
...                 [5, 6, 7]])
>>> jnp.matvec(x1, x2)
Array([[ 50, 122],
       [ 38,  92]], dtype=int32)