Self-Attention

Formula

\[ Q = XW_{Q},\quad K = XW_{K},\quad V = XW_{V} \]
\[ \operatorname{SelfAttn}(X)=\operatorname{softmax}\!\left(\frac{QK^\top}{\sqrt{d_{k}}}\right)V \]

Parameters

  • \(X\): input sequence representations, one vector per token
  • \(W_{Q},W_{K},W_{V}\): learned projection matrices
  • \(Q,K,V\): queries, keys, values, all derived from the same \(X\)
  • \(d_{k}\): dimensionality of the key (and query) vectors, used to scale the dot products
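
As a concrete shape check on the formula above, here is a minimal sketch of the three projections in NumPy; the sizes (5 tokens, 16 model features, 8-dimensional projections) and random weights are illustrative assumptions, not values from this page.

import numpy as np

# Illustrative sizes (assumptions): L tokens, d_model input features, d_k per projection
L, d_model, d_k = 5, 16, 8
rng = np.random.default_rng(0)

X = rng.normal(size=(L, d_model))       # input sequence representations
W_Q = rng.normal(size=(d_model, d_k))   # learned projection matrices
W_K = rng.normal(size=(d_model, d_k))
W_V = rng.normal(size=(d_model, d_k))

Q, K, V = X @ W_Q, X @ W_K, X @ W_V     # queries, keys, values, all from the same X
print(Q.shape, K.shape, V.shape)        # (5, 8) (5, 8) (5, 8)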

What it means

Self-attention lets each token attend to other tokens in the same sequence (and often itself) to build context-aware representations.

What it's used for

  • Transformer encoder/decoder blocks.
  • Capturing long-range dependencies without recurrence.

Key properties

  • Parallelizable across sequence positions.
  • Can represent pairwise interactions between all tokens in a layer.

Common gotchas

  • Quadratic memory/time in sequence length for standard implementations.
  • Causal masking is required for autoregressive decoding; a minimal mask construction is sketched just after this list.
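
A minimal sketch of such a causal mask, assuming the scores are held in an L-by-L NumPy array: entries strictly above the diagonal (future positions) are set to -inf so they receive zero weight after the softmax. The sequence length is an illustrative assumption.

import numpy as np

L = 4                                                    # illustrative sequence length (assumption)
scores = np.zeros((L, L))                                # stand-in for Q K^T / sqrt(d_k)
causal_mask = np.triu(np.ones((L, L), dtype=bool), k=1)  # True strictly above the diagonal
masked_scores = np.where(causal_mask, -np.inf, scores)   # future positions get zero weight after softmax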

Example

In "The animal didn't cross because it was tired," self-attention can help "it" attend to "animal."

How to Compute (Pseudocode)

Input: sequence representations X
Output: context-aware sequence representations

Q <- X W_Q
K <- X W_K
V <- X W_V
scores <- (Q K^T) / sqrt(d_k)
if a mask is needed (for example, a causal mask), set masked score entries to -inf
weights <- softmax(scores)
return weights V
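
A minimal runnable version of the pseudocode above, written in NumPy under the assumption of a single attention head and an optional boolean mask; the function and variable names are illustrative, not taken from any particular library.

import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)   # subtract row max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(X, W_Q, W_K, W_V, mask=None):
    """Single-head self-attention: X is (L, d_model); returns (L, d_v)."""
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)            # (L, L) pairwise attention scores
    if mask is not None:                       # boolean mask, True = position is blocked
        scores = np.where(mask, -np.inf, scores)
    weights = softmax(scores, axis=-1)         # each row sums to 1
    return weights @ V                         # context-aware representations

# Usage with illustrative sizes (assumptions)
L, d_model, d_k = 6, 16, 8
rng = np.random.default_rng(0)
X = rng.normal(size=(L, d_model))
W_Q, W_K, W_V = [rng.normal(size=(d_model, d_k)) for _ in range(3)]
out = self_attention(X, W_Q, W_K, W_V)
print(out.shape)                               # (6, 8)

Note that weights is the full L-by-L matrix of attention weights, which is where the quadratic cost discussed below comes from.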

Complexity

  • Time: Typically \(O(L^2 d)\) for sequence length \(L\) and hidden/head dimension \(d\) in dense self-attention
  • Space: Typically \(O(L^2)\) attention-score/weight memory (per head, per batch element), plus projections and outputs
  • Assumptions: Standard dense self-attention; multi-head/batch factors multiply constants and memory usage
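
To make the quadratic memory term concrete, the short calculation below estimates the size of the attention-weight tensor alone; the batch size, head count, sequence length, and fp16 storage are illustrative assumptions.

# Rough size of the (batch, heads, L, L) attention-weight tensor, stored in fp16 (2 bytes each)
batch, heads, L = 8, 16, 4096             # illustrative values (assumptions)
attn_bytes = batch * heads * L * L * 2
print(f"{attn_bytes / 2**30:.1f} GiB")    # 4.0 GiB for these settings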

See also