Masked Language Modeling (MLM)¶
Formula¶
\[
\mathcal{L}=-\sum_{t\in M}\log P(x_t\mid x_{\setminus M})
\]
Parameters¶
- \(M\): set of masked positions
- \(x_{\setminus M}\): observed tokens with masked positions hidden/replaced
What it means¶
MLM trains a model to reconstruct masked tokens using surrounding context.
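As a minimal numeric sketch of the formula, suppose the model assigns the true token probability 0.9 at one masked position and 0.5 at another (values made up for illustration):

```python
import math

# Model's predicted probability of the TRUE token at each masked position
# (hypothetical values for illustration).
probs = [0.9, 0.5]

# L = -sum over masked positions t of log P(x_t | x_{\M})
loss = -sum(math.log(p) for p in probs)

# Higher probability on the truth -> smaller contribution to the loss.
print(round(loss, 4))
```

Note how the confident prediction (0.9) contributes far less to the loss than the uncertain one (0.5).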
What it's used for¶
- Encoder-style pretraining (e.g., BERT-like models).
- Learning bidirectional contextual representations.
Key properties¶
- Uses both left and right context.
- Loss is computed only on selected masked positions.
Common gotchas¶
- The pretraining masking strategy (the masking ratio and how selected tokens are replaced) measurably affects downstream performance.
- MLM does not directly train autoregressive generation; a masked model cannot produce text left-to-right without further adaptation.
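On the first gotcha: a widely used masking policy is BERT's 80/10/10 recipe, where each selected token is replaced by `[MASK]` 80% of the time, by a random token 10% of the time, and left unchanged 10% of the time. A sketch of that policy (the `VOCAB` list and function name are illustrative, not from any particular library):

```python
import random

MASK = "[MASK]"
VOCAB = ["capital", "city", "of", "France", "Paris", "is", "the"]  # toy vocabulary

def mask_tokens(tokens, mask_prob=0.15, seed=0):
    """BERT-style masking: of the selected positions, 80% become [MASK],
    10% become a random token, 10% are left unchanged. This is one common
    recipe, not the only possible policy."""
    rng = random.Random(seed)
    out, targets = list(tokens), {}
    for i, tok in enumerate(tokens):
        if rng.random() < mask_prob:
            targets[i] = tok            # loss will be computed at this position
            r = rng.random()
            if r < 0.8:
                out[i] = MASK           # 80%: replace with the mask token
            elif r < 0.9:
                out[i] = rng.choice(VOCAB)  # 10%: replace with a random token
            # else: 10%: keep the original token unchanged
    return out, targets
```

Keeping some selected tokens unchanged (the final 10%) discourages the model from relying on the literal `[MASK]` symbol, which never appears at fine-tuning time.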
Example¶
Input: "Paris is the [MASK] of France"; the model predicts "capital" at the masked position, using both the left context ("Paris is the") and the right context ("of France").
How to Compute (Pseudocode)¶
Input: token sequence batch, masking policy, encoder-style LM
Output: MLM training loss
sample/select masked positions according to the masking policy
replace/mask tokens at selected positions
run the model to obtain logits for all positions
compute loss only on masked positions
return masked-token prediction loss
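The steps above can be sketched end to end in Python. The "encoder" here is a stand-in that returns random logits (a real implementation would call the model's forward pass); the function name and defaults are illustrative:

```python
import math
import random

def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    return [e / z for e in exps]

def mlm_loss(token_ids, vocab_size, mask_prob=0.15, mask_id=0, seed=0):
    rng = random.Random(seed)
    # 1. Sample masked positions per the masking policy (ensure at least one).
    masked = [i for i in range(len(token_ids)) if rng.random() < mask_prob] or [0]
    # 2. Replace tokens at the selected positions with the mask id.
    corrupted = [mask_id if i in masked else t for i, t in enumerate(token_ids)]
    # 3. Toy "encoder": random logits for every position
    #    (stand-in for a real model's forward pass).
    logits = [[rng.gauss(0, 1) for _ in range(vocab_size)] for _ in corrupted]
    # 4. Cross-entropy only on masked positions, against the ORIGINAL tokens,
    #    averaged over the masked set (the formula's sum, normalized).
    loss = 0.0
    for i in masked:
        p_true = softmax(logits[i])[token_ids[i]]
        loss -= math.log(p_true)
    return loss / len(masked)
```

Note that the loss targets are the original `token_ids`, not the corrupted sequence, and that unmasked positions contribute nothing to the loss, matching the formula's sum over \(t \in M\) only.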
Complexity¶
- Time: Depends on model architecture; usually dominated by encoder forward/backward passes over the sequence batch
- Space: Depends on model activations and sequence length; masking bookkeeping is typically small overhead
- Assumptions: BERT-style MLM workflow shown; masking ratios/policies affect constants but not the dominant model-compute term