ai ml · intermediate · 12m · 2026-06-09

Transformers Part 1 — Attention, Q/K/V, Multi-Head

Session 1 of the 48-session learning series.

Date: Wed, 2026-06-10 · Time: 18:00–20:00 IST · Track: 🧠 LLMs & Agents (LLM) · Parent 28-day topic: Day 01 · Est. read: 2 h

Why this session matters

This is Session 01 of 48 in the LLMs & Agents track. It follows the series rhythm of one focused topic per session, paced so you have time to absorb it rather than rush.

Agenda

  • Why transformers won — recurrence vs attention
  • The residual stream — a fixed-width bus every block reads and writes
  • Self-attention as soft routing — Q, K, V and the √d_k trick
  • Multi-head attention — running h heads in parallel, head specialisation
  • What we'll layer on next session (RoPE, MLP, LayerNorm, KV cache)

Pre-read (skim before the session)

Deep dive

1. Why transformers won

Before 2017, sequence models meant recurrence (LSTM, GRU): read tokens one at a time, carry a hidden state. Two killers:

  1. Sequential dependency. You can't compute step t before step t-1. GPUs are massively parallel — RNNs leave most of the silicon idle.
  2. Long-range decay. By token 500, the gradient signal from token 1 is vanishing (or exploding). Gating helped, didn't solve it.

The transformer replaced recurrence with attention: every position computes a weighted sum over every other position in one matrix multiply. The whole sequence is processed in parallel; any position can look at any other in O(1) hops. The price: O(T²) attention cost — most modern engineering is about paying less of that price (FlashAttention, GQA, sliding windows, linear attention).

2. A concrete example

"The cat sat on the mat because it was tired."

When the model processes "it", it has to decide: does it refer to cat or mat? Attention lets it look back at every previous word, score how relevant each one is, and pull information from the most relevant (here: cat) into its own representation.

That's the one trick. Q/K/V, multi-head, RoPE, MLPs — those are engineering layered around that single idea.
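The "it" example can be sketched with a toy softmax over hand-picked relevance scores. The scores below are made up for illustration (a trained model would compute them from Q and K), and the names `context`, `scores` and `softmax` are hypothetical.

```python
import math

def softmax(scores):
    """Turn raw relevance scores into weights that sum to 1."""
    peak = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

# Words that "it" can look back at, with hypothetical raw scores.
context = ["The", "cat", "sat", "on", "the", "mat", "because"]
scores = [0.1, 4.0, 0.5, 0.0, 0.1, 2.0, 0.2]  # made up for illustration

weights = softmax(scores)
best_word = context[weights.index(max(weights))]

for word, w in zip(context, weights):
    print(f"{word:>8}: {w:.3f}")
print("most attended:", best_word)
```

Most of the weight lands on "cat", some on "mat", and very little elsewhere. The representation of "it" would then be updated with a mix of value vectors in those proportions.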

3. The residual stream as a bus

Think of a decoder-only transformer as a fixed-width residual stream of shape (B, T, d_model):

  • B = batch size
  • T = sequence length
  • d_model = model width (768 for GPT-2 small, 4096 for Llama-3-8B, 8192 for Llama-3-70B, 12288 for GPT-3 175B)

Every block reads the stream, adds a small update, passes it on. Each block is a refinement, not a replacement:

   tokens  →  embed  →  block 1  →  block 2  →  …  →  block N  →  LN  →  unembed  →  logits
                          ▲                                          (stream shape: (B,T,d))
                          │ residual: x = x + sublayer(LN(x))
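A minimal sketch of the update rule in the diagram, x = x + sublayer(LN(x)), for a single token vector. `toy_sublayer` is a hypothetical stand-in for attention or an MLP, and the LayerNorm here has no learned gain or bias; the point is only that each block adds to the stream and never changes its width.

```python
import math

def layer_norm(x, eps=1e-5):
    """Normalise one token vector to zero mean and unit variance (no learned gain/bias)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]

def toy_sublayer(x, scale):
    """Hypothetical stand-in for attention or an MLP: any map from width d to width d."""
    return [scale * math.tanh(v) for v in x]

def run_stream(x, n_blocks):
    """Apply x = x + sublayer(LN(x)) n_blocks times; the width never changes."""
    for i in range(n_blocks):
        update = toy_sublayer(layer_norm(x), scale=0.1 * (i + 1))
        x = [a + b for a, b in zip(x, update)]
    return x

x0 = [0.5, -1.0, 2.0, 0.0]   # one token, d_model = 4
xN = run_stream(x0, n_blocks=3)
print(len(x0), len(xN))       # width before and after
print(xN)
```

Because every block only adds an update, the original embedding is still present in the final stream, just with refinements stacked on top.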

4. Self-attention — soft, learned routing

For each token we project the residual stream to three things:

  • Query Q = x · W_Q — "what am I looking for?"
  • Key K = x · W_K — "what do I represent?"
  • Value V = x · W_V — "what would I contribute if attended to?"

All three have shape (B, T, d_head) for a single head; the d_k in the formula is this d_head. We compute:

attention(Q, K, V) = softmax( Q · Kᵀ / √d_k ) · V

Step by step:

  1. Q · Kᵀ → (B, T, T) raw scores. Cell (i, j) = how much position i should look at position j.
  2. Divide by √d_k — keeps logits stable so softmax doesn't saturate.
  3. Causal mask for decoder-only models: set every entry above the diagonal (j > i) to −∞, so position i only sees positions ≤ i.
  4. Row-wise softmax → each row a probability distribution.
  5. Multiply by V → weighted sum of value vectors.

Output (B, T, d_head). Same per-position shape, but each position now carries information from every allowed other position.
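The five steps can be sketched in plain Python for one head and one sequence (no batch dimension). This is an illustrative version under stated assumptions, not an efficient one: Q, K and V are given directly as small hand-written matrices rather than projected from x, and the causal mask is applied by skipping positions j > i, which is equivalent to setting their scores to −∞ before the softmax.

```python
import math

def causal_attention(Q, K, V):
    """Single-head causal attention for one sequence; Q, K, V are T x d lists of lists."""
    T, d_k = len(Q), len(Q[0])
    out, weights = [], []
    for i in range(T):
        # Steps 1-3: scaled scores against positions j <= i only (the causal mask).
        scores = [sum(q * k for q, k in zip(Q[i], K[j])) / math.sqrt(d_k)
                  for j in range(i + 1)]
        # Step 4: softmax over the allowed positions; masked ones get weight 0.
        peak = max(scores)
        exps = [math.exp(s - peak) for s in scores]
        total = sum(exps)
        row = [e / total for e in exps] + [0.0] * (T - i - 1)
        # Step 5: weighted sum of value vectors.
        out.append([sum(row[j] * V[j][c] for j in range(T))
                    for c in range(len(V[0]))])
        weights.append(row)
    return out, weights

Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
out, weights = causal_attention(Q, K, V)

for row in weights:
    print([round(w, 3) for w in row])
```

Position 0 can only see itself, so its weight row is [1, 0, 0] and its output is exactly V[0]. Later rows spread their weight over more positions.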

5. The √d_k scaling — why it's there

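For Q and K with independent, zero-mean, unit-variance entries, each dot product q · k has variance d_k, so raw logits have a standard deviation of √d_k. At d_k = 128 that's about 11. Softmax over logits spread that widely collapses to near one-hot: one position takes almost all the mass. Then: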

  • Gradient vanishes (a peaked softmax has tiny gradient w.r.t. logits).
  • The model can't learn to redistribute attention.

Dividing by √d_k keeps logits at roughly unit variance, so softmax stays diffuse and gradients flow. Try it: drop the scaling at d_head = 128 (e.g. d_model = 4096 with h = 32) and training tends to become unstable early on.
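A quick numerical check of the variance argument, as a sketch under the same assumption of independent standard-normal entries. The helper names (`dot_variance`, `softmax_peak`) and the choice of 16 keys are hypothetical.

```python
import math
import random

def dot_variance(d_k, n_samples=2000, seed=0):
    """Empirical variance of q . k for standard-normal q and k of width d_k."""
    rng = random.Random(seed)
    dots = []
    for _ in range(n_samples):
        q = [rng.gauss(0, 1) for _ in range(d_k)]
        k = [rng.gauss(0, 1) for _ in range(d_k)]
        dots.append(sum(a * b for a, b in zip(q, k)))
    mean = sum(dots) / n_samples
    return sum((x - mean) ** 2 for x in dots) / n_samples

def softmax_peak(logits):
    """Largest softmax probability: close to 1.0 means near one-hot."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    return max(exps) / sum(exps)

d_k = 128
var = dot_variance(d_k)

# One query against 16 random keys, with and without the 1/sqrt(d_k) scaling.
rng = random.Random(1)
q = [rng.gauss(0, 1) for _ in range(d_k)]
keys = [[rng.gauss(0, 1) for _ in range(d_k)] for _ in range(16)]
raw = [sum(a * b for a, b in zip(q, k)) for k in keys]
scaled = [x / math.sqrt(d_k) for x in raw]

print("variance / d_k:", round(var / d_k, 2))
print("peak prob, unscaled:", round(softmax_peak(raw), 3))
print("peak prob, scaled:  ", round(softmax_peak(scaled), 3))
```

The measured variance should sit close to d_k, and the unscaled row is always at least as peaked as the scaled one, since dividing logits by a constant greater than 1 flattens the softmax.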

6. Multi-head attention

A single attention head computes one routing pattern. Multi-head attention runs h heads in parallel, each of width d_head = d_model / h:

MHA(x) = concat[head_1, …, head_h] · W_O
   head_i = attention(x · W_Q_i, x · W_K_i, x · W_V_i)

Typical numbers:

Model         d_model   h    d_head   Layers
GPT-2 small   768       12   64       12
GPT-3 175B    12288     96   128      96
Llama-3-8B    4096      32   128      32
Llama-3-70B   8192      64   128      80
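A small sketch that checks d_head = d_model / h against the numbers in the table and shows the head split on a toy vector. `head_dim` and `split_heads` are hypothetical helpers; `split_heads` mimics the reshape a real implementation applies to the projected Q, K and V vectors.

```python
# (d_model, h, d_head) as listed in the table above.
configs = {
    "GPT-2 small": (768, 12, 64),
    "GPT-3 175B": (12288, 96, 128),
    "Llama-3-8B": (4096, 32, 128),
    "Llama-3-70B": (8192, 64, 128),
}

def head_dim(d_model, n_heads):
    """d_head = d_model / h; the split only works if it divides evenly."""
    if d_model % n_heads != 0:
        raise ValueError("d_model must be divisible by n_heads")
    return d_model // n_heads

def split_heads(vec, n_heads):
    """Cut one d_model-wide vector into n_heads contiguous slices of width d_head."""
    d_head = head_dim(len(vec), n_heads)
    return [vec[i * d_head:(i + 1) * d_head] for i in range(n_heads)]

for name, (d_model, h, d_head) in configs.items():
    assert head_dim(d_model, h) == d_head, name

heads = split_heads(list(range(8)), n_heads=4)
print(heads)  # [[0, 1], [2, 3], [4, 5], [6, 7]]
```

Concatenating the slices in order gives back the original vector, which is what the concat in the MHA formula relies on before W_O mixes the heads.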

Heads specialise. Mechanistic interpretability work has documented:

  • Previous-token heads — always attend to position i-1.
  • Induction heads — copy patterns: after seeing … A B … A, attend to that previous B. These are the engine of in-context learning.
  • Syntactic heads — subject-verb agreement, coreference, bracket matching.

You don't tell the model to have these — they emerge during training. That parallelism is the point of multi-head: many hypotheses, jointly trained.
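To make the induction pattern concrete, here is a hard-coded sketch of where an idealised induction head would look. It is a hypothetical rule written by hand, not how a trained head is implemented: real heads learn a soft version of this through their Q/K weights.

```python
def induction_target(tokens):
    """Index an idealised induction head attends to from the last token, or None.

    Rule: find the most recent earlier occurrence of the current token
    and attend to the token that came right after it.
    """
    current = tokens[-1]
    for j in range(len(tokens) - 2, -1, -1):  # earlier positions, most recent first
        if tokens[j] == current:
            return j + 1
    return None

tokens = ["A", "B", "C", "A"]
target = induction_target(tokens)
print(target, tokens[target])  # 1 B
```

Copying the attended token forward predicts "B" as the next token, which is the "… A B … A → B" completion described above.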

7. Memory & compute cost

For batch B, layers L, heads h, head-dim d_head, sequence length T:

Component          FLOPs (forward)                Memory
Attention scores   2 · B · L · h · T² · d_head    B · L · h · T² (the matrix)
Attention · V      2 · B · L · h · T² · d_head
MLP                2 · (2 · B · L · T · d · 4d)   B · T · 4d

The MLP row counts both matmuls (d → 4d and 4d → d).

Two takeaways:

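  1. MLPs dominate FLOPs for typical T < 4k (≈ 2/3 of the weight-matmul compute; the Q/K/V/O projections, not shown in the table, take most of the rest).
  2. Attention dominates memory at long context. B=1, h=32, T=32k in fp16 is 64 GiB of attention matrices per layer, and 2 TiB if the matrices for all L=32 layers are held at once. This is why FlashAttention exists: it tiles the computation and never materialises the full matrix.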
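A minimal calculator for plugging numbers into these formulas. The function names are hypothetical, and the MLP count assumes a plain two-matmul MLP (d → 4d → d) with both matmuls included at 2 FLOPs per multiply-add. It also assumes d = h · d_head.

```python
def attention_cost(B, L, h, d_head, T, bytes_per_elem=2):
    """Return (forward FLOPs, bytes) for the T x T part of attention."""
    score_flops = 2 * B * L * h * T * T * d_head        # Q @ K^T
    av_flops = 2 * B * L * h * T * T * d_head           # softmax(scores) @ V
    matrix_bytes = B * L * h * T * T * bytes_per_elem   # if fully materialised
    return score_flops + av_flops, matrix_bytes

def mlp_flops(B, L, T, d):
    """Forward FLOPs for an assumed d -> 4d -> d MLP (two matmuls)."""
    return 2 * (2 * B * L * T * d * 4 * d)

GIB = 2 ** 30

# One layer at long context: fp16 score matrices for 32 heads, T = 32768.
_, one_layer_bytes = attention_cost(B=1, L=1, h=32, d_head=128, T=32768)
print(one_layer_bytes / GIB, "GiB per layer")

# Short context: attention's T^2 term is small next to the MLP.
attn, _ = attention_cost(B=1, L=32, h=32, d_head=128, T=2048)
mlp = mlp_flops(B=1, L=32, T=2048, d=4096)
print("attention / MLP FLOPs:", attn / mlp)
```

Under these assumptions the ratio works out to T / (4d), so the T² attention term only catches up with the MLP once T reaches several times d_model.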

8. From a single block to a full model

A decoder-only block is:

x = x + MHA(LayerNorm(x))     # attention sub-layer
x = x + MLP(LayerNorm(x))     # MLP sub-layer (Session 6)

Pre-norm style (LN inside the residual branch, before each sub-layer) trains stably for very deep stacks. The residual connection is essential: it gives gradients a direct path from the logits back to the embedding, even through 96 layers.

9. Hands-on (last 30 min)

import torch
import torch.nn as nn

class CausalSelfAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.d_head = d_model // n_heads
        self.n_heads = n_heads
        # One fused projection produces Q, K and V for all heads at once.
        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.proj = nn.Linear(d_model, d_model, bias=False)  # W_O

    def forward(self, x):
        B, T, D = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)  # each (B, T, D)
        # Split D into heads: (B, T, D) -> (B, n_heads, T, d_head)
        q, k, v = (t.view(B, T, self.n_heads, self.d_head).transpose(1, 2)
                   for t in (q, k, v))
        # Scaled dot-product scores: (B, n_heads, T, T)
        att = (q @ k.transpose(-2, -1)) / (self.d_head ** 0.5)
        # Causal mask: 1 on and below the diagonal, 0 above it
        mask = torch.tril(torch.ones(T, T, device=x.device)).view(1, 1, T, T)
        att = att.masked_fill(mask == 0, float('-inf')).softmax(-1)
        # Weighted sum of values, then merge heads back to (B, T, D)
        out = (att @ v).transpose(1, 2).contiguous().view(B, T, D)
        return self.proj(out)

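Run on (2, 16, 128) random input with d_model=128, n_heads=4. Then remove the / (self.d_head ** 0.5) scaling and inspect att[0, 0, 5] (batch 0, head 0, query position 5). att is local to forward, so return it or stash it on self first. The row should become visibly more peaked, drifting toward one-hot. That's the experiment.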

10. What's next (Session 6)

  • Positional encoding — sinusoidal vs learned vs RoPE
  • The MLP — where most of the facts live
  • LayerNorm — pre-norm vs post-norm
  • KV cache — what actually grows in memory at inference

Reading material

In-depth research material

Video reference

▶︎ 3Blue1Brown — Attention in transformers, step by step

Pick a quiet 30 minutes during this session to actually watch it. Don't multitask.

LeetCode — Two Sum

  • Link: https://leetcode.com/problems/two-sum/
  • Difficulty: Easy
  • Why this problem: Hash-map for O(n) lookup; the canonical interview opener.
  • Time-box: 30 minutes. Look up the editorial only after.

Post-session checklist

By the end of this session you should be able to:

  • Draw the residual stream from token IDs → logits with N blocks and label every arrow.
  • Derive Q, K, V from the residual stream and explain what each projection learns.
  • State why scaled dot-product attention divides by √d_k and what breaks without it.
  • List 3 head specialisations documented by mechanistic interpretability.
  • Compute attention FLOPs and memory for B=1, L=32, h=32, d_head=128, T=4096.
  • Implement a causal self-attention block from scratch (≤ 30 lines).

Generated from sessions_data.py + content_part*.py. To edit a video / leetcode / title, edit the data file and re-run write_sessions.py.