Gradient Accumulation for Large Batch Training
Large batch sizes stabilize training and improve convergence – but they require more GPU memory. Gradient accumulation lets you simulate a large batch by processing several smaller mini-batches and accumulating their gradients before updating the model.
How It Works
In a standard loop, you update weights after every mini-batch. With gradient accumulation, you delay the update:
- Run a forward and backward pass on a mini-batch – gradients accumulate in .grad;
- Repeat for accumulation_steps mini-batches without calling optimizer.step();
- After accumulation_steps batches, call optimizer.step() and optimizer.zero_grad().
This is mathematically equivalent to training on a batch of size batch_size × accumulation_steps – exactly so when the loss is mean-reduced and the mini-batches are equal in size. (Layers that depend on batch statistics, such as BatchNorm, still see the smaller mini-batches.)
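You can verify this equivalence numerically. The following sketch (using a hypothetical toy linear model, not part of the tutorial's example) compares the gradients from one full batch of 8 samples against four accumulated mini-batches of 2, each with its loss divided by the number of accumulation steps:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Two identical copies of a toy model
model_a = nn.Linear(10, 1)
model_b = nn.Linear(10, 1)
model_b.load_state_dict(model_a.state_dict())

x = torch.randn(8, 10)
y = torch.randn(8, 1)
criterion = nn.MSELoss()  # mean-reduced

# (a) One full batch of 8
criterion(model_a(x), y).backward()

# (b) Four accumulated mini-batches of 2, each loss scaled by 1/4
for i in range(4):
    xb, yb = x[2 * i : 2 * i + 2], y[2 * i : 2 * i + 2]
    (criterion(model_b(xb), yb) / 4).backward()

# Gradients match up to floating-point rounding
print(torch.allclose(model_a.weight.grad, model_b.weight.grad, atol=1e-6))  # True
```

Because MSELoss averages over each chunk, summing the four quarter-scaled chunk losses reproduces the mean over all 8 samples, so the accumulated gradients match the full-batch gradients.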
Implementation
```python
import torch
import torch.nn as nn
import torch.optim as optim

# Simulating a dataloader with random batches
vocab_size, d_model, seq_len = 1000, 128, 20
model = nn.Sequential(nn.Embedding(vocab_size, d_model),
                      nn.Linear(d_model, vocab_size))
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

accumulation_steps = 4

optimizer.zero_grad()
for step in range(20):
    inputs = torch.randint(0, vocab_size, (8, seq_len))
    targets = torch.randint(0, vocab_size, (8, seq_len))

    logits = model(inputs)
    loss = criterion(logits.view(-1, vocab_size), targets.view(-1))

    # Normalize so the gradient magnitude matches a full batch
    loss = loss / accumulation_steps
    loss.backward()

    if (step + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
        print(f"Step {step + 1} – weights updated")
```
Dividing the loss by accumulation_steps before calling .backward() keeps gradient magnitudes consistent with what a true large-batch update would produce.
Run this locally and experiment with different accumulation_steps values. Notice that weights are only updated every 4 steps.