Assembling a Transformer from Scratch
Swipe um das Menü anzuzeigen
You now have all the components. This chapter puts them together into a full encoder-decoder transformer.
Overall Structure
The transformer consists of stacked encoder and decoder blocks. The encoder reads the source sequence and produces a contextual representation (memory). The decoder generates the target sequence one token at a time, attending to both its own previous outputs and the encoder's memory.
The data flow:
- Token indices → embeddings → add positional encoding;
- Pass through $N$ encoder layers → produce `memory`;
- Target token indices → embeddings → add positional encoding;
- Pass through $N$ decoder layers (with masked self-attention and cross-attention to `memory`);
- Linear projection → logits over the vocabulary.
Implementation
import torch
import torch.nn as nn
import math
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a batch of embeddings.

    Even feature columns hold ``sin(pos / 10000**(i / d_model))`` and odd
    columns hold the matching ``cos``, where ``i`` is the even column index.

    Args:
        d_model: Embedding width (last dimension of the input). May be odd.
        max_len: Longest sequence the precomputed table supports.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pos = torch.arange(0, max_len).unsqueeze(1).float()  # (max_len, 1)
        i = torch.arange(0, d_model, 2).float()  # even column indices
        angles = pos / torch.pow(10000, i / d_model)  # (max_len, ceil(d_model / 2))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # When d_model is odd there is one fewer odd column than even column,
        # so drop the last angle to keep the shapes aligned. For even d_model
        # this slice is the whole tensor and the result is unchanged.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        # Registered as a buffer, not a parameter, so it is never trained.
        # The leading dim of 1 broadcasts over the batch.
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        Expects the sequence on dim 1 and features on the last dim,
        e.g. ``(batch, seq, d_model)``.

        Raises:
            ValueError: If the sequence is longer than ``max_len``.
        """
        seq_len = x.size(1)
        max_len = self.pe.size(1)
        if seq_len > max_len:
            # Without this check the addition fails with an opaque
            # broadcasting error.
            raise ValueError(
                f"sequence length {seq_len} exceeds PositionalEncoding "
                f"max_len {max_len}"
            )
        return x + self.pe[:, :seq_len]
class EncoderLayer(nn.Module):
    """One post-norm transformer encoder block.

    Self-attention and a position-wise feed-forward network. Each is wrapped
    as ``LayerNorm(x + sublayer(x))``.

    Args:
        d_model: Model (embedding) width.
        n_heads: Number of attention heads (``nn.MultiheadAttention``
            requires ``d_model`` to be divisible by it).
        d_ff: Hidden width of the feed-forward network.
    """

    def __init__(self, d_model, n_heads, d_ff):
        super().__init__()
        # Submodules are created attention-first, then feed-forward. This
        # order fixes how parameter initialisation draws from the RNG.
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        widen = nn.Linear(d_model, d_ff)
        narrow = nn.Linear(d_ff, d_model)
        self.ff = nn.Sequential(widen, nn.ReLU(), narrow)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, mask=None):
        """Encode ``x``; batch-first because the attention uses ``batch_first=True``.

        Args:
            x: Input of shape ``(batch, seq, d_model)``.
            mask: Optional ``attn_mask`` forwarded unchanged to self-attention.
        """
        attended = self.self_attn(x, x, x, attn_mask=mask)[0]
        hidden = self.norm1(x + attended)
        return self.norm2(hidden + self.ff(hidden))
class DecoderLayer(nn.Module):
    """One post-norm transformer decoder block.

    Three sublayers, each applied as ``LayerNorm(x + sublayer(x))``:
    self-attention over the target, cross-attention from the target to the
    encoder output ``memory``, and a position-wise feed-forward network.

    Args:
        d_model: Model (embedding) width.
        n_heads: Number of attention heads (``nn.MultiheadAttention``
            requires ``d_model`` to be divisible by it).
        d_ff: Hidden width of the feed-forward network.
    """

    def __init__(self, d_model, n_heads, d_ff):
        super().__init__()
        # Submodules are created in the order self-attention, cross-attention,
        # feed-forward. This order fixes how parameter initialisation draws
        # from the RNG.
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        widen = nn.Linear(d_model, d_ff)
        narrow = nn.Linear(d_ff, d_model)
        self.ff = nn.Sequential(widen, nn.ReLU(), narrow)
        self.norm1, self.norm2, self.norm3 = (nn.LayerNorm(d_model) for _ in range(3))

    def forward(self, x, memory, tgt_mask=None):
        """Decode ``x`` while attending to ``memory``.

        Args:
            x: Target-side input, ``(batch, tgt_len, d_model)``.
            memory: Encoder output attended to by cross-attention,
                ``(batch, src_len, d_model)``.
            tgt_mask: Optional ``attn_mask`` for the target self-attention.
                No mask is applied unless one is passed in.
        """
        hidden = self.norm1(x + self.self_attn(x, x, x, attn_mask=tgt_mask)[0])
        hidden = self.norm2(hidden + self.cross_attn(hidden, memory, memory)[0])
        return self.norm3(hidden + self.ff(hidden))
class Transformer(nn.Module):
    """Encoder-decoder transformer with one token embedding shared by source and target.

    Source and target are batches of token indices shaped ``(batch, seq_len)``.
    All attention layers are batch-first. The output is a tensor of logits
    over the vocabulary, shaped ``(batch, tgt_len, vocab_size)``.

    Args:
        vocab_size: Size of the shared vocabulary.
        d_model: Model (embedding) width.
        n_heads: Attention heads per layer.
        d_ff: Hidden width of each feed-forward network.
        n_layers: Number of encoder layers, and also of decoder layers.
    """

    def __init__(self, vocab_size, d_model=512, n_heads=8, d_ff=2048, n_layers=6):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_enc = PositionalEncoding(d_model)
        self.encoder = nn.ModuleList(
            [EncoderLayer(d_model, n_heads, d_ff) for _ in range(n_layers)]
        )
        self.decoder = nn.ModuleList(
            [DecoderLayer(d_model, n_heads, d_ff) for _ in range(n_layers)]
        )
        self.output_proj = nn.Linear(d_model, vocab_size)

    @staticmethod
    def causal_mask(size, device=None):
        """Return a ``(size, size)`` boolean mask that hides future positions.

        ``True`` marks a position that may NOT be attended to, which is the
        boolean ``attn_mask`` convention of ``nn.MultiheadAttention``.
        Position ``t`` can therefore only see positions ``<= t``.
        """
        return torch.triu(
            torch.ones(size, size, dtype=torch.bool, device=device), diagonal=1
        )

    def encode(self, src):
        """Embed ``src``, add positions, and run it through every encoder layer."""
        x = self.pos_enc(self.embedding(src))
        for layer in self.encoder:
            x = layer(x)
        return x

    def decode(self, tgt, memory, tgt_mask=None, *, causal=True):
        """Run ``tgt`` through the decoder stack, attending to ``memory``.

        Args:
            tgt: Target token indices, ``(batch, tgt_len)``.
            memory: Encoder output from :meth:`encode`.
            tgt_mask: Explicit self-attention mask. If given, it is used as-is.
            causal: When ``tgt_mask`` is None, build a causal mask so that no
                target position can attend to later positions. Pass
                ``causal=False`` for unmasked (bidirectional) decoding.
        """
        if tgt_mask is None and causal:
            # Without this, training with model(src, tgt) would let every
            # target token see the tokens it is supposed to predict.
            tgt_mask = self.causal_mask(tgt.size(1), device=tgt.device)
        x = self.pos_enc(self.embedding(tgt))
        for layer in self.decoder:
            x = layer(x, memory, tgt_mask)
        return x

    def forward(self, src, tgt, tgt_mask=None, *, causal=True):
        """Encode ``src``, decode ``tgt`` against it, and project to vocabulary logits.

        ``tgt_mask`` and ``causal`` have the same meaning as in :meth:`decode`.
        """
        memory = self.encode(src)
        output = self.decode(tgt, memory, tgt_mask, causal=causal)
        return self.output_proj(output)
Run this locally and instantiate the model with a small vocabulary to inspect its output shape:
if __name__ == "__main__":
    # Smoke test, guarded so that importing this module does not build a
    # model and run a forward pass as a side effect.
    model = Transformer(vocab_size=1000, d_model=128, n_heads=4, d_ff=512, n_layers=2)
    src = torch.randint(0, 1000, (2, 10))
    tgt = torch.randint(0, 1000, (2, 8))
    with torch.no_grad():  # inference only; skip building the autograd graph
        out = model(src, tgt)
    print(out.shape)  # Expected: torch.Size([2, 8, 1000])
    assert out.shape == (2, 8, 1000), f"unexpected output shape {tuple(out.shape)}"
War alles klar?
Danke für Ihr Feedback!
Abschnitt 1. Kapitel 10
Fragen Sie AI
Fragen Sie AI
Fragen Sie alles oder probieren Sie eine der vorgeschlagenen Fragen, um unser Gespräch zu beginnen
Abschnitt 1. Kapitel 10