<?xml version="1.0" encoding="utf-8"?><feed xmlns="http://www.w3.org/2005/Atom" xml:lang="en"><generator uri="https://jekyllrb.com/" version="4.4.1">Jekyll</generator><link href="https://debtirthasaha.github.io/feed.xml" rel="self" type="application/atom+xml"/><link href="https://debtirthasaha.github.io/" rel="alternate" type="text/html" hreflang="en"/><updated>2026-05-18T13:28:22+00:00</updated><id>https://debtirthasaha.github.io/feed.xml</id><title type="html">blank</title><subtitle>Engineering notes on building ML systems from scratch — transformers, tokenizers, GPT-2 reproduction, and what I learn along the way. </subtitle><entry><title type="html">Eight A100s, $61, and 124M parameters</title><link href="https://debtirthasaha.github.io/blog/2026/gpt2-124m/" rel="alternate" type="text/html" title="Eight A100s, $61, and 124M parameters"/><published>2026-05-17T18:00:00+00:00</published><updated>2026-05-17T18:00:00+00:00</updated><id>https://debtirthasaha.github.io/blog/2026/gpt2-124m</id><content type="html" xml:base="https://debtirthasaha.github.io/blog/2026/gpt2-124m/"><![CDATA[<p>End-to-end reproduction of Karpathy’s “Let’s reproduce GPT-2 (124M)” video. The build started on a 4 GB GTX 1650 and finished on 8× A100 SXM4 rented from Lambda. Training ran for 19,073 steps over 10B FineWeb-Edu tokens. Total cost: <strong>$61.23</strong>. Final val loss <strong>3.40</strong> vs OpenAI’s released baseline of 3.29 — a <strong>97% match</strong>. 
HellaSwag 26.99% vs 29.45% baseline.</p> <table> <thead> <tr> <th>Metric</th> <th>Start</th> <th>End</th> <th>OpenAI baseline</th> <th>% of target</th> </tr> </thead> <tbody> <tr> <td>Val loss</td> <td>10.951</td> <td><strong>3.3969</strong></td> <td>3.292</td> <td><strong>97%</strong></td> </tr> <tr> <td>HellaSwag acc_norm</td> <td>24.82%</td> <td><strong>26.99%</strong></td> <td>29.45%</td> <td>91%</td> </tr> <tr> <td>Sustained throughput</td> <td>—</td> <td>1.1M tokens/sec</td> <td>—</td> <td>—</td> </tr> <tr> <td>Steps/min</td> <td>—</td> <td>~127</td> <td>—</td> <td>—</td> </tr> </tbody> </table> <p>The val loss gap is small enough that another ~$80 of training would close it. The HellaSwag gap is the well-known FineWeb-Edu artifact — high-quality educational text gives slightly worse commonsense reasoning per perplexity unit than OpenAI’s WebText.</p> <p>The actual training run, with the OpenAI baseline as a reference line:</p> <pre><code class="language-plotly">{"data":[{"x":[0,50,100,150,200,250,300,350,400,450,500,550,600,650,700,750,800,850,900,950,1000,1050,1100,1150,1200,1250,1300,1350,1400,1450,1500,1550,1600,1650,1700,1750,1800,1850,1900,1950,2000,2050,2100,2150,2200,2250,2300,2350,2400,2450,2500,2550,2600,2650,2700,2750,2800,2850,2900,2950,3000,3050,3100,3150,3200,3250,3300,3350,3400,3450,3500,3550,3600,3650,3700,3750,3800,3850,3900,3950,4000,4050,4100,4150,4200,4250,4300,4350,4400,4450,4500,4550,4600,4650,4700,4750,4800,4850,4900,4950,5000,5050,5100,5150,5200,5250,5300,5350,5400,5450,5500,5550,5600,5650,5700,5750,5800,5850,5900,5950,6000,6050,6100,6150,6200,6250,6300,6350,6400,6450,6500,6550,6600,6650,6700,6750,6800,6850,6900,6950,7000,7050,7100,7150,7200,7250,7300,7350,7400,7450,7500,7550,7600,7650,7700,7750,7800,7850,7900,7950,8000,8050,8100,8150,8200,8250,8300,8350,8400,8450,8500,8550,8600,8650,8700,8750,8800,8850,8900,8950,9000,9050,9100,9150,9200,9250,9300,9350,9400,9450,9500,9550,9600,9650,9700,9750,9800,9850,9900,9950,10000,10050,10100,10150,1
0200,10250,10300,10350,10400,10450,10500,10550,10600,10650,10700,10750,10800,10850,10900,10950,11000,11050,11100,11150,11200,11250,11300,11350,11400,11450,11500,11550,11600,11650,11700,11750,11800,11850,11900,11950,12000,12050,12100,12150,12200,12250,12300,12350,12400,12450,12500,12550,12600,12650,12700,12750,12800,12850,12900,12950,13000,13050,13100,13150,13200,13250,13300,13350,13400,13450,13500,13550,13600,13650,13700,13750,13800,13850,13900,13950,14000,14050,14100,14150,14200,14250,14300,14350,14400,14450,14500,14550,14600,14650,14700,14750,14800,14850,14900,14950,15000,15050,15100,15150,15200,15250,15300,15350,15400,15450,15500,15550,15600,15650,15700,15750,15800,15850,15900,15950,16000,16050,16100,16150,16200,16250,16300,16350,16400,16450,16500,16550,16600,16650,16700,16750,16800,16850,16900,16950,17000,17050,17100,17150,17200,17250,17300,17350,17400,17450,17500,17550,17600,17650,17700,17750,17800,17850,17900,17950,18000,18050,18100,18150,18200,18250,18300,18350,18400,18450,18500,18550,18600,18650,18700,18750,18800,18850,18900,18950,19000,19050],"y":[10.955029,8.652573,7.339675,6.86025,6.542106,6.387939,6.357413,6.096048,5.967309,5.876346,5.71647,5.620126,5.435894,5.515324,5.390418,5.392198,5.161924,5.103988,4.995365,4.951733,4.926783,4.729811,4.587543,4.550589,4.48293,4.45039,4.537868,4.524739,4.583992,4.460568,4.444868,4.381177,4.368808,4.243211,4.285657,4.238783,4.14917,4.080226,4.098678,4.060075,4.298099,4.237405,4.191657,4.160752,4.198476,4.104283,4.108762,4.067921,4.116701,3.930878,3.955742,3.91789,3.876447,4.088891,4.056384,4.056029,4.073021,4.063436,4.069367,4.050061,4.010026,3.995684,3.911222,3.906269,3.913971,3.918218,3.8181,3.854215,3.778675,4.010101,3.959394,3.957729,3.982995,3.894079,3.907338,3.947609,3.883737,3.910084,3.870479,3.850236,3.727764,3.670568,3.701516,3.90143,3.958811,3.923477,3.937549,3.883154,3.868949,3.855658,3.828887,3.774447,3.781784,3.798658,3.635849,3.673353,3.645618,3.586813,3.888197,3.844004,3.799917,3.815917,3.847712,3.832865
,3.779544,3.74312,3.729076,3.561608,3.449939,3.528541,3.683025,3.799625,3.81723,3.845233,3.772759,3.711059,3.710322,3.754577,3.779361,3.760361,3.777559,3.711643,3.720454,3.777155,3.65674,3.782415,3.646935,3.719334,3.650414,3.724901,3.751636,3.754004,3.742701,3.696117,3.698276,3.68034,3.759598,3.708417,3.695954,3.631462,3.729955,3.715655,3.698024,3.702647,3.629437,3.723177,3.659523,3.68377,3.655551,3.658967,3.7362,3.69976,3.615459,3.622903,3.649003,3.639076,3.662286,3.650483,3.670292,3.651448,3.695603,3.68487,3.654457,3.708802,3.614856,3.69809,3.693779,3.759378,3.655924,3.603687,3.635201,3.577073,3.665966,3.597528,3.628942,3.662407,3.59794,3.680565,3.593917,3.621378,3.630538,3.659089,3.641257,3.618741,3.630788,3.586698,3.624945,3.674439,3.589103,3.632992,3.573211,3.668649,3.576148,3.580356,3.653369,3.636115,3.52358,3.505952,3.577158,3.549963,3.615004,3.577269,3.568972,3.564436,3.570774,3.634584,3.541682,3.667766,3.649652,3.562427,3.635001,3.523956,3.58213,3.502887,3.647513,3.51616,3.524837,3.57378,3.494234,3.525831,3.57591,3.551852,3.528437,3.449448,3.544546,3.442327,3.526242,3.499907,3.521501,3.529324,3.462544,3.582874,3.455339,3.536372,3.499629,3.504132,3.533975,3.490507,3.533957,3.489776,3.557726,3.498231,3.519463,3.488743,3.457328,3.46037,3.476739,3.496565,3.492133,3.472932,3.479477,3.561807,3.48808,3.509819,3.694302,3.460399,3.524822,3.494879,3.392873,3.560445,3.419302,3.431916,3.512982,3.439761,3.548496,3.439065,3.460358,3.483511,3.41854,3.472522,3.46191,3.484112,3.439632,3.421415,3.482553,3.374554,3.473166,3.405603,3.465045,3.494958,3.42211,3.420622,3.453494,3.398499,3.427082,3.472697,3.431491,3.40042,3.439947,3.439783,3.438387,3.458462,3.404739,3.435358,3.447509,3.449286,3.387485,3.45961,3.476094,3.443756,3.433138,3.329407,3.445033,3.35931,3.425076,3.378695,3.468765,3.427226,3.371132,3.437642,3.390105,3.415966,3.414496,3.380322,3.463654,3.346844,3.372707,3.376734,3.404138,3.317891,3.375628,3.411178,3.285608,3.410325,3.337027,3.41946,3.418103,3.326298,3.391832
,3.404434,3.378114,3.332703,3.339972,3.366685,3.418615,3.384357,3.433587,3.384535,3.312392,3.38399,3.330543,3.428728,3.399785,3.389828,3.414865,3.348916,3.35787,3.314832,3.453763,3.406743,3.298486,3.358698,3.260078,3.487381,3.320862,3.405267,3.411797,3.287154,3.387619,3.278841,3.413348,3.371803,3.388267,3.352105,3.261056,3.370182,3.351461,3.389376,3.346794,3.29316,3.40975,3.297334,3.394829,3.381837,3.26427,3.313486,3.348226,3.354343,3.289057,3.260655,3.272449,3.453633],"name":"train loss","type":"scatter","mode":"lines","line":{"color":"rgba(99,110,250,0.35)","width":1},"hovertemplate":"step %{x}&lt;br&gt;train %{y:.3f}&lt;extra&gt;&lt;/extra&gt;"},{"x":[0,250,500,750,1000,1250,1500,1750,2000,2250,2500,2750,3000,3250,3500,3750,4000,4250,4500,4750,5000,5250,5500,5750,6000,6250,6500,6750,7000,7250,7500,7750,8000,8250,8500,8750,9000,9250,9500,9750,10000,10250,10500,10750,11000,11250,11500,11750,12000,12250,12500,12750,13000,13250,13500,13750,14000,14250,14500,14750,15000,15250,15500,15750,16000,16250,16500,16750,17000,17250,17500,17750,18000,18250,18500,18750,19000,19072],"y":[10.9512,6.43,5.824,5.3065,4.8817,4.5899,4.4297,4.3199,4.2401,4.171,4.1257,4.0726,4.0291,3.9984,3.9699,3.9341,3.911,3.8882,3.8702,3.8511,3.83,3.813,3.8044,3.782,3.7661,3.7459,3.7313,3.7218,3.7073,3.6976,3.6843,3.6784,3.6638,3.6543,3.6425,3.6341,3.6258,3.616,3.607,3.5969,3.5895,3.5786,3.5725,3.5648,3.5544,3.5483,3.5396,3.5319,3.5236,3.5161,3.5099,3.5031,3.4963,3.4897,3.485,3.4792,3.4719,3.4666,3.4611,3.4558,3.4511,3.4459,3.4418,3.4376,3.4329,3.4287,3.4259,3.4222,3.4196,3.4169,3.4137,3.4108,3.4082,3.406,3.4044,3.4023,3.3994,3.3969],"name":"val loss","type":"scatter","mode":"lines+markers","line":{"color":"#EF553B","width":2.5},"marker":{"size":4},"hovertemplate":"step %{x}&lt;br&gt;val %{y:.3f}&lt;extra&gt;&lt;/extra&gt;"},{"x":[0,19072],"y":[3.292,3.292],"name":"OpenAI GPT-2 124M val baseline 
(3.292)","type":"scatter","mode":"lines","line":{"color":"#00cc96","width":1.5,"dash":"dash"},"hoverinfo":"skip"}],"layout":{"title":{"text":"Training and validation loss, 19,073 steps on FineWeb-Edu"},"xaxis":{"title":"step","range":[-100,19200]},"yaxis":{"title":"loss"},"height":460,"margin":{"l":60,"r":30,"t":60,"b":50},"hovermode":"x unified","legend":{"x":0.6,"y":0.95}}}
</code></pre> <h2 id="model-matching-hfs-gpt-2-byte-for-byte">Model: matching HF’s GPT-2 byte-for-byte</h2> <p>The model class mirrors HuggingFace’s <code class="language-plaintext highlighter-rouge">GPT2LMHeadModel</code> parameter names exactly so that <code class="language-plaintext highlighter-rouge">from_pretrained("gpt2")</code> can copy weights in. The HF GPT-2 state_dict holds 148 tensors, in three groups:</p> <ul> <li><strong>Embeddings.</strong> <code class="language-plaintext highlighter-rouge">transformer.wte.weight (50257, 768)</code>, <code class="language-plaintext highlighter-rouge">transformer.wpe.weight (1024, 768)</code>. Position embeddings are <em>learned</em>, not sinusoidal, a choice GPT-2 (2019) carried over from the original 2018 GPT.</li> <li><strong>12 identical blocks.</strong> Each has pre-norm <code class="language-plaintext highlighter-rouge">ln_1</code>, fused QKV projection <code class="language-plaintext highlighter-rouge">attn.c_attn.weight (768, 2304)</code>, output projection <code class="language-plaintext highlighter-rouge">attn.c_proj</code>, pre-norm <code class="language-plaintext highlighter-rouge">ln_2</code>, MLP up <code class="language-plaintext highlighter-rouge">mlp.c_fc.weight (768, 3072)</code> (4× expansion), MLP down <code class="language-plaintext highlighter-rouge">mlp.c_proj</code>. GELU between the MLP linears.</li> <li><strong>Final.</strong> <code class="language-plaintext highlighter-rouge">transformer.ln_f</code> (the extra LayerNorm GPT-2 added vs the original 2017 transformer) and <code class="language-plaintext highlighter-rouge">lm_head.weight (50257, 768)</code>.</li> </ul> <p>Two implementation gotchas:</p> <p><strong>HF weights are stored transposed.</strong> HuggingFace uses a TF-legacy <code class="language-plaintext highlighter-rouge">Conv1D</code> layer that stores <code class="language-plaintext highlighter-rouge">(in, out)</code>-shaped weights. 
Standard PyTorch <code class="language-plaintext highlighter-rouge">nn.Linear</code> stores <code class="language-plaintext highlighter-rouge">(out, in)</code>. So when copying HF weights into our <code class="language-plaintext highlighter-rouge">nn.Linear</code>-based model, the four matrices in each block, <code class="language-plaintext highlighter-rouge">c_attn</code>, <code class="language-plaintext highlighter-rouge">c_proj</code> (attn), <code class="language-plaintext highlighter-rouge">c_fc</code>, <code class="language-plaintext highlighter-rouge">c_proj</code> (mlp), need <code class="language-plaintext highlighter-rouge">.t()</code>. The other tensors copy directly.</p> <p><strong>Weight tying.</strong> <code class="language-plaintext highlighter-rouge">lm_head.weight</code> and <code class="language-plaintext highlighter-rouge">transformer.wte.weight</code> are the <em>same tensor</em> in GPT-2: one physical 38.6M-parameter matrix used both as input embedding and output classifier. That matrix alone is ~30% of the full 124M. Implementation: <code class="language-plaintext highlighter-rouge">self.transformer.wte.weight = self.lm_head.weight</code>. Two Python names, one tensor object.</p> <h2 id="residual-stream-init">Residual stream init</h2> <p>GPT-2 scales the output projection of each residual sub-layer at init by <code class="language-plaintext highlighter-rouge">1/sqrt(2*n_layer)</code>. With <code class="language-plaintext highlighter-rouge">n_layer = 12</code>, that’s a factor of ~0.204. Why: at each residual addition <code class="language-plaintext highlighter-rouge">x = x + sub_layer(x)</code>, the variance of <code class="language-plaintext highlighter-rouge">x</code> grows. Without rescaling, variance compounds across 24 sub-layers (12 attn + 12 MLP) and the residual stream’s scale is already blown up at init, before any training step. 
The rescale keeps post-stack variance close to the input variance.</p> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">for</span> <span class="n">pn</span><span class="p">,</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">self</span><span class="p">.</span><span class="nf">named_parameters</span><span class="p">():</span>
    <span class="k">if</span> <span class="n">pn</span><span class="p">.</span><span class="nf">endswith</span><span class="p">(</span><span class="sh">'</span><span class="s">c_proj.weight</span><span class="sh">'</span><span class="p">):</span>
        <span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">init</span><span class="p">.</span><span class="nf">normal_</span><span class="p">(</span><span class="n">p</span><span class="p">,</span> <span class="n">mean</span><span class="o">=</span><span class="mf">0.0</span><span class="p">,</span> <span class="n">std</span><span class="o">=</span><span class="mf">0.02</span> <span class="o">*</span> <span class="p">(</span><span class="mi">2</span> <span class="o">*</span> <span class="n">config</span><span class="p">.</span><span class="n">n_layer</span><span class="p">)</span> <span class="o">**</span> <span class="o">-</span><span class="mf">0.5</span><span class="p">)</span>
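# Illustration only, not part of the original init code: the arithmetic behind
# the rescale, assuming n_layer = 12 and the base std of 0.02 quoted in the text.
n_layer = 12
scale = (2 * n_layer) ** -0.5      # 1/sqrt(24), roughly 0.204
scaled_std = 0.02 * scale          # roughly 0.0041, the std used for c_proj weights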
</code></pre></div></div> <p>Everything else is standard <code class="language-plaintext highlighter-rouge">N(0, 0.02)</code>. The GPT-2 paper describes this rescale (a factor of 1/√N, with N the number of residual layers); the code above follows nanoGPT’s implementation of it, and it makes training much more stable.</p> <h2 id="data-fineweb-edu-sharded">Data: FineWeb-Edu, sharded</h2> <p>10B tokens of FineWeb-Edu (the educational subset), tokenized with tiktoken’s <code class="language-plaintext highlighter-rouge">gpt2</code> encoding, written out as 100 shards of <code class="language-plaintext highlighter-rouge">(100M tokens, uint16)</code> <code class="language-plaintext highlighter-rouge">.npy</code> files. Tokenization is multiprocess and dominated by Python overhead: ~30 minutes on the 8× A100 instance, ~$26 just to tokenize.</p> <p><code class="language-plaintext highlighter-rouge">uint16</code> because vocab size is 50257, comfortably under 65536. That halves disk size relative to int32.</p> <p><code class="language-plaintext highlighter-rouge">DataLoaderLite</code> keeps one shard in memory at a time, advances to the next when exhausted, and respects the DDP rank/world-size partitioning:</p> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">reset</span><span class="p">(</span><span class="n">self</span><span class="p">):</span>
    <span class="n">self</span><span class="p">.</span><span class="n">current_shard</span> <span class="o">=</span> <span class="mi">0</span>
    <span class="n">self</span><span class="p">.</span><span class="n">tokens</span> <span class="o">=</span> <span class="nf">load_tokens</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">shards</span><span class="p">[</span><span class="n">self</span><span class="p">.</span><span class="n">current_shard</span><span class="p">])</span>
    <span class="n">self</span><span class="p">.</span><span class="n">current_position</span> <span class="o">=</span> <span class="n">self</span><span class="p">.</span><span class="n">B</span> <span class="o">*</span> <span class="n">self</span><span class="p">.</span><span class="n">T</span> <span class="o">*</span> <span class="n">self</span><span class="p">.</span><span class="n">process_rank</span>

<span class="k">def</span> <span class="nf">next_batch</span><span class="p">(</span><span class="n">self</span><span class="p">):</span>
    <span class="n">B</span><span class="p">,</span> <span class="n">T</span> <span class="o">=</span> <span class="n">self</span><span class="p">.</span><span class="n">B</span><span class="p">,</span> <span class="n">self</span><span class="p">.</span><span class="n">T</span>
    <span class="n">buf</span> <span class="o">=</span> <span class="n">self</span><span class="p">.</span><span class="n">tokens</span><span class="p">[</span><span class="n">self</span><span class="p">.</span><span class="n">current_position</span> <span class="p">:</span> <span class="n">self</span><span class="p">.</span><span class="n">current_position</span> <span class="o">+</span> <span class="n">B</span><span class="o">*</span><span class="n">T</span> <span class="o">+</span> <span class="mi">1</span><span class="p">]</span>
    <span class="n">x</span> <span class="o">=</span> <span class="n">buf</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">].</span><span class="nf">view</span><span class="p">(</span><span class="n">B</span><span class="p">,</span> <span class="n">T</span><span class="p">)</span>
    <span class="n">y</span> <span class="o">=</span> <span class="n">buf</span><span class="p">[</span><span class="mi">1</span><span class="p">:].</span><span class="nf">view</span><span class="p">(</span><span class="n">B</span><span class="p">,</span> <span class="n">T</span><span class="p">)</span>
    <span class="n">self</span><span class="p">.</span><span class="n">current_position</span> <span class="o">+=</span> <span class="n">B</span> <span class="o">*</span> <span class="n">T</span> <span class="o">*</span> <span class="n">self</span><span class="p">.</span><span class="n">num_processes</span>
    <span class="k">if</span> <span class="n">self</span><span class="p">.</span><span class="n">current_position</span> <span class="o">+</span> <span class="p">(</span><span class="n">B</span> <span class="o">*</span> <span class="n">T</span> <span class="o">*</span> <span class="n">self</span><span class="p">.</span><span class="n">num_processes</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">&gt;</span> <span class="nf">len</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">tokens</span><span class="p">):</span>
        <span class="n">self</span><span class="p">.</span><span class="n">current_shard</span> <span class="o">=</span> <span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">current_shard</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">%</span> <span class="nf">len</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">shards</span><span class="p">)</span>
        <span class="n">self</span><span class="p">.</span><span class="n">tokens</span> <span class="o">=</span> <span class="nf">load_tokens</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">shards</span><span class="p">[</span><span class="n">self</span><span class="p">.</span><span class="n">current_shard</span><span class="p">])</span>
        <span class="n">self</span><span class="p">.</span><span class="n">current_position</span> <span class="o">=</span> <span class="n">self</span><span class="p">.</span><span class="n">B</span> <span class="o">*</span> <span class="n">self</span><span class="p">.</span><span class="n">T</span> <span class="o">*</span> <span class="n">self</span><span class="p">.</span><span class="n">process_rank</span>
    <span class="k">return</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span>
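# Sketch, not part of the original loader: a hypothetical pure-Python check
# that the per-rank read positions used above tile a shard without overlap.
def rank_offsets(B, T, num_processes, steps):
    """Return {rank: [start positions]} for the first `steps` batches."""
    return {
        rank: [B * T * (rank + step * num_processes) for step in range(steps)]
        for rank in range(num_processes)
    }

# Assumed example values: B=16, T=1024, 8 ranks, as in the run described here.
offsets = rank_offsets(B=16, T=1024, num_processes=8, steps=3)
starts = sorted(pos for positions in offsets.values() for pos in positions)
# Sorted starts are exactly B*T apart, so no two B*T-token input windows overlap.
assert all(b - a == 16 * 1024 for a, b in zip(starts, starts[1:]))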
</code></pre></div></div> <p><code class="language-plaintext highlighter-rouge">current_position</code> starts at <code class="language-plaintext highlighter-rouge">B * T * process_rank</code> so each rank reads a different stride. Combined with <code class="language-plaintext highlighter-rouge">current_position += B * T * num_processes</code> after every step, the 8 ranks tile the shard without overlap.</p> <h2 id="hyperparameters-every-number-traced-to-a-source">Hyperparameters: every number traced to a source</h2> <p><strong>The single most important insight: papers specify training in tokens, not steps.</strong> Everything follows from that.</p> <table> <thead> <tr> <th>Knob</th> <th>Value</th> <th>Source</th> </tr> </thead> <tbody> <tr> <td>Total batch (tokens / step)</td> <td>524288 = 2¹⁹ ≈ 0.5M</td> <td>GPT-3 paper Table 2.1, “GPT-3 Small”</td> </tr> <tr> <td><code class="language-plaintext highlighter-rouge">max_steps</code></td> <td>19073</td> <td><code class="language-plaintext highlighter-rouge">10_000_000_000 / 524288</code> = full pass over 10B tokens</td> </tr> <tr> <td><code class="language-plaintext highlighter-rouge">warmup_steps</code></td> <td>715</td> <td><code class="language-plaintext highlighter-rouge">375_000_000 / 524288</code> (GPT-3 §2.3: linear warmup over 375M tokens)</td> </tr> <tr> <td><code class="language-plaintext highlighter-rouge">max_lr</code></td> <td>6e-4</td> <td>GPT-3 paper Table 2.1</td> </tr> <tr> <td><code class="language-plaintext highlighter-rouge">min_lr</code></td> <td><code class="language-plaintext highlighter-rouge">max_lr * 0.1</code></td> <td>GPT-3 §2.3: “cosine decay to 10% of max”</td> </tr> <tr> <td><code class="language-plaintext highlighter-rouge">betas</code></td> <td>(0.9, 0.95)</td> <td>GPT-3 §2.3 — note β₂=0.95, not Adam default 0.999</td> </tr> <tr> <td><code class="language-plaintext highlighter-rouge">weight_decay</code></td> <td>0.1</td> <td>GPT-3 §2.3</td> </tr> <tr> <td>Wd applies to 2D+ params 
only</td> <td>—</td> <td>GPT-3 §2.3</td> </tr> <tr> <td><code class="language-plaintext highlighter-rouge">clip_grad_norm</code></td> <td>1.0</td> <td>GPT-3 §2.3</td> </tr> <tr> <td><code class="language-plaintext highlighter-rouge">n_layer, n_head, n_embd</code></td> <td>12, 12, 768</td> <td>GPT-2 paper Table 2 (124M row)</td> </tr> <tr> <td><code class="language-plaintext highlighter-rouge">block_size</code></td> <td>1024</td> <td>GPT-2 paper</td> </tr> <tr> <td><code class="language-plaintext highlighter-rouge">vocab_size</code> (padded)</td> <td>50304</td> <td>Nearest multiple of 128 ≥ 50257, for tensor-core tile alignment. nanoGPT addition, not a paper.</td> </tr> </tbody> </table> <p>The provenance layers cleanly:</p> <ul> <li><strong>GPT-3 paper</strong> → all training hyperparameters (LR, betas, weight decay, batch, schedule)</li> <li><strong>GPT-2 paper</strong> → architecture (layers, heads, embd, pre-norm, ln_f, GELU)</li> <li><strong>nanoGPT</strong> → implementation tricks (vocab pad to 50304, residual init rescale, fused AdamW)</li> </ul> <h2 id="gradient-accumulation-and-the-global-batch-math">Gradient accumulation and the global-batch math</h2> <p>Global batch size is fixed at 524288 tokens. With 8 GPUs and per-GPU <code class="language-plaintext highlighter-rouge">B=16, T=1024</code>:</p> <div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>per_gpu_tokens     = B * T              = 16 * 1024  = 16384
global_per_step    = per_gpu * world    = 16384 * 8  = 131072
grad_accum_steps   = 524288 / global    = 524288 / 131072 = 4
</code></pre></div></div> <p>Every “macro step” = 4 forward+backward passes per GPU + 1 optimizer step + 1 all-reduce. Compute is dominated by the forward+backward; the all-reduce is bandwidth-bound but small compared to compute.</p> <p>In code:</p> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">for</span> <span class="n">micro_step</span> <span class="ow">in</span> <span class="nf">range</span><span class="p">(</span><span class="n">grad_accum_steps</span><span class="p">):</span>
    <span class="n">x</span><span class="p">,</span> <span class="n">y</span> <span class="o">=</span> <span class="n">train_loader</span><span class="p">.</span><span class="nf">next_batch</span><span class="p">()</span>
    <span class="n">x</span><span class="p">,</span> <span class="n">y</span> <span class="o">=</span> <span class="n">x</span><span class="p">.</span><span class="nf">to</span><span class="p">(</span><span class="n">device</span><span class="p">),</span> <span class="n">y</span><span class="p">.</span><span class="nf">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
    <span class="k">with</span> <span class="n">torch</span><span class="p">.</span><span class="nf">autocast</span><span class="p">(</span><span class="n">device_type</span><span class="o">=</span><span class="sh">'</span><span class="s">cuda</span><span class="sh">'</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="p">.</span><span class="n">bfloat16</span><span class="p">):</span>
        <span class="n">logits</span><span class="p">,</span> <span class="n">loss</span> <span class="o">=</span> <span class="nf">model</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span>
    <span class="n">loss</span> <span class="o">=</span> <span class="n">loss</span> <span class="o">/</span> <span class="n">grad_accum_steps</span>     <span class="c1"># critical
</span>    <span class="k">if</span> <span class="n">ddp</span><span class="p">:</span>
        <span class="n">model</span><span class="p">.</span><span class="n">require_backward_grad_sync</span> <span class="o">=</span> <span class="p">(</span><span class="n">micro_step</span> <span class="o">==</span> <span class="n">grad_accum_steps</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span>
    <span class="n">loss</span><span class="p">.</span><span class="nf">backward</span><span class="p">()</span>
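# Sketch, not from the original script: a hypothetical pure-Python check of why
# the division matters. Gradients add up across micro-steps, so summing the
# per-micro-batch mean gradients overshoots the full-batch mean by grad_accum_steps.
micro_grads = [0.5, 1.5, 1.0, 3.0]             # made-up mean gradient of each micro-batch
accum = len(micro_grads)                       # plays the role of grad_accum_steps
unscaled = sum(micro_grads)                    # what accumulation yields without the division
scaled = sum(g / accum for g in micro_grads)   # what loss / grad_accum_steps yields
assert unscaled == accum * scaled              # the unscaled sum is accum times too large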
</code></pre></div></div> <p>Two things to be careful about:</p> <ul> <li><strong><code class="language-plaintext highlighter-rouge">loss / grad_accum_steps</code></strong> because <code class="language-plaintext highlighter-rouge">loss.backward()</code> accumulates gradients with <code class="language-plaintext highlighter-rouge">+=</code>. Without dividing, four micro-steps give 4× the gradient, equivalent to multiplying the LR by 4. Subtle bug.</li> <li><strong><code class="language-plaintext highlighter-rouge">require_backward_grad_sync</code></strong> is DDP’s escape hatch: gradient all-reduce normally fires on every <code class="language-plaintext highlighter-rouge">.backward()</code> call. With grad accumulation, we only want to sync on the <em>last</em> micro-step. The intermediate all-reduces would just waste bandwidth syncing partial gradients that will be added to again before the step. Setting the flag to <code class="language-plaintext highlighter-rouge">False</code> for the first 3 micro-steps and <code class="language-plaintext highlighter-rouge">True</code> only on the 4th skips 3 of every 4 all-reduces, cutting ~75% of gradient-sync traffic with no correctness cost.</li> </ul> <h2 id="speed-bf16-sdpa-tf32-vocab-padding">Speed: bf16, SDPA, TF32, vocab padding</h2> <p>On the 4 GB 1650 baseline at <code class="language-plaintext highlighter-rouge">B=4, T=32</code>: ~1080 tokens/sec at fp32. The speedups that arrived with Ampere and later GPUs are not available on that card. The cloud instance’s A100s are Ampere-class and pick them all up:</p> <ul> <li><strong>bfloat16 autocast</strong> for forward+backward. Halves memory bandwidth and unlocks tensor-core throughput. 
bf16 has fp32’s exponent range, so unlike fp16 you don’t need a gradient scaler — the simple <code class="language-plaintext highlighter-rouge">torch.autocast(dtype=torch.bfloat16)</code> context manager works directly.</li> <li><strong>Scaled-dot-product attention</strong> (<code class="language-plaintext highlighter-rouge">F.scaled_dot_product_attention</code>) replaces the manual Q@K.T softmax. On Ampere+ this dispatches to FlashAttention 2 under the hood — fused kernel, no materialized <code class="language-plaintext highlighter-rouge">(T, T)</code> attention matrix, much better memory traffic. On the 1650 SDPA falls back to manual and gives ~0% speedup; on A100 it’s a real win.</li> <li><strong>TF32 matmuls</strong> via <code class="language-plaintext highlighter-rouge">torch.set_float32_matmul_precision('high')</code>. Same fp32 storage but tensor-core compute. Free speedup, no accuracy cost on any benchmark I checked.</li> <li><strong>vocab_size padded 50257 → 50304.</strong> Multiple of 128 = aligns with tensor-core tile sizes. The extra rows are dead weight that never see gradient because no real token ID maps to them, but they make the matmul faster. ~150 KB of extra parameters, several % throughput. Worth it.</li> <li><strong><code class="language-plaintext highlighter-rouge">torch.compile</code></strong> off for this run because HellaSwag eval has variable shapes (each example has its own <code class="language-plaintext highlighter-rouge">max_len</code>) and compile recompiles per shape — 10042 recompiles is catastrophic. 
Without HellaSwag in the loop, flip compile back on and reclaim ~20-30%.</li> </ul> <p>DDP wraps <em>after</em> compile (when compile is on):</p> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">model</span> <span class="o">=</span> <span class="nc">GPT</span><span class="p">(</span><span class="nc">GPTConfig</span><span class="p">(</span><span class="n">vocab_size</span><span class="o">=</span><span class="mi">50304</span><span class="p">))</span>
<span class="n">model</span><span class="p">.</span><span class="nf">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
<span class="k">if</span> <span class="n">use_compile</span><span class="p">:</span>
    <span class="n">model</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="nf">compile</span><span class="p">(</span><span class="n">model</span><span class="p">)</span>
<span class="k">if</span> <span class="n">ddp</span><span class="p">:</span>
    <span class="n">model</span> <span class="o">=</span> <span class="nc">DDP</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">device_ids</span><span class="o">=</span><span class="p">[</span><span class="n">ddp_local_rank</span><span class="p">])</span>
<span class="n">raw_model</span> <span class="o">=</span> <span class="n">model</span><span class="p">.</span><span class="n">module</span> <span class="k">if</span> <span class="n">ddp</span> <span class="k">else</span> <span class="n">model</span>
</code></pre></div></div> <p><code class="language-plaintext highlighter-rouge">torch.compile(DDP(model))</code> would try to trace DDP’s gradient-sync wrappers, which isn’t real compute. Compile-first, DDP-second.</p> <h2 id="the-optimizer-weight-decay-split">The optimizer: weight decay split</h2> <p>Following GPT-3 §2.3, weight decay applies to “weight” matrices (2D and higher) but not to biases or LayerNorm γ/β (1D parameters). Implementation:</p> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">configure_optimizers</span><span class="p">(</span><span class="n">self</span><span class="p">,</span> <span class="n">weight_decay</span><span class="p">,</span> <span class="n">learning_rate</span><span class="p">,</span> <span class="n">betas</span><span class="p">,</span> <span class="n">device_type</span><span class="p">):</span>
    <span class="n">param_dict</span> <span class="o">=</span> <span class="p">{</span><span class="n">pn</span><span class="p">:</span> <span class="n">p</span> <span class="k">for</span> <span class="n">pn</span><span class="p">,</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">self</span><span class="p">.</span><span class="nf">named_parameters</span><span class="p">()</span> <span class="k">if</span> <span class="n">p</span><span class="p">.</span><span class="n">requires_grad</span><span class="p">}</span>
    <span class="n">decay_params</span>   <span class="o">=</span> <span class="p">[</span><span class="n">p</span> <span class="k">for</span> <span class="n">n</span><span class="p">,</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">param_dict</span><span class="p">.</span><span class="nf">items</span><span class="p">()</span> <span class="k">if</span> <span class="n">p</span><span class="p">.</span><span class="nf">dim</span><span class="p">()</span> <span class="o">&gt;=</span> <span class="mi">2</span><span class="p">]</span>
    <span class="n">nodecay_params</span> <span class="o">=</span> <span class="p">[</span><span class="n">p</span> <span class="k">for</span> <span class="n">n</span><span class="p">,</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">param_dict</span><span class="p">.</span><span class="nf">items</span><span class="p">()</span> <span class="k">if</span> <span class="n">p</span><span class="p">.</span><span class="nf">dim</span><span class="p">()</span> <span class="o">&lt;</span>  <span class="mi">2</span><span class="p">]</span>
    <span class="n">optim_groups</span> <span class="o">=</span> <span class="p">[</span>
        <span class="p">{</span><span class="sh">'</span><span class="s">params</span><span class="sh">'</span><span class="p">:</span> <span class="n">decay_params</span><span class="p">,</span>   <span class="sh">'</span><span class="s">weight_decay</span><span class="sh">'</span><span class="p">:</span> <span class="n">weight_decay</span><span class="p">},</span>
        <span class="p">{</span><span class="sh">'</span><span class="s">params</span><span class="sh">'</span><span class="p">:</span> <span class="n">nodecay_params</span><span class="p">,</span> <span class="sh">'</span><span class="s">weight_decay</span><span class="sh">'</span><span class="p">:</span> <span class="mf">0.0</span><span class="p">},</span>
    <span class="p">]</span>
    <span class="n">use_fused</span> <span class="o">=</span> <span class="n">device_type</span> <span class="o">==</span> <span class="sh">'</span><span class="s">cuda</span><span class="sh">'</span>
    <span class="k">return</span> <span class="n">torch</span><span class="p">.</span><span class="n">optim</span><span class="p">.</span><span class="nc">AdamW</span><span class="p">(</span><span class="n">optim_groups</span><span class="p">,</span> <span class="n">lr</span><span class="o">=</span><span class="n">learning_rate</span><span class="p">,</span> <span class="n">betas</span><span class="o">=</span><span class="n">betas</span><span class="p">,</span>
                             <span class="n">eps</span><span class="o">=</span><span class="mf">1e-8</span><span class="p">,</span> <span class="n">fused</span><span class="o">=</span><span class="n">use_fused</span><span class="p">)</span>
</code></pre></div></div> <p><code class="language-plaintext highlighter-rouge">fused=True</code> calls a single CUDA kernel for the AdamW update instead of dispatching per-parameter. Trivially faster, no accuracy impact.</p> <h2 id="lr-schedule-warmup--cosine--floor">LR schedule: warmup + cosine + floor</h2> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">get_lr</span><span class="p">(</span><span class="n">it</span><span class="p">):</span>
    <span class="k">if</span> <span class="n">it</span> <span class="o">&lt;</span> <span class="n">warmup_steps</span><span class="p">:</span>
        <span class="k">return</span> <span class="n">max_lr</span> <span class="o">*</span> <span class="p">(</span><span class="n">it</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">/</span> <span class="n">warmup_steps</span>
    <span class="k">if</span> <span class="n">it</span> <span class="o">&gt;</span> <span class="n">max_steps</span><span class="p">:</span>
        <span class="k">return</span> <span class="n">min_lr</span>
    <span class="n">decay_ratio</span> <span class="o">=</span> <span class="p">(</span><span class="n">it</span> <span class="o">-</span> <span class="n">warmup_steps</span><span class="p">)</span> <span class="o">/</span> <span class="p">(</span><span class="n">max_steps</span> <span class="o">-</span> <span class="n">warmup_steps</span><span class="p">)</span>
    <span class="n">coeff</span> <span class="o">=</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="p">(</span><span class="mf">1.0</span> <span class="o">+</span> <span class="n">math</span><span class="p">.</span><span class="nf">cos</span><span class="p">(</span><span class="n">math</span><span class="p">.</span><span class="n">pi</span> <span class="o">*</span> <span class="n">decay_ratio</span><span class="p">))</span>
    <span class="k">return</span> <span class="n">min_lr</span> <span class="o">+</span> <span class="n">coeff</span> <span class="o">*</span> <span class="p">(</span><span class="n">max_lr</span> <span class="o">-</span> <span class="n">min_lr</span><span class="p">)</span>
</code></pre></div></div> <p>Linear ramp up to <code class="language-plaintext highlighter-rouge">max_lr</code> over <code class="language-plaintext highlighter-rouge">warmup_steps</code> (because of the <code class="language-plaintext highlighter-rouge">it + 1</code>, step 0 starts at <code class="language-plaintext highlighter-rouge">max_lr / warmup_steps</code> rather than 0). Then cosine decay to <code class="language-plaintext highlighter-rouge">min_lr = 0.1 * max_lr</code>. The cosine starts at <code class="language-plaintext highlighter-rouge">coeff=1</code> (full <code class="language-plaintext highlighter-rouge">max_lr</code>) and ends at <code class="language-plaintext highlighter-rouge">coeff=0</code> (full <code class="language-plaintext highlighter-rouge">min_lr</code>). At step 7240, <code class="language-plaintext highlighter-rouge">decay_ratio = (7240 - 715) / (19073 - 715) ≈ 0.355</code>, so <code class="language-plaintext highlighter-rouge">coeff = 0.5 * (1 + cos(π·0.355)) ≈ 0.72</code> and the LR is roughly <code class="language-plaintext highlighter-rouge">(0.1 + 0.9 * 0.72) * max_lr ≈ 0.75 * max_lr</code>. The halfway point of the decay, <code class="language-plaintext highlighter-rouge">0.55 * max_lr</code>, only arrives at <code class="language-plaintext highlighter-rouge">decay_ratio = 0.5</code>, around step 9,894.</p> <h2 id="hellaswag-inline-eval">HellaSwag inline eval</h2> <p>HellaSwag is a 4-way multiple-choice commonsense benchmark. Each example: a context + 4 candidate endings, pick the most plausible. Scoring: for each candidate, average the LM’s per-token loss over just the ending tokens. The candidate with lowest loss is the model’s pick.</p> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">get_most_likely_row</span><span class="p">(</span><span class="n">tokens</span><span class="p">,</span> <span class="n">mask</span><span class="p">,</span> <span class="n">logits</span><span class="p">):</span>
    <span class="n">shift_logits</span> <span class="o">=</span> <span class="n">logits</span><span class="p">[...,</span> <span class="p">:</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="p">:].</span><span class="nf">contiguous</span><span class="p">()</span>
    <span class="n">shift_tokens</span> <span class="o">=</span> <span class="n">tokens</span><span class="p">[...,</span> <span class="mi">1</span><span class="p">:].</span><span class="nf">contiguous</span><span class="p">()</span>
    <span class="n">shift_losses</span> <span class="o">=</span> <span class="n">F</span><span class="p">.</span><span class="nf">cross_entropy</span><span class="p">(</span>
        <span class="n">shift_logits</span><span class="p">.</span><span class="nf">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">shift_logits</span><span class="p">.</span><span class="nf">size</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)),</span>
        <span class="n">shift_tokens</span><span class="p">.</span><span class="nf">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">),</span>
        <span class="n">reduction</span><span class="o">=</span><span class="sh">'</span><span class="s">none</span><span class="sh">'</span><span class="p">,</span>
    <span class="p">).</span><span class="nf">view</span><span class="p">(</span><span class="n">tokens</span><span class="p">.</span><span class="nf">size</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
    <span class="n">shift_mask</span> <span class="o">=</span> <span class="n">mask</span><span class="p">[...,</span> <span class="mi">1</span><span class="p">:].</span><span class="nf">contiguous</span><span class="p">()</span>
    <span class="n">masked_losses</span> <span class="o">=</span> <span class="n">shift_losses</span> <span class="o">*</span> <span class="n">shift_mask</span>
    <span class="n">avg_loss</span> <span class="o">=</span> <span class="n">masked_losses</span><span class="p">.</span><span class="nf">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> <span class="o">/</span> <span class="n">shift_mask</span><span class="p">.</span><span class="nf">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">avg_loss</span><span class="p">.</span><span class="nf">argmin</span><span class="p">().</span><span class="nf">item</span><span class="p">()</span>
</code></pre></div></div> <p>10042 examples, sharded across the 8 ranks. The accuracies all-reduce at the end of each eval block.</p> <p>Random chance = 25%. OpenAI’s released GPT-2 (124M) baseline = 29.45%. Our reproduction = 26.99%.</p> <h2 id="training-memory-optimizer-state-is-most-of-it">Training memory: optimizer state is most of it</h2> <table> <thead> <tr> <th>What lives in GPU memory</th> <th>Training</th> <th>Inference</th> </tr> </thead> <tbody> <tr> <td>Model weights (124M × 4B)</td> <td>~500 MB</td> <td>~500 MB</td> </tr> <tr> <td>Gradients (one per param)</td> <td>~500 MB</td> <td>—</td> </tr> <tr> <td>Adam <code class="language-plaintext highlighter-rouge">m</code> running mean</td> <td>~500 MB</td> <td>—</td> </tr> <tr> <td>Adam <code class="language-plaintext highlighter-rouge">v</code> running variance</td> <td>~500 MB</td> <td>—</td> </tr> <tr> <td>Activations for backward</td> <td>huge, ∝ B×T</td> <td>—</td> </tr> <tr> <td>Forward activation buffers</td> <td>yes</td> <td>yes (ephemeral)</td> </tr> <tr> <td><strong>Total</strong></td> <td><strong>2-4 GB minimum</strong></td> <td><strong>~600 MB</strong></td> </tr> </tbody> </table> <p>The optimizer state is what kills training on small cards. A 4 GB GTX 1650 can comfortably <em>infer</em> 124M at fp16 but can’t comfortably <em>train</em> it. 
Gradients plus AdamW’s two moment buffers add 3× the model size on top of the weights, before you’ve allocated a single activation.</p> <h2 id="cost-ledger">Cost ledger</h2> <table> <thead> <tr> <th>Phase</th> <th>Time</th> <th>Cost</th> </tr> </thead> <tbody> <tr> <td>FineWeb-Edu tokenization (multiprocess on the 8× A100 instance)</td> <td>~30 min</td> <td>~$8</td> </tr> <tr> <td>Training: 19,073 steps × ~472 ms/step</td> <td>~2.5 hours</td> <td>~$40</td> </tr> <tr> <td>Setup, idle, HF download throttles, debugging</td> <td>~1 hour</td> <td>~$13</td> </tr> <tr> <td><strong>Total</strong></td> <td><strong>~4 hours</strong></td> <td><strong>$61.23</strong></td> </tr> </tbody> </table> <p>8× A100 SXM4 on Lambda was $15.92/hr at the time. Every minute spent typing slowly costs real money, so have the next command ready before SSH’ing in.</p> <p><code class="language-plaintext highlighter-rouge">tmux new -s train</code> keeps the training run alive across SSH drops. Doesn’t save money, but the alternative is a $40 training run dying because your laptop went to sleep.</p> <h2 id="what-the-trained-model-produces">What the trained model produces</h2> <p>Sample completions at val_loss 3.40 with <code class="language-plaintext highlighter-rouge">temperature=1.0, top_k=50</code>, prompt <code class="language-plaintext highlighter-rouge">"Hello, I'm a language model,"</code>:</p> <ol> <li>Coherent prose about teachers, students, classrooms. Grammatical for 3 sentences.</li> <li>Rambling but grammatical paragraph on social media and jobs.</li> <li><strong>Repetition loop</strong>: <code class="language-plaintext highlighter-rouge">"I'm a language model. I'm a language model. I'm a language model."</code></li> <li>Has dialogue with quote marks, sentence-level coherence.</li> </ol> <p>Diagnosis: working LM with sentence-level coherence and occasional repetition loops, which is the typical failure mode of a 124M model at val_loss 3.4. At 3.1 the loops fade.
With <code class="language-plaintext highlighter-rouge">temperature=0.7, top_k=20</code> the output is cleaner but less creative.</p> <h2 id="what-the-model-actually-is">What the model actually is</h2> <p>124,475,904 floating-point numbers plus ~200 lines of Python that combine them. The checkpoint file is those numbers plus a kilobyte of config. Inference loads the numbers and runs the architecture forward.</p> <table> <thead> <tr> <th>Tensor</th> <th>Shape</th> <th>Count</th> </tr> </thead> <tbody> <tr> <td><code class="language-plaintext highlighter-rouge">wte.weight</code> (tied with <code class="language-plaintext highlighter-rouge">lm_head</code>)</td> <td>50304 × 768</td> <td>38,633,472</td> </tr> <tr> <td><code class="language-plaintext highlighter-rouge">wpe.weight</code></td> <td>1024 × 768</td> <td>786,432</td> </tr> <tr> <td>12 × Block</td> <td>each ~7.09M</td> <td>85,054,464</td> </tr> <tr> <td><code class="language-plaintext highlighter-rouge">ln_f</code> (γ + β)</td> <td>768 × 2</td> <td>1,536</td> </tr> <tr> <td><strong>Total</strong></td> <td> </td> <td><strong>124,475,904</strong></td> </tr> </tbody> </table> <p>Random init and trained model have the same 124M numbers in identical shapes. The only difference is the numerical values. Training is 19,073 nudges of roughly <code class="language-plaintext highlighter-rouge">lr × gradient × (-1)</code> applied to every number (AdamW rescales each gradient first, but the picture holds). No single number means anything. Meaning emerges from the collective behavior of all of them running through the matmuls.</p> <h2 id="what-would-close-the-remaining-gap">What would close the remaining gap</h2> <ul> <li><strong>Train 5K more steps on 8× A100</strong> to hit OpenAI val loss exactly. ~$80, gets val loss to ~3.29.</li> <li><strong>Switch base dataset from FineWeb-Edu to OpenWebText.</strong> Educational text is high-quality but narrow.
WebText is broader and gives better commonsense — closes the HellaSwag gap.</li> <li><strong>Quantize for inference.</strong> int8 or int4 brings the inference footprint from 600 MB to ~150 MB, runs much faster on the 1650.</li> </ul> <p>Code and plots: <a href="https://github.com/debtirthasaha/gpt2-124m-reproduction">github.com/debtirthasaha/gpt2-124m-reproduction</a>. The trained checkpoint (523 MB) is on Hugging Face at <a href="https://huggingface.co/MR0b0t/gpt2-124m-reproduction">MR0b0t/gpt2-124m-reproduction</a>.</p>]]></content><author><name></name></author><category term="deep-learning"/><category term="gpt-2"/><category term="reproduction"/><category term="training"/><category term="ddp"/><summary type="html"><![CDATA[Full reproduction of GPT-2 124M on rented multi-GPU hardware. Val loss 3.40 vs OpenAI's 3.29 (97% match), HellaSwag 27% vs 29.45%, in 2.5 hours of training.]]></summary></entry><entry><title type="html">BPE from scratch, and why your LLM can’t count L’s</title><link href="https://debtirthasaha.github.io/blog/2026/bpe-tokenizer/" rel="alternate" type="text/html" title="BPE from scratch, and why your LLM can’t count L’s"/><published>2026-04-25T10:00:00+00:00</published><updated>2026-04-25T10:00:00+00:00</updated><id>https://debtirthasaha.github.io/blog/2026/bpe-tokenizer</id><content type="html" xml:base="https://debtirthasaha.github.io/blog/2026/bpe-tokenizer/"><![CDATA[<p>A byte-pair-encoding tokenizer in pure Python on byte arrays. No NumPy, no neural net, no gradients. The trained tokenizer is two dicts:</p> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">merges</span><span class="p">:</span> <span class="p">{(</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">):</span> <span class="nb">int</span><span class="p">}</span>    <span class="c1"># the parameters of the tokenizer
</span><span class="n">vocab</span><span class="p">:</span>  <span class="p">{</span><span class="nb">int</span><span class="p">:</span> <span class="nb">bytes</span><span class="p">}</span>         <span class="c1"># derived from merges, used to decode
</span></code></pre></div></div> <p>That’s the entire model. Then a long second half on what the tokenizer makes weird about LLMs: SolidGoldMagikarp, spelling failures, arithmetic failures, and the encode/decode asymmetry.</p> <h2 id="why-tokenization-exists">Why tokenization exists</h2> <p>LLMs eat integers, not text. The tokenizer maps strings ↔ integer sequences.</p> <p>Two obvious approaches, both bad:</p> <ul> <li><strong>Unicode code points as tokens.</strong> ~150K possible code points → vocab too large. The Unicode standard also keeps changing — not stable.</li> <li><strong>Raw UTF-8 bytes.</strong> Vocab is a clean 256, but every text becomes 3-4× longer. Attention is <code class="language-plaintext highlighter-rouge">O(T²)</code>. Long sequences blow up compute and exhaust context length.</li> </ul> <p>BPE: start at 256 raw bytes, iteratively merge the most frequent adjacent pair into a new token. Sequences shrink, vocab grows in a controlled way, you stop whenever you like. Vocab size is now a tunable hyperparameter.</p> <p>GPT-2 uses ~50K. GPT-4 uses ~100K. Llama 2 uses ~32K.</p> <h2 id="utf-8-in-one-paragraph">UTF-8 in one paragraph</h2> <p>UTF-8 encodes each Unicode code point as 1-4 bytes. ASCII is 1 byte (compatible with the old world). CJK ideographs are 3 bytes. Most emoji are 4. Crucially, not every byte sequence is valid UTF-8 — <code class="language-plaintext highlighter-rouge">b'\x80'</code> alone is not a legal start byte. This matters when we look at encode/decode round-tripping.</p> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="sh">"</span><span class="s">Hello</span><span class="sh">"</span><span class="p">.</span><span class="nf">encode</span><span class="p">(</span><span class="sh">'</span><span class="s">utf-8</span><span class="sh">'</span><span class="p">)</span>   <span class="c1"># b'Hello'             (5 bytes)
</span><span class="sh">"</span><span class="s">안</span><span class="sh">"</span><span class="p">.</span><span class="nf">encode</span><span class="p">(</span><span class="sh">'</span><span class="s">utf-8</span><span class="sh">'</span><span class="p">)</span>       <span class="c1"># b'\xec\x95\x88'      (3 bytes)
</span><span class="sh">"</span><span class="s">🌊</span><span class="sh">"</span><span class="p">.</span><span class="nf">encode</span><span class="p">(</span><span class="sh">'</span><span class="s">utf-8</span><span class="sh">'</span><span class="p">)</span>       <span class="c1"># b'\xf0\x9f\x8c\x8a'  (4 bytes)
</span></code></pre></div></div> <p>GPT-2 and GPT-4 run BPE directly on UTF-8 bytes (byte-level BPE), as does Llama 3. SentencePiece runs BPE on code points and falls back to bytes only for rare ones. It is clunkier, but you’ll meet it in Llama 2 and Mistral.</p> <h2 id="the-bpe-algorithm">The BPE algorithm</h2> <p>Given a sequence of token IDs:</p> <ol> <li>Count adjacent pairs.</li> <li>Pick the most frequent pair.</li> <li>Mint a new token ID (256, 257, 258, …).</li> <li>Replace every occurrence of the pair with the new ID.</li> <li>Record the merge.</li> </ol> <p>Repeat N times. The dict of recorded merges <em>is</em> the trained tokenizer.</p> <div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Start (vocab=4):  a a b d a a b d a a b              (length 11)
Most freq: (a,a) → mint Z
Round 1:          Z b d Z b d Z b                    (length 8, vocab 5)
Most freq: (Z,b) → mint Y       ← Z is brand new but already mergeable
Round 2:          Y d Y d Y                          (length 5, vocab 6)
</code></pre></div></div> <p>Token 256 can participate in the round-2 merge that creates token 257. BPE is hierarchical — merges form a forest where new merges build on top of old ones. This hierarchy is <em>load-bearing</em> and shows up again when we build <code class="language-plaintext highlighter-rouge">vocab</code> from <code class="language-plaintext highlighter-rouge">merges</code> and again when we encode.</p> <h2 id="implementation">Implementation</h2> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">get_stats</span><span class="p">(</span><span class="n">ids</span><span class="p">):</span>
    <span class="n">counts</span> <span class="o">=</span> <span class="p">{}</span>
    <span class="k">for</span> <span class="n">pair</span> <span class="ow">in</span> <span class="nf">zip</span><span class="p">(</span><span class="n">ids</span><span class="p">,</span> <span class="n">ids</span><span class="p">[</span><span class="mi">1</span><span class="p">:]):</span>
        <span class="n">counts</span><span class="p">[</span><span class="n">pair</span><span class="p">]</span> <span class="o">=</span> <span class="n">counts</span><span class="p">.</span><span class="nf">get</span><span class="p">(</span><span class="n">pair</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span> <span class="o">+</span> <span class="mi">1</span>
    <span class="k">return</span> <span class="n">counts</span>

<span class="k">def</span> <span class="nf">merge</span><span class="p">(</span><span class="n">ids</span><span class="p">,</span> <span class="n">pair</span><span class="p">,</span> <span class="n">idx</span><span class="p">):</span>
    <span class="n">newids</span> <span class="o">=</span> <span class="p">[]</span>
    <span class="n">i</span> <span class="o">=</span> <span class="mi">0</span>
    <span class="k">while</span> <span class="n">i</span> <span class="o">&lt;</span> <span class="nf">len</span><span class="p">(</span><span class="n">ids</span><span class="p">):</span>
        <span class="k">if</span> <span class="n">i</span> <span class="o">&lt;</span> <span class="nf">len</span><span class="p">(</span><span class="n">ids</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span> <span class="ow">and</span> <span class="n">ids</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">==</span> <span class="n">pair</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="ow">and</span> <span class="n">ids</span><span class="p">[</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">]</span> <span class="o">==</span> <span class="n">pair</span><span class="p">[</span><span class="mi">1</span><span class="p">]:</span>
            <span class="n">newids</span><span class="p">.</span><span class="nf">append</span><span class="p">(</span><span class="n">idx</span><span class="p">)</span>
            <span class="n">i</span> <span class="o">+=</span> <span class="mi">2</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="n">newids</span><span class="p">.</span><span class="nf">append</span><span class="p">(</span><span class="n">ids</span><span class="p">[</span><span class="n">i</span><span class="p">])</span>
            <span class="n">i</span> <span class="o">+=</span> <span class="mi">1</span>
    <span class="k">return</span> <span class="n">newids</span>
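# Illustrative check (not part of the original code): replay the toy walkthrough
# with a=97, b=98, d=100, minting the hypothetical ids 256 (Z) and 257 (Y).
toy = [97, 97, 98, 100, 97, 97, 98, 100, 97, 97, 98]
toy_stats = get_stats(toy)                # (97, 97) and (97, 98) each occur 3 times
round1 = merge(toy, (97, 97), 256)        # Z b d Z b d Z b
round2 = merge(round1, (256, 98), 257)    # Y d Y d Y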
</code></pre></div></div> <p><code class="language-plaintext highlighter-rouge">zip(ids, ids[1:])</code> is the idiomatic way to iterate adjacent pairs. The <code class="language-plaintext highlighter-rouge">i &lt; len(ids) - 1</code> check must come first — otherwise <code class="language-plaintext highlighter-rouge">ids[i+1]</code> on the last element raises <code class="language-plaintext highlighter-rouge">IndexError</code>. Python’s <code class="language-plaintext highlighter-rouge">and</code> short-circuits, so the bounds check before the comparison is the fix.</p> <p>Training is the loop:</p> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">tokens</span> <span class="o">=</span> <span class="nf">list</span><span class="p">(</span><span class="n">text</span><span class="p">.</span><span class="nf">encode</span><span class="p">(</span><span class="sh">'</span><span class="s">utf-8</span><span class="sh">'</span><span class="p">))</span>   <span class="c1"># ints in [0, 255]
</span><span class="n">ids</span>    <span class="o">=</span> <span class="nf">list</span><span class="p">(</span><span class="n">tokens</span><span class="p">)</span>
<span class="n">merges</span> <span class="o">=</span> <span class="p">{}</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nf">range</span><span class="p">(</span><span class="n">num_merges</span><span class="p">):</span>
    <span class="n">stats</span> <span class="o">=</span> <span class="nf">get_stats</span><span class="p">(</span><span class="n">ids</span><span class="p">)</span>
    <span class="n">pair</span>  <span class="o">=</span> <span class="nf">max</span><span class="p">(</span><span class="n">stats</span><span class="p">,</span> <span class="n">key</span><span class="o">=</span><span class="n">stats</span><span class="p">.</span><span class="n">get</span><span class="p">)</span>
    <span class="n">idx</span>   <span class="o">=</span> <span class="mi">256</span> <span class="o">+</span> <span class="n">i</span>
    <span class="n">ids</span>   <span class="o">=</span> <span class="nf">merge</span><span class="p">(</span><span class="n">ids</span><span class="p">,</span> <span class="n">pair</span><span class="p">,</span> <span class="n">idx</span><span class="p">)</span>
    <span class="n">merges</span><span class="p">[</span><span class="n">pair</span><span class="p">]</span> <span class="o">=</span> <span class="n">idx</span>
</code></pre></div></div> <p><code class="language-plaintext highlighter-rouge">list(text.encode('utf-8'))</code> because iterating a <code class="language-plaintext highlighter-rouge">bytes</code> object yields ints in Python 3, so <code class="language-plaintext highlighter-rouge">list(bytes_obj)</code> is a flat list of ints in <code class="language-plaintext highlighter-rouge">[0, 255]</code>. We keep <code class="language-plaintext highlighter-rouge">tokens</code> untouched for the compression-ratio report and mutate <code class="language-plaintext highlighter-rouge">ids</code>.</p> <p>20 merges on 20868 bytes of text: down to 16154 tokens, <strong>1.29× compression</strong>. The first pair selected is <code class="language-plaintext highlighter-rouge">(101, 32) = 'e '</code> — words ending in <code class="language-plaintext highlighter-rouge">e</code> followed by a space.</p> <pre><code class="language-plotly">{"data":[{"x":[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20],"y":[20868,20276,19859,19465,19148,18865,18594,18326,18092,17873,17671,17469,17300,17136,16975,16826,16689,16552,16417,16285,16154],"type":"scatter","mode":"lines+markers","line":{"color":"#EF553B","width":2},"marker":{"size":7},"text":["start","'e '","'in'","'s '","'th'","'er'","'t '","'co'","', '","'an'","'d '","'or'","'ar'","'en'","'&lt;g&gt; on '","'&lt;d&gt; on '","'y '","'al'","'on'","'&lt;256&gt;&lt;256&gt; '","'ac'"],"hovertemplate":"merge %{x}&lt;br&gt;pair: %{text}&lt;br&gt;tokens: %{y}&lt;extra&gt;&lt;/extra&gt;","name":"sequence length"}],"layout":{"title":{"text":"Sequence length over 20 BPE merges (20,868 bytes -&gt; 16,154 tokens)"},"xaxis":{"title":"merge step","dtick":2},"yaxis":{"title":"tokens","range":[15500,21500]},"height":420,"margin":{"l":70,"r":30,"t":60,"b":50},"showlegend":false}}
</code></pre> <p>The biggest drops come early — <code class="language-plaintext highlighter-rouge">'e '</code> alone saves 592 tokens. Returns diminish: by merge 20, each new pair removes ~130 tokens. The dominant merges are mostly English space-suffix bigrams (<code class="language-plaintext highlighter-rouge">'e '</code>, <code class="language-plaintext highlighter-rouge">'s '</code>, <code class="language-plaintext highlighter-rouge">'t '</code>, <code class="language-plaintext highlighter-rouge">'d '</code>, <code class="language-plaintext highlighter-rouge">'y '</code>, <code class="language-plaintext highlighter-rouge">', '</code>) and high-frequency root pairs (<code class="language-plaintext highlighter-rouge">'in'</code>, <code class="language-plaintext highlighter-rouge">'th'</code>, <code class="language-plaintext highlighter-rouge">'er'</code>, <code class="language-plaintext highlighter-rouge">'an'</code>, <code class="language-plaintext highlighter-rouge">'or'</code>, <code class="language-plaintext highlighter-rouge">'ar'</code>, <code class="language-plaintext highlighter-rouge">'en'</code>, <code class="language-plaintext highlighter-rouge">'al'</code>, <code class="language-plaintext highlighter-rouge">'on'</code>, <code class="language-plaintext highlighter-rouge">'ac'</code>). BPE rediscovers the morphological structure of English suffixes and roots from byte frequencies alone.</p> <h2 id="building-vocab-from-merges">Building <code class="language-plaintext highlighter-rouge">vocab</code> from <code class="language-plaintext highlighter-rouge">merges</code></h2> <p>Decode needs to know what each token ID is in bytes. Derive it:</p> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">_build_vocab</span><span class="p">(</span><span class="n">merges</span><span class="p">):</span>
    <span class="n">vocab</span> <span class="o">=</span> <span class="p">{</span><span class="n">i</span><span class="p">:</span> <span class="nf">bytes</span><span class="p">([</span><span class="n">i</span><span class="p">])</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nf">range</span><span class="p">(</span><span class="mi">256</span><span class="p">)}</span>
    <span class="nf">for </span><span class="p">(</span><span class="n">p0</span><span class="p">,</span> <span class="n">p1</span><span class="p">),</span> <span class="n">idx</span> <span class="ow">in</span> <span class="n">merges</span><span class="p">.</span><span class="nf">items</span><span class="p">():</span>
        <span class="n">vocab</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span> <span class="o">=</span> <span class="n">vocab</span><span class="p">[</span><span class="n">p0</span><span class="p">]</span> <span class="o">+</span> <span class="n">vocab</span><span class="p">[</span><span class="n">p1</span><span class="p">]</span>
    <span class="k">return</span> <span class="n">vocab</span>
</code></pre></div></div> <p>The insertion-order requirement is real. Token 258 might be <code class="language-plaintext highlighter-rouge">vocab[256] + vocab[257]</code>, so both must already exist when we look them up. Python 3.7+ guarantees that <code class="language-plaintext highlighter-rouge">dict.items()</code> iterates in insertion order (CPython 3.6 already behaves this way, but only as an implementation detail). On Python ≤3.5 the iteration order is arbitrary, so the loop can reach a merge before its parts exist and fail with <code class="language-plaintext highlighter-rouge">KeyError</code>.</p> <h2 id="decode">Decode</h2> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">decode</span><span class="p">(</span><span class="n">ids</span><span class="p">,</span> <span class="n">vocab</span><span class="p">):</span>
    <span class="n">raw_bytes</span> <span class="o">=</span> <span class="sa">b</span><span class="sh">""</span><span class="p">.</span><span class="nf">join</span><span class="p">(</span><span class="n">vocab</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">ids</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">raw_bytes</span><span class="p">.</span><span class="nf">decode</span><span class="p">(</span><span class="sh">'</span><span class="s">utf-8</span><span class="sh">'</span><span class="p">,</span> <span class="n">errors</span><span class="o">=</span><span class="sh">'</span><span class="s">replace</span><span class="sh">'</span><span class="p">)</span>
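# Illustrative check, not from the original post: the first byte of a
# two-byte character is not valid UTF-8 on its own.
half = "\u00e9".encode("utf-8")[:1]    # b'\xc3', a truncated sequence
try:
    half.decode("utf-8")               # errors='strict' is the default
    raise AssertionError("expected UnicodeDecodeError")
except UnicodeDecodeError:
    pass
assert half.decode("utf-8", errors="replace") == "\ufffd"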
</code></pre></div></div> <p><code class="language-plaintext highlighter-rouge">errors='replace'</code> because not every byte sequence is valid UTF-8. If the LLM emits a sequence of token IDs whose concatenated bytes don’t form a valid UTF-8 string, <code class="language-plaintext highlighter-rouge">errors='strict'</code> raises <code class="language-plaintext highlighter-rouge">UnicodeDecodeError</code> and the inference call crashes. <code class="language-plaintext highlighter-rouge">'replace'</code> substitutes U+FFFD (the <code class="language-plaintext highlighter-rouge">?</code>-in-a-diamond character) and keeps going. OpenAI’s released code does the same.</p> <h2 id="encode">Encode</h2> <p>The trick: apply merges <em>in the same order they were created during training</em>. Get this wrong and you produce a different token sequence than the trained vocabulary expects.</p> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">encode</span><span class="p">(</span><span class="n">text</span><span class="p">,</span> <span class="n">merges</span><span class="p">):</span>
    <span class="n">tokens</span> <span class="o">=</span> <span class="nf">list</span><span class="p">(</span><span class="n">text</span><span class="p">.</span><span class="nf">encode</span><span class="p">(</span><span class="sh">'</span><span class="s">utf-8</span><span class="sh">'</span><span class="p">))</span>
    <span class="k">if</span> <span class="nf">len</span><span class="p">(</span><span class="n">tokens</span><span class="p">)</span> <span class="o">&lt;</span> <span class="mi">2</span><span class="p">:</span>
        <span class="k">return</span> <span class="n">tokens</span>

    <span class="k">while</span> <span class="bp">True</span><span class="p">:</span>
        <span class="n">stats</span> <span class="o">=</span> <span class="nf">get_stats</span><span class="p">(</span><span class="n">tokens</span><span class="p">)</span>
        <span class="n">pair</span>  <span class="o">=</span> <span class="nf">min</span><span class="p">(</span><span class="n">stats</span><span class="p">,</span> <span class="n">key</span><span class="o">=</span><span class="k">lambda</span> <span class="n">p</span><span class="p">:</span> <span class="n">merges</span><span class="p">.</span><span class="nf">get</span><span class="p">(</span><span class="n">p</span><span class="p">,</span> <span class="nf">float</span><span class="p">(</span><span class="sh">'</span><span class="s">inf</span><span class="sh">'</span><span class="p">)))</span>
        <span class="k">if</span> <span class="n">pair</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">merges</span><span class="p">:</span>
            <span class="k">break</span>
        <span class="n">tokens</span> <span class="o">=</span> <span class="nf">merge</span><span class="p">(</span><span class="n">tokens</span><span class="p">,</span> <span class="n">pair</span><span class="p">,</span> <span class="n">merges</span><span class="p">[</span><span class="n">pair</span><span class="p">])</span>
    <span class="k">return</span> <span class="n">tokens</span>
</code></pre></div></div> <p>Three subtleties:</p> <ul> <li><strong><code class="language-plaintext highlighter-rouge">min</code>, not <code class="language-plaintext highlighter-rouge">max</code>.</strong> Training picked the most <em>frequent</em> pair. Encoding picks the lowest-rank merge (the one created earliest). Why: later merges depend on earlier ones existing as their building blocks. If the text contains a pair that became token 258 = <code class="language-plaintext highlighter-rouge">(256, 257)</code>, then the merges that create tokens 256 and 257 must be applied first. Always do the earliest available merge.</li> <li><strong><code class="language-plaintext highlighter-rouge">float('inf')</code> as fallback.</strong> Pairs not in <code class="language-plaintext highlighter-rouge">merges</code> get rank infinity, so <code class="language-plaintext highlighter-rouge">min</code> never picks them while a mergeable pair remains. Once every remaining pair has rank infinity, <code class="language-plaintext highlighter-rouge">min</code> returns an unmergeable pair and the <code class="language-plaintext highlighter-rouge">pair not in merges</code> check ends the loop.</li> <li><strong><code class="language-plaintext highlighter-rouge">len(tokens) &lt; 2</code> guard.</strong> Empty or single-byte inputs give empty <code class="language-plaintext highlighter-rouge">stats</code> and <code class="language-plaintext highlighter-rouge">min({})</code> raises <code class="language-plaintext highlighter-rouge">ValueError</code>.</li> </ul> <h2 id="the-encodedecodex-asymmetry">The encode(decode(x)) asymmetry</h2> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="nf">decode</span><span class="p">(</span><span class="nf">encode</span><span class="p">(</span><span class="sh">"</span><span class="s">hello world</span><span class="sh">"</span><span class="p">))</span> <span class="o">==</span> <span class="sh">"</span><span class="s">hello world</span><span class="sh">"</span>   <span class="c1"># always
</span><span class="nf">encode</span><span class="p">(</span><span class="nf">decode</span><span class="p">([</span><span class="mi">128</span><span class="p">]))</span>         <span class="o">==</span> <span class="p">[</span><span class="mi">128</span><span class="p">]</span>            <span class="c1"># NOT guaranteed
</span></code></pre></div></div> <p><code class="language-plaintext highlighter-rouge">decode([128])</code>: byte 128 alone is <code class="language-plaintext highlighter-rouge">b'\x80'</code>, an invalid UTF-8 start byte. With <code class="language-plaintext highlighter-rouge">errors='replace'</code>, decode returns the replacement character U+FFFD. Re-encoding it gives the three UTF-8 bytes of U+FFFD (<code class="language-plaintext highlighter-rouge">b'\xef\xbf\xbd'</code>), not <code class="language-plaintext highlighter-rouge">[128]</code>.</p> <p>Forward (text → tokens → text) is always lossless. Reverse may not be. If your code ever relies on <code class="language-plaintext highlighter-rouge">encode(decode(x)) == x</code>, it has a latent bug.</p> <h2 id="gpt-2-regex-pre-splitting">GPT-2 regex pre-splitting</h2> <p>Plain BPE happily merges across word and punctuation boundaries: <code class="language-plaintext highlighter-rouge">dog.</code>, <code class="language-plaintext highlighter-rouge">dog!</code>, <code class="language-plaintext highlighter-rouge">dog?</code>, <code class="language-plaintext highlighter-rouge">dog,</code> end up as separate tokens. GPT-2 prevents this by <em>forcing</em> a split before BPE runs, using a regex:</p> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="n">regex</span> <span class="k">as</span> <span class="n">re</span>   <span class="c1"># pip install regex (NOT the stdlib re)
</span>
<span class="n">GPT2_PATTERN</span> <span class="o">=</span> <span class="sa">r</span><span class="sh">"""'</span><span class="s">s|</span><span class="sh">'</span><span class="s">t|</span><span class="sh">'</span><span class="s">re|</span><span class="sh">'</span><span class="s">ve|</span><span class="sh">'</span><span class="s">m|</span><span class="sh">'</span><span class="s">ll|</span><span class="sh">'</span><span class="s">d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+</span><span class="sh">"""</span>

<span class="n">re</span><span class="p">.</span><span class="nf">findall</span><span class="p">(</span><span class="n">GPT2_PATTERN</span><span class="p">,</span> <span class="sh">"</span><span class="s">Hello world! I</span><span class="sh">'</span><span class="s">m fine.</span><span class="sh">"</span><span class="p">)</span>
<span class="c1"># ['Hello', ' world', '!', " I", "'m", ' fine', '.']
</span></code></pre></div></div> <p>Pattern parts:</p> <ul> <li><code class="language-plaintext highlighter-rouge">'s|'t|'re|'ve|'m|'ll|'d</code>: English contraction suffixes</li> <li><code class="language-plaintext highlighter-rouge"> ?\p{L}+</code>: optional space + Unicode letters (so “ world” is one chunk)</li> <li><code class="language-plaintext highlighter-rouge"> ?\p{N}+</code>: optional space + Unicode numbers</li> <li><code class="language-plaintext highlighter-rouge"> ?[^\s\p{L}\p{N}]+</code>: optional space + punctuation</li> <li><code class="language-plaintext highlighter-rouge">\s+(?!\S)</code> and <code class="language-plaintext highlighter-rouge">\s+</code>: whitespace runs</li> </ul> <p>BPE runs <em>per chunk</em> and IDs are concatenated. Two known quirks of the GPT-2 pattern: the contraction alternatives match only the ASCII apostrophe (curly quotes break), and they are case-sensitive (<code class="language-plaintext highlighter-rouge">DON'T</code> doesn’t split). GPT-4’s pattern fixes the case issue by wrapping the contractions in <code class="language-plaintext highlighter-rouge">(?i:...)</code>; as far as the published <code class="language-plaintext highlighter-rouge">cl100k_base</code> pattern shows, it still matches only the ASCII apostrophe.</p> <p>GPT-4 also caps numbers at 3 digits per chunk. With arbitrary-length number tokens, <code class="language-plaintext highlighter-rouge">12345</code> might be 1 token but <code class="language-plaintext highlighter-rouge">12346</code> might be <code class="language-plaintext highlighter-rouge">[12, 346]</code>: totally inconsistent splits that wreck digit-position arithmetic. The 3-digit cap forces predictable behavior.</p> <h2 id="whitespace-in-code-a-gpt-2-disaster-a-gpt-4-fix">Whitespace in code: a GPT-2 disaster, a GPT-4 fix</h2> <p>GPT-2 makes every space its own token. Four spaces of Python indent = 4 tokens. GPT-4 merges runs of whitespace into single tokens, roughly doubling effective context for indented code. This is one of the largest single improvements in GPT-4’s tokenizer.</p> <h2 id="special-tokens">Special tokens</h2> <p>Outside BPE entirely. Matched by string before BPE runs, then assigned a hardcoded ID.</p> <div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>GPT-2 vocab layout:
  0..255         raw byte tokens
  256..50255     50,000 BPE merges
  50256          &lt;|endoftext|&gt;           ← only special token
</code></pre></div></div> <p><code class="language-plaintext highlighter-rouge">&lt;|endoftext|&gt;</code> marks document boundaries during training:</p> <div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>doc1_tokens + [50256] + doc2_tokens + [50256] + doc3_tokens + ...
</code></pre></div></div> <p>The model <em>learns</em> what 50256 means. It’s special only in that the tokenizer never produces it from regular text, only when explicitly requested.</p> <p>GPT-4 adds FIM tokens (<code class="language-plaintext highlighter-rouge">&lt;|fim_prefix|&gt;</code>, <code class="language-plaintext highlighter-rouge">&lt;|fim_middle|&gt;</code>, <code class="language-plaintext highlighter-rouge">&lt;|fim_suffix|&gt;</code>) for code completion and <code class="language-plaintext highlighter-rouge">&lt;|im_start|&gt;</code>/<code class="language-plaintext highlighter-rouge">&lt;|im_end|&gt;</code> for chat boundaries.</p> <p>Adding a special token to a pretrained model is surgery: resize the embedding table by N rows, resize the LM head by N columns, typically freeze the base and train only the new slices.</p> <p>Security: passing <code class="language-plaintext highlighter-rouge">allowed_special="all"</code> to tiktoken’s <code class="language-plaintext highlighter-rouge">Encoding.encode</code> with untrusted <code class="language-plaintext highlighter-rouge">user_input</code> lets users inject <code class="language-plaintext highlighter-rouge">&lt;|endoftext|&gt;</code> and confuse boundary logic. 
Default is strict — opt in only when you know the input is trusted.</p> <h2 id="vocab-size-the-only-two-places-it-shows-up">Vocab size: the only two places it shows up</h2> <p><code class="language-plaintext highlighter-rouge">vocab_size</code> appears in exactly two places in the model:</p> <ol> <li>Token embedding: <code class="language-plaintext highlighter-rouge">nn.Embedding(vocab_size, n_embd)</code></li> <li>LM head: <code class="language-plaintext highlighter-rouge">nn.Linear(n_embd, vocab_size)</code></li> </ol> <p>Everything else (attention, MLP, LayerNorm) is independent of vocab size.</p> <table> <thead> <tr> <th> </th> <th>Small (256)</th> <th>Large (1M)</th> </tr> </thead> <tbody> <tr> <td>Sequence length</td> <td>very long</td> <td>short</td> </tr> <tr> <td>Embed / head size</td> <td>tiny</td> <td>huge</td> </tr> <tr> <td>Per-token signal</td> <td>strong (every token seen often)</td> <td>weak (rare tokens undertrained)</td> </tr> </tbody> </table> <p>Sweet spot is <strong>32k–100k</strong>.</p> <h2 id="solidgoldmagikarp">SolidGoldMagikarp</h2> <p>The famous tokenization bug. Mechanism:</p> <ol> <li>OpenAI trained the <strong>tokenizer</strong> on a dataset that included Reddit.</li> <li>User <code class="language-plaintext highlighter-rouge">SolidGoldMagikarp</code> posted enough that BPE merged the username into a single token.</li> <li>OpenAI then trained the <strong>language model</strong> on a different, filtered dataset that didn’t include those Reddit posts.</li> <li>The token exists in vocab but its embedding row was never updated by gradient descent. It’s still the random initialization.</li> <li>At inference, typing <code class="language-plaintext highlighter-rouge">SolidGoldMagikarp</code> loads that random embedding into the model. Undefined behavior.</li> </ol> <p>Observed: the model evades, hallucinates, insults the user, bypasses safety, or gets stuck looping.</p> <p>The C analogy: reading uninitialized memory. 
The slot exists, but no one ever wrote a meaningful value to it.</p> <p>Root cause is the dataset mismatch between tokenizer training and LM training. Prevention is to use the same (or strictly overlapping) datasets, or audit per-token activation counts after pretraining and remove zero-activation tokens.</p> <p>Other examples: <code class="language-plaintext highlighter-rouge">" davidjl"</code>, <code class="language-plaintext highlighter-rouge">" TheNitromeFan"</code>, <code class="language-plaintext highlighter-rouge">" RandomRedditorWithNo"</code>. All Reddit usernames or fragments.</p> <h2 id="spelling-reversal-arithmetic--all-tokenization">Spelling, reversal, arithmetic — all tokenization</h2> <p><strong>Why GPT can’t count the L’s in <code class="language-plaintext highlighter-rouge">DefaultCellStyle</code>.</strong> <code class="language-plaintext highlighter-rouge">DefaultCellStyle</code> is a single token in GPT-4’s vocab. The model sees one opaque ID, not the constituent letters. Asking how many L’s are inside is like asking how many L’s are in the integer <code class="language-plaintext highlighter-rouge">28139</code>.</p> <p>Workaround for character-level tasks: prompt the model to first split with spaces (<code class="language-plaintext highlighter-rouge">D e f a u l t C e l l S t y l e</code>) so each character becomes its own token.</p> <p><strong>Why arithmetic is brittle.</strong> Number tokenization in GPT-2 is essentially arbitrary. <code class="language-plaintext highlighter-rouge">1024</code> might be one token, <code class="language-plaintext highlighter-rouge">123456</code> might be <code class="language-plaintext highlighter-rouge">[12, 3456]</code>. Digit position is unaligned across examples. Carrying digits requires aligning positions, which is structurally impossible when chunks are random. GPT-4’s 3-digit cap helps. 
Llama uses <code class="language-plaintext highlighter-rouge">split_digits=True</code> (one digit per token).</p> <p><strong>Why non-English is undertrained.</strong> Korean text takes ~3× more tokens than English for the same content. Less training signal per concept, less effective context.</p> <p><strong>Trailing whitespace warning.</strong> In training, <code class="language-plaintext highlighter-rouge">" the"</code> is a single token (space + “the”). If you end a prompt with a bare space, the last token is a lone space — almost never seen during training. The model is now out-of-distribution. Don’t end prompts with spaces.</p> <p><strong>Partial token glitches.</strong> Completing <code class="language-plaintext highlighter-rouge">"DefaultCellSty"</code> (a partial token) can produce immediate end-of-text, garbage, or content-policy warnings, because that exact subsequence rarely appears in training. tiktoken’s source has an entire “unstable tokens” module for this case.</p> <h2 id="token-economy">Token economy</h2> <p>Same data, different format, different token count. JSON has overhead from <code class="language-plaintext highlighter-rouge">{</code>, <code class="language-plaintext highlighter-rouge">}</code>, <code class="language-plaintext highlighter-rouge">"</code>, <code class="language-plaintext highlighter-rouge">:</code>, commas. YAML strips most of it.</p> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="sh">'</span><span class="s">{</span><span class="sh">"</span><span class="s">name</span><span class="sh">"</span><span class="s">: </span><span class="sh">"</span><span class="s">Alice</span><span class="sh">"</span><span class="s">, </span><span class="sh">"</span><span class="s">age</span><span class="sh">"</span><span class="s">: 30}</span><span class="sh">'</span>    <span class="c1"># ~12 tokens
</span><span class="sh">'</span><span class="s">name: Alice</span><span class="se">\n</span><span class="s">age: 30</span><span class="sh">'</span>             <span class="c1"># ~8 tokens
</span></code></pre></div></div> <p>~15% savings going from JSON to YAML for the same content. Multiply across an API bill or a long-context model and it matters.</p> <p><a href="https://tiktokenizer.vercel.app">tiktokenizer.vercel.app</a> shows you how anything tokenizes.</p> <h2 id="bugs-to-remember">Bugs to remember</h2> <table> <thead> <tr> <th>#</th> <th>Bug</th> <th>Symptom</th> <th>Fix</th> </tr> </thead> <tbody> <tr> <td>1</td> <td><code class="language-plaintext highlighter-rouge">merge()</code> no bounds check</td> <td><code class="language-plaintext highlighter-rouge">IndexError</code> on last element</td> <td>check <code class="language-plaintext highlighter-rouge">i &lt; len(ids) - 1</code> first</td> </tr> <tr> <td>2</td> <td><code class="language-plaintext highlighter-rouge">decode()</code> with <code class="language-plaintext highlighter-rouge">errors='strict'</code></td> <td>crashes on invalid byte sequences</td> <td><code class="language-plaintext highlighter-rouge">errors='replace'</code></td> </tr> <tr> <td>3</td> <td><code class="language-plaintext highlighter-rouge">encode()</code> on empty/single-byte input</td> <td><code class="language-plaintext highlighter-rouge">min({})</code> raises <code class="language-plaintext highlighter-rouge">ValueError</code></td> <td>early return for <code class="language-plaintext highlighter-rouge">len &lt; 2</code></td> </tr> <tr> <td>4</td> <td>Wrong-order vocab build</td> <td><code class="language-plaintext highlighter-rouge">KeyError</code> on Python ≤3.5 (arbitrary dict order)</td> <td>require 3.7+</td> </tr> <tr> <td>5</td> <td>GPT-2 ASCII-only apostrophe</td> <td>curly-quote contractions break</td> <td>extend the contraction alternatives (GPT-4’s pattern is still ASCII-only)</td> </tr> <tr> <td>6</td> <td>GPT-2 case-sensitive split</td> <td><code class="language-plaintext highlighter-rouge">DON'T</code> doesn’t split</td> <td>use GPT-4 pattern with <code class="language-plaintext highlighter-rouge">(?i:...)</code></td> </tr> <tr> <td>7</td> <td>Special tokens in user input</td> <td>boundary confusion / jailbreak</td> <td>restrict 
<code class="language-plaintext highlighter-rouge">allowed_special</code></td> </tr> <tr> <td>8</td> <td>Trailing whitespace in prompt</td> <td>OOD final token</td> <td>don’t end with space</td> </tr> <tr> <td>9</td> <td>Tokenizer dataset ≠ LM dataset</td> <td>SolidGoldMagikarp glitch tokens</td> <td>same datasets, audit activations</td> </tr> <tr> <td>10</td> <td><code class="language-plaintext highlighter-rouge">encode(decode(x)) == x</code> assumption</td> <td>silent breakage</td> <td>only <code class="language-plaintext highlighter-rouge">decode(encode(x))</code> is safe</td> </tr> </tbody> </table> <p>Code: <a href="https://github.com/debtirthasaha/bpe-tokenizer">github.com/debtirthasaha/bpe-tokenizer</a>. Reference: <a href="https://github.com/karpathy/minbpe">karpathy/minbpe</a>.</p>]]></content><author><name></name></author><category term="nlp"/><category term="tokenization"/><category term="bpe"/><category term="gpt"/><summary type="html"><![CDATA[Byte-pair encoding implemented in pure Python. 
Plus SolidGoldMagikarp, the encode/decode asymmetry, and a list of LLM weirdness all caused by the tokenizer.]]></summary></entry><entry><title type="html">Birkhoff in 8.7 KB</title><link href="https://debtirthasaha.github.io/blog/2026/birkhoff-in-8kb/" rel="alternate" type="text/html" title="Birkhoff in 8.7 KB"/><published>2026-04-20T10:00:00+00:00</published><updated>2026-04-20T10:00:00+00:00</updated><id>https://debtirthasaha.github.io/blog/2026/birkhoff-in-8kb</id><content type="html" xml:base="https://debtirthasaha.github.io/blog/2026/birkhoff-in-8kb/"><![CDATA[<p><a href="https://competition.sair.foundation/competitions/mathematics-distillation-challenge-equational-theories-stage1/overview">SAIR’s Equational Theories competition (Stage 1)</a> — organized by Damek Davis (UPenn), Terence Tao (UCLA), and the SAIR Foundation — gives you a pair of equations over a single binary operator <code class="language-plaintext highlighter-rouge">*</code>:</p> <div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>E1:  L1 = R1
E2:  L2 = R2
</code></pre></div></div> <p>All variables are universally quantified. A <em>magma</em> is just a set with a binary operation — no axioms beyond closure. The question: does every magma that satisfies <code class="language-plaintext highlighter-rouge">E1</code> necessarily satisfy <code class="language-plaintext highlighter-rouge">E2</code>? Output <code class="language-plaintext highlighter-rouge">true</code> or <code class="language-plaintext highlighter-rouge">false</code>.</p> <p>The dataset is drawn from Tao’s <a href="https://teorth.github.io/equational_theories/implications/">Equational Theories Project</a>: 4694 magma laws, giving <code class="language-plaintext highlighter-rouge">4694 × 4693 = 22,028,942</code> ordered implications. The Stage 1 public subsets are <code class="language-plaintext highlighter-rouge">normal</code> (1000 problems, 50/50 true/false split), <code class="language-plaintext highlighter-rouge">hard1</code> (69), <code class="language-plaintext highlighter-rouge">hard2</code> (200, 50/50), <code class="language-plaintext highlighter-rouge">hard3</code> (400, 195/205) and an <code class="language-plaintext highlighter-rouge">order5</code> research subset.</p> <p>The setup is a follow-up to Honda, Murakami &amp; Zhang (2025), <em>Distilling Many-Shot In-Context Learning into a Cheat Sheet</em>: instead of having one model write the cheatsheet, SAIR runs an open competition so the cheatsheet is <em>discovered</em> across submissions. You submit a single Markdown file — a prompt template the harness fills in with the two equations and sends to a fixed set of frozen models (GPT-OSS 120B, Llama 3.3 70B, Gemma 4 31B). No fine-tuning, no agents, no tool calls, no chain-of-thought tax beyond what fits in the 8192-token completion budget. Hard cap on cheatsheet size: 10 KB. Scoring on Stage 1 is correctness (accuracy and F1) only — no proof artefacts, no calibrated probabilities. 
Those come in Stage 2 (Lean proofs, counterexamples, calibration).</p> <p>The cheatsheet I submitted is <strong>8.71 KB</strong> — under the cap with room to spare. It replaces free-form reasoning with a 9-magma closed-form decision procedure. The headline result: Gemma 4 31B running this prompt beat GPT-OSS 120B by 16 accuracy points on the hardest band.</p> <h2 id="the-math-birkhoff-completeness">The math: Birkhoff completeness</h2> <p>By Birkhoff’s theorem, <code class="language-plaintext highlighter-rouge">E1 ⊨ E2</code> (E1 semantically entails E2) iff E2 is derivable from E1 by equational logic — reflexivity, symmetry, transitivity, congruence, substitution. Equivalently, <code class="language-plaintext highlighter-rouge">E1 ⊨ E2</code> iff there is <em>no</em> magma satisfying E1 but not E2.</p> <p>This gives you exactly two sound moves:</p> <ul> <li><strong>Return <code class="language-plaintext highlighter-rouge">false</code></strong>: exhibit a specific magma where E1 holds and E2 fails.</li> <li><strong>Return <code class="language-plaintext highlighter-rouge">true</code></strong>: derive E2 from E1 (or argue no counterexample exists).</li> </ul> <p>Anything else is not a proof. In particular, “the equations look similar / share variables / I don’t see a derivation” is not sound. This is where LLMs get into trouble.</p> <h2 id="why-free-form-prompting-struggles">Why free-form prompting struggles</h2> <p>Ask an LLM “does E1 imply E2?” and it will produce English reasoning that <em>looks</em> like a proof. Sometimes it is one. Often it is structural pattern-matching that happens to be wrong on subtle cases: an implication that needs you to <em>construct</em> a specific failing magma, or a non-implication where the equation pair shares enough surface structure that the model is fooled.</p> <p>The fix is to remove the freedom. 
Instead of asking the model to reason, give it a <em>closed-form procedure</em> over a finite catalog of magmas, and a hard rule that the only way to return <code class="language-plaintext highlighter-rouge">false</code> is to point to a specific catalog entry that refutes.</p> <h2 id="the-9-magmas">The 9 magmas</h2> <p>A magma is <code class="language-plaintext highlighter-rouge">(M, *)</code>. For each one, the cheatsheet supplies a <em>closed-form predicate</em> on the equation tree that decides whether the magma satisfies a given equation. No enumeration over <code class="language-plaintext highlighter-rouge">M</code>, no search — a direct formula.</p> <table> <thead> <tr> <th>#</th> <th>Magma <code class="language-plaintext highlighter-rouge">a*b</code></th> <th>Satisfies <code class="language-plaintext highlighter-rouge">L = R</code> iff</th> </tr> </thead> <tbody> <tr> <td>0</td> <td><code class="language-plaintext highlighter-rouge">a*b = b</code> (right-projection)</td> <td><code class="language-plaintext highlighter-rouge">rm(L) == rm(R)</code></td> </tr> <tr> <td>1</td> <td><code class="language-plaintext highlighter-rouge">a*b = a</code> (left-projection)</td> <td><code class="language-plaintext highlighter-rouge">lm(L) == lm(R)</code></td> </tr> <tr> <td>2</td> <td><code class="language-plaintext highlighter-rouge">a*b = c</code> (constant)</td> <td>both depths ≥ 1, or <code class="language-plaintext highlighter-rouge">L</code>, <code class="language-plaintext highlighter-rouge">R</code> same bare var</td> </tr> <tr> <td>3</td> <td><code class="language-plaintext highlighter-rouge">a*b = a + b</code> on ℤ/2 (XOR)</td> <td><code class="language-plaintext highlighter-rouge">count(v, L) ≡ count(v, R)</code> mod 2 ∀v</td> </tr> <tr> <td>4</td> <td><code class="language-plaintext highlighter-rouge">a*b = a + b</code> on ℤ/3</td> <td><code class="language-plaintext highlighter-rouge">count(v, L) ≡ count(v, R)</code> mod 3 ∀v</td> </tr> <tr> <td>5</td> 
<td><code class="language-plaintext highlighter-rouge">a*b = a + b</code> on ℤ</td> <td><code class="language-plaintext highlighter-rouge">count(v, L) == count(v, R)</code> ∀v</td> </tr> <tr> <td>6</td> <td><code class="language-plaintext highlighter-rouge">a*b = b + 1</code> on ℤ/3 (right-successor)</td> <td><code class="language-plaintext highlighter-rouge">rm(L) == rm(R)</code> and <code class="language-plaintext highlighter-rouge">drm(L) ≡ drm(R)</code> mod 3</td> </tr> <tr> <td>7</td> <td><code class="language-plaintext highlighter-rouge">a*b = a + 1</code> on ℤ/3 (left-successor)</td> <td><code class="language-plaintext highlighter-rouge">lm(L) == lm(R)</code> and <code class="language-plaintext highlighter-rouge">dlm(L) ≡ dlm(R)</code> mod 3</td> </tr> <tr> <td>8</td> <td><code class="language-plaintext highlighter-rouge">a*b = −a − b</code> on ℤ/3 (negation-sum)</td> <td><code class="language-plaintext highlighter-rouge">signed_count(v, L) ≡ signed_count(v, R)</code> mod 3 ∀v</td> </tr> </tbody> </table> <p>Where the tree primitives are:</p> <div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>lm(v) = v,        lm(a*b) = lm(a)               # leftmost leaf
rm(v) = v,        rm(a*b) = rm(b)               # rightmost leaf
dlm(v) = 0,       dlm(a*b) = 1 + dlm(a)         # left-spine length
drm(v) = 0,       drm(a*b) = 1 + drm(b)         # right-spine length
count(v, t) = number of times v appears as a leaf in t
signed_count(v, t) = Σ (-1)^depth(leaf) over leaf occurrences of v
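
# Illustrative Python sketch of the primitives above (an assumption about
# representation, not the cheatsheet's text): a term is a variable name
# (str) or a pair (left, right) standing for left*right.
def lm(t):  return t if isinstance(t, str) else lm(t[0])
def rm(t):  return t if isinstance(t, str) else rm(t[1])
def dlm(t): return 0 if isinstance(t, str) else 1 + dlm(t[0])
def drm(t): return 0 if isinstance(t, str) else 1 + drm(t[1])

def count(v, t):
    return int(t == v) if isinstance(t, str) else count(v, t[0]) + count(v, t[1])

def signed_count(v, t, depth=0):
    if isinstance(t, str):
        return (-1) ** depth if t == v else 0
    return signed_count(v, t[0], depth + 1) + signed_count(v, t[1], depth + 1)

R = ((("x", "y"), "z"), "w")   # ((x*y)*z)*w
assert (lm(R), rm(R), dlm(R), drm(R)) == ("x", "w", 3, 1)
assert signed_count("x", R) == -1   # x sits at depth 3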
</code></pre></div></div> <p>Each row of the table is a one-line check the model can run by walking the equation tree once. No symbolic manipulation, no equational rewriting, no induction.</p> <h2 id="the-decision-procedure">The decision procedure</h2> <p>For each equation <code class="language-plaintext highlighter-rouge">E</code>, compute <code class="language-plaintext highlighter-rouge">sig(E) = (h0, h1, …, h8)</code> — the 9-bit vector of which catalog magmas satisfy E.</p> <div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>def implies(E1, E2):
    s1, s2 = sig(E1), sig(E2)
    for i in range(9):
        if s1[i] and not s2[i]:
            return "false"     # magma i refutes
    return "true"
</code></pre></div></div> <p>If any catalog magma satisfies E1 but falsifies E2, return <code class="language-plaintext highlighter-rouge">false</code> and the witness is that magma. Otherwise return <code class="language-plaintext highlighter-rouge">true</code>, the default when no catalog magma produces a refutation. Only the <code class="language-plaintext highlighter-rouge">false</code> answers are certified: a <code class="language-plaintext highlighter-rouge">true</code> <em>can</em> be wrong (the true answer might be <code class="language-plaintext highlighter-rouge">false</code> via some magma outside the catalog), but every <code class="language-plaintext highlighter-rouge">false</code> names a concrete counterexample, so the catalog produces no false <code class="language-plaintext highlighter-rouge">false</code>s.</p> <h2 id="refutation-discipline">Refutation discipline</h2> <p>The hardest part of getting the LLM to be sound is making it stop hallucinating counterexamples. The cheatsheet enforces:</p> <blockquote> <p>A refutation must name an index <code class="language-plaintext highlighter-rouge">i ∈ {0..8}</code> with <code class="language-plaintext highlighter-rouge">sig(E1)[i]=T</code> and <code class="language-plaintext highlighter-rouge">sig(E2)[i]=F</code>. Anything else is not a refutation. Do not infer <code class="language-plaintext highlighter-rouge">false</code> from structural similarity, shared letters, or the absence of an obvious derivation.</p> </blockquote> <p>And the cheatsheet’s last instruction before falling back to <code class="language-plaintext highlighter-rouge">true</code>:</p> <blockquote> <p>If either bit is uncertain after re-check, return <code class="language-plaintext highlighter-rouge">true</code>. Sound procedure: a hallucinated refutation is worse than a missed one, because a missed refutation at least falls back to the mathematically honest default.</p> </blockquote> <p>This asymmetric default is the lever. Free-form LLM reasoning falsely says <code class="language-plaintext highlighter-rouge">false</code> constantly. 
Constraining false to “name your <code class="language-plaintext highlighter-rouge">i</code>” cuts those errors almost entirely, at the cost of a few extra false-true answers (which the catalog can’t help anyway).</p> <h2 id="the-mod-3-magmas-are-where-errors-happen">The mod-3 magmas are where errors happen</h2> <p>Magmas 6, 7, 8 are the ones the cheatsheet spends the most space on, because they’re the ones models reliably get wrong. The error pattern: get <code class="language-plaintext highlighter-rouge">lm</code>/<code class="language-plaintext highlighter-rouge">rm</code> right, then either skip computing <code class="language-plaintext highlighter-rouge">dlm</code>/<code class="language-plaintext highlighter-rouge">drm</code> or compute them wrong, or get the mod-3 arithmetic wrong (signed counts can be negative; <code class="language-plaintext highlighter-rouge">−1 mod 3 = 2</code>).</p> <p>The fix is to inline worked examples for each, written so the model <em>has</em> to walk the tree. Example for magma 7:</p> <div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Example A (TRUE): x = ((x*y)*z)*w
  lm(L) = x. For R, descend left: R → (x*y)*z → x*y → x. lm(R) = x. Match.
  dlm(L) = 0. dlm(R) counts those 3 left-descents = 3.
  0 mod 3 = 0, 3 mod 3 = 0. Equal. h7 = TRUE.

Example B (FALSE): x = (x*y)*z
  lm(L) = lm(R) = x. Match.
  dlm(L) = 0, dlm(R) = 2.
  0 mod 3 = 0, 2 mod 3 = 2. Differ. h7 = FALSE.
</code></pre></div></div> <p>The example doesn’t teach the model what <code class="language-plaintext highlighter-rouge">dlm</code> is. It teaches it that <em>it has to walk the tree</em> before answering, by being a worked instance with the walk shown.</p> <p>Magma 8 (signed counts mod 3) gets the most defensive treatment, because depth-parity arithmetic with negative residues is the most error-prone single check in the whole procedure.</p> <h2 id="the-mandatory-parse-step">The mandatory PARSE step</h2> <p>Before any rule fires, the cheatsheet requires the model to <em>explicitly produce</em> <code class="language-plaintext highlighter-rouge">lm</code>, <code class="language-plaintext highlighter-rouge">rm</code>, <code class="language-plaintext highlighter-rouge">dlm</code>, <code class="language-plaintext highlighter-rouge">drm</code>, depth, and the per-variable counts for both equation sides. Skipping this step is the single largest source of wrong answers — models guess <code class="language-plaintext highlighter-rouge">lm(R)</code> instead of descending the tree.</p> <p>Forcing the structured output reframes the problem. The model isn’t reasoning about magma implication anymore; it’s filling in a six-slot form, then running nine if-statements, then doing one loop with a hard stop condition. This format is <em>much</em> friendlier to the LLM substrate than free-form math.</p> <h2 id="results">Results</h2> <p>The Stage 1 leaderboard has <strong>1,061 participants</strong>. It scores on accuracy and F1 across three difficulty buckets — normal, hard, extra_hard — for each of the three frozen models. Restricted leaderboards (single model, or the order-5 research subset) score the same submission from a different angle.</p> <p>Cheatsheet size: 8.71 KB. Mean parse success: 100%. 
Mean per-query cost (across all models, all sets): roughly $0.0004–$0.0009 depending on model.</p> <h3 id="overall-leaderboard-all-models-all-sets--rank-85">Overall leaderboard (all models, all sets) — rank 85</h3> <table> <thead> <tr> <th>Model</th> <th>Set</th> <th>Accuracy</th> <th>F1</th> </tr> </thead> <tbody> <tr> <td>GPT-OSS 120B</td> <td>normal</td> <td>77.8%</td> <td>81.7%</td> </tr> <tr> <td>GPT-OSS 120B</td> <td>hard</td> <td>74.0%</td> <td>79.3%</td> </tr> <tr> <td>GPT-OSS 120B</td> <td>extra_hard</td> <td>49.7%</td> <td>66.4%</td> </tr> <tr> <td>Llama 3.3 70B</td> <td>normal</td> <td>61.0%</td> <td>63.1%</td> </tr> <tr> <td>Llama 3.3 70B</td> <td>hard</td> <td>56.5%</td> <td>61.6%</td> </tr> <tr> <td>Llama 3.3 70B</td> <td>extra_hard</td> <td>31.7%</td> <td>41.1%</td> </tr> <tr> <td>Gemma 4 31B</td> <td>normal</td> <td>52.0%</td> <td>13.3%</td> </tr> <tr> <td>Gemma 4 31B</td> <td>hard</td> <td>51.0%</td> <td>16.5%</td> </tr> <tr> <td>Gemma 4 31B</td> <td>extra_hard</td> <td><strong>65.8%</strong></td> <td>48.4%</td> </tr> </tbody> </table> <p>Aggregate: <strong>57.7% acc / 52.4% F1 → rank 85</strong>.</p> <pre><code class="language-plotly">{"data":[{"x":["normal","hard","extra_hard"],"y":[77.8,74.0,49.7],"name":"GPT-OSS 120B","type":"bar","marker":{"color":"#636efa"},"hovertemplate":"GPT-OSS 120B&lt;br&gt;%{x}: %{y}%&lt;extra&gt;&lt;/extra&gt;"},{"x":["normal","hard","extra_hard"],"y":[61.0,56.5,31.7],"name":"Llama 3.3 70B","type":"bar","marker":{"color":"#EF553B"},"hovertemplate":"Llama 3.3 70B&lt;br&gt;%{x}: %{y}%&lt;extra&gt;&lt;/extra&gt;"},{"x":["normal","hard","extra_hard"],"y":[52.0,51.0,65.8],"name":"Gemma 4 31B","type":"bar","marker":{"color":"#00cc96"},"hovertemplate":"Gemma 4 31B&lt;br&gt;%{x}: %{y}%&lt;extra&gt;&lt;/extra&gt;"}],"layout":{"title":{"text":"Accuracy by model and difficulty"},"barmode":"group","yaxis":{"title":"accuracy 
(%)","range":[0,100]},"xaxis":{"title":"difficulty"},"height":420,"margin":{"l":60,"r":30,"t":60,"b":50},"legend":{"orientation":"h","x":0.1,"y":-0.15}}}
</code></pre> <p>The crossover on the extra_hard band is the result the rest of the post is about. Gemma (the smallest model) lands above GPT-OSS (the largest) for the first time.</p> <h3 id="gpt-oss-only-leaderboard--rank-13">GPT-OSS-only leaderboard — rank 13</h3> <p>Restricting the same submission to the 120B model: <strong>67.2% acc / 75.8% F1</strong>.</p> <h3 id="order-5-research-subset-gpt-oss--rank-21">Order-5 research subset (GPT-OSS) — rank 21</h3> <p>Order-5 equations have deeper trees and form a separate research-tier leaderboard. <strong>79.8% acc / 83.2% F1</strong>.</p> <h3 id="the-gemma-extra-hard-anomaly--rank-20">The Gemma extra-hard anomaly — rank 20</h3> <p>On the extra-hard set specifically, the per-model ranks tell a different story:</p> <table> <thead> <tr> <th>Model</th> <th>Params</th> <th>Extra-hard accuracy</th> <th>Rank on extra-hard</th> </tr> </thead> <tbody> <tr> <td>Gemma 4 31B</td> <td>31B</td> <td>65.8%</td> <td><strong>20</strong></td> </tr> <tr> <td>GPT-OSS 120B</td> <td>120B</td> <td>49.7%</td> <td>157</td> </tr> <tr> <td>Llama 3.3 70B</td> <td>70B</td> <td>31.7%</td> <td>278</td> </tr> </tbody> </table> <p>Gemma 4 31B, the smallest model in the eval, got the <em>best</em> extra-hard accuracy with this prompt by a wide margin (16 points over GPT-OSS, 34 points over Llama), despite being roughly 4× smaller than GPT-OSS.</p> <p>The likely explanation: extra-hard problems benefit most from structured procedure-following. GPT-OSS at 120B has more “smart enough to deviate” headroom: it interprets the cheatsheet, decides parts of it are unnecessary, and falls back to free-form reasoning that fails on the hardest cases. Gemma at 31B has less headroom for that kind of agency. It follows the procedure step by step because that’s what fits in its working memory, and the procedure is sound. 
On the easier sets, where GPT-OSS’s looser interpretation usually gets the right answer anyway, the gap reverses.</p> <p>If true, this is a real prompt-engineering result: highly structured, low-freedom prompts may <em>favor</em> smaller models on the hardest problems, because the largest models will second-guess the structure and lose.</p> <h3 id="f1-is-low-on-gemma--why">F1 is low on Gemma — why</h3> <p>Gemma’s F1 on normal/hard is low (13.3%, 16.5%) despite accuracy of 52.0% and 51.0%. This is the asymmetric default at work: Gemma takes the “if uncertain, return true” instruction very literally and answers <code class="language-plaintext highlighter-rouge">true</code> on most edge cases. Accuracy stays okay because the base rate of <code class="language-plaintext highlighter-rouge">true</code> in the dataset is high. F1 collapses because Gemma produces few <code class="language-plaintext highlighter-rouge">false</code> answers, so recall on the <code class="language-plaintext highlighter-rouge">false</code> class, and with it F1, is poor. The same instruction that <em>fixes</em> GPT-OSS’s hallucinated refutations <em>over-fixes</em> Gemma’s. This is the cost of a single uniform prompt across very different models.</p> <h3 id="sanity-check-qwen-25-7b-locally">Sanity check: Qwen 2.5 7B locally</h3> <p>After the submission I ran the same cheatsheet against a <em>much</em> smaller model (Qwen 2.5 7B, 4-bit quantized, running locally on a 4 GB GTX 1650 via Ollama) over the first 50 problems of <code class="language-plaintext highlighter-rouge">hard3</code> (the local dataset roughly aligned with the competition’s extra-hard band). 
Same prompt, same temperature 0, same parse rule (take the last <code class="language-plaintext highlighter-rouge">true</code>/<code class="language-plaintext highlighter-rouge">false</code> token).</p> <table> <thead> <tr> <th>Model</th> <th>Params</th> <th>Quant</th> <th>Set</th> <th>N</th> <th>Accuracy</th> <th>F1 (true)</th> <th>Precision</th> <th>Recall</th> </tr> </thead> <tbody> <tr> <td>Qwen 2.5 7B</td> <td>7B</td> <td>4-bit</td> <td>hard3 (first 50)</td> <td>50</td> <td>56.0%</td> <td>66.7%</td> <td>53.8%</td> <td>87.5%</td> </tr> <tr> <td>Gemma 4 31B</td> <td>31B</td> <td>fp16</td> <td>extra_hard</td> <td>500</td> <td>65.8%</td> <td>48.4%</td> <td>—</td> <td>—</td> </tr> <tr> <td>GPT-OSS 120B</td> <td>120B</td> <td>—</td> <td>extra_hard</td> <td>500</td> <td>49.7%</td> <td>66.4%</td> <td>—</td> <td>—</td> </tr> <tr> <td>Llama 3.3 70B</td> <td>70B</td> <td>—</td> <td>extra_hard</td> <td>500</td> <td>31.7%</td> <td>41.1%</td> <td>—</td> <td>—</td> </tr> </tbody> </table> <p>Confusion matrix on Qwen (positive class is <code class="language-plaintext highlighter-rouge">true</code>): TP=21, TN=7, FP=18, FN=3, so recall on <code class="language-plaintext highlighter-rouge">true</code> is 21/24 = 87.5% and precision is 21/39 = 53.8%. Same lopsided pattern as Gemma. The model defaults to <code class="language-plaintext highlighter-rouge">true</code> and is dragged down by the false-positive count on actual-<code class="language-plaintext highlighter-rouge">false</code> problems.</p> <p>50 problems is a small slice and the bands aren’t a perfect match. Even so: a 7B model at 4-bit, running on a single laptop GPU, comes in <em>above</em> GPT-OSS 120B and Llama 3.3 70B on the hardest band. The cheatsheet is doing more of the work than the model is.</p> <p>A failure in the other direction, one of the three false negatives on Qwen (problem 1 of hard3):</p> <div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>E1:  x = x * (x * y)
E2:  x = (x * ((x * x) * x)) * x
actual: true     pred: false
</code></pre></div></div> <p>The model’s trace claims <code class="language-plaintext highlighter-rouge">sig(E1)[2] = T</code> for the Constant magma. The rule says <code class="language-plaintext highlighter-rouge">depth(L) ≥ 1 ∧ depth(R) ≥ 1</code> or <code class="language-plaintext highlighter-rouge">L, R same bare var</code>. Here <code class="language-plaintext highlighter-rouge">L = x</code> is a leaf (depth 0), so the conjunction fails; R is not a bare variable, so the second alternative fails too. <code class="language-plaintext highlighter-rouge">h2</code> should therefore be <code class="language-plaintext highlighter-rouge">F</code>; the model wrote <code class="language-plaintext highlighter-rouge">T</code>, and a non-existent refutation gets reported as a real one.</p> <p>This is the exact category the cheatsheet warns about: a tree-walking error on the depth primitive, propagating into a wrong bit, propagating into a hallucinated refutation. The fix for a future version is more aggressive worked examples on the Constant rule, the same defensive treatment magmas 6, 7, 8 already get. The PARSE step probably also needs to force the model to emit <code class="language-plaintext highlighter-rouge">depth(L)</code> and <code class="language-plaintext highlighter-rouge">depth(R)</code> as explicit lines before any rule fires.</p> <h2 id="what-id-change">What I’d change</h2> <ul> <li><strong>Per-model branches.</strong> The Gemma F1 collapse is fixable: weaken the refutation discipline slightly for smaller models so they produce <code class="language-plaintext highlighter-rouge">false</code> more often, while keeping it strict for the larger ones. 
The competition’s “one prompt, three models” rule ruled this out for the submission, but it’s the obvious next experiment.</li> <li><strong>More magmas.</strong> Extending the catalog past 9 is mostly mechanical: every magma you add comes with a closed-form predicate that plugs into <code class="language-plaintext highlighter-rouge">sig</code>. The hard part is finding magmas that <em>actually</em> refute on the test distribution. An earlier draft had a 4-element bit-swap magma <code class="language-plaintext highlighter-rouge">a*b = 2(a mod 2) + floor(b/2)</code> (the “C4” magma), which can be characterized algebraically via slot-pairs; it added measurable coverage on the harder sets in offline testing, but I cut it from the submission to keep the cheatsheet at a size that smaller models still parse reliably.</li> <li><strong>Re-run the PARSE step under model-specific delimiters.</strong> Different model families parse code blocks and bullet lists with slightly different reliability. The PARSE step is the single most important determinant of accuracy on the hard set; getting that step to fire correctly for Gemma vs Llama vs GPT-OSS is worth more than any new magma.</li> </ul> <h2 id="apply-the-procedure-to-this-instance">Apply the procedure to this instance</h2> <p>The submission file ends with the actual placeholder block the harness substitutes into. The full cheatsheet, the Qwen 2.5 7B sanity-check results, and the verification scripts are at <a href="https://github.com/debtirthasaha/equational-theories-cheatsheet">github.com/debtirthasaha/equational-theories-cheatsheet</a>.</p>]]></content><author><name></name></author><category term="nlp"/><category term="prompting"/><category term="llm-reasoning"/><category term="benchmarks"/><category term="equational-logic"/><summary type="html"><![CDATA[An 8.71 KB prompt for SAIR's equational-theories competition (Tao + Davis, follow-up to Honda-Murakami-Zhang 2025). 
Replace free-form LLM reasoning with a 9-magma Birkhoff-sound decision procedure. A 31B model running this prompt beat a 120B one on the hardest set.]]></summary></entry><entry><title type="html">Tiny Shakespeare, tiny GPT</title><link href="https://debtirthasaha.github.io/blog/2026/tiny-shakespeare-tiny-gpt/" rel="alternate" type="text/html" title="Tiny Shakespeare, tiny GPT"/><published>2026-04-15T10:00:00+00:00</published><updated>2026-04-15T10:00:00+00:00</updated><id>https://debtirthasaha.github.io/blog/2026/tiny-shakespeare-tiny-gpt</id><content type="html" xml:base="https://debtirthasaha.github.io/blog/2026/tiny-shakespeare-tiny-gpt/"><![CDATA[<p>Same architecture as GPT-2, scaled to fit a 4 GB GPU and trained on 1 MB of Shakespeare. Built one mechanism at a time, measuring val loss after each addition.</p> <table> <thead> <tr> <th>After adding</th> <th>Val loss</th> </tr> </thead> <tbody> <tr> <td>Bigram baseline</td> <td>2.88</td> </tr> <tr> <td>+ single-head self-attention</td> <td>2.41</td> </tr> <tr> <td>+ multi-head</td> <td>2.32</td> </tr> <tr> <td>+ feed-forward MLP</td> <td>2.23</td> </tr> <tr> <td>+ residual + LayerNorm (3 blocks)</td> <td>2.09</td> </tr> <tr> <td>Scaled up (4 layers, 192-d, 6 heads)</td> <td><strong>1.59</strong></td> </tr> </tbody> </table> <pre><code class="language-plotly">{"data":[{"x":["bigram","+ single-head attn","+ multi-head","+ feed-forward","+ residual + LN (x3)","scaled up"],"y":[2.88,2.41,2.32,2.23,2.09,1.59],"type":"bar","marker":{"color":["#bbbbbb","#9b9bff","#7a7aff","#5959ff","#3838ff","#EF553B"]},"text":[2.88,2.41,2.32,2.23,2.09,1.59],"textposition":"outside","hovertemplate":"%{x}&lt;br&gt;val loss %{y}&lt;extra&gt;&lt;/extra&gt;"}],"layout":{"title":{"text":"Validation loss after each architectural addition"},"yaxis":{"title":"val loss","range":[0,3.2]},"xaxis":{"tickangle":-25},"height":420,"margin":{"l":60,"r":30,"t":60,"b":120},"showlegend":false}}
</code></pre> <p>1.83M parameters, val loss 1.59. Output has speaker tags, sentence rhythm, and (mostly) closed quotes.</p> <h2 id="data-in-one-tensor">Data, in one tensor</h2> <p>Dataset is <code class="language-plaintext highlighter-rouge">input.txt</code> — every Shakespeare play, 1,115,394 characters, 65 unique characters including <code class="language-plaintext highlighter-rouge">\n</code>.</p> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">chars</span> <span class="o">=</span> <span class="nf">sorted</span><span class="p">(</span><span class="nf">list</span><span class="p">(</span><span class="nf">set</span><span class="p">(</span><span class="n">text</span><span class="p">)))</span>
<span class="n">stoi</span>  <span class="o">=</span> <span class="p">{</span><span class="n">ch</span><span class="p">:</span> <span class="n">i</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">ch</span> <span class="ow">in</span> <span class="nf">enumerate</span><span class="p">(</span><span class="n">chars</span><span class="p">)}</span>
<span class="n">itos</span>  <span class="o">=</span> <span class="p">{</span><span class="n">i</span><span class="p">:</span> <span class="n">ch</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">ch</span> <span class="ow">in</span> <span class="nf">enumerate</span><span class="p">(</span><span class="n">chars</span><span class="p">)}</span>
<span class="n">encode</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">s</span><span class="p">:</span> <span class="p">[</span><span class="n">stoi</span><span class="p">[</span><span class="n">c</span><span class="p">]</span> <span class="k">for</span> <span class="n">c</span> <span class="ow">in</span> <span class="n">s</span><span class="p">]</span>
<span class="n">decode</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">l</span><span class="p">:</span> <span class="sh">''</span><span class="p">.</span><span class="nf">join</span><span class="p">([</span><span class="n">itos</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">l</span><span class="p">])</span>

<span class="n">data</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="nf">tensor</span><span class="p">(</span><span class="nf">encode</span><span class="p">(</span><span class="n">text</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="p">.</span><span class="nb">long</span><span class="p">)</span>
<span class="n">n</span> <span class="o">=</span> <span class="nf">int</span><span class="p">(</span><span class="mf">0.9</span> <span class="o">*</span> <span class="nf">len</span><span class="p">(</span><span class="n">data</span><span class="p">))</span>
<span class="n">train_data</span> <span class="o">=</span> <span class="n">data</span><span class="p">[:</span><span class="n">n</span><span class="p">]</span>   <span class="c1"># 1,003,854 tokens
</span><span class="n">val_data</span>   <span class="o">=</span> <span class="n">data</span><span class="p">[</span><span class="n">n</span><span class="p">:]</span>   <span class="c1">#   111,540 tokens
</span></code></pre></div></div> <p>That’s the tokenizer. (A real BPE tokenizer comes in the <a href="/blog/2026/bpe-tokenizer/">next post</a>.)</p> <h2 id="sampling-chunks-not-whole-documents">Sampling chunks, not whole documents</h2> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">get_batch</span><span class="p">(</span><span class="n">split</span><span class="p">):</span>
    <span class="n">data</span> <span class="o">=</span> <span class="n">train_data</span> <span class="k">if</span> <span class="n">split</span> <span class="o">==</span> <span class="sh">'</span><span class="s">train</span><span class="sh">'</span> <span class="k">else</span> <span class="n">val_data</span>
    <span class="n">ix</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="nf">randint</span><span class="p">(</span><span class="nf">len</span><span class="p">(</span><span class="n">data</span><span class="p">)</span> <span class="o">-</span> <span class="n">block_size</span><span class="p">,</span> <span class="p">(</span><span class="n">batch_size</span><span class="p">,))</span>
    <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="nf">stack</span><span class="p">([</span><span class="n">data</span><span class="p">[</span><span class="n">i</span>    <span class="p">:</span> <span class="n">i</span> <span class="o">+</span> <span class="n">block_size</span>    <span class="p">]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">ix</span><span class="p">])</span>
    <span class="n">y</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="nf">stack</span><span class="p">([</span><span class="n">data</span><span class="p">[</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">:</span> <span class="n">i</span> <span class="o">+</span> <span class="n">block_size</span> <span class="o">+</span> <span class="mi">1</span><span class="p">]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">ix</span><span class="p">])</span>
    <span class="k">return</span> <span class="n">x</span><span class="p">.</span><span class="nf">to</span><span class="p">(</span><span class="n">device</span><span class="p">),</span> <span class="n">y</span><span class="p">.</span><span class="nf">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
</code></pre></div></div> <p><code class="language-plaintext highlighter-rouge">y</code> is <code class="language-plaintext highlighter-rouge">x</code> advanced by one position in the data, so <code class="language-plaintext highlighter-rouge">y[t]</code> is the token that follows <code class="language-plaintext highlighter-rouge">x[t]</code>. Each position in the chunk is one training example: the model predicts position <code class="language-plaintext highlighter-rouge">t+1</code> given positions <code class="language-plaintext highlighter-rouge">0..t</code>. A chunk of length 8 produces 8 examples in parallel. Stacking <code class="language-plaintext highlighter-rouge">batch_size</code> chunks gives independent training signal — sequences in a batch don’t communicate.</p> <h2 id="the-mathematical-trick">The mathematical trick</h2> <p>Before self-attention there’s one observation worth dwelling on: how do you let each position aggregate information from all previous positions in <em>parallel</em>?</p> <p>Goal: <code class="language-plaintext highlighter-rouge">xbow[b, t] = mean(x[b, 0..t])</code>. The naive version is a double loop over <code class="language-plaintext highlighter-rouge">b</code> and <code class="language-plaintext highlighter-rouge">t</code>. The fast version is a matrix multiply.</p> <p>Build a lower-triangular matrix of equal weights, with each row summing to 1:</p> <div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>wei = [[1.00, 0.00, 0.00, 0.00],
       [0.50, 0.50, 0.00, 0.00],
       [0.33, 0.33, 0.33, 0.00],
       [0.25, 0.25, 0.25, 0.25]]
</code></pre></div></div> <p>Now <code class="language-plaintext highlighter-rouge">wei @ x</code> produces, for each row of <code class="language-plaintext highlighter-rouge">wei</code>, a weighted sum over the value vectors. Row <code class="language-plaintext highlighter-rouge">t</code> only mixes positions <code class="language-plaintext highlighter-rouge">0..t</code>. Same answer as the loop, one matmul.</p> <p>Self-attention will turn <code class="language-plaintext highlighter-rouge">wei</code> from “equal weights” into “weights computed from the data.” The matrix-multiply-with-causal-mask structure stays.</p> <h2 id="self-attention-single-head">Self-attention, single head</h2> <p>Three linear projections from the residual stream to a smaller <code class="language-plaintext highlighter-rouge">head_size</code>:</p> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">class</span> <span class="nc">Head</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="n">self</span><span class="p">,</span> <span class="n">head_size</span><span class="p">):</span>
        <span class="nf">super</span><span class="p">().</span><span class="nf">__init__</span><span class="p">()</span>
        <span class="n">self</span><span class="p">.</span><span class="n">key</span>   <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="nc">Linear</span><span class="p">(</span><span class="n">n_embd</span><span class="p">,</span> <span class="n">head_size</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span>
        <span class="n">self</span><span class="p">.</span><span class="n">query</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="nc">Linear</span><span class="p">(</span><span class="n">n_embd</span><span class="p">,</span> <span class="n">head_size</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span>
        <span class="n">self</span><span class="p">.</span><span class="n">value</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="nc">Linear</span><span class="p">(</span><span class="n">n_embd</span><span class="p">,</span> <span class="n">head_size</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span>
        <span class="n">self</span><span class="p">.</span><span class="nf">register_buffer</span><span class="p">(</span><span class="sh">'</span><span class="s">tril</span><span class="sh">'</span><span class="p">,</span> <span class="n">torch</span><span class="p">.</span><span class="nf">tril</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="nf">ones</span><span class="p">(</span><span class="n">block_size</span><span class="p">,</span> <span class="n">block_size</span><span class="p">)))</span>

    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="n">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
        <span class="n">B</span><span class="p">,</span> <span class="n">T</span><span class="p">,</span> <span class="n">C</span> <span class="o">=</span> <span class="n">x</span><span class="p">.</span><span class="n">shape</span>
        <span class="n">k</span> <span class="o">=</span> <span class="n">self</span><span class="p">.</span><span class="nf">key</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
        <span class="n">q</span> <span class="o">=</span> <span class="n">self</span><span class="p">.</span><span class="nf">query</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>

        <span class="n">wei</span> <span class="o">=</span> <span class="n">q</span> <span class="o">@</span> <span class="n">k</span><span class="p">.</span><span class="nf">transpose</span><span class="p">(</span><span class="o">-</span><span class="mi">2</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">k</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span><span class="o">**-</span><span class="mf">0.5</span>   <span class="c1"># (B, T, T)
</span>        <span class="n">wei</span> <span class="o">=</span> <span class="n">wei</span><span class="p">.</span><span class="nf">masked_fill</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">tril</span><span class="p">[:</span><span class="n">T</span><span class="p">,</span> <span class="p">:</span><span class="n">T</span><span class="p">]</span> <span class="o">==</span> <span class="mi">0</span><span class="p">,</span> <span class="nf">float</span><span class="p">(</span><span class="sh">'</span><span class="s">-inf</span><span class="sh">'</span><span class="p">))</span>
        <span class="n">wei</span> <span class="o">=</span> <span class="n">F</span><span class="p">.</span><span class="nf">softmax</span><span class="p">(</span><span class="n">wei</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>

        <span class="n">v</span> <span class="o">=</span> <span class="n">self</span><span class="p">.</span><span class="nf">value</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">wei</span> <span class="o">@</span> <span class="n">v</span>
</code></pre></div></div> <p>Three details worth pointing at:</p> <ul> <li><code class="language-plaintext highlighter-rouge">1/sqrt(head_size)</code> scaling. Without it, the variance of the dot products grows linearly with <code class="language-plaintext highlighter-rouge">head_size</code>. Softmax then saturates toward one-hot and the gradients through it shrink. Scaling keeps the softmax diffuse at init.</li> <li><code class="language-plaintext highlighter-rouge">register_buffer</code> for <code class="language-plaintext highlighter-rouge">tril</code>. The mask is a constant; we don’t want it tracked as a learnable parameter, but we <em>do</em> want it to move to GPU when <code class="language-plaintext highlighter-rouge">model.to(device)</code> is called.</li> <li>Mask is <code class="language-plaintext highlighter-rouge">−inf</code> strictly above the diagonal. Inside the softmax, <code class="language-plaintext highlighter-rouge">exp(−inf) = 0</code>, so future positions get exactly zero weight.</li> </ul> <p>Single head added on top of the bigram baseline drops val loss 2.88 → 2.41.</p> <h2 id="multi-head-attention">Multi-head attention</h2> <p>Run <code class="language-plaintext highlighter-rouge">n_head</code> heads in parallel, concatenate, then project back to <code class="language-plaintext highlighter-rouge">n_embd</code>:</p> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">class</span> <span class="nc">MultiHeadAttention</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="n">self</span><span class="p">,</span> <span class="n">num_heads</span><span class="p">,</span> <span class="n">head_size</span><span class="p">):</span>
        <span class="nf">super</span><span class="p">().</span><span class="nf">__init__</span><span class="p">()</span>
        <span class="n">self</span><span class="p">.</span><span class="n">heads</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="nc">ModuleList</span><span class="p">([</span><span class="nc">Head</span><span class="p">(</span><span class="n">head_size</span><span class="p">)</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nf">range</span><span class="p">(</span><span class="n">num_heads</span><span class="p">)])</span>
        <span class="n">self</span><span class="p">.</span><span class="n">proj</span>  <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="nc">Linear</span><span class="p">(</span><span class="n">head_size</span> <span class="o">*</span> <span class="n">num_heads</span><span class="p">,</span> <span class="n">n_embd</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="n">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
        <span class="n">out</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="nf">cat</span><span class="p">([</span><span class="nf">h</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="k">for</span> <span class="n">h</span> <span class="ow">in</span> <span class="n">self</span><span class="p">.</span><span class="n">heads</span><span class="p">],</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">self</span><span class="p">.</span><span class="nf">proj</span><span class="p">(</span><span class="n">out</span><span class="p">)</span>
</code></pre></div></div> <p><code class="language-plaintext highlighter-rouge">head_size = n_embd // n_head</code>, so concatenation gets you back to <code class="language-plaintext highlighter-rouge">n_embd</code>. The final <code class="language-plaintext highlighter-rouge">proj</code> is what lets heads interact — concatenation alone just glues outputs together.</p> <p><code class="language-plaintext highlighter-rouge">nn.ModuleList</code> instead of a plain Python list: PyTorch needs to know about these submodules to register their parameters. A plain list is invisible to <code class="language-plaintext highlighter-rouge">model.parameters()</code>.</p> <p>Val loss: 2.32.</p> <h2 id="feed-forward-per-token-computation">Feed-forward: per-token computation</h2> <p>After attention mixes information across positions, the FFN lets each position <em>think</em> about what it gathered. Same MLP applied to every token independently:</p> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">self</span><span class="p">.</span><span class="n">net</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="nc">Sequential</span><span class="p">(</span>
    <span class="n">nn</span><span class="p">.</span><span class="nc">Linear</span><span class="p">(</span><span class="n">n_embd</span><span class="p">,</span> <span class="mi">4</span> <span class="o">*</span> <span class="n">n_embd</span><span class="p">),</span>
    <span class="n">nn</span><span class="p">.</span><span class="nc">ReLU</span><span class="p">(),</span>
    <span class="n">nn</span><span class="p">.</span><span class="nc">Linear</span><span class="p">(</span><span class="mi">4</span> <span class="o">*</span> <span class="n">n_embd</span><span class="p">,</span> <span class="n">n_embd</span><span class="p">),</span>
<span class="p">)</span>
</code></pre></div></div> <p>4× expansion is from the original Transformer paper. Val loss: 2.23.</p> <p>A useful mental model from this point on: <strong>attention is communication, feed-forward is computation.</strong> Tokens talk to each other in attention, then each one updates its own representation in the FFN.</p> <h2 id="the-block-pre-norm--residual">The block: pre-norm + residual</h2> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">class</span> <span class="nc">Block</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="n">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
        <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">self</span><span class="p">.</span><span class="nf">sa</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="nf">ln1</span><span class="p">(</span><span class="n">x</span><span class="p">))</span>
        <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">self</span><span class="p">.</span><span class="nf">ffwd</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="nf">ln2</span><span class="p">(</span><span class="n">x</span><span class="p">))</span>
        <span class="k">return</span> <span class="n">x</span>
</code></pre></div></div> <p>Three things going on:</p> <ul> <li><strong>Residual <code class="language-plaintext highlighter-rouge">x + ...</code></strong>: the addition creates a gradient highway. Gradients flow back through <code class="language-plaintext highlighter-rouge">+</code> undisturbed to earlier layers. Without it, deep networks have vanishing gradients.</li> <li><strong>LayerNorm before the sub-block (pre-norm)</strong>: the original 2017 paper put LN <em>after</em>; modern practice puts it before. Pre-norm trains more stably at depth.</li> <li><strong>Sub-blocks start near-zero in output</strong>: at init, attention and FFN contribute tiny perturbations to the residual stream. Useful signal accumulates gradually.</li> </ul> <p>Val loss after stacking 3 such blocks: 2.09.</p> <h2 id="scaling-up">Scaling up</h2> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">batch_size</span>    <span class="o">=</span> <span class="mi">64</span>
<span class="n">block_size</span>    <span class="o">=</span> <span class="mi">128</span>
<span class="n">n_embd</span>        <span class="o">=</span> <span class="mi">192</span>
<span class="n">n_head</span>        <span class="o">=</span> <span class="mi">6</span>      <span class="c1"># head_size = 192/6 = 32
</span><span class="n">n_layer</span>       <span class="o">=</span> <span class="mi">4</span>
<span class="n">dropout</span>       <span class="o">=</span> <span class="mf">0.2</span>
</code></pre></div></div> <p>1.83M parameters. Val loss <strong>1.59</strong>. The drop from 2.09 → 1.59 is mostly capacity — same architecture, just more of it.</p> <p>Sample output:</p> <div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>DUKE VINCENTIO:
Whither dost thou pursue, and what shall be done
To these things that he was forc'd to make?

ANGELO:
My lord, I will entreat your grace's hand.
</code></pre></div></div> <p>The model is making it up character by character. There is no concept of a word in its vocabulary. It learned word boundaries, speaker tags, and sentence structure from 1 MB of text.</p> <h2 id="whats-missing-vs-gpt-2">What’s missing vs GPT-2</h2> <p>This is GPT-2’s architecture at small scale. The pieces <em>not</em> in this build:</p> <ul> <li>A real tokenizer (we used per-character; GPT-2 uses BPE — <a href="/blog/2026/bpe-tokenizer/">post</a>).</li> <li>Weight tying between <code class="language-plaintext highlighter-rouge">wte</code> and <code class="language-plaintext highlighter-rouge">lm_head</code> (the input embedding and output classifier share the same matrix in GPT-2 — <code class="language-plaintext highlighter-rouge">lm_head.weight = wte.weight</code>).</li> <li>Initialization variants (GPT-2 scales the output projection of each residual layer by <code class="language-plaintext highlighter-rouge">1/sqrt(2*n_layer)</code> to control variance through depth).</li> <li>A proper optimizer recipe (cosine LR schedule, weight decay split, warmup) and DDP for multi-GPU.</li> </ul> <p>All of those show up in the <a href="/blog/2026/gpt2-124m/">GPT-2 reproduction</a>.</p> <p>Code: <a href="https://github.com/debtirthasaha/tiny-gpt-shakespeare">github.com/debtirthasaha/tiny-gpt-shakespeare</a>.</p>]]></content><author><name></name></author><category term="deep-learning"/><category term="transformer"/><category term="attention"/><category term="gpt"/><summary type="html"><![CDATA[A 1.83M-parameter decoder-only transformer trained on 1MB of Shakespeare. 
Architecture is identical to GPT-2, just smaller.]]></summary></entry><entry><title type="html">makemore: from counting bigrams to a WaveNet</title><link href="https://debtirthasaha.github.io/blog/2026/makemore/" rel="alternate" type="text/html" title="makemore: from counting bigrams to a WaveNet"/><published>2026-04-08T10:00:00+00:00</published><updated>2026-04-08T10:00:00+00:00</updated><id>https://debtirthasaha.github.io/blog/2026/makemore</id><content type="html" xml:base="https://debtirthasaha.github.io/blog/2026/makemore/"><![CDATA[<p><code class="language-plaintext highlighter-rouge">names.txt</code>: 32,033 names, one per line. Vocabulary is 26 letters + <code class="language-plaintext highlighter-rouge">.</code> (start/end token) = 27 characters. Each name is wrapped in the boundary token, so <code class="language-plaintext highlighter-rouge">emma</code> becomes <code class="language-plaintext highlighter-rouge">.emma.</code> and its bigrams are <code class="language-plaintext highlighter-rouge">(.,e), (e,m), (m,m), (m,a), (a,.)</code>. Goal: predict the next character.</p> <p>Five models, each adding one mechanism. 
Loss is negative log likelihood, lower is better.</p> <table> <thead> <tr> <th>Model</th> <th>Mechanism</th> <th>Val NLL</th> </tr> </thead> <tbody> <tr> <td>Bigram counts</td> <td>27×27 count matrix, +1 smoothing</td> <td>2.45</td> </tr> <tr> <td>Bigram NN</td> <td>27→27 logits, softmax, gradient descent</td> <td>2.46</td> </tr> <tr> <td>MLP (Bengio 2003)</td> <td>3-char context, 10-dim embedding, 200-hidden tanh</td> <td>2.10</td> </tr> <tr> <td>MLP + BN + Kaiming</td> <td>same + proper init + batch norm</td> <td>2.05</td> </tr> <tr> <td>WaveNet-style</td> <td>hierarchical pairwise fusion, 8-char context</td> <td>1.99</td> </tr> </tbody> </table> <pre><code class="language-plotly">{"data":[{"x":["bigram counts","bigram NN","MLP","MLP + BN + Kaiming","WaveNet-style"],"y":[2.45,2.46,2.10,2.05,1.99],"type":"bar","marker":{"color":["#bbbbbb","#9b9bff","#5959ff","#3838ff","#EF553B"]},"text":[2.45,2.46,2.10,2.05,1.99],"textposition":"outside","hovertemplate":"%{x}&lt;br&gt;val NLL %{y}&lt;extra&gt;&lt;/extra&gt;"}],"layout":{"title":{"text":"Validation NLL across the five models"},"yaxis":{"title":"NLL (lower = better)","range":[0,3.0]},"xaxis":{"tickangle":-25},"height":420,"margin":{"l":60,"r":30,"t":60,"b":120},"showlegend":false}}
</code></pre> <h2 id="1-counting-bigrams">1. Counting bigrams</h2> <p>Build the count matrix directly:</p> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">N</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="nf">zeros</span><span class="p">((</span><span class="mi">27</span><span class="p">,</span> <span class="mi">27</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="p">.</span><span class="n">int32</span><span class="p">)</span>
<span class="k">for</span> <span class="n">w</span> <span class="ow">in</span> <span class="n">words</span><span class="p">:</span>
    <span class="n">chs</span> <span class="o">=</span> <span class="p">[</span><span class="sh">'</span><span class="s">.</span><span class="sh">'</span><span class="p">]</span> <span class="o">+</span> <span class="nf">list</span><span class="p">(</span><span class="n">w</span><span class="p">)</span> <span class="o">+</span> <span class="p">[</span><span class="sh">'</span><span class="s">.</span><span class="sh">'</span><span class="p">]</span>
    <span class="k">for</span> <span class="n">ch1</span><span class="p">,</span> <span class="n">ch2</span> <span class="ow">in</span> <span class="nf">zip</span><span class="p">(</span><span class="n">chs</span><span class="p">,</span> <span class="n">chs</span><span class="p">[</span><span class="mi">1</span><span class="p">:]):</span>
        <span class="n">N</span><span class="p">[</span><span class="n">stoi</span><span class="p">[</span><span class="n">ch1</span><span class="p">],</span> <span class="n">stoi</span><span class="p">[</span><span class="n">ch2</span><span class="p">]]</span> <span class="o">+=</span> <span class="mi">1</span>

<span class="n">P</span> <span class="o">=</span> <span class="p">(</span><span class="n">N</span> <span class="o">+</span> <span class="mi">1</span><span class="p">).</span><span class="nf">float</span><span class="p">()</span>
<span class="n">P</span> <span class="o">/=</span> <span class="n">P</span><span class="p">.</span><span class="nf">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
</code></pre></div></div> <p><code class="language-plaintext highlighter-rouge">+1</code> smoothing avoids <code class="language-plaintext highlighter-rouge">log(0)</code> on bigrams that never appeared in training.</p> <p>Sampling is <code class="language-plaintext highlighter-rouge">torch.multinomial(P[ix], num_samples=1)</code> in a loop until you draw the <code class="language-plaintext highlighter-rouge">.</code> token.</p> <p>NLL = <code class="language-plaintext highlighter-rouge">−mean(log P[bigram])</code> over the training set = <strong>2.4543</strong>. As a sanity check: <code class="language-plaintext highlighter-rouge">exp(−2.45) ≈ 8.6%</code>, vs 1/27 ≈ 3.7% for uniform random. In geometric mean, the bigram model assigns roughly 2.3× the probability that chance would to the correct next character.</p> <h2 id="2-the-same-bigram-model-as-a-neural-net">2. The same bigram model as a neural net</h2> <p>Same model, found by gradient descent instead of counting:</p> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">xenc</span> <span class="o">=</span> <span class="n">F</span><span class="p">.</span><span class="nf">one_hot</span><span class="p">(</span><span class="n">xs</span><span class="p">,</span> <span class="n">num_classes</span><span class="o">=</span><span class="mi">27</span><span class="p">).</span><span class="nf">float</span><span class="p">()</span>   <span class="c1"># (n, 27), n = number of bigrams
</span><span class="n">W</span>    <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="nf">randn</span><span class="p">((</span><span class="mi">27</span><span class="p">,</span> <span class="mi">27</span><span class="p">),</span> <span class="n">requires_grad</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>

<span class="n">logits</span> <span class="o">=</span> <span class="n">xenc</span> <span class="o">@</span> <span class="n">W</span>
<span class="n">counts</span> <span class="o">=</span> <span class="n">logits</span><span class="p">.</span><span class="nf">exp</span><span class="p">()</span>
<span class="n">probs</span>  <span class="o">=</span> <span class="n">counts</span> <span class="o">/</span> <span class="n">counts</span><span class="p">.</span><span class="nf">sum</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="n">loss</span>   <span class="o">=</span> <span class="o">-</span><span class="n">probs</span><span class="p">[</span><span class="n">torch</span><span class="p">.</span><span class="nf">arange</span><span class="p">(</span><span class="n">n</span><span class="p">),</span> <span class="n">ys</span><span class="p">].</span><span class="nf">log</span><span class="p">().</span><span class="nf">mean</span><span class="p">()</span>
</code></pre></div></div> <p><code class="language-plaintext highlighter-rouge">xenc @ W</code> is a row lookup (one-hot times matrix = pick a row of W). The “logits” are log-counts up to a constant. <code class="language-plaintext highlighter-rouge">softmax(logits)</code> matches the row-normalized count matrix. Trained 200 steps with <code class="language-plaintext highlighter-rouge">lr=50</code>, lands at NLL <strong>2.4576</strong> — within 0.01 of the count model.</p> <p>The takeaway: this is <em>the same model</em>, just parameterized differently. The neural net’s W converges to the log of the count matrix. The equivalence breaks the moment you add nonlinearity or more context.</p> <h2 id="3-mlp-bengio-2003">3. MLP, Bengio 2003</h2> <p>Bigrams are too local. With context <code class="language-plaintext highlighter-rouge">..</code> you can’t tell <code class="language-plaintext highlighter-rouge">e</code> from <code class="language-plaintext highlighter-rouge">o</code>; with context <code class="language-plaintext highlighter-rouge">..em</code> you can. Bump context from 1 → 3 characters.</p> <div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>input: 3 char indices, e.g. [0, 0, 5]
  → embedding C: (27, 10)             → (3, 10)
  → concatenate                       → (30,)
  → Linear(30, 200) + tanh
  → Linear(200, 27)                   → logits
  → softmax                           → next-char distribution
</code></pre></div></div> <p>Dataset built by sliding a 3-window over each name:</p> <div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>context           target
['.', '.', '.']   'e'
['.', '.', 'e']   'm'
['.', 'e', 'm']   'm'
['e', 'm', 'm']   'a'
['m', 'm', 'a']   '.'
</code></pre></div></div> <p><code class="language-plaintext highlighter-rouge">build_dataset()</code> returns <code class="language-plaintext highlighter-rouge">X (228146, 3)</code> and <code class="language-plaintext highlighter-rouge">Y (228146,)</code>. 80/10/10 train/dev/test split.</p> <p>Forward:</p> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">emb</span> <span class="o">=</span> <span class="n">C</span><span class="p">[</span><span class="n">Xb</span><span class="p">]</span>                          <span class="c1"># (B, 3, 10)
</span><span class="n">h</span>   <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="nf">tanh</span><span class="p">(</span><span class="n">emb</span><span class="p">.</span><span class="nf">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">30</span><span class="p">)</span> <span class="o">@</span> <span class="n">W1</span> <span class="o">+</span> <span class="n">b1</span><span class="p">)</span>   <span class="c1"># (B, 200)
</span><span class="n">logits</span> <span class="o">=</span> <span class="n">h</span> <span class="o">@</span> <span class="n">W2</span> <span class="o">+</span> <span class="n">b2</span>                  <span class="c1"># (B, 27)
</span><span class="n">loss</span> <span class="o">=</span> <span class="n">F</span><span class="p">.</span><span class="nf">cross_entropy</span><span class="p">(</span><span class="n">logits</span><span class="p">,</span> <span class="n">Yb</span><span class="p">)</span>
</code></pre></div></div> <p><code class="language-plaintext highlighter-rouge">emb.view(-1, 30)</code> flattens the 3-char window into a 30-d vector. The network can learn position-dependent patterns because each character’s embedding occupies a different slice of the input.</p> <p>Trains in ~30 sec. Val loss ~2.10. Sampled names start sounding like names: <code class="language-plaintext highlighter-rouge">montelle</code>, <code class="language-plaintext highlighter-rouge">kymbry</code>, <code class="language-plaintext highlighter-rouge">madiet</code>.</p> <h2 id="4-the-three-init-bugs-nobody-tells-you-about">4. The three init bugs nobody tells you about</h2> <p>The MLP works, but if you instrument it, three things are quietly broken at step 0.</p> <p><strong>Initial loss is too high.</strong> Loss at random init is ~27 (exploded softmax). Expected value is <code class="language-plaintext highlighter-rouge">−log(1/27) ≈ 3.3</code>. Cause: <code class="language-plaintext highlighter-rouge">W2</code> and <code class="language-plaintext highlighter-rouge">b2</code> initialized from <code class="language-plaintext highlighter-rouge">N(0, 1)</code> produce logits with huge variance: softmax assigns near-1 probability to one random class, and if it’s not the right one, <code class="language-plaintext highlighter-rouge">−log(tiny) ≈ huge</code>. Fix: multiply <code class="language-plaintext highlighter-rouge">W2</code> by ~0.01 and zero <code class="language-plaintext highlighter-rouge">b2</code>. Initial loss drops to 3.32.</p> <p><strong>Tanh saturation.</strong> Most pre-activations land outside <code class="language-plaintext highlighter-rouge">[-2, 2]</code> at init, where <code class="language-plaintext highlighter-rouge">tanh</code> is flat. Local gradient <code class="language-plaintext highlighter-rouge">(1 − tanh²(x))</code> is near 0, gradients can’t flow through these neurons, and they’re effectively dead. 
Diagnose with <code class="language-plaintext highlighter-rouge">(h.abs() &gt; 0.99).float().mean(0)</code>, which gives the saturated fraction per neuron; at init this is &gt;97% for some neurons. Fix: scale <code class="language-plaintext highlighter-rouge">W1</code> so that <code class="language-plaintext highlighter-rouge">x @ W1</code> has variance ~1.</p> <p><strong>Eyeballing the scaling factor.</strong> Kaiming He’s paper derives the scale for ReLU, and PyTorch extends it with a per-nonlinearity gain: for a layer with <code class="language-plaintext highlighter-rouge">fan_in</code> inputs, draw weights from a zero-mean Gaussian with standard deviation <code class="language-plaintext highlighter-rouge">gain/sqrt(fan_in)</code>, where <code class="language-plaintext highlighter-rouge">gain = 5/3</code> for tanh and <code class="language-plaintext highlighter-rouge">sqrt(2)</code> for relu. PyTorch ships this as <code class="language-plaintext highlighter-rouge">torch.nn.init.kaiming_normal_</code>.</p> <p>After Kaiming init: pre-activations stay in <code class="language-plaintext highlighter-rouge">[-2, 2]</code>, no dead neurons, loss starts where it should. Val loss drops from 2.10 to ~2.07 just from fixing initialization.</p> <h2 id="5-batchnorm-forcing-the-distribution-post-hoc">5. BatchNorm: forcing the distribution post-hoc</h2> <p>Kaiming gets you into the right range <em>at init</em>. As you train, weights drift, distributions shift again. 
BatchNorm normalizes the pre-activation distribution <em>every forward pass</em>:</p> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">bnmeani</span> <span class="o">=</span> <span class="n">hpreact</span><span class="p">.</span><span class="nf">mean</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="n">bnstdi</span>  <span class="o">=</span> <span class="n">hpreact</span><span class="p">.</span><span class="nf">std</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">keepdim</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="n">hpreact</span> <span class="o">=</span> <span class="n">bngain</span> <span class="o">*</span> <span class="p">(</span><span class="n">hpreact</span> <span class="o">-</span> <span class="n">bnmeani</span><span class="p">)</span> <span class="o">/</span> <span class="n">bnstdi</span> <span class="o">+</span> <span class="n">bnbias</span>
</code></pre></div></div> <p><code class="language-plaintext highlighter-rouge">bngain</code> and <code class="language-plaintext highlighter-rouge">bnbias</code> are learnable; they let the network un-do the normalization if it wants. In practice they stay small: the network mostly wants the normalized version.</p> <p>The annoying part is <em>inference</em>. At inference there may be no batch: you might be predicting one example at a time. So BatchNorm keeps an exponential moving average of train-time batch statistics and uses those at eval. Two extra non-learnable buffers per BN layer. Train/eval modes diverge.</p> <p>This is also why <code class="language-plaintext highlighter-rouge">model.eval()</code> matters: without it, BatchNorm at inference would normalize with the statistics of the current batch. For a single example the batch standard deviation is undefined, so the output is garbage.</p> <p>Val loss with init fixes + BN: <strong>~2.05</strong>.</p> <h2 id="6-manual-backprop-every-gradient-by-hand">6. Manual backprop, every gradient by hand</h2> <p>For one stretch of training I deleted <code class="language-plaintext highlighter-rouge">loss.backward()</code> and computed every gradient by hand, layer by layer.</p> <p>The cross-entropy case is the one worth writing out. Cross-entropy fuses three ops: softmax, pick the correct-class probability, <code class="language-plaintext highlighter-rouge">−log</code>. Differentiating directly:</p> <p>For the correct class <code class="language-plaintext highlighter-rouge">y</code>:</p> <div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>p_y = exp(z_y) / S        where S = Σ exp(z_j)
dL/dz_y = p_y − 1
</code></pre></div></div> <p>For any other class <code class="language-plaintext highlighter-rouge">i ≠ y</code>:</p> <div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>dL/dz_i = p_i
</code></pre></div></div> <p>So <code class="language-plaintext highlighter-rouge">dlogits = probs.clone(); dlogits[range(n), y] -= 1; dlogits /= n</code>. The final division by <code class="language-plaintext highlighter-rouge">n</code> comes from averaging the loss over the batch. That’s it. The most common loss function in deep learning has a three-line gradient.</p> <p>Once you’ve done this, autograd stops being a black box. PyTorch records a backward function for each op, much like the <code class="language-plaintext highlighter-rouge">_backward</code> closures in <a href="/blog/2026/building-micrograd/">micrograd</a>, then walks the DAG in reverse and applies these closed-form rules.</p> <h2 id="7-wavenet-style-hierarchical-fusion">7. WaveNet-style hierarchical fusion</h2> <p>The MLP smashes all 8 characters into one vector and runs a single Linear over it. Every character has to interact with every other character in one shot.</p> <p>WaveNet processes pairs of adjacent characters, then pairs of pairs, then pairs of those:</p> <div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>[c1 c2  c3 c4  c5 c6  c7 c8]   8 chars, 10-dim each
  \_/    \_/    \_/    \_/
  [b1    b2    b3    b4]        4 bigram reps
    \___/        \___/
    [q1          q2]              2 four-gram reps
       \________/
           [o1]                    1 output → predict next char
</code></pre></div></div> <p>Each fusion is the same operation: <code class="language-plaintext highlighter-rouge">Linear((B, T/2, 2C) → (B, T/2, C))</code> + <code class="language-plaintext highlighter-rouge">tanh</code>. Local context builds up gradually.</p> <p>Same dataset, same training loop. Val NLL: <strong>~1.99</strong>.</p> <h2 id="where-val-loss-can-keep-dropping">Where val loss can keep dropping</h2> <table> <thead> <tr> <th>Add</th> <th>Expected drop</th> </tr> </thead> <tbody> <tr> <td>Longer context (12, 16 chars)</td> <td>small, diminishing</td> </tr> <tr> <td>More embedding dims</td> <td>small</td> </tr> <tr> <td>Multi-head self-attention</td> <td>substantial — bigrams → attention is the biggest single step</td> </tr> <tr> <td>More data</td> <td>this dataset is tiny</td> </tr> </tbody> </table> <p>Attention is what the <a href="/blog/2026/tiny-shakespeare-tiny-gpt/">tiny GPT post</a> picks up.</p> <p>Code: <a href="https://github.com/debtirthasaha/makemore-from-scratch">github.com/debtirthasaha/makemore-from-scratch</a>.</p>]]></content><author><name></name></author><category term="deep-learning"/><category term="language-models"/><category term="mlp"/><category term="batchnorm"/><summary type="html"><![CDATA[Five character-level language models trained on 32K baby names. Bigram → MLP → BatchNorm → manual backprop → hierarchical fusion.]]></summary></entry><entry><title type="html">micrograd: a scalar-valued autograd engine</title><link href="https://debtirthasaha.github.io/blog/2026/building-micrograd/" rel="alternate" type="text/html" title="micrograd: a scalar-valued autograd engine"/><published>2026-04-01T10:00:00+00:00</published><updated>2026-04-01T10:00:00+00:00</updated><id>https://debtirthasaha.github.io/blog/2026/building-micrograd</id><content type="html" xml:base="https://debtirthasaha.github.io/blog/2026/building-micrograd/"><![CDATA[<p>A scalar-valued autograd engine. 
Every value is a Python <code class="language-plaintext highlighter-rouge">float</code> wrapped in a <code class="language-plaintext highlighter-rouge">Value</code> object that knows what produced it. Operator overloads build a DAG implicitly. <code class="language-plaintext highlighter-rouge">backward()</code> topologically sorts the DAG and runs each node’s local gradient rule in reverse. About 150 lines total.</p> <h2 id="the-value-class">The <code class="language-plaintext highlighter-rouge">Value</code> class</h2> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">class</span> <span class="nc">Value</span><span class="p">:</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="n">self</span><span class="p">,</span> <span class="n">data</span><span class="p">,</span> <span class="n">_children</span><span class="o">=</span><span class="p">(),</span> <span class="n">_op</span><span class="o">=</span><span class="sh">''</span><span class="p">):</span>
        <span class="n">self</span><span class="p">.</span><span class="n">data</span> <span class="o">=</span> <span class="n">data</span>
        <span class="n">self</span><span class="p">.</span><span class="n">grad</span> <span class="o">=</span> <span class="mf">0.0</span>
        <span class="n">self</span><span class="p">.</span><span class="n">_prev</span> <span class="o">=</span> <span class="nf">set</span><span class="p">(</span><span class="n">_children</span><span class="p">)</span>
        <span class="n">self</span><span class="p">.</span><span class="n">_op</span> <span class="o">=</span> <span class="n">_op</span>
        <span class="n">self</span><span class="p">.</span><span class="n">_backward</span> <span class="o">=</span> <span class="k">lambda</span><span class="p">:</span> <span class="bp">None</span>
</code></pre></div></div> <ul> <li><code class="language-plaintext highlighter-rouge">data</code>: the forward scalar.</li> <li><code class="language-plaintext highlighter-rouge">grad</code>: filled in by backward, starts at 0.</li> <li><code class="language-plaintext highlighter-rouge">_prev</code>: parents in the DAG.</li> <li><code class="language-plaintext highlighter-rouge">_op</code>: string label (debugging only).</li> <li><code class="language-plaintext highlighter-rouge">_backward</code>: closure each operation sets; default no-op for leaves.</li> </ul> <h2 id="operator-overloads-register-local-gradient-rules">Operator overloads register local gradient rules</h2> <p>Addition:</p> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">__add__</span><span class="p">(</span><span class="n">self</span><span class="p">,</span> <span class="n">other</span><span class="p">):</span>
    <span class="n">other</span> <span class="o">=</span> <span class="n">other</span> <span class="k">if</span> <span class="nf">isinstance</span><span class="p">(</span><span class="n">other</span><span class="p">,</span> <span class="n">Value</span><span class="p">)</span> <span class="k">else</span> <span class="nc">Value</span><span class="p">(</span><span class="n">other</span><span class="p">)</span>
    <span class="n">out</span> <span class="o">=</span> <span class="nc">Value</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">data</span> <span class="o">+</span> <span class="n">other</span><span class="p">.</span><span class="n">data</span><span class="p">,</span> <span class="p">(</span><span class="n">self</span><span class="p">,</span> <span class="n">other</span><span class="p">),</span> <span class="sh">'</span><span class="s">+</span><span class="sh">'</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">_backward</span><span class="p">():</span>
        <span class="n">self</span><span class="p">.</span><span class="n">grad</span>  <span class="o">+=</span> <span class="mf">1.0</span> <span class="o">*</span> <span class="n">out</span><span class="p">.</span><span class="n">grad</span>
        <span class="n">other</span><span class="p">.</span><span class="n">grad</span> <span class="o">+=</span> <span class="mf">1.0</span> <span class="o">*</span> <span class="n">out</span><span class="p">.</span><span class="n">grad</span>
    <span class="n">out</span><span class="p">.</span><span class="n">_backward</span> <span class="o">=</span> <span class="n">_backward</span>
    <span class="k">return</span> <span class="n">out</span>
</code></pre></div></div> <p>Three things at once: forward arithmetic, DAG construction (<code class="language-plaintext highlighter-rouge">(self, other)</code> as parents), and the local rule (<code class="language-plaintext highlighter-rouge">d(a+b)/da = 1</code>, <code class="language-plaintext highlighter-rouge">d(a+b)/db = 1</code>).</p> <p>Multiplication: same shape, different local rule.</p> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">_backward</span><span class="p">():</span>
    <span class="n">self</span><span class="p">.</span><span class="n">grad</span>  <span class="o">+=</span> <span class="n">other</span><span class="p">.</span><span class="n">data</span> <span class="o">*</span> <span class="n">out</span><span class="p">.</span><span class="n">grad</span>
    <span class="n">other</span><span class="p">.</span><span class="n">grad</span> <span class="o">+=</span> <span class="n">self</span><span class="p">.</span><span class="n">data</span>  <span class="o">*</span> <span class="n">out</span><span class="p">.</span><span class="n">grad</span>
</code></pre></div></div> <p><code class="language-plaintext highlighter-rouge">tanh</code>:</p> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">tanh</span><span class="p">(</span><span class="n">self</span><span class="p">):</span>
    <span class="n">t</span> <span class="o">=</span> <span class="p">(</span><span class="n">math</span><span class="p">.</span><span class="nf">exp</span><span class="p">(</span><span class="mi">2</span><span class="o">*</span><span class="n">self</span><span class="p">.</span><span class="n">data</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">/</span> <span class="p">(</span><span class="n">math</span><span class="p">.</span><span class="nf">exp</span><span class="p">(</span><span class="mi">2</span><span class="o">*</span><span class="n">self</span><span class="p">.</span><span class="n">data</span><span class="p">)</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span>
    <span class="n">out</span> <span class="o">=</span> <span class="nc">Value</span><span class="p">(</span><span class="n">t</span><span class="p">,</span> <span class="p">(</span><span class="n">self</span><span class="p">,),</span> <span class="sh">'</span><span class="s">tanh</span><span class="sh">'</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">_backward</span><span class="p">():</span>
        <span class="n">self</span><span class="p">.</span><span class="n">grad</span> <span class="o">+=</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">t</span><span class="o">**</span><span class="mi">2</span><span class="p">)</span> <span class="o">*</span> <span class="n">out</span><span class="p">.</span><span class="n">grad</span>
    <span class="n">out</span><span class="p">.</span><span class="n">_backward</span> <span class="o">=</span> <span class="n">_backward</span>
    <span class="k">return</span> <span class="n">out</span>
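
# --- Sketch (not part of the original class): finite-difference check of the tanh rule ---
# Assumption: plain floats stand in for Value objects, and the helper names
# below (tanh_forward, analytic_grad, numeric_grad) are hypothetical, introduced
# only for this illustration.
import math

def tanh_forward(x):
    # Same formula as the tanh method above.
    return (math.exp(2 * x) - 1) / (math.exp(2 * x) + 1)

def analytic_grad(x):
    # Local rule registered in _backward: d tanh(x) / dx = 1 - tanh(x)**2
    t = tanh_forward(x)
    return 1 - t ** 2

def numeric_grad(f, x, h=1e-6):
    # Central-difference approximation of df/dx.
    return (f(x + h) - f(x - h)) / (2 * h)

# The two columns should agree to several decimal places.
for x in (-2.0, 0.0, 0.7):
    print(x, analytic_grad(x), numeric_grad(tanh_forward, x))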
</code></pre></div></div> <p><code class="language-plaintext highlighter-rouge">**</code> and <code class="language-plaintext highlighter-rouge">exp</code> are the same pattern.</p> <h2 id="why--and-not-">Why <code class="language-plaintext highlighter-rouge">+=</code> and not <code class="language-plaintext highlighter-rouge">=</code></h2> <p>If a node feeds into multiple downstream nodes, the chain rule sums contributions over all paths. <code class="language-plaintext highlighter-rouge">+=</code> is exactly that sum.</p> <p>This is why <code class="language-plaintext highlighter-rouge">optimizer.zero_grad()</code> exists in PyTorch. Gradients accumulate by design; you have to clear them between training steps or you accumulate gradients across batches.</p> <h2 id="backward-via-topo-sort">Backward via topo-sort</h2> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">backward</span><span class="p">(</span><span class="n">self</span><span class="p">):</span>
    <span class="n">topo</span> <span class="o">=</span> <span class="p">[]</span>
    <span class="n">visited</span> <span class="o">=</span> <span class="nf">set</span><span class="p">()</span>
    <span class="k">def</span> <span class="nf">build_topo</span><span class="p">(</span><span class="n">v</span><span class="p">):</span>
        <span class="k">if</span> <span class="n">v</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">visited</span><span class="p">:</span>
            <span class="n">visited</span><span class="p">.</span><span class="nf">add</span><span class="p">(</span><span class="n">v</span><span class="p">)</span>
            <span class="k">for</span> <span class="n">child</span> <span class="ow">in</span> <span class="n">v</span><span class="p">.</span><span class="n">_prev</span><span class="p">:</span>
                <span class="nf">build_topo</span><span class="p">(</span><span class="n">child</span><span class="p">)</span>
            <span class="n">topo</span><span class="p">.</span><span class="nf">append</span><span class="p">(</span><span class="n">v</span><span class="p">)</span>
    <span class="nf">build_topo</span><span class="p">(</span><span class="n">self</span><span class="p">)</span>

    <span class="n">self</span><span class="p">.</span><span class="n">grad</span> <span class="o">=</span> <span class="mf">1.0</span>
    <span class="k">for</span> <span class="n">v</span> <span class="ow">in</span> <span class="nf">reversed</span><span class="p">(</span><span class="n">topo</span><span class="p">):</span>
        <span class="n">v</span><span class="p">.</span><span class="nf">_backward</span><span class="p">()</span>
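
# --- Sketch (not from the original repo): accumulation across paths, end to end ---
# Assumption: TinyValue is a hypothetical, stripped-down stand-in for Value that
# supports only + and * between TinyValue objects. It exists to show that a node
# used on several paths ends up with the sum of all path contributions.
class TinyValue:
    def __init__(self, data, _children=()):
        self.data = data
        self.grad = 0.0
        self._prev = set(_children)
        self._backward = lambda: None

    def __add__(self, other):
        out = TinyValue(self.data + other.data, (self, other))
        def _backward():
            self.grad += out.grad
            other.grad += out.grad
        out._backward = _backward
        return out

    def __mul__(self, other):
        out = TinyValue(self.data * other.data, (self, other))
        def _backward():
            self.grad += other.data * out.grad
            other.grad += self.data * out.grad
        out._backward = _backward
        return out

    def backward(self):
        topo, visited = [], set()
        def build_topo(v):
            if v not in visited:
                visited.add(v)
                for child in v._prev:
                    build_topo(child)
                topo.append(v)
        build_topo(self)
        self.grad = 1.0
        for v in reversed(topo):
            v._backward()

a = TinyValue(3.0)
y = a * a + a   # a feeds the product twice and the sum once
y.backward()
print(a.grad)   # dy/da = 2*a + 1 = 7.0; with = instead of +=, contributions would be lost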
</code></pre></div></div> <p>Topo sort guarantees that when <code class="language-plaintext highlighter-rouge">v._backward()</code> runs, <code class="language-plaintext highlighter-rouge">v.grad</code> already has its final value: every downstream node has already pushed into it. Calling <code class="language-plaintext highlighter-rouge">_backward</code> in any other order can propagate a gradient that is still missing contributions.</p> <p>The base case <code class="language-plaintext highlighter-rouge">self.grad = 1.0</code> is the seed: the gradient of the final scalar with respect to itself is 1.</p> <h2 id="a-2-3-3-1-mlp-no-pytorch">A 3-4-4-1 MLP, no PyTorch</h2> <p>A <code class="language-plaintext highlighter-rouge">Neuron</code> is a list of weight <code class="language-plaintext highlighter-rouge">Value</code>s, a bias <code class="language-plaintext highlighter-rouge">Value</code>, and a <code class="language-plaintext highlighter-rouge">tanh</code>. A <code class="language-plaintext highlighter-rouge">Layer</code> is a list of <code class="language-plaintext highlighter-rouge">Neuron</code>s. An <code class="language-plaintext highlighter-rouge">MLP</code> is a list of <code class="language-plaintext highlighter-rouge">Layer</code>s.</p> <p>Training loop:</p> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">for</span> <span class="n">k</span> <span class="ow">in</span> <span class="nf">range</span><span class="p">(</span><span class="mi">50</span><span class="p">):</span>
    <span class="n">ypred</span> <span class="o">=</span> <span class="p">[</span><span class="nf">n</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">xs</span><span class="p">]</span>
    <span class="n">loss</span>  <span class="o">=</span> <span class="nf">sum</span><span class="p">((</span><span class="n">yp</span> <span class="o">-</span> <span class="n">y</span><span class="p">)</span><span class="o">**</span><span class="mi">2</span> <span class="k">for</span> <span class="n">yp</span><span class="p">,</span> <span class="n">y</span> <span class="ow">in</span> <span class="nf">zip</span><span class="p">(</span><span class="n">ypred</span><span class="p">,</span> <span class="n">ys</span><span class="p">))</span>

    <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">n</span><span class="p">.</span><span class="nf">parameters</span><span class="p">():</span>
        <span class="n">p</span><span class="p">.</span><span class="n">grad</span> <span class="o">=</span> <span class="mf">0.0</span>
    <span class="n">loss</span><span class="p">.</span><span class="nf">backward</span><span class="p">()</span>

    <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">n</span><span class="p">.</span><span class="nf">parameters</span><span class="p">():</span>
        <span class="n">p</span><span class="p">.</span><span class="n">data</span> <span class="o">-=</span> <span class="mf">0.05</span> <span class="o">*</span> <span class="n">p</span><span class="p">.</span><span class="n">grad</span>
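
# --- Sketch (not from the original repo): where a 41-parameter count comes from ---
# Assumption: each neuron holds one weight per input plus one bias, and the
# layer sizes used below (3 inputs, then 4, 4 and 1 neurons) are inferred from
# the 41-parameter figure rather than read from the repo. count_parameters is a
# hypothetical helper.
def count_parameters(n_inputs, layer_sizes):
    total = 0
    fan_in = n_inputs
    for n_neurons in layer_sizes:
        total += n_neurons * (fan_in + 1)  # weights plus one bias per neuron
        fan_in = n_neurons
    return total

print(count_parameters(3, [4, 4, 1]))  # 16 + 20 + 5 = 41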
</code></pre></div></div> <p><code class="language-plaintext highlighter-rouge">(yp - y)**2</code>, <code class="language-plaintext highlighter-rouge">sum</code>, every neuron’s <code class="language-plaintext highlighter-rouge">tanh</code> — all of it constructs nodes in the same <code class="language-plaintext highlighter-rouge">Value</code> DAG. <code class="language-plaintext highlighter-rouge">loss.backward()</code> walks the whole thing.</p> <p>41 parameters. The actual training trajectory on the four-point demo, every step:</p> <pre><code class="language-plotly">{"data":[{"x":[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49],"y":[3.764166,3.514946,3.152909,2.638927,2.02037,1.407935,0.914896,0.591433,0.401572,0.290094,0.22114,0.175845,0.144458,0.121721,0.104642,0.091425,0.080942,0.072454,0.065461,0.059612,0.054657,0.050411,0.046738,0.043532,0.040712,0.038214,0.035988,0.033992,0.032194,0.030566,0.029087,0.027736,0.026498,0.025361,0.024312,0.023343,0.022443,0.021607,0.020828,0.020101,0.01942,0.018782,0.018182,0.017618,0.017086,0.016584,0.016109,0.015659,0.015233,0.014828],"type":"scatter","mode":"lines+markers","line":{"color":"#EF553B","width":2},"marker":{"size":4},"hovertemplate":"step %{x}&lt;br&gt;loss %{y:.4f}&lt;extra&gt;&lt;/extra&gt;","name":"loss"}],"layout":{"title":{"text":"Training loss, 41-parameter MLP on 4 points"},"xaxis":{"title":"step"},"yaxis":{"title":"sum-of-squares loss","type":"log"},"height":420,"margin":{"l":70,"r":30,"t":60,"b":50},"showlegend":false}}
</code></pre> <p>3.76 → 0.015 in 50 steps. Steep drop for the first ~10 steps as the network finds the rough direction, then a slow log-linear decline as it refines. Every gradient was computed by my own code.</p> <h2 id="why-scalar-valued">Why scalar-valued</h2> <p>PyTorch operates on tensors and broadcasts the same chain rule across millions of elements per op. The math is identical; tensor ops just batch it. Once the scalar engine works, the upgrade to a tensor engine is engineering, not concept.</p> <p>Code: <a href="https://github.com/debtirthasaha/micrograd-from-scratch">github.com/debtirthasaha/micrograd-from-scratch</a>.</p>]]></content><author><name></name></author><category term="deep-learning"/><category term="autograd"/><category term="backprop"/><summary type="html"><![CDATA[A 150-line autograd engine that supports +, *, **, tanh, exp, and a tiny MLP on top.]]></summary></entry><entry><title type="html">A transformer that reads C++ and writes Python</title><link href="https://debtirthasaha.github.io/blog/2026/cpp-to-python-transformer/" rel="alternate" type="text/html" title="A transformer that reads C++ and writes Python"/><published>2026-03-01T10:00:00+00:00</published><updated>2026-03-01T10:00:00+00:00</updated><id>https://debtirthasaha.github.io/blog/2026/cpp-to-python-transformer</id><content type="html" xml:base="https://debtirthasaha.github.io/blog/2026/cpp-to-python-transformer/"><![CDATA[<p>Encoder-decoder transformer that takes C++ source and emits the equivalent Python. Trained on XLCoST (a corpus of competitive-programming problems with parallel solutions in seven languages). 16.4M parameters. Best checkpoint at epoch 19, val_loss <strong>2.0474</strong>, sized to fit a GTX 1650 4 GB.</p> <p>Most “transformer from scratch” implementations are English → French. 
This is C++ → Python, which makes the input distribution different (structured code, lots of repeated tokens, hard syntax) and exposes a tokenization problem the dataset’s stock format papers over.</p> <h2 id="problem-setup">Problem setup</h2> <p>XLCoST ships parallel source files. Pairs look like:</p> <div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1">// C++</span>
<span class="kt">int</span> <span class="nf">binary_search</span><span class="p">(</span><span class="n">vector</span><span class="o">&lt;</span><span class="kt">int</span><span class="o">&gt;&amp;</span> <span class="n">a</span><span class="p">,</span> <span class="kt">int</span> <span class="n">x</span><span class="p">)</span> <span class="p">{</span>
    <span class="kt">int</span> <span class="n">lo</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span> <span class="n">hi</span> <span class="o">=</span> <span class="n">a</span><span class="p">.</span><span class="n">size</span><span class="p">()</span> <span class="o">-</span> <span class="mi">1</span><span class="p">;</span>
    <span class="k">while</span> <span class="p">(</span><span class="n">lo</span> <span class="o">&lt;=</span> <span class="n">hi</span><span class="p">)</span> <span class="p">{</span>
        <span class="kt">int</span> <span class="n">mid</span> <span class="o">=</span> <span class="p">(</span><span class="n">lo</span> <span class="o">+</span> <span class="n">hi</span><span class="p">)</span> <span class="o">/</span> <span class="mi">2</span><span class="p">;</span>
        <span class="k">if</span> <span class="p">(</span><span class="n">a</span><span class="p">[</span><span class="n">mid</span><span class="p">]</span> <span class="o">==</span> <span class="n">x</span><span class="p">)</span> <span class="k">return</span> <span class="n">mid</span><span class="p">;</span>
        <span class="k">if</span> <span class="p">(</span><span class="n">a</span><span class="p">[</span><span class="n">mid</span><span class="p">]</span> <span class="o">&lt;</span> <span class="n">x</span><span class="p">)</span> <span class="n">lo</span> <span class="o">=</span> <span class="n">mid</span> <span class="o">+</span> <span class="mi">1</span><span class="p">;</span>
        <span class="k">else</span> <span class="n">hi</span> <span class="o">=</span> <span class="n">mid</span> <span class="o">-</span> <span class="mi">1</span><span class="p">;</span>
    <span class="p">}</span>
    <span class="k">return</span> <span class="o">-</span><span class="mi">1</span><span class="p">;</span>
<span class="p">}</span>
</code></pre></div></div> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">binary_search</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
    <span class="n">lo</span><span class="p">,</span> <span class="n">hi</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span> <span class="nf">len</span><span class="p">(</span><span class="n">a</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span>
    <span class="k">while</span> <span class="n">lo</span> <span class="o">&lt;=</span> <span class="n">hi</span><span class="p">:</span>
        <span class="n">mid</span> <span class="o">=</span> <span class="p">(</span><span class="n">lo</span> <span class="o">+</span> <span class="n">hi</span><span class="p">)</span> <span class="o">//</span> <span class="mi">2</span>
        <span class="k">if</span> <span class="n">a</span><span class="p">[</span><span class="n">mid</span><span class="p">]</span> <span class="o">==</span> <span class="n">x</span><span class="p">:</span> <span class="k">return</span> <span class="n">mid</span>
        <span class="k">if</span> <span class="n">a</span><span class="p">[</span><span class="n">mid</span><span class="p">]</span> <span class="o">&lt;</span> <span class="n">x</span><span class="p">:</span> <span class="n">lo</span> <span class="o">=</span> <span class="n">mid</span> <span class="o">+</span> <span class="mi">1</span>
        <span class="k">else</span><span class="p">:</span> <span class="n">hi</span> <span class="o">=</span> <span class="n">mid</span> <span class="o">-</span> <span class="mi">1</span>
    <span class="k">return</span> <span class="o">-</span><span class="mi">1</span>
</code></pre></div></div> <p>An encoder-decoder transformer is the right shape: the encoder reads the entire C++ source bidirectionally, the decoder generates Python autoregressively with cross-attention into the encoder’s output. This is the original 2017 architecture.</p> <h2 id="hyperparameters-sized-to-4-gb">Hyperparameters, sized to 4 GB</h2> <div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>d_model         = 256
N (layers)      = 4
h (heads)       = 8
d_ff            = 512
dropout         = 0.1
label_smoothing = 0.1
max_seq_len     = 350
batch_size      = 8   (with dynamic padding)
</code></pre></div></div> <p>16.4M parameters. The constraints are real: with <code class="language-plaintext highlighter-rouge">d_model = 512</code> and full padding to 350 tokens, the VRAM math doesn’t close. Every knob got pulled until the model both fit and trained.</p> <p>VRAM math, roughly: <code class="language-plaintext highlighter-rouge">B × T × T × h × 4 bytes</code> for attention scores per layer. At <code class="language-plaintext highlighter-rouge">B=8, T=350, h=8</code>: <code class="language-plaintext highlighter-rouge">8 × 350² × 8 × 4 = ~31 MB</code> <em>per layer</em>, and you need this for forward and backward. Across 8 layers (4 encoder + 4 decoder) and other activations it adds up fast. Dynamic padding (pad each batch to its longest sequence rather than to <code class="language-plaintext highlighter-rouge">max_seq_len</code>) is what made training viable.</p> <h2 id="the-dataset-problem">The dataset problem</h2> <p>XLCoST is pre-tokenized. The files look like:</p> <div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>int NEW_LINE binary_search ( vector &lt; int &gt; &amp; a , int x ) { NEW_LINE INDENT ...
</code></pre></div></div> <p><code class="language-plaintext highlighter-rouge">NEW_LINE</code>, <code class="language-plaintext highlighter-rouge">INDENT</code>, <code class="language-plaintext highlighter-rouge">DEDENT</code> are XLCoST’s whitespace-preserving tokens. Splitting on whitespace gives you tokens directly — no tokenizer required.</p> <p>This is fine for training and evaluation on XLCoST. It is not fine for inference on real code, because real code doesn’t ship with <code class="language-plaintext highlighter-rouge">NEW_LINE</code> tokens.</p> <p>So a second tokenizer had to be built: a raw C++ tokenizer that takes ordinary source and produces the XLCoST tokenization. It handles:</p> <ul> <li>Comments (<code class="language-plaintext highlighter-rouge">//</code>, <code class="language-plaintext highlighter-rouge">/* */</code>) — stripped before tokenization</li> <li>String literals — preserved as single tokens (don’t split inside quoted strings)</li> <li>Multi-char operators (<code class="language-plaintext highlighter-rouge">&lt;&lt;</code>, <code class="language-plaintext highlighter-rouge">&gt;&gt;</code>, <code class="language-plaintext highlighter-rouge">==</code>, <code class="language-plaintext highlighter-rouge">!=</code>, <code class="language-plaintext highlighter-rouge">&lt;=</code>, <code class="language-plaintext highlighter-rouge">&gt;=</code>, <code class="language-plaintext highlighter-rouge">&amp;&amp;</code>, <code class="language-plaintext highlighter-rouge">||</code>, <code class="language-plaintext highlighter-rouge">++</code>, <code class="language-plaintext highlighter-rouge">--</code>, <code class="language-plaintext highlighter-rouge">-&gt;</code>, <code class="language-plaintext highlighter-rouge">::</code>, <code class="language-plaintext highlighter-rouge">+=</code>, etc.) 
— match greedy</li> <li>Numbers, identifiers — match maximally</li> <li>Whitespace → <code class="language-plaintext highlighter-rouge">NEW_LINE</code>, <code class="language-plaintext highlighter-rouge">INDENT</code>, <code class="language-plaintext highlighter-rouge">DEDENT</code> based on column position</li> </ul> <p>The inference path is <code class="language-plaintext highlighter-rouge">raw C++ → my tokenizer → XLCoST tokens → model → Python tokens → join</code>.</p> <h2 id="vocabulary-coverage-and-unks">Vocabulary coverage and UNKs</h2> <p>Vocab is built from the training set with <code class="language-plaintext highlighter-rouge">min_freq=2</code> — any token appearing fewer than 2 times is replaced with <code class="language-plaintext highlighter-rouge">&lt;UNK&gt;</code>. Final vocab is ~12K source tokens, ~10K target tokens.</p> <p>This means common things work and uncommon things fail. <code class="language-plaintext highlighter-rouge">binary_search</code> is in vocab. <code class="language-plaintext highlighter-rouge">Hello, World!\n</code> is not — the string literal <code class="language-plaintext highlighter-rouge">"Hello, World!\n"</code> is a single rare token, gets mapped to <code class="language-plaintext highlighter-rouge">&lt;UNK&gt;</code>, and the model has no signal to translate it. You can confirm this by tokenizing <code class="language-plaintext highlighter-rouge">cout &lt;&lt; "Hello, World!" &lt;&lt; endl;</code> and watching the string vanish into an UNK.</p> <p>For competitive-programming-style code (loops, arrays, recursion, math) coverage is good and translation is fluent. For anything string-heavy it falls apart.</p> <h2 id="the-architecture-12-components">The architecture, 12 components</h2> <p>All twelve are in <code class="language-plaintext highlighter-rouge">model.py</code>. 
Quick map:</p> <ol> <li><code class="language-plaintext highlighter-rouge">InputEmbeddings</code> — <code class="language-plaintext highlighter-rouge">nn.Embedding(vocab_size, d_model)</code>, output scaled by <code class="language-plaintext highlighter-rouge">sqrt(d_model)</code></li> <li><code class="language-plaintext highlighter-rouge">PositionalEncoding</code> — sinusoidal, fixed (not learned)</li> <li><code class="language-plaintext highlighter-rouge">LayerNormalization</code> — manual implementation with learnable γ, β</li> <li><code class="language-plaintext highlighter-rouge">FeedForwardBlock</code> — <code class="language-plaintext highlighter-rouge">Linear(d_model, d_ff) → ReLU → Dropout → Linear(d_ff, d_model)</code></li> <li><code class="language-plaintext highlighter-rouge">MultiHeadAttentionBlock</code> — Q/K/V projections, scaled-dot-product, output projection. Stores <code class="language-plaintext highlighter-rouge">attention_scores</code> as a buffer for later visualization.</li> <li><code class="language-plaintext highlighter-rouge">ResidualConnection</code> — <code class="language-plaintext highlighter-rouge">x + dropout(sublayer(norm(x)))</code> (pre-norm)</li> <li><code class="language-plaintext highlighter-rouge">EncoderBlock</code> — self-attention + FFN, each wrapped in residual</li> <li><code class="language-plaintext highlighter-rouge">Encoder</code> — stack of <code class="language-plaintext highlighter-rouge">N</code> encoder blocks + final LayerNorm</li> <li><code class="language-plaintext highlighter-rouge">DecoderBlock</code> — masked self-attention + cross-attention + FFN</li> <li><code class="language-plaintext highlighter-rouge">Decoder</code> — stack of <code class="language-plaintext highlighter-rouge">N</code> decoder blocks + final LayerNorm</li> <li><code class="language-plaintext highlighter-rouge">ProjectionLayer</code> — <code class="language-plaintext highlighter-rouge">Linear(d_model, vocab_size)</code>, no softmax 
(cross-entropy applies it internally)</li> <li><code class="language-plaintext highlighter-rouge">Transformer</code> — encoder + decoder + source/target embeddings + source/target positional + projection</li> </ol> <p>Pre-norm everywhere. Output of the projection layer is logits, not log-softmax — <code class="language-plaintext highlighter-rouge">nn.CrossEntropyLoss</code> expects logits and applies log-softmax internally for numerical stability.</p> <p>The cross-attention in the decoder is where the two halves meet: decoder <code class="language-plaintext highlighter-rouge">Q</code> comes from the decoder’s own residual stream, but <code class="language-plaintext highlighter-rouge">K</code> and <code class="language-plaintext highlighter-rouge">V</code> come from the encoder’s final output. Each decoder position queries the encoded source to decide what to translate next.</p> <h2 id="training-warmup--label-smoothing">Training: warmup + label smoothing</h2> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">optimizer</span> <span class="o">=</span> <span class="nc">Adam</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="n">lr</span><span class="o">=</span><span class="mf">1e-4</span><span class="p">,</span> <span class="n">betas</span><span class="o">=</span><span class="p">(</span><span class="mf">0.9</span><span class="p">,</span> <span class="mf">0.98</span><span class="p">),</span> <span class="n">eps</span><span class="o">=</span><span class="mf">1e-9</span><span class="p">)</span>
<span class="n">scheduler</span> <span class="o">=</span> <span class="nc">LambdaLR</span><span class="p">(</span><span class="n">optimizer</span><span class="p">,</span>
                     <span class="k">lambda</span> <span class="n">step</span><span class="p">:</span> <span class="n">d_model</span> <span class="o">**</span> <span class="o">-</span><span class="mf">0.5</span> <span class="o">*</span> <span class="nf">min</span><span class="p">(</span><span class="nf">max</span><span class="p">(</span><span class="n">step</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="o">**</span> <span class="o">-</span><span class="mf">0.5</span><span class="p">,</span> <span class="nf">max</span><span class="p">(</span><span class="n">step</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">warmup_steps</span> <span class="o">**</span> <span class="o">-</span><span class="mf">1.5</span><span class="p">))</span>
<span class="n">criterion</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="nc">CrossEntropyLoss</span><span class="p">(</span><span class="n">ignore_index</span><span class="o">=</span><span class="n">PAD</span><span class="p">,</span> <span class="n">label_smoothing</span><span class="o">=</span><span class="mf">0.1</span><span class="p">)</span>
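
# --- Sketch (not from the original repo): the warmup factor as a plain function ---
# Assumptions: warmup_steps=4000 is only an illustrative value (the post does not
# state the one it used), and noam_factor is a hypothetical name. The step is
# clamped to 1 because LambdaLR evaluates the lambda at step 0, where a negative
# power of zero raises ZeroDivisionError. LambdaLR multiplies the optimizer's
# base lr by this factor, so the effective rate is lr * noam_factor(step).
def noam_factor(step, d_model=256, warmup_steps=4000):
    step = max(step, 1)
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

# Rises linearly until warmup_steps, then decays as 1/sqrt(step).
for s in (1, 1000, 4000, 16000):
    print(s, noam_factor(s))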
</code></pre></div></div> <p>The LR schedule is from the original transformer paper: linear warmup for <code class="language-plaintext highlighter-rouge">warmup_steps</code>, then <code class="language-plaintext highlighter-rouge">1/sqrt(step)</code> decay. <code class="language-plaintext highlighter-rouge">ignore_index=PAD</code> masks padding tokens from the loss. <code class="language-plaintext highlighter-rouge">label_smoothing=0.1</code> gives 10% of the probability mass to non-target tokens uniformly — softens the optimization target and regularizes.</p> <p>Greedy decoding for inference. No beam search.</p> <h2 id="results">Results</h2> <p>20 epochs. Train and val loss every epoch:</p> <table> <thead> <tr> <th>Epoch</th> <th>Train</th> <th>Val</th> </tr> </thead> <tbody> <tr> <td>13</td> <td>1.9109</td> <td>2.0615</td> </tr> <tr> <td>15</td> <td>1.8708</td> <td>2.0545</td> </tr> <tr> <td>16</td> <td>1.8542</td> <td>2.0511</td> </tr> <tr> <td>19</td> <td>1.8103</td> <td><strong>2.0474</strong> ← best</td> </tr> <tr> <td>20</td> <td>1.7964</td> <td>2.0576 ← train still dropping, val rising</td> </tr> </tbody> </table> <p>Best checkpoint at epoch 19. Train kept dropping past that point but val stopped — classic overfitting signature. The save-best-by-val-loss logic kept epoch 19 as <code class="language-plaintext highlighter-rouge">best_model.pt</code>.</p> <p>Sample translation:</p> <div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kt">int</span> <span class="nf">binary_search</span><span class="p">(</span><span class="n">vector</span><span class="o">&lt;</span><span class="kt">int</span><span class="o">&gt;&amp;</span> <span class="n">a</span><span class="p">,</span> <span class="kt">int</span> <span class="n">x</span><span class="p">)</span> <span class="p">{</span>
    <span class="kt">int</span> <span class="n">lo</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span> <span class="n">hi</span> <span class="o">=</span> <span class="n">a</span><span class="p">.</span><span class="n">size</span><span class="p">()</span> <span class="o">-</span> <span class="mi">1</span><span class="p">;</span>
    <span class="p">...</span>
<span class="p">}</span>
</code></pre></div></div> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">binary_search</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
    <span class="n">lo</span><span class="p">,</span> <span class="n">hi</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span> <span class="nf">len</span><span class="p">(</span><span class="n">a</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span>
    <span class="k">while</span> <span class="n">lo</span> <span class="o">&lt;=</span> <span class="n">hi</span><span class="p">:</span>
        <span class="n">mid</span> <span class="o">=</span> <span class="p">(</span><span class="n">lo</span> <span class="o">+</span> <span class="n">hi</span><span class="p">)</span> <span class="o">//</span> <span class="mi">2</span>
        <span class="bp">...</span>
</code></pre></div></div> <p>Hello-world-style code with rare string literals fails: <code class="language-plaintext highlighter-rouge">cout &lt;&lt; "Hello"</code> produces <code class="language-plaintext highlighter-rouge">print(&lt;UNK&gt;)</code>. Loops, math, recursion, and array indexing all translate cleanly.</p> <h2 id="what-the-model-actually-learned">What the model actually learned</h2> <p>Inspired by <a href="https://transformer-circuits.pub/2026/nla/index.html#introduction">Anthropic’s circuits work</a>, I loaded the checkpoint and probed two things: the token embedding matrices on each side, and the attention pattern of the last encoder layer.</p> <h3 id="embedding-nearest-neighbors">Embedding nearest neighbors</h3> <p>For a seed token, find the closest tokens in embedding space by cosine similarity. Both the source-side (C++) and target-side (Python) embedding tables show recognizable semantic structure, alongside some noisy neighbors from rare identifiers, even though no one taught the model what these tokens <em>mean</em>.</p> <p>C++ side (top 6 per row):</p> <div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>int    -&gt; endl, ll, ;, [EOS], long, &lt;&lt;
for    -&gt; while, memset, getline, case, faces, sortRowWise
if     -&gt; ==, 127, while, case, break, fast
vector -&gt; calloc, begin, multiset, sizeof, NthPostordernode, word_size
&lt;      -&gt; &gt;, ::, %, &lt;=, &amp;, #
==     -&gt; case, 127, &lt;=, if, checkAbundant, !=
true   -&gt; Magic, False, ||, True, slope3, npos
string -&gt; char, "4", chanceA, 122, modifyString, findNumberOfLIS
</code></pre></div></div> <p>Python side (top 6 per row):</p> <div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>def    -&gt; class, divTermCount, for, in, NEW_LINE, Euler
if     -&gt; elif(0.70), while, or, and, isPower2, checkPerfectcube
for    -&gt; while, in, def, [, range, within
range  -&gt; in, while, sqrt, ord, int, xrange
+      -&gt; +=, -, &gt;&gt;=, -=, &lt;&lt;=, &gt;
==     -&gt; !=(0.67), &gt;=, &gt;, &lt;=, &lt;, than
print  -&gt; return, PrintList, format, Squares, cout, round
&lt;      -&gt; &gt;(0.65), &gt;=, &lt;=, ==, &lt;, -&gt;
True   -&gt; False, 82, 3.14159265, 0.25, 4.5, None
str    -&gt; acos, log2, trailingZero, int, string, singlePrimeFactor
</code></pre></div></div> <p>A few clusters that aren’t accidents:</p> <ul> <li><strong>Comparison operators</strong>. On the Python side, <code class="language-plaintext highlighter-rouge">&lt;</code> is nearest to <code class="language-plaintext highlighter-rouge">&gt;</code> (0.65), then <code class="language-plaintext highlighter-rouge">&gt;=</code> (0.63), <code class="language-plaintext highlighter-rouge">&lt;=</code> (0.63), <code class="language-plaintext highlighter-rouge">==</code> (0.53), <code class="language-plaintext highlighter-rouge">&lt;</code> (0.51), <code class="language-plaintext highlighter-rouge">-&gt;</code>. Apart from the trailing <code class="language-plaintext highlighter-rouge">-&gt;</code>, that is a tight cluster of binary comparison operators.</li> <li><strong>Boolean values</strong>. <code class="language-plaintext highlighter-rouge">True</code> finds <code class="language-plaintext highlighter-rouge">False</code> and <code class="language-plaintext highlighter-rouge">None</code>. <code class="language-plaintext highlighter-rouge">true</code> (on the C++ side) finds <code class="language-plaintext highlighter-rouge">True</code>, <code class="language-plaintext highlighter-rouge">False</code>, and <code class="language-plaintext highlighter-rouge">||</code>. On both sides, the model puts truth values close together regardless of casing.</li> <li><strong>Control flow</strong>. <code class="language-plaintext highlighter-rouge">if</code> on the Python side has <code class="language-plaintext highlighter-rouge">elif</code> as its nearest neighbor at cosine 0.70, by a clear margin. <code class="language-plaintext highlighter-rouge">for</code> is nearest to <code class="language-plaintext highlighter-rouge">while</code>. <code class="language-plaintext highlighter-rouge">def</code> is nearest to <code class="language-plaintext highlighter-rouge">class</code>.</li> <li><strong>Cross-language synonyms</strong>. 
<code class="language-plaintext highlighter-rouge">print</code> (Python) has <code class="language-plaintext highlighter-rouge">cout</code> in its top-6. Neighbors are computed within one embedding table, so this <code class="language-plaintext highlighter-rouge">cout</code> is a token in the Python-side vocabulary, not the C++ one; the two sides have separate vocabularies and separate embedding tables, which cannot be compared directly. Even so, the model places it next to <code class="language-plaintext highlighter-rouge">print</code>, which suggests it treats the two as playing structurally similar output roles.</li> <li><strong>C++ integer family</strong>. <code class="language-plaintext highlighter-rouge">int</code> has <code class="language-plaintext highlighter-rouge">ll</code> (the <code class="language-plaintext highlighter-rouge">typedef long long ll</code> shorthand competitive programmers use) and <code class="language-plaintext highlighter-rouge">long</code> in its top 6. The model picked up that these are interchangeable integer types.</li> </ul> <p>None of this is taught explicitly. The supervision signal is a cross-entropy loss on next-token prediction in a sequence-to-sequence setup. Semantically related tokens end up close together because the loss is lower when interchangeable tokens have similar representations.</p> <h3 id="encoder-attention-on-int-sum--a--b--new_line-return-sum-">Encoder attention on <code class="language-plaintext highlighter-rouge">int sum = a + b ; NEW_LINE return sum ;</code></h3> <p>This is head 0 of the last encoder layer’s self-attention on one short example. Rows = query positions, columns = key positions. 
Hover for exact weights:</p> <pre><code class="language-plotly">{"data":[{"z":[[0.02,0.00,0.04,0.00,0.00,0.00,0.16,0.57,0.02,0.00,0.14,0.07],[0.00,0.00,0.02,0.00,0.00,0.00,0.09,0.76,0.01,0.00,0.07,0.05],[0.01,0.00,0.06,0.00,0.00,0.00,0.11,0.64,0.03,0.00,0.12,0.02],[0.01,0.00,0.05,0.00,0.00,0.00,0.16,0.57,0.02,0.00,0.16,0.02],[0.01,0.00,0.03,0.00,0.00,0.00,0.10,0.75,0.01,0.00,0.08,0.03],[0.01,0.00,0.06,0.00,0.00,0.00,0.18,0.51,0.03,0.00,0.18,0.03],[0.03,0.00,0.06,0.00,0.00,0.00,0.19,0.41,0.05,0.00,0.16,0.09],[0.07,0.00,0.08,0.00,0.01,0.00,0.22,0.26,0.11,0.00,0.18,0.08],[0.02,0.00,0.04,0.00,0.00,0.00,0.17,0.52,0.03,0.00,0.15,0.08],[0.01,0.00,0.02,0.00,0.00,0.00,0.09,0.75,0.01,0.00,0.08,0.05],[0.03,0.00,0.06,0.00,0.00,0.00,0.18,0.46,0.04,0.00,0.15,0.08],[0.05,0.00,0.05,0.00,0.00,0.00,0.23,0.31,0.06,0.00,0.18,0.11]],"x":["int","sum","=","a","+","b",";","NEW_LINE","return","sum",";","[EOS]"],"y":["int","sum","=","a","+","b",";","NEW_LINE","return","sum",";","[EOS]"],"type":"heatmap","colorscale":"Viridis","hovertemplate":"q: %{y}&lt;br&gt;k: %{x}&lt;br&gt;weight: %{z:.2f}&lt;extra&gt;&lt;/extra&gt;","colorbar":{"title":{"text":"attn"}}}],"layout":{"title":{"text":"Encoder self-attention, head 0, last layer"},"xaxis":{"title":"key (attended to)","side":"top"},"yaxis":{"title":"query (attending)","autorange":"reversed"},"height":500,"margin":{"l":80,"r":30,"t":90,"b":50}}}
</code></pre> <p>In text form:</p> <div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>              int   sum    =     a     +     b     ;   NEW_  ret   sum    ;   [EOS]
   int    [ 0.02  0.00  0.04  0.00  0.00  0.00  0.16  0.57  0.02  0.00  0.14  0.07 ]
   sum    [ 0.00  0.00  0.02  0.00  0.00  0.00  0.09  0.76  0.01  0.00  0.07  0.05 ]
     =    [ 0.01  0.00  0.06  0.00  0.00  0.00  0.11  0.64  0.03  0.00  0.12  0.02 ]
     a    [ 0.01  0.00  0.05  0.00  0.00  0.00  0.16  0.57  0.02  0.00  0.16  0.02 ]
     +    [ 0.01  0.00  0.03  0.00  0.00  0.00  0.10  0.75  0.01  0.00  0.08  0.03 ]
     b    [ 0.01  0.00  0.06  0.00  0.00  0.00  0.18  0.51  0.03  0.00  0.18  0.03 ]
     ;    [ 0.03  0.00  0.06  0.00  0.00  0.00  0.19  0.41  0.05  0.00  0.16  0.09 ]
NEW_LI    [ 0.07  0.00  0.08  0.00  0.01  0.00  0.22  0.26  0.11  0.00  0.18  0.08 ]
return    [ 0.02  0.00  0.04  0.00  0.00  0.00  0.17  0.52  0.03  0.00  0.15  0.08 ]
   sum    [ 0.01  0.00  0.02  0.00  0.00  0.00  0.09  0.75  0.01  0.00  0.08  0.05 ]
     ;    [ 0.03  0.00  0.06  0.00  0.00  0.00  0.18  0.46  0.04  0.00  0.15  0.08 ]
 [EOS]    [ 0.05  0.00  0.05  0.00  0.00  0.00  0.23  0.31  0.06  0.00  0.18  0.11 ]
</code></pre></div></div> <p>The argmax-key per query position:</p> <div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>q[ 0] int       -&gt; k[ 7] NEW_LINE  (0.57)
q[ 1] sum       -&gt; k[ 7] NEW_LINE  (0.76)
q[ 2] =         -&gt; k[ 7] NEW_LINE  (0.64)
q[ 3] a         -&gt; k[ 7] NEW_LINE  (0.57)
q[ 4] +         -&gt; k[ 7] NEW_LINE  (0.75)
q[ 5] b         -&gt; k[ 7] NEW_LINE  (0.51)
q[ 6] ;         -&gt; k[ 7] NEW_LINE  (0.41)
q[ 7] NEW_LINE  -&gt; k[ 7] NEW_LINE  (0.26)
q[ 8] return    -&gt; k[ 7] NEW_LINE  (0.52)
q[ 9] sum       -&gt; k[ 7] NEW_LINE  (0.75)
q[10] ;         -&gt; k[ 7] NEW_LINE  (0.46)
q[11] [EOS]     -&gt; k[ 7] NEW_LINE  (0.31)
</code></pre></div></div> <p>Every single position is attending most heavily to position 7, the <code class="language-plaintext highlighter-rouge">NEW_LINE</code> statement boundary. The mass on that one column ranges from 0.26 (the boundary attending to itself) to 0.76 (the first <code class="language-plaintext highlighter-rouge">sum</code>, with <code class="language-plaintext highlighter-rouge">+</code> and the second <code class="language-plaintext highlighter-rouge">sum</code> at 0.75). Most of the remaining mass sits on the two <code class="language-plaintext highlighter-rouge">;</code> columns (0.09 to 0.23 and 0.07 to 0.18); the columns for <code class="language-plaintext highlighter-rouge">sum</code>, <code class="language-plaintext highlighter-rouge">a</code>, <code class="language-plaintext highlighter-rouge">+</code>, and <code class="language-plaintext highlighter-rouge">b</code> are at or near zero in every row.</p> <p>This head has specialized into something like a <em>statement-end aggregator</em>, a hub at the boundary marker that closes the first statement. One caveat on direction: a row of attention weights says where a query position reads from, so what this matrix shows directly is every position, in both statements, reading the <code class="language-plaintext highlighter-rouge">NEW_LINE</code> value vector; any gathering of statement content into <code class="language-plaintext highlighter-rouge">NEW_LINE</code> has to come from its own, more diffuse row or from earlier layers. If that gathering does happen, the decoder’s cross-attention has a privileged column at <code class="language-plaintext highlighter-rouge">NEW_LINE</code> that summarizes the C++ statement, and the decoder can read from it to emit the Python equivalent. XLCoST’s pre-tokenization scheme emits these explicit boundary tokens, which gives the model a natural place to put statement-level information. Head 0 of the last encoder layer appears to be using them for that.</p> <p>Other heads in the same layer attend differently (some local, some diagonal, some on operators), so the specialization isn’t uniform. But this one’s pattern is clear, and it is the kind of mechanistic finding that motivates the circuits-style probing in the first place.</p> <h2 id="what-would-close-the-gap">What would close the gap</h2> <ul> <li><strong>Subword tokenization</strong> (BPE on the raw source) instead of per-word vocab + UNKs. The whole “rare string literal” failure should go away.</li> <li><strong>Bigger model</strong> if you have the VRAM. <code class="language-plaintext highlighter-rouge">d_model=512</code> and 6 layers is the standard small-transformer scale, but doesn’t fit at <code class="language-plaintext highlighter-rouge">T=350</code> on 4 GB.</li> <li><strong>Beam search</strong> at decode time. 
Greedy is fine for code, but a beam of 4 often picks better completions for long sequences.</li> </ul> <p>Code: <a href="https://github.com/debtirthasaha/cpp-to-python-transformer">github.com/debtirthasaha/cpp-to-python-transformer</a>. The 16 numbered tests in <code class="language-plaintext highlighter-rouge">test_step*.py</code> build up each of the 12 components in isolation before the full model is assembled. Trained checkpoint (189 MB) is on Hugging Face at <a href="https://huggingface.co/MR0b0t/cpp-to-python-transformer">MR0b0t/cpp-to-python-transformer</a>.</p>]]></content><author><name></name></author><category term="deep-learning"/><category term="transformer"/><category term="translation"/><category term="code"/><summary type="html"><![CDATA[16.4M-parameter encoder-decoder transformer for C++ → Python code translation, trained on XLCoST on a 4 GB GPU. val_loss 2.0474.]]></summary></entry></feed>