jax.numpy.vecmat
jax.numpy.vecmat(x1, x2, /)
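Batched conjugate vector-matrix product.
JAX implementation of numpy.vecmat(). For each batch element it computes out[..., j] = sum over i of conj(x1[..., i]) * x2[..., i, j], so the vector argument x1 is complex-conjugated before the product.
- Parameters:
  - x1 (ArrayLike): array of shape (..., M).
  - x2 (ArrayLike): array of shape (..., M, N). Leading dimensions must be broadcast-compatible with the leading dimensions of x1.
- Returns:
  An array of shape (..., N) containing the batched conjugate vector-matrix product.
- Return type:
  Array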
See also
- jax.numpy.linalg.vecdot(): batched vector product.
- jax.numpy.matvec(): matrix-vector product.
- jax.numpy.matmul(): general matrix multiplication.
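The related functions listed above can express the same computation. As a sketch with complex inputs: with matmul, each conjugated vector is treated as a 1 x M row matrix; with jnp.linalg.vecdot, which also conjugates its first argument, each output entry j is a dot product of x1 with column j of x2. These equivalences are illustrations, not a description of how JAX implements vecmat internally.

```python
import jax.numpy as jnp

x1 = jnp.array([[1 + 2j, 3 - 1j],
                [0 + 1j, 2 + 0j]])              # shape (2, 2): two vectors, M = 2
x2 = jnp.array([[1 + 0j, 2 - 1j, 0 + 1j],
                [4 + 0j, 1 + 1j, 3 + 0j]])      # shape (2, 3): M = 2, N = 3

# Via matmul: conjugate x1, insert a length-1 row axis, multiply, drop that axis.
via_matmul = (jnp.conj(x1)[..., None, :] @ x2)[..., 0, :]

# Via vecdot: one conjugated dot product per column of x2, stacked on the last axis.
via_vecdot = jnp.stack(
    [jnp.linalg.vecdot(x1, x2[..., :, j]) for j in range(x2.shape[-1])],
    axis=-1,
)
```

Both results have shape (2, 3). For example, entry [0, 0] is (1 - 2j) * 1 + (3 + 1j) * 4 = 13 + 2j.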
Examples
Simple vector-matrix product:
>>> import jax.numpy as jnp
>>> x1 = jnp.array([[1, 2, 3]])
>>> x2 = jnp.array([[4, 5],
...                 [6, 7],
...                 [8, 9]])
>>> jnp.vecmat(x1, x2)
Array([[40, 46]], dtype=int32)
Batched vector-matrix product:
>>> x1 = jnp.array([[1, 2, 3],
...                 [4, 5, 6]])
>>> jnp.vecmat(x1, x2)
Array([[ 40,  46],
       [ 94, 109]], dtype=int32)
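In the batched example above, every row of x1 is paired with the same matrix x2. When x2 is itself a stack of matrices, the leading axes are matched element by element. The sketch below illustrates that batching semantics under the assumption that both inputs share one leading batch axis: it maps a single, unbatched product over that axis with jax.vmap and checks it against a broadcasting matmul. The helper single_vecmat is hypothetical.

```python
import jax
import jax.numpy as jnp

def single_vecmat(v, a):
    # One unbatched conjugate vector-matrix product: (M,) x (M, N) -> (N,).
    return jnp.conj(v) @ a

# Two vectors with shape (batch, M) = (2, 3).
vs = jnp.array([[1, 2, 3],
                [4, 5, 6]])
# Two matrices with shape (batch, M, N) = (2, 3, 2).
mats = jnp.stack([
    jnp.array([[4, 5], [6, 7], [8, 9]]),
    jnp.eye(3, 2, dtype=jnp.int32),  # [[1, 0], [0, 1], [0, 0]]
])

# vmap pairs vs[b] with mats[b] for each b: result shape (2, 2).
batched = jax.vmap(single_vecmat)(vs, mats)

# The same result via broadcasting matmul over the leading batch axis.
broadcast = (jnp.conj(vs)[..., None, :] @ mats)[..., 0, :]
```

Here the first row reproduces [40, 46] from the simple example, and the second row selects the first two entries of [4, 5, 6], giving [4, 5].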