Pre-training Large Language Models

Gradient Accumulation for Large Batch Training


Larger batches give lower-variance gradient estimates, which tends to stabilize training. However, the activations stored for the backward pass grow with batch size, so a large batch may not fit in GPU memory. Gradient accumulation lets you simulate a large batch by processing several smaller mini-batches one at a time and summing their gradients before updating the model.

How It Works

In a standard loop, you update weights after every mini-batch. With gradient accumulation, you delay the update:

  1. Run a forward and backward pass on a mini-batch – gradients accumulate in .grad;
  2. Repeat for accumulation_steps mini-batches without calling optimizer.step();
  3. After accumulation_steps batches, call optimizer.step() and optimizer.zero_grad().
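The accumulate-then-step behavior described in these steps can be observed on a single scalar parameter. This is a minimal sketch assuming PyTorch; the toy losses 3*w and 5*w stand in for two mini-batch losses, and plain SGD with lr=0.1 is chosen only to keep the arithmetic easy to follow.

```python
import torch

# One scalar parameter so the gradients are easy to read
w = torch.tensor([2.0], requires_grad=True)
optimizer = torch.optim.SGD([w], lr=0.1)

# Steps 1-2: two backward passes without optimizer.step(); gradients add up in w.grad
(3 * w).sum().backward()            # d/dw = 3
grad_after_first = w.grad.clone()
(5 * w).sum().backward()            # d/dw = 5, added to the existing 3
grad_after_second = w.grad.clone()

# Step 3: a single update using the accumulated gradient, then clear it
optimizer.step()                    # w <- 2.0 - 0.1 * (3 + 5) = 1.2 (up to float32 rounding)
optimizer.zero_grad()

print(grad_after_first.item(), grad_after_second.item(), w.item())
```

Depending on the PyTorch version, optimizer.zero_grad() either sets w.grad to None or fills it with zeros; in both cases the next mini-batch starts a fresh accumulation.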

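With the loss scaling described below and equally sized mini-batches, the accumulated gradient equals the gradient of a single batch of size batch_size × accumulation_steps, up to floating-point rounding. Layers that depend on batch statistics, such as BatchNorm, only ever see the small mini-batch and therefore break this equivalence.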
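The equivalence follows from the linearity of the gradient. The derivation below assumes a mean-reduced loss (the default for nn.CrossEntropyLoss) and mini-batches of equal size; the symbols are introduced here for illustration: K stands for accumulation_steps, B for the number of items averaged in each mini-batch (for token-level cross-entropy, the tokens in the mini-batch), ℓ_i for the loss of item i, and 𝓑_k for the k-th mini-batch.

```latex
% Mean loss of the k-th mini-batch
L_k(\theta) = \frac{1}{B} \sum_{i \in \mathcal{B}_k} \ell_i(\theta)

% Mean loss of the full batch of K \cdot B items
L(\theta) = \frac{1}{KB} \sum_{k=1}^{K} \sum_{i \in \mathcal{B}_k} \ell_i(\theta)
          = \frac{1}{K} \sum_{k=1}^{K} L_k(\theta)

% Gradient of the full-batch loss = sum of the gradients of the scaled mini-batch losses
\nabla_\theta L(\theta) = \sum_{k=1}^{K} \nabla_\theta \left( \frac{L_k(\theta)}{K} \right)
```

The right-hand side is exactly what ends up in .grad when each scaled loss L_k / K is backpropagated in turn.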

Implementation

```python
import torch
import torch.nn as nn
import torch.optim as optim

# Simulating a dataloader with random batches
vocab_size, d_model, seq_len = 1000, 128, 20
model = nn.Sequential(nn.Embedding(vocab_size, d_model), nn.Linear(d_model, vocab_size))
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

accumulation_steps = 4
optimizer.zero_grad()

for step in range(20):
    inputs = torch.randint(0, vocab_size, (8, seq_len))
    targets = torch.randint(0, vocab_size, (8, seq_len))

    logits = model(inputs)  # shape: (8, seq_len, vocab_size)
    loss = criterion(logits.view(-1, vocab_size), targets.view(-1))

    # Normalize so the gradient magnitude matches a full batch
    loss = loss / accumulation_steps
    loss.backward()  # adds this mini-batch's gradients to .grad

    if (step + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
        print(f"Step {step + 1} – weights updated")
```

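Dividing the loss by accumulation_steps before calling .backward() keeps gradient magnitudes consistent with what a true large-batch update would produce. nn.CrossEntropyLoss averages over its mini-batch by default, so summing accumulation_steps unscaled gradients would give a gradient accumulation_steps times larger than the average over the full batch.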
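To check both claims numerically, the sketch below compares one backward pass over a full batch of 32 sequences with four accumulated mini-batches of 8, with and without the loss scaling. It assumes PyTorch and reuses the toy Embedding + Linear model from the listing at a smaller size; the helpers flat_grads and accumulated_grads are hypothetical names introduced only for this comparison.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

vocab_size, d_model, seq_len = 50, 16, 5
batch_size, accumulation_steps = 8, 4
model = nn.Sequential(nn.Embedding(vocab_size, d_model), nn.Linear(d_model, vocab_size))
criterion = nn.CrossEntropyLoss()  # mean over all tokens in the batch

inputs = torch.randint(0, vocab_size, (batch_size * accumulation_steps, seq_len))
targets = torch.randint(0, vocab_size, (batch_size * accumulation_steps, seq_len))

def flat_grads():
    # Hypothetical helper: concatenate every parameter's .grad into one vector
    return torch.cat([p.grad.reshape(-1) for p in model.parameters()])

def accumulated_grads(scale):
    # Hypothetical helper: accumulate gradients over equal-sized mini-batches,
    # dividing each mini-batch loss by `scale` before backward()
    model.zero_grad()
    for x, y in zip(inputs.chunk(accumulation_steps), targets.chunk(accumulation_steps)):
        loss = criterion(model(x).reshape(-1, vocab_size), y.reshape(-1))
        (loss / scale).backward()
    return flat_grads()

# Reference: one backward pass over the full batch of 32 sequences
model.zero_grad()
criterion(model(inputs).reshape(-1, vocab_size), targets.reshape(-1)).backward()
full_batch = flat_grads()

scaled = accumulated_grads(scale=accumulation_steps)  # loss / accumulation_steps
unscaled = accumulated_grads(scale=1)                 # loss not divided

print(torch.allclose(scaled, full_batch, atol=1e-6))
print(torch.allclose(unscaled, accumulation_steps * full_batch, atol=1e-5))
```

The scaled accumulation matches the full-batch gradient within float32 tolerance, while the unscaled one is accumulation_steps times too large. The match relies on every mini-batch contributing the same number of tokens; with padding tokens excluded via ignore_index, mini-batches can differ in size and the scaling is no longer exact.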

Run this locally and experiment with different accumulation_steps values. With accumulation_steps = 4, the weights are updated only on every 4th step (steps 4, 8, 12, 16 and 20).
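When experimenting, note that the loop above only steps the optimizer when (step + 1) is divisible by accumulation_steps. The sketch below, using a hypothetical helper update_schedule written for this illustration, lists the update steps for a 20-step run and counts the trailing mini-batches whose gradients are still sitting in .grad when the loop ends.

```python
def update_schedule(num_steps, accumulation_steps):
    """Hypothetical helper: return the 1-based steps at which the loop calls
    optimizer.step(), and how many trailing mini-batches are never applied."""
    updates = [s + 1 for s in range(num_steps) if (s + 1) % accumulation_steps == 0]
    leftover = num_steps % accumulation_steps
    return updates, leftover

for k in (1, 2, 3, 4):
    updates, leftover = update_schedule(20, k)
    print(f"accumulation_steps={k}: {len(updates)} updates, "
          f"last at step {updates[-1]}, {leftover} leftover mini-batches")
```

With accumulation_steps = 3, for example, the gradients from steps 19 and 20 are accumulated but never used. If that matters, one option is an extra optimizer.step() after the loop, keeping in mind that this final group is smaller than accumulation_steps, so dividing its losses by accumulation_steps under-weights it.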


Which of the following scenarios is the best reason to use gradient accumulation during training?


Section 1, Chapter 7
