Implementing the Pre-training Loop in PyTorch
The pre-training loop is the core of language model training. Each iteration follows the same sequence: load a batch, compute the loss, backpropagate, update weights.
The Training Loop
```python
import torch
import torch.nn as nn
import torch.optim as optim


class SimpleLanguageModel(nn.Module):
    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.linear = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        # x: (batch, seq_len) token ids -> (batch, seq_len, vocab_size) logits
        return self.linear(self.embedding(x))


vocab_size = 1000
d_model = 128
batch_size = 16
seq_len = 20

model = SimpleLanguageModel(vocab_size, d_model)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# Dummy batch: random input tokens and random targets (placeholders only;
# real CLM targets are the inputs shifted by one position)
inputs = torch.randint(0, vocab_size, (batch_size, seq_len))
targets = torch.randint(0, vocab_size, (batch_size, seq_len))

model.train()
for epoch in range(20):
    optimizer.zero_grad()                  # 1. Clear gradients
    logits = model(inputs)                 # 2. Forward pass
    loss = criterion(                      # 3. Compute CLM loss
        logits.view(-1, vocab_size),       #    (batch * seq_len, vocab_size)
        targets.view(-1)                   #    (batch * seq_len,)
    )
    loss.backward()                        # 4. Backpropagate
    optimizer.step()                       # 5. Update weights
    print(f"Epoch {epoch + 1:02d} - loss: {loss.item():.4f}")
```
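Each step has a fixed role. Calling optimizer.zero_grad() first is essential: PyTorch accumulates gradients in each parameter's .grad attribute instead of overwriting them, so without it the gradients from the previous iteration are added to the current ones and corrupt the update.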
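A quick way to see this accumulation is to call backward() twice without clearing gradients in between. The minimal sketch below uses a hypothetical one-parameter "model" w with loss w², so the gradient is 2w = 4.0 at w = 2.0; the second backward pass adds to w.grad rather than replacing it, and zero_grad() resets it.

```python
import torch

# Hypothetical one-parameter "model", used only to inspect .grad
w = torch.nn.Parameter(torch.tensor(2.0))
optimizer = torch.optim.SGD([w], lr=0.1)

def compute_loss():
    # loss = w^2, so d(loss)/dw = 2w = 4.0 at w = 2.0
    return w ** 2

compute_loss().backward()
first_grad = w.grad.item()        # 4.0

compute_loss().backward()         # no zero_grad(): the new gradient is added
accumulated_grad = w.grad.item()  # 8.0

optimizer.zero_grad()             # clears w.grad (None by default in PyTorch 2.x, zeros in older versions)
print(first_grad, accumulated_grad, w.grad)
```

In a training loop, that doubled gradient would make optimizer.step() apply a stale, oversized update, which is exactly what clearing gradients at the start of each iteration prevents.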
Note that in a real pre-training setup, the targets are the inputs shifted by one position, so the target at position t is the token at position t + 1: this is the CLM objective from the previous chapter. The random targets here only serve to demonstrate the structure of the loop.
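One common way to build such pairs is to cut windows of seq_len + 1 tokens from a tokenized text stream and split each window into inputs (all but the last token) and targets (all but the first). The sketch below assumes a hypothetical 1-D tensor token_stream of token ids, filled with random values purely as a placeholder, and a hypothetical helper make_clm_batch; it is one possible approach, not the course's data pipeline.

```python
import torch

def make_clm_batch(token_stream, batch_size, seq_len):
    # Hypothetical helper: sample batch_size random windows of seq_len + 1
    # tokens from a 1-D token stream and split each into (inputs, targets).
    max_start = token_stream.numel() - (seq_len + 1)
    starts = torch.randint(0, max_start + 1, (batch_size,))
    windows = torch.stack([token_stream[s : s + seq_len + 1] for s in starts.tolist()])
    # .contiguous() so that targets.view(-1) in the training loop works on the slices
    inputs = windows[:, :-1].contiguous()   # tokens 0 .. seq_len - 1 of each window
    targets = windows[:, 1:].contiguous()   # tokens 1 .. seq_len (next-token labels)
    return inputs, targets

vocab_size, batch_size, seq_len = 1000, 16, 20
token_stream = torch.randint(0, vocab_size, (10_000,))  # placeholder for real tokenized text
inputs, targets = make_clm_batch(token_stream, batch_size, seq_len)

# The target at position t is the input token at position t + 1
print(inputs.shape, targets.shape)
print(torch.equal(inputs[:, 1:], targets[:, :-1]))
```

Because the returned tensors have the same (batch_size, seq_len) shape as the dummy ones, they can be passed to the loop above in place of the random inputs and targets.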
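Run this locally and watch the loss decrease over the 20 epochs. Because every epoch reuses the same fixed batch, the drop shows the model memorizing that batch rather than learning language. Then replace SimpleLanguageModel with the transformer you built in Course 2 to see the full pipeline in action.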
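If the Course 2 model is not at hand, a small stand-in can be assembled from PyTorch's built-in nn.TransformerEncoder with a causal mask. The sketch below is a hypothetical example, not the Course 2 architecture: the class name TinyCausalTransformer and the n_heads, n_layers and max_len values are illustrative assumptions. It keeps the same interface as SimpleLanguageModel, mapping a (batch, seq_len) tensor of token ids to (batch, seq_len, vocab_size) logits, so it can be swapped into the loop without other changes.

```python
import torch
import torch.nn as nn

class TinyCausalTransformer(nn.Module):
    # Hypothetical stand-in; not the architecture built in Course 2
    def __init__(self, vocab_size, d_model, n_heads=4, n_layers=2, max_len=512):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_len, d_model)  # learned positional embeddings
        layer = nn.TransformerEncoderLayer(
            d_model, n_heads, dim_feedforward=4 * d_model, batch_first=True
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=n_layers)
        self.lm_head = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        seq_len = x.size(1)
        positions = torch.arange(seq_len, device=x.device)
        h = self.token_emb(x) + self.pos_emb(positions)  # (batch, seq_len, d_model)
        # Causal mask: -inf above the diagonal blocks attention to future tokens
        mask = torch.triu(
            torch.full((seq_len, seq_len), float("-inf"), device=x.device), diagonal=1
        )
        h = self.encoder(h, mask=mask)
        return self.lm_head(h)  # (batch, seq_len, vocab_size)

model = TinyCausalTransformer(vocab_size=1000, d_model=128)
logits = model(torch.randint(0, 1000, (16, 20)))
print(logits.shape)
```

The causal mask is what makes this a language model under the CLM objective: without it, each position could attend to the very token it is asked to predict.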