Pre-training Large Language Models

Gradient Accumulation for Large Batch Training


Large batch sizes stabilize training and improve convergence – but they require more GPU memory. Gradient accumulation lets you simulate a large batch by processing several smaller mini-batches and accumulating their gradients before updating the model.

How It Works

In a standard loop, you update weights after every mini-batch. With gradient accumulation, you delay the update:

  1. Run a forward and backward pass on a mini-batch – gradients accumulate in .grad;
  2. Repeat for accumulation_steps mini-batches without calling optimizer.step();
  3. After accumulation_steps batches, call optimizer.step() and optimizer.zero_grad().
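The three steps can be sketched as a minimal loop (toy linear model and random data, purely for illustration):

```python
import torch
import torch.nn as nn

# Minimal sketch of the three steps (toy model and random data for illustration)
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.MSELoss()
accumulation_steps = 4
num_updates = 0

optimizer.zero_grad()
for step in range(8):  # 8 mini-batches -> 2 optimizer updates
    x, y = torch.randn(2, 4), torch.randn(2, 2)
    loss = criterion(model(x), y) / accumulation_steps  # scale toward a full-batch mean
    loss.backward()                                     # step 1: gradients accumulate in .grad
    if (step + 1) % accumulation_steps == 0:            # steps 2-3: update only every 4 batches
        optimizer.step()
        optimizer.zero_grad()
        num_updates += 1

print(num_updates)  # 2
```

The update fires on steps 4 and 8, so eight mini-batches produce exactly two weight updates.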

Provided each mini-batch loss is divided by accumulation_steps (so that gradients average rather than sum), this is mathematically equivalent to training on a single batch of size batch_size × accumulation_steps.
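You can verify this equivalence numerically on a tiny model (toy linear layer and random data, chosen only for the check): the gradient from one batch of 8 matches the accumulated, normalized gradients from four mini-batches of 2.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(3, 1)
criterion = nn.MSELoss()
x, y = torch.randn(8, 3), torch.randn(8, 1)

# Gradient from one full batch of 8
model.zero_grad()
criterion(model(x), y).backward()
full_grad = model.weight.grad.clone()

# Accumulated gradients from 4 mini-batches of 2, each loss divided by 4
model.zero_grad()
for xb, yb in zip(x.split(2), y.split(2)):
    (criterion(model(xb), yb) / 4).backward()
acc_grad = model.weight.grad.clone()

print(torch.allclose(full_grad, acc_grad, atol=1e-6))  # True
```

The equality holds because MSELoss averages uniformly over the batch: the mean over 8 samples equals the sum of four per-mini-batch means, each divided by 4.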

Implementation

import torch
import torch.nn as nn
import torch.optim as optim

# Simulating a dataloader with random batches
vocab_size, d_model, seq_len = 1000, 128, 20
model = nn.Sequential(nn.Embedding(vocab_size, d_model), nn.Linear(d_model, vocab_size))
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

accumulation_steps = 4
optimizer.zero_grad()

for step in range(20):
    inputs = torch.randint(0, vocab_size, (8, seq_len))
    targets = torch.randint(0, vocab_size, (8, seq_len))

    logits = model(inputs)
    loss = criterion(logits.view(-1, vocab_size), targets.view(-1))

    # Normalize so the gradient magnitude matches a full batch
    loss = loss / accumulation_steps
    loss.backward()

    if (step + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
        print(f"Step {step + 1} – weights updated")

Dividing the loss by accumulation_steps before calling .backward() keeps gradient magnitudes consistent with what a true large-batch update would produce.
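To see why the division matters, compare accumulated gradients with and without it on a toy setup (hypothetical model and data, for illustration only): without normalization the accumulated gradient is exactly accumulation_steps times too large.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(3, 1)
criterion = nn.MSELoss()
batches = [(torch.randn(2, 3), torch.randn(2, 1)) for _ in range(4)]

# Accumulate without dividing the loss: gradients sum up
model.zero_grad()
for xb, yb in batches:
    criterion(model(xb), yb).backward()
unscaled = model.weight.grad.clone()

# Accumulate with the loss divided by accumulation_steps (4)
model.zero_grad()
for xb, yb in batches:
    (criterion(model(xb), yb) / 4).backward()
scaled = model.weight.grad.clone()

print(torch.allclose(unscaled, scaled * 4, atol=1e-6))  # True: unscaled is 4x larger
```

An unnormalized accumulated gradient effectively multiplies the learning rate by accumulation_steps, which can destabilize training.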

Run this locally and experiment with different accumulation_steps values; notice that the weights are updated only once every accumulation_steps mini-batches (here, every 4).



Section 1. Chapter 7
