Assembling a Transformer from Scratch
Swipe to show the menu
You now have all the components. This chapter puts them together into a full encoder-decoder transformer.
Overall Structure
The transformer consists of stacked encoder and decoder blocks. The encoder reads the source sequence and produces a contextual representation (memory). The decoder generates the target sequence one token at a time, attending to both its own previous outputs and the encoder's memory.
The data flow:
- Token indices → embeddings → add positional encoding;
- Pass through N encoder layers → produce memory;
- Target token indices → embeddings → add positional encoding;
- Pass through N decoder layers (with masked self-attention and cross-attention to memory);
- Linear projection → logits over vocabulary.
Implementation
import torch
import torch.nn as nn
import math
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to token embeddings.

    Follows "Attention Is All You Need": even feature indices carry sine,
    odd indices carry cosine, with geometrically spaced wavelengths.

    Args:
        d_model: embedding dimension (even or odd — see note below).
        max_len: maximum sequence length supported by the cached table.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pos = torch.arange(0, max_len).unsqueeze(1).float()  # (max_len, 1)
        i = torch.arange(0, d_model, 2).float()              # even dim indices
        angles = pos / torch.pow(10000, i / d_model)         # (max_len, ceil(d_model/2))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer cosine column than sine column;
        # slice `angles` so the shapes match. (The original unsliced
        # assignment raised a shape-mismatch error whenever d_model was odd.)
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        # Registered as a buffer (not a Parameter): moves with .to(device)
        # and is saved in state_dict, but receives no gradient updates.
        self.register_buffer("pe", pe.unsqueeze(0))          # (1, max_len, d_model)

    def forward(self, x):
        """Return x plus positional encodings; x is (batch, seq_len, d_model)."""
        return x + self.pe[:, : x.size(1)]
class EncoderLayer(nn.Module):
    """One transformer encoder block.

    Self-attention followed by a position-wise feed-forward network, each
    wrapped in a post-norm residual connection (Add & Norm).
    """

    def __init__(self, d_model, n_heads, d_ff):
        super().__init__()
        # batch_first=True: all tensors are (batch, seq, feature).
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, mask=None):
        """Run one encoder block over x (batch, seq, d_model); mask is an
        optional attention mask forwarded to self-attention."""
        attended, _ = self.self_attn(x, x, x, attn_mask=mask)
        x = self.norm1(x + attended)          # Add & Norm around attention
        return self.norm2(x + self.ff(x))     # Add & Norm around feed-forward
class DecoderLayer(nn.Module):
    """One transformer decoder block.

    Masked self-attention, then cross-attention over the encoder memory,
    then a position-wise feed-forward network — each followed by a
    post-norm residual connection (Add & Norm).
    """

    def __init__(self, d_model, n_heads, d_ff):
        super().__init__()
        # batch_first=True: all tensors are (batch, seq, feature).
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(self, x, memory, tgt_mask=None):
        """Run one decoder block.

        x is the target-side activations (batch, tgt_seq, d_model); memory is
        the encoder output; tgt_mask is the (usually causal) self-attention mask.
        """
        self_attended, _ = self.self_attn(x, x, x, attn_mask=tgt_mask)
        x = self.norm1(x + self_attended)
        # Queries come from the decoder; keys/values come from encoder memory.
        cross_attended, _ = self.cross_attn(x, memory, memory)
        x = self.norm2(x + cross_attended)
        return self.norm3(x + self.ff(x))
class Transformer(nn.Module):
    """Encoder-decoder transformer for sequence-to-sequence modeling.

    The encoder turns source token indices into a contextual representation
    (memory); the decoder attends to its own prefix and to that memory, and
    a final linear layer projects to vocabulary logits.

    NOTE: source and target share one embedding table and one positional
    encoding; no causal mask is built automatically — pass tgt_mask yourself.
    """

    def __init__(self, vocab_size, d_model=512, n_heads=8, d_ff=2048, n_layers=6):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_enc = PositionalEncoding(d_model)
        self.encoder = nn.ModuleList(
            [EncoderLayer(d_model, n_heads, d_ff) for _ in range(n_layers)]
        )
        self.decoder = nn.ModuleList(
            [DecoderLayer(d_model, n_heads, d_ff) for _ in range(n_layers)]
        )
        self.output_proj = nn.Linear(d_model, vocab_size)

    def encode(self, src):
        """Encode source indices (batch, src_seq) into memory (batch, src_seq, d_model)."""
        memory = self.pos_enc(self.embedding(src))
        for enc_layer in self.encoder:
            memory = enc_layer(memory)
        return memory

    def decode(self, tgt, memory, tgt_mask=None):
        """Decode target indices (batch, tgt_seq) against encoder memory."""
        hidden = self.pos_enc(self.embedding(tgt))
        for dec_layer in self.decoder:
            hidden = dec_layer(hidden, memory, tgt_mask)
        return hidden

    def forward(self, src, tgt, tgt_mask=None):
        """Return logits of shape (batch, tgt_seq, vocab_size)."""
        return self.output_proj(self.decode(tgt, self.encode(src), tgt_mask))
Run this locally and instantiate the model with a small vocabulary to inspect its output shape:
# Build a small model: vocab 1000, d_model 128, 4 heads, FFN width 512,
# 2 encoder layers and 2 decoder layers.
model = Transformer(vocab_size=1000, d_model=128, n_heads=4, d_ff=512, n_layers=2)
# Random token indices: batch of 2 source sequences (len 10) and targets (len 8).
src = torch.randint(0, 1000, (2, 10))
tgt = torch.randint(0, 1000, (2, 8))
# Logits over the vocabulary for every target position.
out = model(src, tgt)
print(out.shape) # Expected: torch.Size([2, 8, 1000])
Was everything clear?
Thanks for your feedback!
Section 1. Chapter 10
Ask AI
Ask AI
Ask anything, or try one of the suggested questions to start our chat
Section 1. Chapter 10