Cross-Attention

Formula

\[ Q = X_{q} W_{Q},\quad K = X_{c} W_{K},\quad V = X_{c} W_{V} \]
\[ \operatorname{CrossAttn}(X_{q},X_{c})=\operatorname{softmax}\!\left(\frac{QK^\top}{\sqrt{d_{k}}}\right)V \]

Parameters

  • \(X_{q}\): query-side representations (e.g., decoder tokens)
  • \(X_{c}\): context/source representations (e.g., encoder outputs)
  • \(Q\) from \(X_{q}\), \(K,V\) from \(X_{c}\)

What it means

Cross-attention lets one sequence (or modality) read information from another sequence/modality.

What it's used for

  • Encoder-decoder transformers (translation, summarization).
  • Multimodal models (text attending to image/audio features).

Key properties

  • Output length follows the query sequence length.
  • Keys/values provide the memory being retrieved from.
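As a concrete shape check (with illustrative sizes): suppose \(L_q = 3\) query tokens attend over \(L_c = 5\) context tokens. Then

\[ QK^\top \in \mathbb{R}^{3\times 5},\qquad \operatorname{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V \in \mathbb{R}^{3\times d_v}, \]

so the output has one row per query token, regardless of the context length.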

Common gotchas

  • Easy to confuse self-attention with cross-attention; the distinction is where the projections come from: in self-attention \(Q\), \(K\), and \(V\) all come from the same sequence, while in cross-attention \(Q\) comes from \(X_q\) and \(K, V\) come from \(X_c\).
  • Masking rules differ from self-attention: cross-attention typically applies no causal mask, only a padding mask over the context positions, though the exact rules depend on the task.

Example

In machine translation, decoder token states query encoder outputs via cross-attention to condition on the source sentence.

How to Compute (Pseudocode)

Input: query-side states X_q, context states X_c
Output: cross-attended query representations

Q <- X_q W_Q
K <- X_c W_K
V <- X_c W_V
scores <- (Q K^T) / sqrt(d_k)
weights <- softmax(scores)
return weights V
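The pseudocode above can be sketched in NumPy. This is a minimal single-head, unbatched version; all variable names, shapes, and the optional padding mask are illustrative assumptions, not a reference implementation:

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row max for numerical stability before exponentiating.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def cross_attention(X_q, X_c, W_Q, W_K, W_V, mask=None):
    """Single-head dense cross-attention (sketch).

    X_q:  (L_q, d_model) query-side states
    X_c:  (L_c, d_model) context states
    mask: optional (L_q, L_c) boolean array; True = attend, False = block
    """
    Q = X_q @ W_Q                       # (L_q, d_k) — queries from X_q
    K = X_c @ W_K                       # (L_c, d_k) — keys from X_c
    V = X_c @ W_V                       # (L_c, d_v) — values from X_c
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)     # (L_q, L_c)
    if mask is not None:
        # Push blocked positions to a large negative value before softmax.
        scores = np.where(mask, scores, -1e9)
    weights = softmax(scores, axis=-1)  # each row sums to 1 over the context
    return weights @ V                  # (L_q, d_v)

# Toy shapes: 3 query tokens attend over 5 context tokens.
rng = np.random.default_rng(0)
d_model, d_k, d_v, L_q, L_c = 8, 4, 4, 3, 5
X_q = rng.standard_normal((L_q, d_model))
X_c = rng.standard_normal((L_c, d_model))
W_Q = rng.standard_normal((d_model, d_k))
W_K = rng.standard_normal((d_model, d_k))
W_V = rng.standard_normal((d_model, d_v))

out = cross_attention(X_q, X_c, W_Q, W_K, W_V)
print(out.shape)  # output length follows the query length: (3, 4)
```

Note that the output shape depends only on \(L_q\) and \(d_v\); the context length \(L_c\) appears only in the intermediate attention weights.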

Complexity

  • Time: \(O(L_q L_c d)\) for dense attention, where \(L_q\) is the query length, \(L_c\) the context length, and \(d\) the model dimension
  • Space: \(O(L_q L_c)\) for attention scores/weights (per head, per batch element), plus projections/outputs
  • Assumptions: Dense cross-attention; batch and multi-head dimensions omitted for readability

See also