jax.numpy.vecmat

jax.numpy.vecmat(x1, x2, /)

Batched conjugate vector-matrix product.

JAX implementation of numpy.vecmat().

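Parameters:

x1 (ArrayLike): array of shape (..., M).
x2 (ArrayLike): array of shape (..., M, N). Its leading (batch) dimensions must be broadcast-compatible with the leading dimensions of x1.
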
Returns:

An array of shape (..., N) containing the batched conjugate vector-matrix product.

Return type:

Array

Examples

Simple vector-matrix product:

>>> import jax.numpy as jnp
>>> x1 = jnp.array([[1, 2, 3]])
>>> x2 = jnp.array([[4, 5],
...                 [6, 7],
...                 [8, 9]])
>>> jnp.vecmat(x1, x2)
Array([[40, 46]], dtype=int32)

Batched vector-matrix product:

>>> x1 = jnp.array([[1, 2, 3],
...                 [4, 5, 6]])
>>> jnp.vecmat(x1, x2)
Array([[ 40,  46],
       [ 94, 109]], dtype=int32)