Transformers Part 1 — Attention, Q/K/V, Multi-Head
Session 1 of the 48-session learning series.
Date: Wed, 2026-06-10 · Time: 18:00–20:00 IST · Track: 🧠 LLMs & Agents (LLM) · Parent 28-day topic: Day 01 · Est. read: 2 h
Why this session matters
This is Session 01 of 48 in the LLMs & Agents track. It sticks to one focused topic, paced so you have time to actually absorb it rather than rush.
Agenda
- Why transformers won — recurrence vs attention
- The residual stream — a fixed-width bus every block reads and writes
- Self-attention as soft routing — Q, K, V and the √d_k trick
- Multi-head attention — running h heads in parallel, head specialisation
- What we'll layer on next session (RoPE, MLP, LayerNorm, KV cache)
Pre-read (skim before the session)
- Attention Is All You Need (Vaswani et al., 2017)
- The Illustrated Transformer — Jay Alammar
- A Mathematical Framework for Transformer Circuits — Anthropic
- nanoGPT model.py (Karpathy)
Deep dive
1. Why transformers won
Before 2017, sequence models meant recurrence (LSTM, GRU): read tokens one at a time, carry a hidden state. Two killers:
- Sequential dependency. You can't compute step t before step t−1. GPUs are massively parallel; RNNs leave most of the silicon idle.
- Long-range decay. By token 500, the gradient signal from token 1 is vanishing (or exploding). Gating helped but didn't solve it.
The transformer replaced recurrence with attention: every position computes a weighted sum over all the positions it can see, for the whole sequence at once, in a couple of batched matrix multiplies. Nothing waits on a previous step, and any position can reach any other in a single hop (O(1) path length). The price is O(T²) attention cost in compute and, naively, in memory. Much of modern engineering is about paying less for attention (FlashAttention, GQA, sliding windows, linear attention).
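A minimal NumPy sketch of that contrast, assuming toy sizes, random weights, and the raw inputs reused as queries, keys and values (no learned projections). The helper names are hypothetical. The RNN loop can't start step t until step t−1 is done; each iteration of the per-position attention loop reads only the inputs, so the whole loop collapses into one batched computation with the same result.

```python
import numpy as np

rng = np.random.default_rng(0)
T, d = 6, 8                                 # toy sequence length and width
X = rng.standard_normal((T, d))             # one row per token
W_h = rng.standard_normal((d, d)) * 0.1     # hypothetical RNN weights
W_x = rng.standard_normal((d, d)) * 0.1

def rnn_states(X):
    """h_t = tanh(h_{t-1} W_h + x_t W_x): every step needs the previous one."""
    h = np.zeros(d)
    states = []
    for x_t in X:                           # inherently sequential
        h = np.tanh(h @ W_h + x_t @ W_x)
        states.append(h)
    return np.stack(states)

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True) # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def attention_loop(X):
    """Causal attention one position at a time; no iteration reads another's output."""
    out = np.zeros_like(X)
    for i in range(len(X)):
        w = softmax(X[: i + 1] @ X[i] / np.sqrt(d))   # position i scores positions 0..i
        out[i] = w @ X[: i + 1]
    return out

def attention_batched(X):
    """The same computation for every position at once: two matrix multiplies."""
    scores = X @ X.T / np.sqrt(d)
    scores[np.triu_indices(len(X), k=1)] = -np.inf     # causal mask: hide j > i
    return softmax(scores) @ X

H = rnn_states(X)
A = attention_batched(X)
print(np.allclose(attention_loop(X), A))   # True
```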
2. A concrete example
"The cat sat on the mat because it was tired."
When the model processes "it", it has to decide: does it refer to cat or mat? Attention lets it look back at every previous word, score how relevant each one is, and pull information from the most relevant (here: cat) into its own representation.
That's the one trick. Q/K/V, multi-head, RoPE, MLPs — those are engineering layered around that single idea.
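A toy version of that one trick, with invented 2-D vectors. Real models learn vectors with hundreds or thousands of dimensions; every number below is made up for readability, and the √d_k scaling introduced later is omitted. The query for "it" is scored against a key for each word up to and including "it", softmax turns the scores into weights, and the output is the weighted sum of value vectors.

```python
import numpy as np

tokens = ["The", "cat", "sat", "on", "the", "mat", "because", "it"]
# Invented keys: dimension 0 ~ "animate noun", dimension 1 ~ "place/object noun".
K = np.array([[0, 0], [3, 0], [0.5, 0], [0, 0], [0, 0], [0, 3], [0, 0], [1, 0]])
V = K.copy()                    # reuse keys as values to keep the toy small
q_it = np.array([1.0, 0.2])     # invented query for "it": "looking for an animate noun"

scores = K @ q_it                                  # one relevance score per visible word
weights = np.exp(scores) / np.exp(scores).sum()    # softmax
output = weights @ V                               # what "it" pulls into its representation

for tok, w in zip(tokens, weights):
    print(f"{tok:>8}: {w:.2f}")
```

With these numbers, "cat" gets about two thirds of the weight and "mat" about 6%.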
3. The residual stream as a bus
Think of a decoder-only transformer as a fixed-width residual stream of shape (B, T, d_model):
- B = batch size
- T = sequence length
- d_model = model width (768 for GPT-2 small, 4096 for Llama-3-8B, 8192 for Llama-3-70B, 12288 for GPT-3 175B)
Every block reads the stream, adds a small update, passes it on. Each block is a refinement, not a replacement:
```
tokens → embed → block 1 → block 2 → … → block N → LN → unembed → logits
stream shape between embed and the final LN: (B, T, d_model)
each sub-layer inside a block: x = x + sublayer(LN(x))
```
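A sketch of the bus in NumPy, assuming toy sizes, a LayerNorm without learned gain and bias, a random linear map standing in for each block's sub-layers, and an unembedding tied to the embedding matrix (GPT-2 ties them; other models may not). The point is the invariant: every block reads (B, T, d_model) and writes back (B, T, d_model).

```python
import numpy as np

rng = np.random.default_rng(0)
B, T, d_model, vocab, n_blocks = 2, 5, 16, 50, 4   # toy sizes

def layer_norm(x, eps=1e-5):
    """LayerNorm over the last axis, without learned gain/bias for brevity."""
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

W_embed = rng.standard_normal((vocab, d_model)) * 0.02
# Stand-in sub-layers: one random linear map per block (a real block has attention + MLP).
W_blocks = [rng.standard_normal((d_model, d_model)) * 0.02 for _ in range(n_blocks)]

tokens = rng.integers(0, vocab, size=(B, T))
x = W_embed[tokens]                          # (B, T, d_model): the residual stream
for W in W_blocks:
    x = x + layer_norm(x) @ W                # read the stream, add a small update
    assert x.shape == (B, T, d_model)        # the width never changes
logits = layer_norm(x) @ W_embed.T           # unembed (tied to the embedding here)
print(logits.shape)                          # (2, 5, 50)
```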
4. Self-attention — soft, learned routing
For each token we project the residual stream to three things:
- Query Q = x · W_Q: "what am I looking for?"
- Key K = x · W_K: "what do I represent?"
- Value V = x · W_V: "what would I contribute if attended to?"
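For a single head, Q, K and V are each (B, T, d_head), since W_Q, W_K, W_V map d_model → d_head; the d_k below is that same d_head. We compute: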
attention(Q, K, V) = softmax( Q · Kᵀ / √d_k ) · V
Step by step:
- Q · Kᵀ → (B, T, T) raw scores. Cell (i, j) = how much position i should look at position j.
- Divide by √d_k: keeps logits stable so softmax doesn't saturate.
- Causal mask for decoder-only models: set every score with j > i to −∞ (keep the lower triangle), so position i only sees positions ≤ i.
- Row-wise softmax → each row is a probability distribution.
- Multiply by V → weighted sum of value vectors.
Output: (B, T, d_head), the same shape as V, but each position now mixes in information from every position it is allowed to see (itself and everything before it).
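The five steps for one head can be sketched in NumPy as follows, assuming randomly initialised projection matrices and toy sizes; causal_attention_head is a hypothetical name. It returns the attention weights too, so you can check the properties claimed above: every row sums to 1 and nothing above the diagonal gets weight.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)            # stable: subtract the row max
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def causal_attention_head(x, W_Q, W_K, W_V):
    """One head. x: (B, T, d_model); W_*: (d_model, d_head). Returns (out, weights)."""
    Q, K, V = x @ W_Q, x @ W_K, x @ W_V                # each (B, T, d_head)
    d_k = Q.shape[-1]
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_k)   # (B, T, T); cell [b, i, j]
    T = x.shape[1]
    future = np.triu(np.ones((T, T), dtype=bool), k=1) # True where j > i
    scores = np.where(future, -np.inf, scores)         # causal mask
    weights = softmax(scores, axis=-1)                 # each row sums to 1
    return weights @ V, weights                        # (B, T, d_head), (B, T, T)

rng = np.random.default_rng(0)
B, T, d_model, d_head = 2, 5, 16, 4
x = rng.standard_normal((B, T, d_model))
W_Q, W_K, W_V = (rng.standard_normal((d_model, d_head)) / np.sqrt(d_model) for _ in range(3))
out, w = causal_attention_head(x, W_Q, W_K, W_V)
print(out.shape, w.shape)   # (2, 5, 4) (2, 5, 5)
```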
5. The √d_k scaling — why it's there
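Suppose the components of q and k are independent with mean 0 and variance 1. A score q · k sums d_k such products, so it has variance d_k and standard deviation √d_k. At d_k = 128 that's about ±11. Softmax over logits spread by ±11 collapses toward one-hot: one position takes almost all the mass. Then: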
- Gradient vanishes (a peaked softmax has tiny gradient w.r.t. logits).
- The model can't learn to redistribute attention.
Dividing by √d_k brings the logits back to roughly unit variance, so the softmax stays diffuse and gradients flow. Without it, large-d_k models tend to start with near one-hot attention and train poorly or unstably.
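A quick NumPy check of the variance argument, assuming q and k components drawn i.i.d. from N(0, 1), d_k = 128 and 16 keys per query. It measures the spread of raw versus scaled logits and how much probability the top key grabs on average. It shows the forward-pass saturation only; the gradient problem follows from it.

```python
import numpy as np

rng = np.random.default_rng(0)
d_k, n_keys, n_trials = 128, 16, 2000

q = rng.standard_normal((n_trials, d_k))             # components ~ N(0, 1)
k = rng.standard_normal((n_trials, n_keys, d_k))
logits = np.einsum("nd,nkd->nk", q, k)              # raw dot products, (n_trials, n_keys)

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

print(f"std of raw logits:    {logits.std():.1f}")                  # close to sqrt(128) ~ 11.3
print(f"std of scaled logits: {(logits / np.sqrt(d_k)).std():.2f}")  # close to 1
print(f"mean top weight, unscaled: {softmax(logits).max(axis=-1).mean():.2f}")
print(f"mean top weight, scaled:   {softmax(logits / np.sqrt(d_k)).max(axis=-1).mean():.2f}")
```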
6. Multi-head attention
A single attention head computes one routing pattern. Multi-head attention runs h heads in parallel, each with d_head = d_model / h:
MHA(x) = concat[head_1, …, head_h] · W_O
head_i = attention(x · W_Q_i, x · W_K_i, x · W_V_i)
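A NumPy sketch of the formula, assuming toy sizes, random weights and causal masking; mha, attention and split_heads are hypothetical names. Splitting one fused d_model × d_model projection into h column blocks is the same as giving each head its own W_Q_i, W_K_i, W_V_i, which the last lines cross-check for one head.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def attention(Q, K, V):
    """Causal scaled dot-product attention over the last two axes: (..., T, d)."""
    T, d_k = Q.shape[-2], Q.shape[-1]
    scores = Q @ np.swapaxes(K, -1, -2) / np.sqrt(d_k)
    future = np.triu(np.ones((T, T), dtype=bool), k=1)      # j > i
    return softmax(np.where(future, -np.inf, scores)) @ V

def split_heads(M, h):
    """(B, T, d_model) -> (B, h, T, d_head); head i gets columns i*d_head:(i+1)*d_head."""
    B, T, d_model = M.shape
    return M.reshape(B, T, h, d_model // h).transpose(0, 2, 1, 3)

def mha(x, W_Q, W_K, W_V, W_O, h):
    """MHA(x) = concat[head_1, ..., head_h] @ W_O, each W of shape (d_model, d_model)."""
    B, T, d_model = x.shape
    heads = attention(split_heads(x @ W_Q, h), split_heads(x @ W_K, h),
                      split_heads(x @ W_V, h))                 # (B, h, T, d_head)
    concat = heads.transpose(0, 2, 1, 3).reshape(B, T, d_model)
    return concat @ W_O

rng = np.random.default_rng(0)
B, T, d_model, h = 2, 6, 32, 4
d_head = d_model // h
x = rng.standard_normal((B, T, d_model))
W_Q, W_K, W_V, W_O = (rng.standard_normal((d_model, d_model)) / np.sqrt(d_model)
                      for _ in range(4))
y = mha(x, W_Q, W_K, W_V, W_O, h)                              # (B, T, d_model)

# Cross-check head index 1 against head_i = attention(x W_Q_i, x W_K_i, x W_V_i).
cols = slice(1 * d_head, 2 * d_head)
head_1 = attention(x @ W_Q[:, cols], x @ W_K[:, cols], x @ W_V[:, cols])
heads = attention(split_heads(x @ W_Q, h), split_heads(x @ W_K, h), split_heads(x @ W_V, h))
print(np.allclose(heads[:, 1], head_1))   # True
```

Because the heads partition d_model, the four projections hold 4 · d_model² weights whatever h is: more heads means narrower heads, not more parameters.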
Typical numbers:
| Model | d_model | h | d_head | Layers |
|---|---|---|---|---|
| GPT-2 small | 768 | 12 | 64 | 12 |
| GPT-3 175B | 12288 | 96 | 128 | 96 |
| Llama-3-8B | 4096 | 32 | 128 | 32 |
| Llama-3-70B | 8192 | 64 | 128 | 80 |
Heads specialise. Mechanistic interpretability work has documented:
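- Previous-token heads: attend almost entirely to position i−1.
- Induction heads: complete repeated patterns. After seeing … A B … A, the head attends from the second A to the B that followed the first A and copies it forward, predicting B. Anthropic's induction-heads work argues they drive much of in-context learning.
- Syntactic heads: subject-verb agreement, coreference, bracket matching.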
You don't tell the model to have these; they emerge during training. That is the point of multiple heads: many routing hypotheses, trained jointly.
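The behaviour an induction head implements can be written out as a rule. This is a plain-Python sketch of the input-output behaviour only, not of the neural mechanism, and induction_predict is a hypothetical name. In the circuit described in Anthropic's induction-heads work, a previous-token head first tags each position with the token before it; the induction head then attends to the position whose previous token matches the current one and copies that position's token.

```python
def induction_predict(tokens):
    """For each position i, guess tokens[i+1] by copying what followed the most
    recent earlier occurrence of tokens[i]; None if tokens[i] hasn't appeared before."""
    preds = []
    for i, tok in enumerate(tokens):
        pred = None
        for j in range(i - 1, -1, -1):     # scan backwards: most recent match first
            if tokens[j] == tok:
                pred = tokens[j + 1]       # j < i, so j + 1 <= i is already visible
                break
        preds.append(pred)
    return preds

seq = "the cat sat on the mat . the cat".split()
print(induction_predict(seq)[-1])   # sat
```

After the second "the cat", the rule predicts "sat" purely by copying the earlier pattern, with no knowledge of what the words mean.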
7. Memory & compute cost
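For batch B, layers L, heads h, head-dim d_head, sequence length T and model width d = h · d_head (2 FLOPs per multiply-add):
| Component | FLOPs (forward) | Memory (elements) |
|---|---|---|
| Attention scores (Q · Kᵀ) | 2 · B · L · h · T² · d_head | B · L · h · T² (the matrix) |
| Attention · V | 2 · B · L · h · T² · d_head | n/a |
| Q, K, V, O projections | 8 · B · L · T · d² | n/a |
| MLP (d → 4d → d, two matmuls) | 16 · B · L · T · d² | B · L · T · 4d |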
Two takeaways:
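- MLPs dominate FLOPs at typical context lengths. Counting the Q, K, V, O projections (8 · B · L · T · d²) and a 4× MLP (16 · B · L · T · d², two matmuls), the MLP's share is 16d / (24d + 4T): about 2/3 when T ≪ d, still ≈ 57% at T = d = 4096.
- Attention dominates memory at long context. With B=1, h=32, T=32k in fp16, a single layer's score matrices take 32 · 32768² · 2 bytes = 64 GiB; keeping all L=32 layers' matrices, as a naive backward pass would, is 2 TiB. This is why FlashAttention exists: it tiles the computation and never materialises the full matrix.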
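A small calculator for these cost formulas, assuming standard multi-head attention (no GQA), an MLP expansion factor of 4, 2 FLOPs per multiply-add, 2-byte (fp16) scores, and ignoring softmax, LayerNorm and embedding costs; transformer_costs is a hypothetical helper. Run on the long-context example, it reproduces the 64 GiB per-layer figure and shows the MLP losing its FLOP majority once T is well above d.

```python
def transformer_costs(B, L, h, d_head, T, bytes_per_el=2):
    """Forward FLOPs and score-matrix memory for a decoder-only stack (sketch)."""
    d = h * d_head                                   # model width
    attn_core = 4 * B * L * h * T**2 * d_head        # Q @ K^T plus att @ V
    proj = 8 * B * L * T * d**2                      # Q, K, V, O projections
    mlp = 16 * B * L * T * d**2                      # d -> 4d -> d
    total = attn_core + proj + mlp
    return {
        "flops": total,
        "mlp_share": mlp / total,
        "scores_GiB_per_layer": B * h * T**2 * bytes_per_el / 2**30,
        "scores_GiB_all_layers": B * L * h * T**2 * bytes_per_el / 2**30,
    }

c = transformer_costs(B=1, L=32, h=32, d_head=128, T=32_768)
print(c["scores_GiB_per_layer"], c["scores_GiB_all_layers"])   # 64.0 2048.0
print(round(c["mlp_share"], 2))                                 # 0.29
```

At T = 32k the MLP's share drops to about 29% because the T² attention terms take over; rerun it with the checklist's T = 4096 to do the exercise.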
8. From a single block to a full model
A decoder-only block is:
```python
x = x + MHA(LayerNorm(x))   # attention sub-layer
x = x + MLP(LayerNorm(x))   # MLP sub-layer (Session 6)
```
Pre-norm style (LayerNorm applied to each sub-layer's input, inside the residual branch) trains stably even for very deep stacks. The residual connection is essential: its identity path lets gradients flow directly from the logits back to the embedding, even through GPT-3's 96 layers.
9. Hands-on (last 30 min)
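```python
import torch
import torch.nn as nn

class CausalSelfAttention(nn.Module):
    def __init__(self, d_model, n_heads, scaled=True):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_head = d_model // n_heads
        self.n_heads = n_heads
        self.scaled = scaled      # False skips the 1/sqrt(d_head) factor (experiment)
        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)  # W_Q, W_K, W_V fused
        self.proj = nn.Linear(d_model, d_model, bias=False)     # W_O
        self.last_att = None      # attention weights from the last forward pass

    def forward(self, x):
        B, T, D = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)                  # each (B, T, D)
        # split into heads: (B, T, D) -> (B, n_heads, T, d_head)
        q, k, v = (t.view(B, T, self.n_heads, self.d_head).transpose(1, 2)
                   for t in (q, k, v))
        att = q @ k.transpose(-2, -1)                           # (B, n_heads, T, T)
        if self.scaled:
            att = att / (self.d_head ** 0.5)
        mask = torch.tril(torch.ones(T, T, device=x.device)).view(1, 1, T, T)
        att = att.masked_fill(mask == 0, float('-inf')).softmax(-1)  # causal softmax
        self.last_att = att.detach()
        out = (att @ v).transpose(1, 2).contiguous().view(B, T, D)  # merge heads
        return self.proj(out)
```
Build m = CausalSelfAttention(128, 4), run it on a (2, 16, 128) random input and print m.last_att[0, 0, 5] (batch 0, head 0, query position 5): positions 0 to 5 get weight, the rest are exactly 0. Then set m.scaled = False, rerun on the same input and print the same row: it comes out noticeably peakier, closer to one-hot, and the effect grows with d_head. That's the experiment.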
10. What's next (Session 6)
- Positional encoding — sinusoidal vs learned vs RoPE
- The MLP — where most of the facts live
- LayerNorm — pre-norm vs post-norm
- KV cache — what actually grows in memory at inference
Reading material
- Attention Is All You Need
- The Illustrated Transformer
- A Mathematical Framework for Transformer Circuits
- nanoGPT (Karpathy)
In-depth research material
- FlashAttention
- Multi-Query Attention
- Grouped-Query Attention (Llama-2)
- Induction heads (Anthropic)
- In-context learning as gradient descent
Video reference
▶︎ 3Blue1Brown — Attention in transformers, step by step
Pick a quiet 30 minutes during this session to actually watch it. Don't multitask.
LeetCode — Two Sum
- Link: https://leetcode.com/problems/two-sum/
- Difficulty: Easy
- Why this problem: a hash map gives O(1) average-time lookups, so a single pass solves it in O(n); the canonical interview opener.
- Time-box: 30 minutes. Look up the editorial only after.
Post-session checklist
By the end of this session you should be able to:
- Draw the residual stream from token IDs → logits with N blocks and label every arrow.
- Derive Q, K, V from the residual stream and explain what each projection learns.
- State why scaled dot-product attention divides by √d_k and what breaks without it.
- List 3 head specialisations documented by mechanistic interpretability.
- Compute attention FLOPs and memory for B=1, L=32, h=32, d_head=128, T=4096.
- Implement a causal self-attention block from scratch (≤ 30 lines).