Implementing the Pre-training Loop in PyTorch
The pre-training loop is the core of language model training. Each iteration follows the same sequence: load a batch, compute the loss, backpropagate, update weights.
The Training Loop
```python
import torch
import torch.nn as nn
import torch.optim as optim

class SimpleLanguageModel(nn.Module):
    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.linear = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        return self.linear(self.embedding(x))

vocab_size = 1000
d_model = 128
batch_size = 16
seq_len = 20

model = SimpleLanguageModel(vocab_size, d_model)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# Dummy batch: input tokens and their shifted targets
inputs = torch.randint(0, vocab_size, (batch_size, seq_len))
targets = torch.randint(0, vocab_size, (batch_size, seq_len))

model.train()
for epoch in range(20):
    optimizer.zero_grad()        # 1. Clear gradients
    logits = model(inputs)       # 2. Forward pass
    loss = criterion(            # 3. Compute CLM loss
        logits.view(-1, vocab_size),
        targets.view(-1)
    )
    loss.backward()              # 4. Backpropagate
    optimizer.step()             # 5. Update weights
    print(f"Epoch {epoch + 1:02d} – loss: {loss.item():.4f}")
```
Each step has a fixed role. Calling optimizer.zero_grad() first is essential – without it, gradients from the previous batch accumulate and corrupt the update.
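The accumulation is easy to verify in isolation. A minimal sketch with a single scalar parameter (not part of the loop above, just a demonstration):

```python
import torch

# One scalar parameter; loss = w * 1.0, so d(loss)/dw = 1 per backward pass
w = torch.ones(1, requires_grad=True)

(w * 1.0).sum().backward()
first = w.grad.item()       # 1.0 after one backward pass

(w * 1.0).sum().backward()  # no zeroing in between
second = w.grad.item()      # 2.0 – gradients from both passes summed

w.grad.zero_()              # what optimizer.zero_grad() does per parameter
(w * 1.0).sum().backward()
third = w.grad.item()       # 1.0 again – a clean gradient
```

This summing behavior is deliberate (it enables gradient accumulation over micro-batches), which is exactly why clearing must be explicit.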
Note that in a real pre-training setup, targets is inputs shifted by one position to the right – the CLM objective from the previous chapter. The dummy random targets here are only for demonstrating the loop structure.
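The shift itself is a one-line slicing operation. A sketch with a hypothetical token sequence (in practice the tokens come from your tokenizer):

```python
import torch

# Hypothetical batch of one tokenized sequence
tokens = torch.tensor([[5, 17, 42, 8, 99, 3]])

# CLM objective: predict token t+1 from tokens up to position t
inputs = tokens[:, :-1]   # [5, 17, 42, 8, 99]
targets = tokens[:, 1:]   # [17, 42, 8, 99, 3]
```

Each position in `inputs` is paired with the next token as its target, so a sequence of length L yields L−1 training positions.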
Run this locally and watch the loss decrease over 20 epochs. Then replace SimpleLanguageModel with the transformer you built in Course 2 to see the full pipeline in action.
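As one possible stand-in for that swap, here is a sketch using PyTorch's built-in `nn.TransformerEncoder` with a causal mask; it is an illustrative substitute, not the Course 2 model, and it omits positional embeddings for brevity:

```python
import torch
import torch.nn as nn

class TinyTransformerLM(nn.Module):
    # Illustrative stand-in for the Course 2 transformer
    def __init__(self, vocab_size, d_model, nhead=4, num_layers=2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        layer = nn.TransformerEncoderLayer(d_model, nhead, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers)
        self.linear = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        # Causal mask so position t cannot attend to positions > t
        mask = nn.Transformer.generate_square_subsequent_mask(x.size(1))
        h = self.encoder(self.embedding(x), mask=mask)
        return self.linear(h)

model = TinyTransformerLM(vocab_size=1000, d_model=128)
logits = model(torch.randint(0, 1000, (16, 20)))  # same (batch, seq, vocab) output
```

Because it returns logits of the same shape as `SimpleLanguageModel`, it drops into the training loop unchanged.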