What Is Multi-Head Attention?
Single-head attention projects the input into one representation space and computes one set of attention scores. This works, but it can only capture one type of relationship at a time – which is a significant limitation when a sequence contains multiple overlapping dependencies.
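To make the single-head case concrete, here is a minimal sketch of scaled dot-product self-attention with one head. The sequence length (5 tokens) and model width (512) are illustrative choices, not values from the text:

```python
import torch
import torch.nn.functional as F

seq_len, d_model = 5, 512            # illustrative sizes
x = torch.randn(seq_len, d_model)    # input sequence

# Self-attention: Q, K, V all come from the same input
# (projections omitted here to show the core computation)
Q, K, V = x, x, x

scores = Q @ K.T / d_model ** 0.5    # one set of attention scores
weights = F.softmax(scores, dim=-1)  # each row sums to 1
output = weights @ V                 # (seq_len, d_model)
```

Every token ends up with exactly one attention distribution over the sequence, which is precisely the limitation multi-head attention addresses.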
Multi-head attention runs several attention heads in parallel, each operating in its own subspace. Every head independently projects the input into its own Q, K, V representations and computes its own attention scores. The outputs are then concatenated and passed through a linear layer.
Why Multiple Heads?
Take the sentence "The black cat sat on the mat." Different relationships exist simultaneously:
- "cat" → "sat": subject-verb dependency;
- "black" → "cat": adjective-noun modifier;
- "sat" → "mat": verb-object link via the preposition "on".
A single attention head would need to compress all of these into one set of weights. With multiple heads, each head can specialize — one might learn syntactic dependencies, another semantic ones, another positional patterns.
How It Fits Together
import torch
import torch.nn as nn
# Projecting the same input into 8 independent Q, K, V subspaces
d_model = 512
num_heads = 8
d_k = d_model // num_heads # 64 per head
# Each head gets its own projection matrices (one head's triple shown here)
W_q = nn.Linear(d_model, d_k)
W_k = nn.Linear(d_model, d_k)
W_v = nn.Linear(d_model, d_k)
# After computing attention in each head, outputs are concatenated
# and passed through a final linear layer
W_o = nn.Linear(d_model, d_model)
Each head works on a d_k-dimensional slice of the full d_model space, so the total computation stays comparable to single-head attention – just distributed across subspaces.
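Putting the pieces together, the full forward pass can be sketched as a loop over heads. Real implementations fuse all heads into single d_model × d_model projections and use reshapes instead of a Python loop, so treat this as an illustration of the math, not production code:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

d_model, num_heads = 512, 8
d_k = d_model // num_heads  # 64 per head

# One (W_q, W_k, W_v) triple per head
heads = [
    tuple(nn.Linear(d_model, d_k) for _ in range(3))
    for _ in range(num_heads)
]
W_o = nn.Linear(d_model, d_model)

x = torch.randn(5, d_model)  # 5 tokens

head_outputs = []
for W_q, W_k, W_v in heads:
    Q, K, V = W_q(x), W_k(x), W_v(x)
    scores = Q @ K.T / d_k ** 0.5           # per-head attention scores
    head_outputs.append(F.softmax(scores, dim=-1) @ V)  # (5, d_k)

# Concatenate the 8 x 64-dim head outputs back to 512, then mix
out = W_o(torch.cat(head_outputs, dim=-1))  # (5, 512)
```

Note that the cost is comparable to a single 512-dimensional head: each head's matrices are 8× narrower, and there are 8 of them.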
Run this locally to check that d_model is evenly divisible by num_heads, and explore how the per-head dimension d_k changes with different num_heads values.
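PyTorch also ships a ready-made module. A quick shape sanity check, assuming batch-first tensors:

```python
import torch
import torch.nn as nn

# Built-in multi-head attention: embed_dim must divide evenly by num_heads
mha = nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True)

x = torch.randn(1, 5, 512)  # (batch, seq_len, d_model)
out, attn_weights = mha(x, x, x)  # self-attention: same tensor as Q, K, V

print(out.shape)           # torch.Size([1, 5, 512])
print(attn_weights.shape)  # torch.Size([1, 5, 5]), averaged over heads by default
```

Pass `average_attn_weights=False` to the forward call to get the per-head attention maps instead of their average.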