Causal Language Modeling

Formula

\[ \mathcal{L}=-\sum_{t=1}^T \log P(x_t\mid x_{\lt t}) \]

Parameters

  • \(x_t\): token at position \(t\)
  • \(x_{\lt t}\): all tokens before position \(t\) (the past context only)
  • \(\mathcal{L}\): negative log-likelihood training loss, summed over positions

What it means

Causal language modeling trains a model to predict each token from only the tokens that precede it, minimizing the negative log-likelihood of the sequence under the model.

What it's used for

  • Decoder-only LLM training.
  • Autoregressive text generation.

Key properties

  • Uses causal masking in self-attention.
  • Training and generation objectives align naturally.
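Causal masking can be made concrete with a minimal numpy sketch: entries above the diagonal are set to \(-\infty\) so that, after the softmax, position \(t\) assigns zero attention weight to any future position.

```python
import numpy as np

# Build a causal mask for a sequence of length 4 (toy size for illustration).
# Entry (i, j) is -inf when j > i, i.e. when position i would attend to a
# future position; adding this mask to attention scores before the softmax
# zeroes out those weights.
T = 4
mask = np.triu(np.full((T, T), -np.inf), k=1)
print(mask)
```

Row \(i\) of the mask is all zeros up to column \(i\) and \(-\infty\) afterwards, which is exactly the "past context only" constraint from the formula above.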

Common gotchas

  • "Teacher forcing" during training conditions each prediction on the ground-truth prefix, whereas free-running generation at inference conditions on the model's own previous outputs (exposure bias).
  • Sequence packing can leak context across document boundaries if the attention mask is only causal and not also block-diagonal per packed sequence.
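The packing gotcha can be illustrated with a hedged numpy sketch: a correct mask for packed sequences is causal *within* each segment and additionally blocks attention across packing boundaries. The segment layout here (two sequences of lengths 2 and 3 packed into one row) is a made-up example.

```python
import numpy as np

# Hypothetical packing: sequences of lengths 2 and 3 share one row of length 5.
seg_ids = np.array([0, 0, 1, 1, 1])   # which packed sequence each token belongs to
T = len(seg_ids)

causal = np.tril(np.ones((T, T), dtype=bool))      # position t attends to <= t
same_seq = seg_ids[:, None] == seg_ids[None, :]    # same packed segment only
mask = causal & same_seq                           # True = attention allowed

# Token 2 starts the second sequence, so it must not see tokens 0-1:
print(mask[2])   # [False False  True False False]
```

Using only the causal mask (dropping `same_seq`) would let the second sequence condition on the first, which is the leakage described above.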

Example

Given "The sky is", the model predicts the next token distribution for words like "blue".
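A toy numpy sketch of that prediction step, with made-up logits over a four-word candidate vocabulary (the scores and vocabulary are illustrative, not from a real model):

```python
import numpy as np

# Hypothetical scores for candidate next tokens after "The sky is".
vocab = ["blue", "clear", "falling", "the"]
logits = np.array([3.2, 1.1, -0.5, -2.0])   # made-up values

# Softmax turns scores into a probability distribution over the vocabulary.
probs = np.exp(logits - logits.max())
probs /= probs.sum()

# "blue" receives the highest probability because it has the largest logit.
print(dict(zip(vocab, probs.round(3))))
```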

How to Compute (Pseudocode)

Input: token sequence batch and a decoder-only/causal LM
Output: causal LM training loss

shift targets so each position predicts the next token
run the model with a causal mask to obtain logits for all positions
compute cross-entropy loss between the logits and the next-token targets
average (or sum) over valid positions, ignoring padding
return loss
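The steps above can be sketched in numpy for a single unpadded sequence. This is a minimal illustration assuming the model's logits are already computed; it is not a full training loop, and the function name is chosen here for clarity.

```python
import numpy as np

def causal_lm_loss(logits, tokens):
    """Mean next-token cross-entropy for one sequence.

    logits: shape (T, V) -- one score vector per position.
    tokens: shape (T,)   -- the token ids of the sequence.
    """
    # Shift: position t's logits predict token t+1.
    pred = logits[:-1]                      # (T-1, V)
    target = tokens[1:]                     # (T-1,)

    # Log-softmax computed stably by subtracting the per-row max.
    z = pred - pred.max(axis=-1, keepdims=True)
    log_probs = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))

    # Negative log-likelihood of each true next token, averaged.
    nll = -log_probs[np.arange(len(target)), target]
    return nll.mean()
```

With all-zero logits (a uniform distribution over a vocabulary of size \(V\)), the loss is \(\log V\), which is a handy sanity check when debugging a training setup.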

Complexity

  • Time: Depends on model architecture; for Transformers, each layer costs roughly \(O(T^2 d)\) for masked self-attention plus \(O(T d^2)\) for the feed-forward block per sequence of length \(T\) with hidden size \(d\), scaled by batch size
  • Space: Depends on model size and activation storage across sequence length; naive attention stores \(O(T^2)\) scores per head and can dominate memory
  • Assumptions: Teacher-forced training on full sequences; exact complexity inherits the underlying model's (for example, a Transformer's) runtime and memory behavior

See also