Notice: This page requires JavaScript to function properly.
Please enable JavaScript in your browser settings or update your browser.
Learn: Assembling a Transformer from Scratch | Section
Transformer Architecture

Assembling a Transformer from Scratch

Swipe to show the menu

You now have all the components. This chapter puts them together into a full encoder-decoder transformer.

Overall Structure

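The transformer consists of stacked encoder and decoder blocks. The encoder reads the source sequence and produces a contextual representation (memory). At inference time the decoder generates the target sequence one token at a time, attending to both its own previous outputs and the encoder's memory. During training the whole (shifted) target sequence is fed in at once. A causal mask in the decoder's self-attention ensures that each position can attend only to itself and earlier positions, so the model never sees the tokens it is asked to predict.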

The data flow:

  1. Token indices → embeddings → add positional encoding;
  2. Pass through N encoder layers → produce memory;
  3. Target token indices → embeddings → add positional encoding;
  4. Pass through N decoder layers (with masked self-attention and cross-attention to memory);
  5. Linear projection → logits over vocabulary.

Implementation

import torch
import torch.nn as nn
import math


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a batch of embeddings.

    Even feature columns hold ``sin(pos / 10000**(i / d_model))`` and odd
    columns hold the matching ``cos``, where ``i`` is the even column index.
    The table is precomputed for ``max_len`` positions and stored as a
    registered buffer (not a trainable parameter), so it moves with the
    module on ``.to(device)`` and is saved in the ``state_dict``.

    Args:
        d_model: Feature size of the embeddings. May be odd; the last
            (unpaired) column then carries only a sine term.
        max_len: Longest sequence length the table supports.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pos = torch.arange(0, max_len).unsqueeze(1).float()   # (max_len, 1)
        i = torch.arange(0, d_model, 2).float()                # even column indices
        angles = pos / torch.pow(10000, i / d_model)           # (max_len, ceil(d_model / 2))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even column;
        # trim the angle table so the assignment fits instead of raising.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        # Leading axis of size 1 lets the table broadcast over the batch.
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        ``x`` is expected batch-first, ``(batch, seq_len, d_model)``: dim 1 is
        used as the sequence length. ``seq_len`` must not exceed ``max_len``.
        """
        return x + self.pe[:, : x.size(1)]


class EncoderLayer(nn.Module):
    """One post-norm transformer encoder block.

    Sub-layers: multi-head self-attention, then a two-layer ReLU
    feed-forward network. Each is wrapped as ``LayerNorm(x + sublayer(x))``.

    Args:
        d_model: Model (embedding) width.
        n_heads: Number of attention heads; ``nn.MultiheadAttention``
            requires it to divide ``d_model``.
        d_ff: Hidden width of the feed-forward network.
    """

    def __init__(self, d_model, n_heads, d_ff):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, mask=None, key_padding_mask=None):
        """Encode a batch-first sequence ``x`` of shape ``(batch, seq_len, d_model)``.

        Args:
            x: Input sequence.
            mask: Optional ``attn_mask`` for self-attention, e.g.
                ``(seq_len, seq_len)``; a bool ``True`` blocks attention.
            key_padding_mask: Optional ``(batch, seq_len)`` bool tensor whose
                ``True`` entries mark padding tokens no query may attend to.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Attention weights are discarded, so skip computing them.
        attn_out, _ = self.self_attn(
            x, x, x,
            attn_mask=mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )
        x = self.norm1(x + attn_out)
        x = self.norm2(x + self.ff(x))
        return x


class DecoderLayer(nn.Module):
    """One post-norm transformer decoder block.

    Sub-layers: self-attention over the target, cross-attention from the
    target (queries) to the encoder memory (keys/values), then a two-layer
    ReLU feed-forward network. Each is wrapped as ``LayerNorm(x + sublayer(x))``.

    Args:
        d_model: Model (embedding) width.
        n_heads: Number of attention heads; ``nn.MultiheadAttention``
            requires it to divide ``d_model``.
        d_ff: Hidden width of the feed-forward network.
    """

    def __init__(self, d_model, n_heads, d_ff):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(self, x, memory, tgt_mask=None, tgt_key_padding_mask=None,
                memory_key_padding_mask=None):
        """Decode batch-first target states ``x`` against encoder ``memory``.

        Args:
            x: Target states, ``(batch, tgt_len, d_model)``.
            memory: Encoder output, ``(batch, src_len, d_model)``.
            tgt_mask: Optional ``attn_mask`` for target self-attention
                (e.g. a causal ``(tgt_len, tgt_len)`` mask; bool ``True`` blocks).
            tgt_key_padding_mask: Optional ``(batch, tgt_len)`` bool tensor;
                ``True`` marks padded target positions.
            memory_key_padding_mask: Optional ``(batch, src_len)`` bool tensor;
                ``True`` marks padded source positions that cross-attention
                must ignore.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Attention weights are discarded, so skip computing them.
        attn_out, _ = self.self_attn(
            x, x, x,
            attn_mask=tgt_mask,
            key_padding_mask=tgt_key_padding_mask,
            need_weights=False,
        )
        x = self.norm1(x + attn_out)
        cross_out, _ = self.cross_attn(
            x, memory, memory,
            key_padding_mask=memory_key_padding_mask,
            need_weights=False,
        )
        x = self.norm2(x + cross_out)
        x = self.norm3(x + self.ff(x))
        return x


class Transformer(nn.Module):
    """Encoder-decoder transformer mapping token ids to vocabulary logits.

    Source and target tokens share a single embedding table (hence a single
    vocabulary) and a single sinusoidal positional-encoding table. The
    encoder and decoder layers operate on batch-first tensors.

    Args:
        vocab_size: Number of token ids; also the width of the output logits.
        d_model: Model (embedding) width.
        n_heads: Attention heads per layer.
        d_ff: Hidden width of each feed-forward sub-layer.
        n_layers: Depth of the encoder stack and, separately, of the decoder stack.
        max_len: Longest sequence the positional encoding supports.
    """

    def __init__(self, vocab_size, d_model=512, n_heads=8, d_ff=2048, n_layers=6,
                 max_len=5000):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_enc = PositionalEncoding(d_model, max_len)
        self.encoder = nn.ModuleList(
            [EncoderLayer(d_model, n_heads, d_ff) for _ in range(n_layers)]
        )
        self.decoder = nn.ModuleList(
            [DecoderLayer(d_model, n_heads, d_ff) for _ in range(n_layers)]
        )
        self.output_proj = nn.Linear(d_model, vocab_size)

    @staticmethod
    def _causal_mask(size, device=None):
        """Return a ``(size, size)`` bool mask; ``True`` blocks attending to a later position."""
        return torch.triu(
            torch.ones(size, size, dtype=torch.bool, device=device), diagonal=1
        )

    def encode(self, src, src_mask=None):
        """Embed ``src`` token ids and run them through the encoder stack.

        Args:
            src: Source token ids, assumed ``(batch, src_len)``.
            src_mask: Optional attention mask forwarded to every encoder
                layer's self-attention; ``None`` means no masking.

        Returns:
            The encoder output ("memory").
        """
        # NOTE(review): Vaswani et al. (2017) also scale embeddings by
        # sqrt(d_model) before adding positions; this model does not.
        x = self.pos_enc(self.embedding(src))
        for layer in self.encoder:
            x = layer(x, src_mask)
        return x

    def decode(self, tgt, memory, tgt_mask=None):
        """Embed ``tgt`` token ids and run them through the decoder stack.

        Args:
            tgt: Target token ids, ``(batch, tgt_len)``.
            memory: Encoder output from :meth:`encode`.
            tgt_mask: Attention mask for decoder self-attention. When ``None``,
                a causal mask is built so position ``t`` only attends to
                positions ``<= t``. Pass an all-``False`` bool mask to disable
                masking.

        Returns:
            Decoder hidden states, one per target position.
        """
        if tgt_mask is None:
            # Without a causal mask the decoder can read future target tokens,
            # leaking the answer during teacher-forced training.
            tgt_mask = self._causal_mask(tgt.size(1), device=tgt.device)
        x = self.pos_enc(self.embedding(tgt))
        for layer in self.decoder:
            x = layer(x, memory, tgt_mask)
        return x

    def forward(self, src, tgt, tgt_mask=None, src_mask=None):
        """Return vocabulary logits for every target position.

        Args:
            src: Source token ids, assumed ``(batch, src_len)``.
            tgt: Target token ids, ``(batch, tgt_len)``.
            tgt_mask: Decoder self-attention mask; ``None`` means causal.
            src_mask: Optional encoder self-attention mask.

        Returns:
            Logits of shape ``(batch, tgt_len, vocab_size)``.
        """
        memory = self.encode(src, src_mask)
        output = self.decode(tgt, memory, tgt_mask)
        return self.output_proj(output)

Run this locally and instantiate the model with a small vocabulary to inspect its output shape:

# Smoke test: a deliberately small model so the forward pass runs fast on CPU.
model = Transformer(
    vocab_size=1000,
    d_model=128,
    n_heads=4,
    d_ff=512,
    n_layers=2,
)
src = torch.randint(low=0, high=1000, size=(2, 10))  # 2 source sequences of length 10
tgt = torch.randint(low=0, high=1000, size=(2, 8))   # 2 target sequences of length 8
out = model(src, tgt)
print(out.shape)  # Expected: torch.Size([2, 8, 1000])
Question

Which of the following statements best describes how the transformer integrates its core components?

Please select the correct answer

Was everything clear?

How can we improve it?

Thanks for your feedback!

Section 1. Chapter 10

Ask AI

expand

Ask AI

ChatGPT

Ask anything or try one of the suggested questions to start our conversation

Section 1. Chapter 10
some-alt