jax.numpy.linalg.tensordot
- jax.numpy.linalg.tensordot(x1, x2, /, *, axes=2, precision=None, preferred_element_type=None)
Compute the tensor dot product of two N-dimensional arrays.
JAX implementation of numpy.linalg.tensordot().
- Parameters:
x1 (ArrayLike) – N-dimensional array
x2 (ArrayLike) – M-dimensional array
axes (int | tuple[Sequence[int], Sequence[int]]) – integer or tuple of sequences of integers; defaults to 2. If an integer k, sum over the last k axes of x1 and the first k axes of x2, in order. If a tuple, axes[0] specifies the axes of x1 and axes[1] specifies the axes of x2; the two sequences are paired element by element.
precision (PrecisionLike | None) – either None (default), which means the default precision for the backend, a Precision enum value (Precision.DEFAULT, Precision.HIGH or Precision.HIGHEST), or a tuple of two such values indicating the precision of x1 and x2.
preferred_element_type (DTypeLike | None) – either None (default), which means the default accumulation type for the input types, or a datatype, indicating that results are accumulated in and returned with that datatype.
- Returns:
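array containing the tensor dot product of the inputs. Its shape is the non-contracted axes of x1, in order, followed by the non-contracted axes of x2, in order.
- Return type:
Array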
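The precision and preferred_element_type arguments change how the contraction is computed, not which axes are summed. The following is a minimal sketch, assuming float16 inputs (an illustrative choice, not something the parameter descriptions require), that requests float32 accumulation and output via preferred_element_type together with jax.lax.Precision.HIGHEST.

```python
import jax
import jax.numpy as jnp

# Illustrative low-precision inputs; float16 is an arbitrary choice for this sketch.
x1 = jnp.ones((2, 3, 4), dtype=jnp.float16)
x2 = jnp.ones((3, 4, 5), dtype=jnp.float16)

# Default: the result keeps the (promoted) input dtype.
default_out = jnp.linalg.tensordot(x1, x2)

# Request float32 accumulation and output at the highest precision setting.
wide_out = jnp.linalg.tensordot(
    x1, x2,
    precision=jax.lax.Precision.HIGHEST,
    preferred_element_type=jnp.float32,
)

print(default_out.dtype)  # float16
print(wide_out.dtype)     # float32
print(wide_out.shape)     # (2, 5)
```

Each output element sums 3 * 4 = 12 ones, so every entry of wide_out is 12.0; only the dtype differs between the two calls.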
See also
jax.numpy.tensordot(): equivalent API in the jax.numpy namespace.
jax.numpy.einsum(): NumPy API for more general tensor contractions.
jax.lax.dot_general(): XLA API for more general tensor contractions.
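Since jax.lax.dot_general() is listed as the lower-level API, it may help to see the axes=2 contraction from the examples written against it directly. This is a sketch, not a description of the library's internal code path; it assumes the dimension_numbers layout ((lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch)) with no batch dimensions.

```python
import jax
import jax.numpy as jnp

x1 = jnp.arange(24.).reshape(2, 3, 4)
x2 = jnp.ones((3, 4, 5))

# axes=2 contracts x1 axes (1, 2) with x2 axes (0, 1); there are no batch axes.
dimension_numbers = (((1, 2), (0, 1)), ((), ()))
via_dot_general = jax.lax.dot_general(x1, x2, dimension_numbers)
via_tensordot = jnp.linalg.tensordot(x1, x2)

print(via_dot_general.shape)                         # (2, 5)
print(jnp.allclose(via_dot_general, via_tensordot))  # True
```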
Examples
>>> import jax.numpy as jnp
>>> x1 = jnp.arange(24.).reshape(2, 3, 4)
>>> x2 = jnp.ones((3, 4, 5))
>>> jnp.linalg.tensordot(x1, x2)
Array([[ 66.,  66.,  66.,  66.,  66.],
       [210., 210., 210., 210., 210.]], dtype=float32)
Equivalent result when specifying the axes as explicit sequences:
>>> jnp.linalg.tensordot(x1, x2, axes=([1, 2], [0, 1]))
Array([[ 66.,  66.,  66.,  66.,  66.],
       [210., 210., 210., 210., 210.]], dtype=float32)
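In the tuple form, axes[0][i] is paired with axes[1][i], so what matters is how the two sequences line up with each other. As a sketch (the reordered axes are an illustrative choice), listing the same pairs in reverse order gives the same result, while pairing axes of different lengths is rejected.

```python
import jax.numpy as jnp

x1 = jnp.arange(24.).reshape(2, 3, 4)
x2 = jnp.ones((3, 4, 5))

a = jnp.linalg.tensordot(x1, x2, axes=([1, 2], [0, 1]))
# Same pairs (x1 axis 2 with x2 axis 1, x1 axis 1 with x2 axis 0), listed in reverse order.
b = jnp.linalg.tensordot(x1, x2, axes=([2, 1], [1, 0]))
print(jnp.array_equal(a, b))  # True

# Pairing x1 axis 1 (length 3) with x2 axis 1 (length 4) cannot be contracted.
raised = False
try:
    jnp.linalg.tensordot(x1, x2, axes=([1], [1]))
except (TypeError, ValueError):
    raised = True
print(raised)  # True
```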
Equivalent result via einsum():
>>> jnp.einsum('ijk,jkm->im', x1, x2)
Array([[ 66.,  66.,  66.,  66.,  66.],
       [210., 210., 210., 210., 210.]], dtype=float32)
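In the einsum subscripts, the output 'im' lists the free index of x1 followed by the free index of x2, and the same rule fixes the shape of any tensordot result. The function tensordot_shape below is a hypothetical helper (not a JAX API) that sketches this bookkeeping for both forms of axes, assuming non-negative axis indices.

```python
import jax.numpy as jnp

def tensordot_shape(shape1, shape2, axes=2):
    """Hypothetical helper (not a JAX API): predict the shape of a tensordot result.

    Assumes non-negative axis indices in the tuple form of `axes`.
    """
    if isinstance(axes, int):
        # Integer k: the last k axes of x1 pair with the first k axes of x2.
        axes1 = list(range(len(shape1) - axes, len(shape1)))
        axes2 = list(range(axes))
    else:
        axes1, axes2 = list(axes[0]), list(axes[1])
    for a1, a2 in zip(axes1, axes2):
        if shape1[a1] != shape2[a2]:
            raise ValueError(f"axis {a1} of x1 and axis {a2} of x2 differ in length")
    # Free (non-contracted) axes of x1 come first, then those of x2.
    free1 = [n for i, n in enumerate(shape1) if i not in axes1]
    free2 = [n for i, n in enumerate(shape2) if i not in axes2]
    return tuple(free1 + free2)

x1 = jnp.ones((2, 3, 4))
x2 = jnp.ones((3, 4, 5))
print(tensordot_shape(x1.shape, x2.shape))                   # (2, 5)
print(tensordot_shape(x1.shape, x2.shape, axes=([1], [0])))  # (2, 4, 4, 5)
print(jnp.linalg.tensordot(x1, x2, axes=([1], [0])).shape)   # (2, 4, 4, 5)
```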
Setting axes=1 for two-dimensional inputs is equivalent to a matrix multiplication:
>>> x1 = jnp.array([[1, 2],
...                 [3, 4]])
>>> x2 = jnp.array([[1, 2, 3],
...                 [4, 5, 6]])
>>> jnp.linalg.tensordot(x1, x2, axes=1)
Array([[ 9, 12, 15],
       [19, 26, 33]], dtype=int32)
>>> x1 @ x2
Array([[ 9, 12, 15],
       [19, 26, 33]], dtype=int32)
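The matrix-multiplication case also suggests how a general tensor dot product can be reduced to a single matrix product: move the contracted axes of x1 to the end and those of x2 to the front, flatten each side into a matrix, multiply, and reshape. The function tensordot_via_matmul below is a hypothetical sketch of that classic transpose-and-reshape reduction, not a description of what JAX does internally; it assumes the tuple form of axes with non-negative indices.

```python
import math
import jax.numpy as jnp

def tensordot_via_matmul(x1, x2, axes):
    """Hypothetical sketch: tensordot as transpose + reshape + matmul.

    Assumes the tuple form of `axes` with non-negative axis indices.
    """
    axes1, axes2 = list(axes[0]), list(axes[1])
    free1 = [i for i in range(x1.ndim) if i not in axes1]
    free2 = [i for i in range(x2.ndim) if i not in axes2]
    # Contracted axes of x1 go last; contracted axes of x2 go first, in paired order.
    a = jnp.transpose(x1, free1 + axes1)
    b = jnp.transpose(x2, axes2 + free2)
    k = math.prod(x1.shape[i] for i in axes1)  # total length of the contracted axes
    out_shape = tuple(x1.shape[i] for i in free1) + tuple(x2.shape[i] for i in free2)
    # Flatten to (free1, k) @ (k, free2), then restore the free axes.
    m = a.reshape(-1, k) @ b.reshape(k, -1)
    return m.reshape(out_shape)

x1 = jnp.arange(24.).reshape(2, 3, 4)
x2 = jnp.ones((3, 4, 5))
ref = jnp.linalg.tensordot(x1, x2, axes=([1, 2], [0, 1]))
print(jnp.allclose(tensordot_via_matmul(x1, x2, ([1, 2], [0, 1])), ref))  # True
```

Pairing order is preserved because both sides flatten their contracted axes in the order given by axes, so each flattened index of x1 meets the matching flattened index of x2.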
Setting axes=0 for one-dimensional inputs is equivalent to jax.numpy.linalg.outer():
>>> x1 = jnp.array([1, 2])
>>> x2 = jnp.array([1, 2, 3])
>>> jnp.linalg.tensordot(x1, x2, axes=0)
Array([[1, 2, 3],
       [2, 4, 6]], dtype=int32)
>>> jnp.linalg.outer(x1, x2)
Array([[1, 2, 3],
       [2, 4, 6]], dtype=int32)
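For inputs with more than one dimension, axes=0 still contracts nothing, so every element of x1 multiplies every element of x2 and the result has shape x1.shape + x2.shape. A minimal sketch, assuming a 2-D x1 and a 1-D x2 (illustrative values) and using broadcasting as the point of comparison:

```python
import jax.numpy as jnp

x1 = jnp.array([[1, 2],
                [3, 4]])
x2 = jnp.array([10, 20, 30])

# axes=0: no contraction, so the result has shape x1.shape + x2.shape.
t = jnp.linalg.tensordot(x1, x2, axes=0)
# Broadcasting equivalent: append one singleton axis to x1 for each axis of x2.
b = x1[:, :, None] * x2
print(t.shape)                # (2, 2, 3)
print(jnp.array_equal(t, b))  # True
```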