Pre-training Large Language Models

Gradient Accumulation for Large Batch Training


Large batch sizes stabilize training and improve convergence – but they require more GPU memory. Gradient accumulation lets you simulate a large batch by processing several smaller mini-batches and accumulating their gradients before updating the model.

How It Works

In a standard loop, you update weights after every mini-batch. With gradient accumulation, you delay the update:

  1. Run a forward and backward pass on a mini-batch – gradients accumulate in .grad;
  2. Repeat for accumulation_steps mini-batches without calling optimizer.step();
  3. After accumulation_steps batches, call optimizer.step() and optimizer.zero_grad().

This is mathematically equivalent to training on a batch of size batch_size × accumulation_steps.
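To make that equivalence concrete, let K be accumulation_steps and let each mini-batch loss L_k be the mean over its batch_size examples. Dividing each loss by K before backpropagating (as the code below does) and letting the gradients sum in .grad gives

\nabla_\theta \left( \frac{1}{K} \sum_{k=1}^{K} L_k \right) = \frac{1}{K} \sum_{k=1}^{K} \nabla_\theta L_k,

which is exactly the gradient of the mean loss over all batch_size × K examples. The equivalence is exact as long as nothing in the model depends on batch composition; layers such as BatchNorm, which use per-batch statistics, break it slightly.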

Implementation

import torch
import torch.nn as nn
import torch.optim as optim

# Simulating a dataloader with random batches
vocab_size, d_model, seq_len = 1000, 128, 20

model = nn.Sequential(nn.Embedding(vocab_size, d_model), nn.Linear(d_model, vocab_size))
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

accumulation_steps = 4
optimizer.zero_grad()

for step in range(20):
    inputs = torch.randint(0, vocab_size, (8, seq_len))
    targets = torch.randint(0, vocab_size, (8, seq_len))

    logits = model[1](model[0](inputs))
    loss = criterion(logits.view(-1, vocab_size), targets.view(-1))

    # Normalize so the gradient magnitude matches a full batch
    loss = loss / accumulation_steps
    loss.backward()

    if (step + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
        print(f"Step {step + 1} – weights updated")

Dividing the loss by accumulation_steps before calling .backward() keeps gradient magnitudes consistent with what a true large-batch update would produce.
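If you want to convince yourself of this, here is a minimal sanity check; the tiny linear model and random data below are purely illustrative and not part of the lesson's model. Accumulating gradients from two half-batches, each with its loss divided by 2, reproduces the full-batch gradient.

import torch
import torch.nn as nn

torch.manual_seed(0)
lin = nn.Linear(4, 2)              # toy model, illustrative only
criterion = nn.MSELoss()
x, y = torch.randn(8, 4), torch.randn(8, 2)

# Gradient from one full batch of 8 examples
lin.zero_grad()
criterion(lin(x), y).backward()
full_grad = lin.weight.grad.clone()

# Accumulated gradient from two half-batches of 4, each loss divided by 2
lin.zero_grad()
for xb, yb in zip(x.chunk(2), y.chunk(2)):
    (criterion(lin(xb), yb) / 2).backward()

print(torch.allclose(full_grad, lin.weight.grad, atol=1e-6))  # expected: True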

Run this locally and experiment with different accumulation_steps values. Notice that weights are only updated every 4 steps.
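One practical detail when you move from this toy loop to a real DataLoader: if the number of mini-batches is not a multiple of accumulation_steps, the last few gradients never trigger an update. The sketch below shows one way to flush them; the dataloader is assumed (it is not defined in the lesson), while model, criterion, optimizer, vocab_size, and accumulation_steps are reused from the code above.

optimizer.zero_grad()
last_step = -1
for step, (inputs, targets) in enumerate(dataloader):  # assumed DataLoader yielding (inputs, targets)
    logits = model(inputs)  # Sequential applies Embedding then Linear
    loss = criterion(logits.view(-1, vocab_size), targets.view(-1))
    (loss / accumulation_steps).backward()
    if (step + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
    last_step = step

# Flush gradients left over from an incomplete final accumulation window
if (last_step + 1) % accumulation_steps != 0:
    optimizer.step()
    optimizer.zero_grad()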


Which of the following scenarios is the best reason to use gradient accumulation during training?

Select the correct answer
