Causal Language Modeling Objective Explained
Causal language modeling (CLM) is the pre-training objective behind autoregressive models like GPT. The model is trained to predict the next token given all previous tokens – and only those. It never looks ahead.
The Objective
Given a sequence of tokens $x = (x_1, x_2, \ldots, x_n)$, the model estimates the probability of the sequence as a product of conditional probabilities:
$$P(x) = \prod_{t=1}^{n} P(x_t \mid x_1, \ldots, x_{t-1}; \theta)$$

Training minimizes the negative log-likelihood of this product:

$$\mathcal{L}(\theta) = -\sum_{t=1}^{n} \log P(x_t \mid x_1, \ldots, x_{t-1}; \theta)$$

Minimizing this loss pushes the model to assign high probability to the correct next token at every position. The negative sign turns likelihood maximization into a minimization problem – standard for gradient-based optimization.
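The product-of-conditionals view can be checked numerically: summing per-position log-probabilities of the target tokens gives the sequence log-likelihood, and its negative mean is exactly what `cross_entropy` computes. A minimal sketch with toy logits (the sizes and values here are illustrative, not from a real model):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size = 5
seq = torch.tensor([2, 0, 4])               # a toy 3-token sequence
logits = torch.randn(len(seq), vocab_size)  # pretend model outputs, one row per position

log_probs = F.log_softmax(logits, dim=-1)   # log P(token | previous tokens) at each step
seq_log_prob = log_probs[torch.arange(len(seq)), seq].sum()  # log of the product
nll = -seq_log_prob / len(seq)              # average negative log-likelihood = CLM loss
```

Here `nll` matches `F.cross_entropy(logits, seq)` exactly, which is why implementations express the objective as a cross-entropy over next-token targets.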
Why Causal?
The "causal" in CLM means the model cannot access future tokens during training or inference. At each position, only the left context is available. This is enforced in practice by an attention mask that blocks attention to future positions – the same masked self-attention you implemented in the decoder block.
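The masking step can be sketched directly on a matrix of attention scores. This is a minimal, self-contained illustration (random scores standing in for query–key products, not the full decoder block):

```python
import torch

seq_len = 5
# True above the diagonal: position i may not attend to any j > i
mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)

scores = torch.randn(seq_len, seq_len)            # raw attention scores
scores = scores.masked_fill(mask, float('-inf'))  # block future positions
weights = torch.softmax(scores, dim=-1)           # each row sums to 1 over past positions only
```

After the softmax, every entry above the diagonal is exactly zero, so no information flows from future tokens to earlier positions.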
This constraint makes CLM models naturally suited for text generation: they produce one token at a time, each conditioned on everything generated so far.
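The one-token-at-a-time loop can be sketched as follows. The `model` here is a hypothetical stand-in for a trained decoder – any callable mapping token ids of shape (1, t) to logits of shape (1, t, vocab_size) – and decoding is greedy for simplicity:

```python
import torch

vocab_size = 1000

def model(ids):
    # Hypothetical placeholder for a trained decoder; returns deterministic
    # fake logits so the demo is reproducible
    torch.manual_seed(int(ids.sum()))
    return torch.randn(1, ids.shape[1], vocab_size)

ids = torch.tensor([[1]])  # start token
for _ in range(5):
    logits = model(ids)                                  # forward pass over the whole prefix
    next_id = logits[:, -1, :].argmax(-1, keepdim=True)  # greedy: most likely next token
    ids = torch.cat([ids, next_id], dim=1)               # append and condition on it next step
```

Each iteration conditions only on tokens generated so far, which is exactly the left-context constraint the causal mask enforces during training.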
Contrast with Masked Language Modeling
Masked language modeling (MLM), used in BERT, randomly masks tokens and trains the model to reconstruct them using both left and right context. This makes MLM better for understanding tasks, but unsuitable for generation – you cannot generate left-to-right if you need the right side to make predictions.
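For contrast, a rough sketch of the MLM loss. The mask positions and 15% rate are illustrative assumptions (BERT's actual recipe also sometimes keeps or replaces tokens rather than masking), and the loss is computed only where tokens were masked:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size, seq_len, mask_id = 1000, 10, 0    # mask_id is an assumed [MASK] token id

input_ids = torch.randint(1, vocab_size, (1, seq_len))
mask = torch.zeros(1, seq_len, dtype=torch.bool)
mask[0, [2, 7]] = True                        # fixed positions standing in for a ~15% sample
masked_input = input_ids.masked_fill(mask, mask_id)

logits = torch.randn(1, seq_len, vocab_size)  # pretend bidirectional model outputs
labels = input_ids.masked_fill(~mask, -100)   # ignore_index skips unmasked positions
loss = F.cross_entropy(logits.view(-1, vocab_size), labels.view(-1))
```

Unlike the CLM shift-by-one setup below, the model here sees the full (masked) sequence at once and is graded only on the blanks – which is why it has both left and right context available.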
import torch
import torch.nn.functional as F

# Simulating the CLM loss for a single sequence
vocab_size = 1000
seq_len = 10

# Random logits standing in for model output (batch_size=1, seq_len, vocab_size)
logits = torch.randn(1, seq_len, vocab_size)

# Targets: at each position the model predicts the *next* token,
# so targets are the inputs shifted left by one
# Input: x[0..n-1], Target: x[1..n]
input_ids = torch.randint(0, vocab_size, (1, seq_len))
targets = input_ids[:, 1:]   # drop the first token
logits = logits[:, :-1, :]   # drop the last position to align with targets

loss = F.cross_entropy(logits.reshape(-1, vocab_size), targets.reshape(-1))
print(f"CLM loss: {loss.item():.4f}")
Run this locally to see how the shift-by-one alignment between inputs and targets implements the CLM objective.