Cross-Attention¶
Formula¶
\[
Q = X_{q} W_{Q},\quad K = X_{c} W_{K},\quad V = X_{c} W_{V}
\]
\[
\operatorname{CrossAttn}(X_{q},X_{c})=\operatorname{softmax}\!\left(\frac{QK^\top}{\sqrt{d_{k}}}\right)V
\]
Parameters¶
- \(X_{q}\): query-side representations (e.g., decoder tokens)
- \(X_{c}\): context/source representations (e.g., encoder outputs)
- \(Q\) from \(X_{q}\), \(K,V\) from \(X_{c}\)
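To make the shapes concrete, here is a minimal NumPy sketch of the projections (single head, no batch dimension; the toy sizes and random weights are illustrative assumptions):

```python
import numpy as np

rng = np.random.default_rng(0)
L_q, L_c, d_model, d_k = 4, 7, 16, 8   # toy sizes (assumed)
X_q = rng.normal(size=(L_q, d_model))  # query-side states, e.g. decoder tokens
X_c = rng.normal(size=(L_c, d_model))  # context states, e.g. encoder outputs
W_Q, W_K, W_V = (rng.normal(size=(d_model, d_k)) for _ in range(3))

# Q comes from the query side; K and V come from the context side
Q, K, V = X_q @ W_Q, X_c @ W_K, X_c @ W_V
print(Q.shape, K.shape, V.shape)  # (4, 8) (7, 8) (7, 8)
```

Note that `Q` has one row per query token while `K` and `V` have one row per context token, which is what makes the attention map rectangular.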
What it means¶
Cross-attention lets one sequence (or modality) read information from another sequence/modality.
What it's used for¶
- Encoder-decoder transformers (translation, summarization).
- Multimodal models (text attending to image/audio features).
Key properties¶
- Output length follows the query sequence length.
- Keys/values provide the memory being retrieved from.
Common gotchas¶
- Self-attention and cross-attention are easy to confuse: the only structural difference is that cross-attention takes \(K,V\) from a different sequence than \(Q\).
- Masking rules differ from self-attention: cross-attention typically applies a padding mask over the context rather than a causal mask, though the details depend on the task.
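As a sketch of how context masking typically works (the scores and the boolean padding mask are made-up inputs): masked positions are set to \(-\infty\) before the softmax, so padded context tokens receive zero attention weight.

```python
import numpy as np

scores = np.array([[0.5, 1.0, 2.0],
                   [1.5, 0.2, 0.1]])          # raw scores, shape (L_q=2, L_c=3)
context_pad = np.array([False, False, True])  # last context position is padding (assumed)

# Mask applies over keys (context positions), not queries
masked = np.where(context_pad, -np.inf, scores)
w = np.exp(masked - masked.max(axis=-1, keepdims=True))
w /= w.sum(axis=-1, keepdims=True)            # softmax over the context axis
print(w[:, -1])  # padded position gets zero weight: [0. 0.]
```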
Example¶
In machine translation, decoder token states query encoder outputs via cross-attention to condition on the source sentence.
How to Compute (Pseudocode)¶
Input: query-side states X_q, context states X_c
Output: cross-attended query representations
Q <- X_q W_Q
K <- X_c W_K
V <- X_c W_V
scores <- (Q K^T) / sqrt(d_k)
weights <- softmax(scores)
return weights V
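The pseudocode above can be written as a runnable NumPy function (single head, no batch dimension; the toy sizes and random weights are assumptions for the demo):

```python
import numpy as np

def cross_attention(X_q, X_c, W_Q, W_K, W_V):
    """Dense single-head cross-attention. X_q: (L_q, d), X_c: (L_c, d)."""
    Q = X_q @ W_Q                                  # (L_q, d_k)
    K = X_c @ W_K                                  # (L_c, d_k)
    V = X_c @ W_V                                  # (L_c, d_v)
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)                # (L_q, L_c)
    scores -= scores.max(axis=-1, keepdims=True)   # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True) # softmax over context positions
    return weights @ V                             # (L_q, d_v)

rng = np.random.default_rng(0)
L_q, L_c, d, d_k = 3, 5, 8, 4                      # toy sizes (assumed)
X_q, X_c = rng.normal(size=(L_q, d)), rng.normal(size=(L_c, d))
W_Q, W_K, W_V = (rng.normal(size=(d, d_k)) for _ in range(3))
out = cross_attention(X_q, X_c, W_Q, W_K, W_V)
print(out.shape)  # (3, 4): output length follows the query sequence
```

The shape check at the end illustrates the key property above: the output has one row per query token, regardless of the context length.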
Complexity¶
- Time: \(O(L_q L_c d)\) for dense attention, where \(L_q\) is the query length and \(L_c\) is the context length
- Space: \(O(L_q L_c)\) for attention scores/weights (per head, per batch element), plus projections/outputs
- Assumptions: Dense cross-attention; batch and multi-head dimensions omitted for readability
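As a rough illustration of the \(O(L_q L_c)\) memory term (the sequence lengths and head count below are arbitrary assumptions):

```python
L_q, L_c, n_heads = 512, 1024, 8   # assumed sizes
entries = L_q * L_c * n_heads      # score entries per batch element, all heads
mib = entries * 4 / 2**20          # float32 bytes -> MiB
print(mib)                         # 16.0 MiB just for the attention scores
```

Doubling either sequence length doubles this cost, which is why long-context models often replace dense cross-attention with sparse or low-rank approximations.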