makemore: from counting bigrams to a WaveNet
names.txt: 32,033 names, one per line. The vocabulary is the 26 letters plus . (a single start/end token), 27 characters in total. Each name is wrapped in that token, so emma becomes .emma. and yields the bigrams (.,e), (e,m), (m,m), (m,a), (a,.). The goal: predict the next character.
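A minimal sketch of the vocabulary and the wrapping, assuming names.txt has already been read into a list of lowercase a-z strings (a two-name stand-in here). Giving . index 0 and the letters 1 through 26 matches the later context example [0, 0, 5], which decodes to ..e. The helper name bigrams is hypothetical.

```python
import string

# Stand-in for the 32,033 lines of names.txt (assumption: lowercase a-z only)
words = ["emma", "olivia"]

# '.' gets index 0; letters a..z get 1..26
chars = ["."] + list(string.ascii_lowercase)
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for ch, i in stoi.items()}

def bigrams(word):
    """Wrap a name in start/end tokens and return consecutive character pairs."""
    chs = ["."] + list(word) + ["."]
    return list(zip(chs, chs[1:]))

print(len(stoi))         # 27
print(bigrams("emma"))   # [('.', 'e'), ('e', 'm'), ('m', 'm'), ('m', 'a'), ('a', '.')]
```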
Five models, each adding one mechanism. Loss is negative log-likelihood (NLL); lower is better.
| Model | Mechanism | Val NLL |
|---|---|---|
| Bigram counts | 27×27 count matrix, +1 smoothing | 2.45 |
| Bigram NN | 27→27 logits, softmax, gradient descent | 2.46 |
| MLP (Bengio 2003) | 3-char context, 10-dim embedding, 200-hidden tanh | 2.10 |
| MLP + BN + Kaiming | same + proper init + batch norm | 2.05 |
| WaveNet-style | hierarchical pairwise fusion, 8-char context | 1.99 |
[Figure: bar chart of validation NLL across the five models]
1. Counting bigrams
Build the count matrix directly:
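```python
import torch

# words: list of names; stoi: char -> index, with '.' at index 0
N = torch.zeros((27, 27), dtype=torch.int32)
for w in words:
    chs = ['.'] + list(w) + ['.']
    for ch1, ch2 in zip(chs, chs[1:]):
        N[stoi[ch1], stoi[ch2]] += 1

P = (N + 1).float()               # +1 smoothing: no bigram has zero probability
P /= P.sum(dim=1, keepdim=True)   # each row becomes a distribution over next chars
```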
+1 smoothing avoids log(0) on bigrams that never appeared in training.
Sampling calls torch.multinomial(P[ix], num_samples=1) in a loop, starting from ix = 0 (the . token) and stopping when . is drawn again.
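A sketch of that sampling loop, assuming P is a (27, 27) row-normalized table and itos maps indices back to characters; sample_name is a hypothetical helper. Passing replacement=True and an explicit torch.Generator keeps runs reproducible.

```python
import torch

def sample_name(P, itos, generator=None):
    """Draw characters from the bigram table until the '.' token comes back.

    P: (27, 27) row-normalized probability matrix; itos: index -> char.
    """
    out = []
    ix = 0  # start from the '.' token
    while True:
        ix = torch.multinomial(P[ix], num_samples=1, replacement=True,
                               generator=generator).item()
        if ix == 0:  # drew the end token
            break
        out.append(itos[ix])
    return "".join(out)

itos = {i: ch for i, ch in enumerate(".abcdefghijklmnopqrstuvwxyz")}
g = torch.Generator().manual_seed(2147483647)
P_uniform = torch.full((27, 27), 1.0 / 27)   # hypothetical untrained table
print(sample_name(P_uniform, itos, generator=g))
```

With the real smoothed P, each call returns one generated name; with the uniform stand-in, the output is just a random string of letters.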
NLL = −mean(log P[bigram]) over the training set = 2.4543. As a sanity check, exp(−2.4543) ≈ 8.6% is the geometric-mean probability the model gives the correct next character, vs 1/27 ≈ 3.7% for a uniform guess. The bigram model does roughly 2.3× better than chance.
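A sketch of the evaluation and the sanity-check arithmetic. bigram_nll is a hypothetical helper that averages −log P over every bigram of a word list. exp(−NLL) is the geometric mean of the per-bigram probabilities, which is why it can be compared directly with the uniform 1/27.

```python
import math
import torch

def bigram_nll(P, words, stoi):
    """Average negative log-likelihood of every bigram in `words` under table P."""
    log_likelihood = 0.0
    n = 0
    for w in words:
        chs = ["."] + list(w) + ["."]
        for ch1, ch2 in zip(chs, chs[1:]):
            log_likelihood += torch.log(P[stoi[ch1], stoi[ch2]]).item()
            n += 1
    return -log_likelihood / n

# The sanity check from the text: exp(-NLL) is the geometric-mean probability
# assigned to the correct next character.
nll = 2.4543
print(math.exp(-nll))        # ≈ 0.0859
print(math.exp(-nll) * 27)   # ≈ 2.32, i.e. that many times the uniform 1/27
```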
2. The same bigram model as a neural net
Same model, found by gradient descent instead of counting:
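```python
import torch
import torch.nn.functional as F

# xs, ys: int64 tensors of (current char, next char) indices for every bigram
n = xs.nelement()
xenc = F.one_hot(xs, num_classes=27).float()   # (n, 27)
W = torch.randn((27, 27), requires_grad=True)
logits = xenc @ W                               # log-counts
counts = logits.exp()                           # equivalent to N
probs = counts / counts.sum(1, keepdim=True)    # softmax: rows sum to 1
loss = -probs[torch.arange(n), ys].log().mean()
```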
xenc @ W is a row lookup: multiplying a one-hot vector by a matrix picks out one row of W. The “logits” play the role of log-counts, up to a per-row constant, and softmax(logits) plays the role of the row-normalized count matrix. Trained for 200 steps with lr=50, the network lands at NLL 2.4576, within 0.01 of the count model’s 2.4543.
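A sketch of the gradient-descent loop, assuming xs and ys hold the bigram index pairs; train_bigram_nn is a hypothetical name, and it uses softmax(dim=1) in place of the explicit exp-and-normalize (same math). The default lr=50 mirrors the full-dataset run described above. On a toy dataset where a few rows hold a large share of the examples, a smaller rate is safer, since each row's effective step scales with that row's share of the data.

```python
import torch
import torch.nn.functional as F

def train_bigram_nn(xs, ys, steps=200, lr=50.0, seed=2147483647):
    """Fit the 27x27 bigram weight matrix W by plain gradient descent.

    xs, ys: int64 tensors of (current char, next char) indices.
    Returns W and the per-step loss history.
    """
    g = torch.Generator().manual_seed(seed)
    W = torch.randn((27, 27), generator=g, requires_grad=True)
    n = xs.nelement()
    xenc = F.one_hot(xs, num_classes=27).float()
    losses = []
    for _ in range(steps):
        logits = xenc @ W
        probs = logits.softmax(dim=1)        # same as exp() then row-normalize
        loss = -probs[torch.arange(n), ys].log().mean()
        W.grad = None                        # clear the previous step's gradient
        loss.backward()
        with torch.no_grad():
            W -= lr * W.grad                 # gradient descent update
        losses.append(loss.item())
    return W, losses
```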
The takeaway: this is the same model, just parameterized differently. The neural net’s W converges toward the log of the count matrix, up to a per-row constant. The equivalence breaks the moment you add nonlinearity or more context.
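A quick check of the "log-counts up to a constant" claim, using a random stand-in count table rather than the real one: setting W = log(N + 1) makes softmax reproduce the smoothed count model exactly, and adding an arbitrary constant to each row of W changes nothing.

```python
import torch

# Toy counts (hypothetical): any non-negative integer 27x27 table works.
g = torch.Generator().manual_seed(0)
N = torch.randint(0, 50, (27, 27), generator=g)

# Count model: add-one smoothing, then normalize each row.
P = (N + 1).float()
P /= P.sum(dim=1, keepdim=True)

# "Neural" model with its weights set to log-counts.
W = torch.log((N + 1).float())
probs = torch.softmax(W, dim=1)
print(torch.allclose(P, probs, atol=1e-6))        # True

# Adding any per-row constant to W leaves the softmax unchanged.
shifted = torch.softmax(W + torch.randn(27, 1, generator=g), dim=1)
print(torch.allclose(probs, shifted, atol=1e-6))  # True
```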
3. MLP, Bengio 2003
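Bigrams are too local: the model sees only the previous character, so after m it predicts the same distribution whether the name so far is .em or .om. A longer context such as ..em tells them apart. Bump context from 1 → 3 characters.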
```text
input: 3 char indices, e.g. [0, 0, 5]
→ embedding C: (27, 10) → (3, 10)
→ concatenate → (30,)
→ Linear(30, 200) + tanh
→ Linear(200, 27) → logits
→ softmax → next-char distribution
```
The dataset is built by sliding a 3-character window over each padded name; for emma:
```text
context          target
['.', '.', '.']  'e'
['.', '.', 'e']  'm'
['.', 'e', 'm']  'm'
['e', 'm', 'm']  'a'
['m', 'm', 'a']  '.'
```
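Over all the names, build_dataset() returns X with shape (228146, 3) and Y with shape (228146,): each name of length L contributes L + 1 examples. The data is split 80/10/10 into train/dev/test.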
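A sketch of build_dataset(), assuming stoi maps . to 0; the block_size parameter and the exact signature are assumptions, not necessarily the post's. The window starts as ... and slides one character at a time, so emma reproduces the five rows in the table.

```python
import torch

def build_dataset(words, stoi, block_size=3):
    """Slide a block_size-character window over each '.'-padded name.

    Returns X: (num_examples, block_size) context indices, Y: (num_examples,) targets.
    """
    X, Y = [], []
    for w in words:
        context = [0] * block_size          # start with '...' (index 0 is '.')
        for ch in w + ".":
            ix = stoi[ch]
            X.append(context)
            Y.append(ix)
            context = context[1:] + [ix]    # slide the window one step
    return torch.tensor(X), torch.tensor(Y)

stoi = {ch: i for i, ch in enumerate(".abcdefghijklmnopqrstuvwxyz")}
X, Y = build_dataset(["emma"], stoi)
print(X.shape, Y.shape)   # torch.Size([5, 3]) torch.Size([5])
```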
Forward:
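```python
import torch
import torch.nn.functional as F

# Xb: (B, 3) context indices; C: (27, 10), W1: (30, 200), W2: (200, 27)
emb = C[Xb]                                   # (B, 3, 10)
h = torch.tanh(emb.view(-1, 30) @ W1 + b1)    # (B, 200)
logits = h @ W2 + b2                          # (B, 27)
loss = F.cross_entropy(logits, Yb)
```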
emb.view(-1, 30) flattens the 3-char window into a 30-d vector. The network can still learn position-dependent patterns, because each character’s embedding occupies its own slice of that vector and so meets its own rows of W1.
Trains in ~30 sec. Val loss ~2.10. Sampled names start sounding like names: montelle, kymbry, madiet.
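A self-contained sketch of the training loop behind that result, running minibatch SGD on a synthetic stand-in for X and Y. init_mlp, forward, train_step, and eval_loss are hypothetical helpers; lr=0.1 and batch_size=32 are assumed values, not taken from the post. The plain N(0, 1) init is kept on purpose: it is the unfixed version that the next section picks apart.

```python
import torch
import torch.nn.functional as F

def init_mlp(vocab_size=27, n_embd=10, block_size=3, n_hidden=200, seed=2147483647):
    """Create the MLP parameters with plain N(0, 1) init."""
    g = torch.Generator().manual_seed(seed)
    params = [
        torch.randn((vocab_size, n_embd), generator=g),            # C
        torch.randn((block_size * n_embd, n_hidden), generator=g),  # W1
        torch.randn(n_hidden, generator=g),                          # b1
        torch.randn((n_hidden, vocab_size), generator=g),            # W2
        torch.randn(vocab_size, generator=g),                        # b2
    ]
    for p in params:
        p.requires_grad = True
    return params

def forward(params, Xb):
    C, W1, b1, W2, b2 = params
    emb = C[Xb]                                            # (B, 3, 10)
    h = torch.tanh(emb.view(emb.shape[0], -1) @ W1 + b1)   # (B, 200)
    return h @ W2 + b2                                     # (B, 27) logits

def train_step(params, X, Y, lr=0.1, batch_size=32, generator=None):
    """One minibatch SGD step; returns the minibatch loss."""
    ix = torch.randint(0, X.shape[0], (batch_size,), generator=generator)
    loss = F.cross_entropy(forward(params, X[ix]), Y[ix])
    for p in params:
        p.grad = None
    loss.backward()
    with torch.no_grad():
        for p in params:
            p -= lr * p.grad
    return loss.item()

@torch.no_grad()
def eval_loss(params, X, Y):
    """Loss over the whole (X, Y) set, without tracking gradients."""
    return F.cross_entropy(forward(params, X), Y).item()
```

Swapping in the real X and Y from build_dataset() is the actual experiment; the toy data only checks that the loss goes down.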
4. The three init bugs nobody tells you about
The MLP works, but if you instrument it, three things are quietly broken at step 0.
Initial loss is too high. At random init the loss is ~27 because the softmax has blown up. A model that knows nothing should predict uniformly, which gives −log(1/27) = log 27 ≈ 3.30. Cause: W2 and b2 initialized from N(0, 1) produce logits with a huge spread, so the softmax puts nearly all its probability on one random class; when that class is wrong, −log(tiny) is huge. Fix: scale W2 down by ~0.01 and zero b2. Initial loss drops to 3.32.
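A small sketch of the initial-loss check. The hidden activations h are a random stand-in for real tanh outputs (an assumption), and initial_loss is a hypothetical helper. Equal logits give exactly log 27, N(0, 1) output weights blow the loss up, and scaling W2 by 0.01 with b2 zeroed brings it back near log 27.

```python
import math
import torch
import torch.nn.functional as F

g = torch.Generator().manual_seed(2147483647)
B, n_hidden, vocab_size = 32, 200, 27
Yb = torch.randint(0, vocab_size, (B,), generator=g)

# A perfectly uniform prediction: all logits equal, so the loss is exactly log(27)
uniform = F.cross_entropy(torch.zeros(B, vocab_size), Yb).item()
print(uniform, math.log(27))   # both ≈ 3.2958

# Stand-in hidden activations in [-1, 1], as tanh outputs would be
h = torch.tanh(torch.randn(B, n_hidden, generator=g) * 3)

def initial_loss(w2_scale, b2_scale):
    W2 = torch.randn(n_hidden, vocab_size, generator=g) * w2_scale
    b2 = torch.randn(vocab_size, generator=g) * b2_scale
    return F.cross_entropy(h @ W2 + b2, Yb).item()

print(initial_loss(1.0, 1.0))   # large: the logits have a big spread
print(initial_loss(0.01, 0.0))  # close to log(27)
```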
Tanh saturation. Most pre-activations land outside [-2, 2] at init, where tanh is flat. Its local gradient 1 − tanh²(x) is near 0 there, so gradients can’t flow back through these neurons; a neuron that is saturated on every example is effectively dead. Diagnose with (h.abs() > 0.99).float().mean(0), the saturated fraction per neuron: at init it exceeds 97% for some neurons. Fix: scale W1 so that the pre-activations emb.view(-1, 30) @ W1 + b1 have variance ~1.
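A sketch of the saturation diagnostic on random unit-variance stand-in inputs (playing the role of emb.view(-1, 30)). saturation is a hypothetical helper that returns, per neuron, the fraction of examples with |tanh| > 0.99. Comparing raw N(0, 1) weights with weights scaled by (5/3)/sqrt(30) shows how much of the layer the scaling rescues.

```python
import torch

g = torch.Generator().manual_seed(2147483647)
B, fan_in, n_hidden = 32, 30, 200
x = torch.randn(B, fan_in, generator=g)        # stand-in for emb.view(-1, 30)
W1 = torch.randn(fan_in, n_hidden, generator=g)

def saturation(W1):
    """Per-neuron fraction of examples whose tanh output exceeds 0.99 in magnitude."""
    h = torch.tanh(x @ W1)
    return (h.abs() > 0.99).float().mean(0)    # shape (n_hidden,)

raw = saturation(W1)
scaled = saturation(W1 * (5 / 3) / fan_in**0.5)
print(raw.mean().item(), scaled.mean().item())
```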
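Eyeballing the scaling factor. Rather than hand-tuning the scale, use the formula from Kaiming He’s paper: for a layer with fan_in inputs, draw weights from a normal distribution with standard deviation gain/sqrt(fan_in). The paper derives gain = sqrt(2) for relu; PyTorch’s torch.nn.init.calculate_gain supplies gains for other nonlinearities, including 5/3 for tanh. PyTorch ships the initializer itself as torch.nn.init.kaiming_normal_.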
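A sketch of Kaiming init for W1, done by hand and with torch.nn.init. One layout detail matters: kaiming_normal_ reads fan_in from dimension 1, which matches nn.Linear's (out_features, in_features) weights. The post's W1 is stored as (30, 200) and used as x @ W1, so calling kaiming_normal_ on it directly with mode="fan_in" would use fan_in = 200; the sketch allocates the Linear layout and transposes instead.

```python
import math
import torch

fan_in, fan_out = 30, 200
gain = torch.nn.init.calculate_gain("tanh")   # 5/3
std = gain / math.sqrt(fan_in)

# Manual version, for a weight stored as (fan_in, fan_out) and used as x @ W1
g = torch.Generator().manual_seed(2147483647)
W1 = torch.randn(fan_in, fan_out, generator=g) * std

# Library version: allocate nn.Linear's (out, in) layout so dim 1 is fan_in,
# initialize, then transpose back to (fan_in, fan_out).
W1_lib = torch.empty(fan_out, fan_in)
torch.nn.init.kaiming_normal_(W1_lib, mode="fan_in", nonlinearity="tanh")
W1_lib = W1_lib.T

print(gain, std)
print(W1.std().item(), W1_lib.std().item())   # both should be close to std
```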
After Kaiming init, pre-activations mostly stay within [-2, 2], no neurons are dead, and the loss starts where it should. Val loss drops from 2.10 to ~2.07 just from fixing initialization.
5. BatchNorm: forcing the distribution post-hoc
Kaiming gets you into the right range at init. As you train, the weights drift and the distributions shift again. BatchNorm re-normalizes the pre-activations on every forward pass, using the current batch’s mean and std:
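```python
# hpreact: (B, 200) pre-activations; bngain, bnbias: (1, 200) learnable
bnmeani = hpreact.mean(0, keepdim=True)   # per-neuron batch mean
bnstdi = hpreact.std(0, keepdim=True)     # per-neuron batch std
hpreact = bngain * (hpreact - bnmeani) / bnstdi + bnbias
```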
bngain and bnbias are learnable; they let the network undo the normalization if it wants to. In practice they stay close to their initial values (gain 1, bias 0): the network mostly wants the normalized version.
The annoying part is inference. At inference there is no batch: you might be predicting one example at a time. So BatchNorm keeps an exponential moving average of the train-time batch statistics (mean and std) and uses those at eval. That means two extra non-learnable buffers per BN layer, and the layer behaves differently in train and eval mode.
This is also why model.eval() matters: without it, BatchNorm at inference would use the single-example statistics (variance = 0, division by zero, garbage output).
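A hand-rolled sketch of the train/eval split, using the same std-based normalization as the code above. BatchNorm1dSketch is a hypothetical class, not torch.nn.BatchNorm1d (which tracks a running variance rather than std and defaults to momentum 0.1); momentum=0.001 and the small eps in the denominator are assumptions added here. In training mode it normalizes with batch statistics and updates the running buffers; in eval mode it uses only the buffers, so a single example is fine.

```python
import torch

class BatchNorm1dSketch:
    """Batch norm over dim 0 with running statistics (a sketch)."""

    def __init__(self, dim, momentum=0.001, eps=1e-5):
        self.momentum, self.eps = momentum, eps
        self.training = True
        self.bngain = torch.ones(1, dim, requires_grad=True)   # learnable
        self.bnbias = torch.zeros(1, dim, requires_grad=True)  # learnable
        self.bnmean_running = torch.zeros(1, dim)              # buffer, not trained
        self.bnstd_running = torch.ones(1, dim)                # buffer, not trained

    def __call__(self, hpreact):
        if self.training:
            mean = hpreact.mean(0, keepdim=True)
            std = hpreact.std(0, keepdim=True)
            with torch.no_grad():  # exponential moving average of batch statistics
                m = self.momentum
                self.bnmean_running = (1 - m) * self.bnmean_running + m * mean
                self.bnstd_running = (1 - m) * self.bnstd_running + m * std
        else:
            mean, std = self.bnmean_running, self.bnstd_running
        return self.bngain * (hpreact - mean) / (std + self.eps) + self.bnbias
```

A batch of one in training mode shows the failure the text describes: torch's unbiased std of a single value is NaN, so the whole output is NaN.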
Val loss with init fixes + BN: ~2.05.
6. Manual backprop, every gradient by hand
For one block of training I deleted loss.backward() and computed every gradient by hand, layer by layer.
The cross-entropy case is the one worth writing out. Cross-entropy fuses three ops: softmax, picking out the correct-class probability, and −log. Differentiating through all three at once:
With logits z, every class j gets p_j = exp(z_j) / S, where S = Σ_k exp(z_k), and the loss is L = −log p_y.
For the correct class y:
dL/dz_y = p_y − 1
For any other class i ≠ y:
dL/dz_i = p_i
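The step that differentiating directly skips: write the loss in terms of the logits, and the log-sum-exp term does all the work. Averaging over a batch of n examples then divides each row by n.

```latex
\begin{aligned}
L &= -\log p_y = -z_y + \log S, \qquad S = \sum_j e^{z_j} \\
\frac{\partial \log S}{\partial z_k} &= \frac{e^{z_k}}{S} = p_k \\
\frac{\partial L}{\partial z_k} &= p_k - \mathbb{1}[k = y]
\;\Rightarrow\;
\frac{\partial L}{\partial z_y} = p_y - 1, \quad
\frac{\partial L}{\partial z_i} = p_i \;\; (i \neq y)
\end{aligned}
```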
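So dlogits = probs.clone(); dlogits[range(n), y] -= 1; dlogits /= n, where the final division by n accounts for the loss being a mean over the batch. That’s it: given the softmax probabilities, the most common loss function in deep learning has a three-line gradient.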
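A quick check of the hand-derived gradient against autograd, on random logits and targets standing in for a real batch. Each row of dlogits also sums to zero, since the probabilities sum to 1 and exactly one entry has 1 subtracted.

```python
import torch
import torch.nn.functional as F

g = torch.Generator().manual_seed(2147483647)
n, vocab_size = 32, 27
logits = torch.randn(n, vocab_size, generator=g, requires_grad=True)
y = torch.randint(0, vocab_size, (n,), generator=g)

# Autograd's answer
loss = F.cross_entropy(logits, y)   # mean over the batch
loss.backward()

# Hand-derived answer: softmax, subtract 1 at the correct class, divide by n
dlogits = F.softmax(logits.detach(), dim=1)
dlogits[torch.arange(n), y] -= 1
dlogits /= n

print(torch.allclose(dlogits, logits.grad, atol=1e-6))   # True
```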
Once you’ve done this, autograd stops being a black box. PyTorch records a backward node (grad_fn) for each op, playing the same role as micrograd’s _backward closure, then walks that graph in reverse and applies these closed-form local derivatives via the chain rule.
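A small way to see those recorded nodes: every op on a tensor that requires grad attaches a grad_fn, and next_functions links it to the nodes that produced its inputs. The node class names (SumBackward0, MulBackward0) are PyTorch internals rather than a documented API, so treat them as illustrative.

```python
import torch

x = torch.randn(3, requires_grad=True)
loss = (x * 2).sum()

# Each op recorded a backward node; loss.grad_fn is the last one.
print(type(loss.grad_fn).__name__)   # SumBackward0
parent = loss.grad_fn.next_functions[0][0]
print(type(parent).__name__)         # MulBackward0

loss.backward()   # walks SumBackward0 -> MulBackward0 -> leaf x
print(x.grad)     # tensor([2., 2., 2.])
```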
7. WaveNet-style hierarchical fusion
Widened to an 8-character context, the MLP concatenates all 8 embeddings into one 80-d vector and runs a single Linear over it, so every character has to interact with every other character in one layer.
The WaveNet-style model instead fuses pairs of adjacent characters, then pairs of pairs, then pairs of those:
```text
[c1 c2 c3 c4 c5 c6 c7 c8]   8 chars, 10-dim each
 \___/ \___/ \___/ \___/
 [b1    b2    b3    b4]     4 bigram reps
  \______/    \______/
    [q1          q2]        2 four-gram reps
     \____________/
          [o1]              1 output → predict next char
```
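Each fusion level is the same kind of operation, with its own weights: concatenate adjacent pairs, (B, T, C) → (B, T/2, 2C), then apply Linear(2C → C) + tanh. Context builds up gradually: each level sees twice as many characters as the one below.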
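A minimal PyTorch sketch of the three-level fusion, with hypothetical class names PairFuse and HierarchicalModel. The reshape from (B, T, C) to (B, T/2, 2C) works because positions 2k and 2k+1 are adjacent in memory, so each new row is their concatenation. The hidden width (68) is an arbitrary choice, and the sketch omits any normalization layers the post's model may use.

```python
import torch
import torch.nn as nn

class PairFuse(nn.Module):
    """One fusion level: (B, T, C_in) -> (B, T/2, 2*C_in) -> Linear + tanh -> (B, T/2, C_out)."""

    def __init__(self, c_in, c_out):
        super().__init__()
        self.lin = nn.Linear(2 * c_in, c_out)

    def forward(self, x):
        B, T, C = x.shape
        x = x.reshape(B, T // 2, 2 * C)   # concatenate neighbours (c1,c2), (c3,c4), ...
        return torch.tanh(self.lin(x))

class HierarchicalModel(nn.Module):
    def __init__(self, vocab_size=27, n_embd=10, n_hidden=68):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, n_embd)
        self.levels = nn.Sequential(
            PairFuse(n_embd, n_hidden),     # 8 chars -> 4 bigram reps
            PairFuse(n_hidden, n_hidden),   # 4 bigram reps -> 2 four-gram reps
            PairFuse(n_hidden, n_hidden),   # 2 four-gram reps -> 1 output
        )
        self.head = nn.Linear(n_hidden, vocab_size)

    def forward(self, idx):                 # idx: (B, 8) character indices
        x = self.levels(self.emb(idx))      # (B, 1, n_hidden)
        return self.head(x.squeeze(1))      # (B, vocab_size) logits
```

Feeding (B, 8) indices yields (B, 27) logits, which go into F.cross_entropy exactly as in the MLP.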
Same names (now windowed to 8 characters of context), same training loop. Val NLL: ~1.99.
Where val loss can keep dropping
| Add | Expected drop |
|---|---|
| Longer context (12, 16 chars) | small, diminishing |
| More embedding dims | small |
| Multi-head self-attention | substantial: the biggest single step in the bigrams → attention progression |
| More data | likely helps; this dataset is tiny |
Attention is where the tiny GPT post picks up.