Pre-training Large Language Models

Gradient Accumulation for Large Batch Training


Large batch sizes stabilize training and improve convergence – but they require more GPU memory. Gradient accumulation lets you simulate a large batch by processing several smaller mini-batches and accumulating their gradients before updating the model.

How It Works

In a standard loop, you update weights after every mini-batch. With gradient accumulation, you delay the update:

  1. Run a forward and backward pass on a mini-batch – gradients accumulate in .grad;
  2. Repeat for accumulation_steps mini-batches without calling optimizer.step();
  3. After accumulation_steps batches, call optimizer.step() and optimizer.zero_grad().

With the loss scaled by accumulation_steps, this is mathematically equivalent to training on a batch of size batch_size × accumulation_steps (assuming the model has no batch-size-dependent layers such as BatchNorm).
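You can verify this equivalence directly. The sketch below (a toy linear-regression setup, not part of the lesson's code) computes gradients once on a full batch and once by accumulating over four mini-batches with the loss divided by accumulation_steps; the two gradients match:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Toy model and data; sizes here are illustrative assumptions
model = nn.Linear(4, 1)
criterion = nn.MSELoss()
x = torch.randn(8, 4)
y = torch.randn(8, 1)

# Gradient from one full-batch pass
model.zero_grad()
criterion(model(x), y).backward()
full_grad = model.weight.grad.clone()

# Gradient accumulated over 4 mini-batches of size 2,
# with each loss divided by accumulation_steps
accumulation_steps = 4
model.zero_grad()
for xb, yb in zip(x.chunk(accumulation_steps), y.chunk(accumulation_steps)):
    loss = criterion(model(xb), yb) / accumulation_steps
    loss.backward()

print(torch.allclose(full_grad, model.weight.grad, atol=1e-5))  # → True
```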

Implementation

import torch
import torch.nn as nn
import torch.optim as optim

# Simulating a dataloader with random batches
vocab_size, d_model, seq_len = 1000, 128, 20
model = nn.Sequential(
    nn.Embedding(vocab_size, d_model),
    nn.Linear(d_model, vocab_size)
)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

accumulation_steps = 4
optimizer.zero_grad()

for step in range(20):
    inputs = torch.randint(0, vocab_size, (8, seq_len))
    targets = torch.randint(0, vocab_size, (8, seq_len))
    logits = model(inputs)
    loss = criterion(logits.view(-1, vocab_size), targets.view(-1))
    # Normalize so the gradient magnitude matches a full batch
    loss = loss / accumulation_steps
    loss.backward()
    if (step + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
        print(f"Step {step + 1} – weights updated")

Dividing the loss by accumulation_steps before calling .backward() keeps gradient magnitudes consistent with what a true large-batch update would produce.
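To see why the division matters, compare accumulated gradients with and without it. In this small sketch (again a toy linear model, not the lesson's code), skipping the normalization yields gradients exactly accumulation_steps times larger than the normalized ones:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Toy model and data; sizes are illustrative assumptions
model = nn.Linear(4, 1)
criterion = nn.MSELoss()
x = torch.randn(8, 4)
y = torch.randn(8, 1)
accumulation_steps = 4

# Accumulate WITHOUT dividing the loss
model.zero_grad()
for xb, yb in zip(x.chunk(accumulation_steps), y.chunk(accumulation_steps)):
    criterion(model(xb), yb).backward()
unscaled = model.weight.grad.clone()

# Accumulate WITH the division, as in the training loop above
model.zero_grad()
for xb, yb in zip(x.chunk(accumulation_steps), y.chunk(accumulation_steps)):
    (criterion(model(xb), yb) / accumulation_steps).backward()
scaled = model.weight.grad.clone()

# The unnormalized gradient is accumulation_steps times too large
print(torch.allclose(unscaled, scaled * accumulation_steps, atol=1e-5))  # → True
```

An inflated gradient like this effectively multiplies the learning rate by accumulation_steps, which can destabilize training.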

Run this locally and experiment with different accumulation_steps values. Notice that the weights are only updated once every accumulation_steps (here, 4) mini-batches.


Which of the following scenarios is the best reason to use gradient accumulation during training?



Section 1. Chapter 7
