jax.nn.dot_product_attention#

jax.nn.dot_product_attention(query, key, value, bias=None, mask=None, *, scale=None, is_causal=False, query_seq_lengths=None, key_value_seq_lengths=None, local_window_size=None, implementation=None)[source]#

Scaled dot product attention function.

Computes the attention function on Query, Key, and Value tensors:

\[\mathrm{Attention}(Q, K, V)=\mathrm{softmax}(\frac{QK^T}{\sqrt{d_k}})V\]

Here, we refer to the output of \(QK^T\) as the logits and the output of the \(\mathrm{softmax}\) as the probs.

Throughout this function, we use the following uppercase letters to represent array shapes:

B = batch size
S = length of the key/value (source)
T = length of the query (target)
N = number of attention heads
H = dimension of each attention head
K = number of key/value heads
G = number of groups, which equals N // K
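
A minimal usage sketch (not part of the official documentation; the shapes and random inputs are illustrative assumptions) showing a plain multi-headed call where K equals N, checked against the formula above:

import jax
import jax.numpy as jnp

B, T, S, N, H = 2, 4, 6, 8, 32           # batch, target len, source len, heads, head dim
kq, kk, kv = jax.random.split(jax.random.key(0), 3)
query = jax.random.normal(kq, (B, T, N, H))
key = jax.random.normal(kk, (B, S, N, H))
value = jax.random.normal(kv, (B, S, N, H))

out = jax.nn.dot_product_attention(query, key, value)   # shape (B, T, N, H)

# Reference computation following the formula: logits = QK^T / sqrt(H),
# probs = softmax(logits), output = probs @ V.
logits = jnp.einsum('btnh,bsnh->bnts', query, key) / jnp.sqrt(H)
probs = jax.nn.softmax(logits, axis=-1)
reference = jnp.einsum('bnts,bsnh->btnh', probs, value)
assert jnp.allclose(out, reference, atol=1e-4)   # float32 tolerance
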
Parameters:
  • query (ArrayLike) – query array; shape (BTNH|TNH)

  • key (ArrayLike) – key array; shape (BSKH|SKH). When K equals N, multi-headed attention (MHA, https://arxiv.org/abs/1706.03762) is performed. Otherwise, grouped query attention (GQA, https://arxiv.org/abs/2305.13245) is performed if N is a multiple of K, and multi-query attention (MQA, https://arxiv.org/abs/1911.02150) is performed if K == 1 (a special case of GQA). See the grouped-query sketch after this parameter list.

  • value (ArrayLike) – value array, should have the same shape as the key array.

  • bias (ArrayLike | None) – optional, bias array to be added to logits. The shape must be 4D and broadcastable to (BNTS|NTS).

  • mask (ArrayLike | None) – optional, mask array used to filter out logits. It is a boolean mask where True indicates the element should take part in attention. For an additive mask, users should pass it to bias. The shape must be 4D and broadcastable to (BNTS|NTS).

  • scale (float | None) – scale for the logits. If None, the scale will be set to 1 divided by the square root of the query’s head dimension (i.e. H).

  • is_causal (bool) – If True, causal attention will be applied. Note that some implementations, like xla, generate a mask tensor and apply it to the logits to mask out the non-causal parts of the attention matrix, while others, like cudnn, avoid computing the non-causal regions entirely, providing speedups. See the masking sketch after this parameter list.

  • query_seq_lengths (ArrayLike | None) – int32 array of sequence lengths for query; shape (B).

  • key_value_seq_lengths (ArrayLike | None) – int32 array of sequence lengths for key and value; shape (B).

  • local_window_size (int | tuple[int, int] | None) – Window sizes that restrict self-attention to each token’s local window. If set, this specifies the (left_window_size, right_window_size) for each token. E.g., if local_window_size == (3, 2) and the sequence is [0, 1, 2, 3, 4, 5, c, 7, 8, 9], token c can attend to [3, 4, 5, c, 7, 8]. If a single int is given, it will be interpreted as a symmetric window (window_size, window_size). See the masking sketch after this parameter list.

  • implementation (Literal['xla', 'cudnn'] | None) – A string to control which implementation backend to use. Supported strings are xla and cudnn (cuDNN flash attention). It defaults to None, which will automatically select the best available backend. Note that cudnn supports only a subset of shapes/dtypes, and an exception will be thrown if it’s not supported.
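
A grouped-query sketch, referenced from the key parameter above (shapes and random inputs are illustrative assumptions, not from the official documentation): N = 8 query heads share K = 2 key/value heads, so G = N // K = 4, and slicing the key/value heads down to one gives the MQA special case.

import jax

B, T, S, N, K, H = 2, 4, 6, 8, 2, 32
kq, kk, kv = jax.random.split(jax.random.key(1), 3)
query = jax.random.normal(kq, (B, T, N, H))   # N query heads
key = jax.random.normal(kk, (B, S, K, H))     # K key/value heads
value = jax.random.normal(kv, (B, S, K, H))

# Grouped query attention: each group of N // K query heads shares one KV head.
out_gqa = jax.nn.dot_product_attention(query, key, value)
assert out_gqa.shape == (B, T, N, H)          # output matches the query shape

# Multi-query attention is the K == 1 special case.
out_mqa = jax.nn.dot_product_attention(query, key[:, :, :1, :], value[:, :, :1, :])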
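
A masking sketch for is_causal and local_window_size, referenced from those parameters above (again with assumed shapes; not from the official documentation):

import jax

B, T, N, H = 1, 10, 4, 16
kq, kk, kv = jax.random.split(jax.random.key(2), 3)
q = jax.random.normal(kq, (B, T, N, H))
k = jax.random.normal(kk, (B, T, N, H))
v = jax.random.normal(kv, (B, T, N, H))

# Causal: each query position attends only to itself and earlier positions.
causal_out = jax.nn.dot_product_attention(q, k, v, is_causal=True)

# Local window: each token attends to at most 3 tokens to its left and 2 to its
# right, mirroring the (left_window_size, right_window_size) example above.
local_out = jax.nn.dot_product_attention(q, k, v, local_window_size=(3, 2))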

Returns:

An array of the attention output with the same shape as query.

Return type:

Array