Pre-training Large Language Models

Gradient Accumulation for Large Batch Training


Large batch sizes stabilize training and improve convergence – but they require more GPU memory. Gradient accumulation lets you simulate a large batch by processing several smaller mini-batches and accumulating their gradients before updating the model.

How It Works

In a standard loop, you update weights after every mini-batch. With gradient accumulation, you delay the update:

  1. Run a forward and backward pass on a mini-batch – gradients accumulate in .grad;
  2. Repeat for accumulation_steps mini-batches without calling optimizer.step();
  3. After accumulation_steps batches, call optimizer.step() and optimizer.zero_grad().

This is mathematically equivalent to training on a batch of size batch_size × accumulation_steps.

Implementation

import torch
import torch.nn as nn
import torch.optim as optim

# Simulating a dataloader with random batches
vocab_size, d_model, seq_len = 1000, 128, 20

model = nn.Sequential(
    nn.Embedding(vocab_size, d_model),
    nn.Linear(d_model, vocab_size),
)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

accumulation_steps = 4
optimizer.zero_grad()

for step in range(20):
    inputs = torch.randint(0, vocab_size, (8, seq_len))
    targets = torch.randint(0, vocab_size, (8, seq_len))

    logits = model(inputs)
    loss = criterion(logits.view(-1, vocab_size), targets.view(-1))

    # Normalize so the gradient magnitude matches a full batch
    loss = loss / accumulation_steps
    loss.backward()

    if (step + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
        print(f"Step {step + 1} – weights updated")

Dividing the loss by accumulation_steps before calling .backward() keeps gradient magnitudes consistent with what a true large-batch update would produce.
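You can check this equivalence directly. The sketch below uses a hypothetical toy model (a single nn.Linear with MSE loss, not the course's model) and compares the gradient from one full-batch backward pass against the gradient accumulated over four normalized mini-batch passes:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy setup for the check
model = nn.Linear(4, 2)
criterion = nn.MSELoss()

inputs = torch.randn(8, 4)
targets = torch.randn(8, 2)

# Gradient from a single full-batch pass
model.zero_grad()
criterion(model(inputs), targets).backward()
full_grad = model.weight.grad.clone()

# Gradient accumulated over 4 mini-batches of 2,
# with each loss divided by accumulation_steps
accumulation_steps = 4
model.zero_grad()
for x, y in zip(inputs.chunk(accumulation_steps),
                targets.chunk(accumulation_steps)):
    loss = criterion(model(x), y) / accumulation_steps
    loss.backward()
acc_grad = model.weight.grad.clone()

print(torch.allclose(full_grad, acc_grad, atol=1e-6))
```

The two gradients agree up to floating-point rounding, which is exactly why the normalization step matters: without the division, the accumulated gradient would be accumulation_steps times too large.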

Run this locally and experiment with different accumulation_steps values. Notice that the weights are only updated every accumulation_steps mini-batches (here, every 4 steps).


Which of the following scenarios is the best reason to use gradient accumulation during training?


Section 1, Chapter 7
