jax.numpy.matvec
- jax.numpy.matvec(x1, x2, /)
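Batched matrix-vector product.

JAX implementation of numpy.matvec().

- Parameters:
  - x1 (ArrayLike): array of shape (..., M, N).
  - x2 (ArrayLike): array of shape (..., N). Its leading dimensions must be broadcast-compatible with the leading dimensions of x1.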
- Returns:
  An array of shape (..., M) containing the batched matrix-vector product.
- Return type:
  Array
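The shapes above amount to out[..., m] = sum over n of x1[..., m, n] * x2[..., n], with the leading (batch) dimensions broadcast against each other. The sketch below spells that out with elementwise broadcasting and a sum. `matvec_reference` is a hypothetical helper written for illustration, not part of the JAX API, and the final cross-check only runs when the installed JAX provides `jnp.matvec`.

```python
import jax.numpy as jnp


def matvec_reference(x1, x2):
    """Hypothetical reference helper, not part of the JAX API.

    x1 has shape (..., M, N) and x2 has shape (..., N); leading dimensions
    broadcast. Returns shape (..., M) where
    out[..., m] = sum_n x1[..., m, n] * x2[..., n].
    """
    # x2[..., None, :] has shape (..., 1, N), so it pairs with every row of x1.
    return jnp.sum(x1 * x2[..., None, :], axis=-1)


x1 = jnp.array([[1, 2, 3],
                [4, 5, 6]])     # one (M, N) = (2, 3) matrix
x2 = jnp.array([[7, 8, 9],
                [5, 6, 7]])     # a batch of two length-3 vectors
out = matvec_reference(x1, x2)  # shape (2, 2): one length-M result per vector

# Leading dims (4, 1) and (5,) broadcast to (4, 5); M = 2 is appended.
stack = jnp.ones((4, 1, 2, 3))
vectors = jnp.ones((5, 3))
print(matvec_reference(stack, vectors).shape)  # (4, 5, 2)

if hasattr(jnp, "matvec"):
    # Cross-check against the library function when this JAX version has it.
    assert jnp.array_equal(jnp.matvec(x1, x2), out)
    assert jnp.matvec(stack, vectors).shape == (4, 5, 2)
```

Because it materializes the full (..., M, N) product before summing, this formulation may use more memory than a dedicated contraction; it is meant to illustrate the semantics rather than to replace `jnp.matvec`.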
See also
- jax.numpy.linalg.vecdot(): batched vector product.
- jax.numpy.vecmat(): vector-matrix product.
- jax.numpy.matmul(): general matrix multiplication.
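For real-valued inputs, the same result can be expressed with the related functions listed above. The sketch below assumes real data: `jnp.linalg.vecdot` conjugates its first argument for complex inputs, so the vecdot form matches a matrix-vector product only when `x1` is real. Passing the (2, 3) batch `x2` straight to `jnp.matmul` would treat it as a (2, 3) matrix and raise a shape error, which is why each vector is first turned into an (N, 1) column.

```python
import jax.numpy as jnp

x1 = jnp.array([[1, 2, 3],
                [4, 5, 6]])   # (M, N) = (2, 3)
x2 = jnp.array([[7, 8, 9],
                [5, 6, 7]])   # batch of two N-vectors, shape (2, 3)

# matmul: view each vector as an (N, 1) column, multiply, then drop that axis.
via_matmul = jnp.matmul(x1, x2[..., None])[..., 0]

# vecdot: add an M axis to x2 so every row of x1 is dotted with each vector.
# (vecdot conjugates its first argument, so this form assumes real inputs.)
via_vecdot = jnp.linalg.vecdot(x1, x2[..., None, :])

assert jnp.array_equal(via_matmul, via_vecdot)
```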
Examples
Simple matrix-vector product:
>>> import jax.numpy as jnp
>>> x1 = jnp.array([[1, 2, 3],
...                 [4, 5, 6]])
>>> x2 = jnp.array([7, 8, 9])
>>> jnp.matvec(x1, x2)
Array([ 50, 122], dtype=int32)
Batched matrix-vector product:
>>> x2 = jnp.array([[7, 8, 9],
...                 [5, 6, 7]])
>>> jnp.matvec(x1, x2)
Array([[ 50, 122],
       [ 38,  92]], dtype=int32)
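The batched call above gives the same result as applying the unbatched product to each row of `x2` in turn. In JAX that per-vector view can be written with `jax.vmap`, mapping over the leading axis of `x2` while holding `x1` fixed. The sketch uses the `@` operator for the unbatched product (the helper name `single_matvec` is hypothetical), so it runs even on JAX versions without `jnp.matvec`; the comparison at the end is skipped there.

```python
import jax
import jax.numpy as jnp

x1 = jnp.array([[1, 2, 3],
                [4, 5, 6]])   # shared (M, N) matrix
x2 = jnp.array([[7, 8, 9],
                [5, 6, 7]])   # two vectors stacked along axis 0


def single_matvec(matrix, vector):
    # Unbatched product: (M, N) @ (N,) -> (M,).
    return matrix @ vector


# in_axes=(None, 0): reuse x1 for every entry, map over axis 0 of x2.
batched = jax.vmap(single_matvec, in_axes=(None, 0))(x1, x2)
print(batched.shape)  # (2, 2)

if hasattr(jnp, "matvec"):
    # Same values and layout as the batched jnp.matvec call.
    assert jnp.array_equal(batched, jnp.matvec(x1, x2))
```

Stacking the mapped outputs along axis 0 puts the batch dimension first, which is the same layout `jnp.matvec` uses for the leading dimensions of its result.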