Eight A100s, $61, and 124M parameters
End-to-end reproduction of Karpathy’s “Let’s reproduce GPT-2 (124M)” video. The build started on a 4 GB GTX 1650 and finished on 8× A100 SXM4 rented from Lambda. Training ran for 19,073 steps over 10B FineWeb-Edu tokens. Total cost: $61.23. Final val loss 3.40 vs OpenAI’s released baseline of 3.29 — a 97% match. HellaSwag 26.99% vs 29.45% baseline.
| Metric | Start | End | OpenAI baseline | % of target |
|---|---|---|---|---|
| Val loss | 10.951 | 3.3969 | 3.292 | 97% |
| HellaSwag acc_norm | 24.82% | 26.99% | 29.45% | 91% |
| Sustained throughput | — | 1.1M tokens/sec | — | — |
| Steps/min | — | ~127 | — | — |
The val loss gap is small enough that another ~$80 of training would close it. The HellaSwag gap is the well-known FineWeb-Edu artifact — high-quality educational text gives slightly worse commonsense reasoning per perplexity unit than OpenAI’s WebText.
The actual training run, with the OpenAI baseline as a reference line:
[Figure: Training and validation loss, 19,073 steps on FineWeb-Edu. x-axis: step; y-axis: loss. Series: train loss, val loss, and the OpenAI GPT-2 124M val baseline (3.292) as a dashed reference line.]
Model: matching HF’s GPT-2 byte-for-byte
The model class mirrors HuggingFace’s GPT2LMHeadModel parameter names exactly so that from_pretrained("gpt2") can copy weights in. 148 tensors per HF GPT-2 state_dict, in three groups:
- Embeddings. transformer.wte.weight (50257, 768), transformer.wpe.weight (1024, 768). Position embeddings are learned, not sinusoidal, a choice GPT-2 carried over from the 2018 GPT.
- 12 identical blocks. Each has pre-norm ln_1, fused QKV projection attn.c_attn.weight (768, 2304), output projection attn.c_proj, pre-norm ln_2, MLP up mlp.c_fc.weight (768, 3072) (4× expansion), MLP down mlp.c_proj. GELU between the MLP linears.
- Final. transformer.ln_f (the extra LayerNorm GPT-2 added vs the original 2017 transformer) and lm_head.weight (50257, 768).
Two implementation gotchas:
HF weights are stored transposed. HuggingFace uses a TF-legacy Conv1D layer that stores (in, out)-shaped weights. Standard PyTorch nn.Linear stores (out, in). So when copying HF weights into our nn.Linear-based model, the four matrices c_attn, c_proj (attn), c_fc, c_proj (mlp) need .t(). The other tensors copy directly.
Weight tying. lm_head.weight and transformer.wte.weight are the same tensor in GPT-2 — one physical 38.6M-parameter matrix used both as input embedding and output classifier. That matrix alone is ~30% of the full 124M. Implementation: self.transformer.wte.weight = self.lm_head.weight. Two Python names, one tensor object.
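The transpose gotcha comes down to two storage conventions for the same linear map. The sketch below uses plain Python lists rather than PyTorch so it runs anywhere; the helper names (matmul, conv1d_forward, linear_forward) and the toy 2×3 weight are illustrative assumptions, not code from the repo.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(m):
    """Swap rows and columns, the list-of-lists analogue of tensor.t()."""
    return [list(col) for col in zip(*m)]


def conv1d_forward(x, w_in_out):
    # HF Conv1D convention: weight stored as (in, out), output is x @ W.
    return matmul(x, w_in_out)


def linear_forward(x, w_out_in):
    # nn.Linear convention: weight stored as (out, in), output is x @ W.T.
    return matmul(x, transpose(w_out_in))


x = [[1.0, 2.0]]                        # one example, in_features = 2
hf_weight = [[1.0, 0.0, 2.0],
             [0.0, 1.0, 3.0]]           # toy (in=2, out=3) weight

expected = conv1d_forward(x, hf_weight)
copied = transpose(hf_weight)           # what .t() does during the copy
result = linear_forward(x, copied)
print(expected, result)
```

Copying the (in, out) weight without the transpose would either fail on shape or, for square matrices such as the attention c_proj, silently compute the wrong map.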
Residual stream init
GPT-2 scales the output projection of each residual sub-layer at init by 1/sqrt(2*n_layer). With n_layer = 12, that’s a factor of ~0.204. Why: at each residual addition x = x + sub_layer(x), the variance of x grows. Without rescaling, variance compounds across 24 sub-layers (12 attn + 12 MLP), so the residual stream’s scale has already blown up at initialization, before any training. The rescale keeps post-stack variance close to the input variance.
for pn, p in self.named_parameters():
    if pn.endswith('c_proj.weight'):
        torch.nn.init.normal_(p, mean=0.0, std=0.02 * (2 * config.n_layer) ** -0.5)
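Everything else is standard N(0, 0.02), with 0.02 as the standard deviation. The 1/√N scaling itself is described in the GPT-2 paper (N being the number of residual layers); the code here follows nanoGPT’s implementation of it, and it makes training much more stable.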
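A rough way to see why the factor is 1/sqrt(2*n_layer): sum 24 independent contributions and compare the spread with and without the rescale. This is a simplified sketch, assuming each sub-layer adds independent unit-variance noise and measuring only the accumulated contributions; real sub-layer outputs are correlated with the stream, so treat the numbers as an illustration only.

```python
import math
import random

random.seed(0)
n_layer = 12
n_sublayers = 2 * n_layer               # one attention and one MLP add per block
scale = 1 / math.sqrt(n_sublayers)      # the init rescale factor


def accumulated_std(scale_factor, trials=5000):
    """Empirical std of the sum of n_sublayers independent scaled N(0, 1) terms."""
    total_sq = 0.0
    for _ in range(trials):
        x = sum(scale_factor * random.gauss(0.0, 1.0) for _ in range(n_sublayers))
        total_sq += x * x
    return math.sqrt(total_sq / trials)


unscaled = accumulated_std(1.0)         # grows like sqrt(24)
scaled = accumulated_std(scale)         # stays near 1
print(f"scale={scale:.3f} unscaled_std={unscaled:.2f} scaled_std={scaled:.2f}")
```

Under these assumptions the unscaled sum has a standard deviation near sqrt(24) ≈ 4.9, while the rescaled one stays near 1.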
Data: FineWeb-Edu, sharded
10B tokens of FineWeb-Edu (the educational subset), tokenized with tiktoken’s gpt2 encoding, written out as 100 shards of (100M tokens, uint16) .npy files. Tokenization is multiprocess and dominated by Python overhead: ~30 minutes on the 8× A100 instance, which at $15.92/hr is ~$8 just to tokenize.
uint16 because vocab size is 50257, comfortably under 65536. Cuts disk size in half.
DataLoaderLite keeps one shard in memory at a time, advances to the next when exhausted, and respects the DDP rank/world-size partitioning:
def reset(self):
    self.current_shard = 0
    self.tokens = load_tokens(self.shards[self.current_shard])
    self.current_position = self.B * self.T * self.process_rank

def next_batch(self):
    B, T = self.B, self.T
    buf = self.tokens[self.current_position : self.current_position + B*T + 1]
    x = buf[:-1].view(B, T)
    y = buf[1:].view(B, T)
    self.current_position += B * T * self.num_processes
    if self.current_position + (B * T * self.num_processes + 1) > len(self.tokens):
        self.current_shard = (self.current_shard + 1) % len(self.shards)
        self.tokens = load_tokens(self.shards[self.current_shard])
        self.current_position = self.B * self.T * self.process_rank
    return x, y
current_position starts at B * T * process_rank so each rank reads a different stride. Combined with current_position += B * T * num_processes after every step, the 8 ranks tile the shard without overlap.
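The tiling claim can be checked without loading any data. The sketch below replays only the offset arithmetic for 8 ranks over a few steps; positions_for_rank is a hypothetical helper written for this check, not part of DataLoaderLite.

```python
B, T, num_processes = 16, 1024, 8


def positions_for_rank(rank, steps):
    """Start offsets one rank reads over `steps` batches within a single shard."""
    pos = B * T * rank                  # same starting offset as reset()
    starts = []
    for _ in range(steps):
        starts.append(pos)
        pos += B * T * num_processes    # same stride as next_batch()
    return starts


steps = 3
# Input windows [start, start + B*T) for every rank and step, sorted by offset.
windows = sorted(
    (start, start + B * T)
    for rank in range(num_processes)
    for start in positions_for_rank(rank, steps)
)

# Each window should begin exactly where the previous one ends.
contiguous = all(
    prev_end == next_start
    for (_, prev_end), (next_start, _) in zip(windows, windows[1:])
)
print(len(windows), contiguous, windows[-1][1])
```

Each rank actually slices B*T + 1 tokens, so the single extra target token overlaps the first input token of the next window; the input windows themselves do not overlap.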
Hyperparameters: every number traced to a source
The single most important insight: papers specify training in tokens, not steps. Everything follows from that.
| Knob | Value | Source |
|---|---|---|
| Total batch (tokens / step) | 524288 = 2¹⁹ ≈ 0.5M | GPT-3 paper Table 2.1, “GPT-3 Small” |
| max_steps | 19073 | 10_000_000_000 / 524288 = full pass over 10B tokens |
| warmup_steps | 715 | 375_000_000 / 524288 (GPT-3 §2.3: linear warmup over 375M tokens) |
| max_lr | 6e-4 | GPT-3 paper Table 2.1 |
| min_lr | max_lr * 0.1 | GPT-3 §2.3: “cosine decay to 10% of max” |
| betas | (0.9, 0.95) | GPT-3 §2.3; note β₂=0.95, not Adam default 0.999 |
| weight_decay | 0.1 | GPT-3 §2.3 |
| Wd applies to 2D+ params only | — | GPT-3 §2.3 |
| clip_grad_norm | 1.0 | GPT-3 §2.3 |
| n_layer, n_head, n_embd | 12, 12, 768 | GPT-2 paper Table 2 (124M row) |
| block_size | 1024 | GPT-2 paper |
| vocab_size (padded) | 50304 | Nearest multiple of 128 ≥ 50257, for tensor-core tile alignment. nanoGPT addition, not a paper. |
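The tokens-to-steps conversion behind max_steps and warmup_steps is plain integer division. A minimal sketch, where tokens_to_steps is a hypothetical helper name:

```python
def tokens_to_steps(n_tokens, tokens_per_step):
    """Whole optimizer steps needed to consume n_tokens (the remainder is dropped)."""
    return n_tokens // tokens_per_step


total_batch_tokens = 2 ** 19            # 524288 tokens per optimizer step
max_steps = tokens_to_steps(10_000_000_000, total_batch_tokens)
warmup_steps = tokens_to_steps(375_000_000, total_batch_tokens)
print(max_steps, warmup_steps)
```

Changing the batch size or the dataset size then only changes the inputs here; the schedule stays defined in tokens.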
The provenance layers cleanly:
- GPT-3 paper → all training hyperparameters (LR, betas, weight decay, batch, schedule)
- GPT-2 paper → architecture (layers, heads, embd, pre-norm, ln_f, GELU) and the residual init rescale
- nanoGPT → implementation tricks (vocab pad to 50304, fused AdamW)
Gradient accumulation and the global-batch math
Global batch size is fixed at 524288 tokens. With 8 GPUs and per-GPU B=16, T=1024:
per_gpu_tokens = B * T = 16 * 1024 = 16384
tokens_per_micro_step = per_gpu_tokens * world_size = 16384 * 8 = 131072
grad_accum_steps = 524288 / tokens_per_micro_step = 524288 / 131072 = 4
Every “macro step” = 4 forward+backward passes per GPU + 1 optimizer step + 1 all-reduce. Compute is dominated by the forward+backward; the all-reduce is bandwidth-bound but small compared to compute.
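The same arithmetic as a function makes it easy to see how the accumulation count changes with hardware. This is a sketch; grad_accum_steps_for is a hypothetical name, and the divisibility check is an added safeguard rather than something the text describes.

```python
def grad_accum_steps_for(total_batch_tokens, B, T, world_size):
    """Micro-steps per optimizer step needed to reach the global token batch."""
    tokens_per_micro_step = B * T * world_size
    if total_batch_tokens % tokens_per_micro_step != 0:
        raise ValueError("total batch must be divisible by B * T * world_size")
    return total_batch_tokens // tokens_per_micro_step


eight_gpus = grad_accum_steps_for(524288, B=16, T=1024, world_size=8)
one_gpu = grad_accum_steps_for(524288, B=16, T=1024, world_size=1)
print(eight_gpus, one_gpu)
```

With the same per-GPU batch, a single GPU needs 8× as many micro-steps to see the same 524288 tokens per optimizer step.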
In code:
for micro_step in range(grad_accum_steps):
    x, y = train_loader.next_batch()
    x, y = x.to(device), y.to(device)
    with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
        logits, loss = model(x, y)
    loss = loss / grad_accum_steps  # critical
    if ddp:
        model.require_backward_grad_sync = (micro_step == grad_accum_steps - 1)
    loss.backward()
Two things to be careful about:
- loss / grad_accum_steps because loss.backward() accumulates gradients with +=. Without dividing, four micro-steps give 4× the gradient, which for plain SGD is the same as multiplying the LR by 4. Subtle bug.
- require_backward_grad_sync is DDP’s escape hatch: gradient all-reduce normally fires on every .backward() call. With grad accumulation, we only want to sync on the last micro-step. The intermediate all-reduces would just waste bandwidth syncing partial gradients that will be added to again before the step. Setting the flag to False for the first 3 micro-steps and True only on the 4th skips 3 of the 4 all-reduces, cutting ~75% of gradient-sync traffic with no correctness cost.
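The loss-scaling point can be verified on a toy problem with hand-written gradients. The sketch below assumes a one-parameter model pred = w * x with mean-squared-error loss, chosen only because its gradient is easy to write down; it is not the GPT training loop.

```python
def mse_grad(w, xs, ys):
    """d/dw of mean((w*x - y)^2) over the given examples."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


w = 0.5
xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.0 * x for x in xs]
accum_steps = 4
micro = len(xs) // accum_steps

full_batch_grad = mse_grad(w, xs, ys)   # gradient of the mean loss over all 8

scaled_sum = 0.0                        # accumulate grad of (loss / accum_steps)
unscaled_sum = 0.0                      # accumulate grad of loss, the buggy version
for i in range(accum_steps):
    mx = xs[i * micro:(i + 1) * micro]
    my = ys[i * micro:(i + 1) * micro]
    g = mse_grad(w, mx, my)
    scaled_sum += g / accum_steps
    unscaled_sum += g

print(full_batch_grad, scaled_sum, unscaled_sum)
```

The scaled accumulation reproduces the full-batch gradient, while the unscaled one comes out accum_steps times too large.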
Speed: bf16, SDPA, TF32, vocab padding
On the 4 GB 1650 baseline at B=4, T=32: ~1080 tokens/sec at fp32. The Ampere-and-later speedups aren’t available on that card. The cloud A100s are Ampere-generation and reclaim them all:
- bfloat16 autocast for forward+backward. Halves memory bandwidth and unlocks tensor-core throughput. bf16 has fp32’s exponent range, so unlike fp16 you don’t need a gradient scaler: the simple torch.autocast(dtype=torch.bfloat16) context manager works directly.
- Scaled-dot-product attention (F.scaled_dot_product_attention) replaces the manual Q@K.T softmax. On Ampere+ this dispatches to FlashAttention 2 under the hood: fused kernel, no materialized (T, T) attention matrix, much better memory traffic. On the 1650 SDPA falls back to manual and gives ~0% speedup; on A100 it’s a real win.
- TF32 matmuls via torch.set_float32_matmul_precision('high'). Same fp32 storage but tensor-core compute. Free speedup, no accuracy cost on any benchmark I checked.
- vocab_size padded 50257 → 50304. Multiple of 128 = aligns with tensor-core tile sizes. The extra rows are dead weight that never see gradient because no real token ID maps to them, but they make the matmul faster. ~150 KB of extra parameters, several % throughput. Worth it.
- torch.compile off for this run because HellaSwag eval has variable shapes (each example has its own max_len) and compile recompiles per shape. 10042 recompiles is catastrophic. Without HellaSwag in the loop, flip compile back on and reclaim ~20-30%.
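The padding numbers in the vocab_size bullet follow from rounding up to the next multiple of 128. A small sketch, where pad_to_multiple is a hypothetical helper and fp32 storage (4 bytes per parameter) is assumed for the size estimate:

```python
def pad_to_multiple(n, multiple):
    """Smallest multiple of `multiple` that is >= n."""
    return ((n + multiple - 1) // multiple) * multiple


vocab_size = 50257
n_embd = 768

padded_vocab = pad_to_multiple(vocab_size, 128)
extra_rows = padded_vocab - vocab_size
extra_params = extra_rows * n_embd       # rows added to the tied embedding matrix
extra_kb_fp32 = extra_params * 4 / 1000  # assumes 4 bytes per parameter
print(padded_vocab, extra_rows, extra_params, extra_kb_fp32)
```

Only 47 rows are added, which is where the roughly 150 KB figure comes from.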
DDP wraps after compile (when compile is on):
model = GPT(GPTConfig(vocab_size=50304))
model.to(device)
if use_compile:
    model = torch.compile(model)
if ddp:
    model = DDP(model, device_ids=[ddp_local_rank])
raw_model = model.module if ddp else model
torch.compile(DDP(model)) would try to trace DDP’s gradient-sync wrappers, which isn’t real compute. Compile-first, DDP-second.
The optimizer: weight decay split
Following GPT-3 §2.3, weight decay applies to “weight” matrices (2D and higher) but not to biases or LayerNorm γ/β (1D parameters). Implementation:
def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
    param_dict = {pn: p for pn, p in self.named_parameters() if p.requires_grad}
    decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
    nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
    optim_groups = [
        {'params': decay_params, 'weight_decay': weight_decay},
        {'params': nodecay_params, 'weight_decay': 0.0},
    ]
    use_fused = device_type == 'cuda'
    return torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas,
                             eps=1e-8, fused=use_fused)
fused=True calls a single CUDA kernel for the AdamW update instead of dispatching per-parameter. Trivially faster, no accuracy impact.
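To see what the dim() >= 2 rule actually selects, the split can be replayed on parameter shapes alone, without building the model. The shape table below is reconstructed from the architecture described earlier (nn.Linear layout, padded vocab, tied lm_head listed once); the names are abbreviated and should be read as an assumption, not a dump of the real state_dict.

```python
import math

n_layer, n_embd, vocab_size, block_size = 12, 768, 50304, 1024


def gpt2_param_shapes():
    """Parameter shapes for the 124M config, with the tied lm_head counted once."""
    shapes = {"wte.weight": (vocab_size, n_embd), "wpe.weight": (block_size, n_embd)}
    for i in range(n_layer):
        p = f"h.{i}."
        shapes[p + "ln_1.weight"] = (n_embd,)
        shapes[p + "ln_1.bias"] = (n_embd,)
        shapes[p + "attn.c_attn.weight"] = (3 * n_embd, n_embd)
        shapes[p + "attn.c_attn.bias"] = (3 * n_embd,)
        shapes[p + "attn.c_proj.weight"] = (n_embd, n_embd)
        shapes[p + "attn.c_proj.bias"] = (n_embd,)
        shapes[p + "ln_2.weight"] = (n_embd,)
        shapes[p + "ln_2.bias"] = (n_embd,)
        shapes[p + "mlp.c_fc.weight"] = (4 * n_embd, n_embd)
        shapes[p + "mlp.c_fc.bias"] = (4 * n_embd,)
        shapes[p + "mlp.c_proj.weight"] = (n_embd, 4 * n_embd)
        shapes[p + "mlp.c_proj.bias"] = (n_embd,)
    shapes["ln_f.weight"] = (n_embd,)
    shapes["ln_f.bias"] = (n_embd,)
    return shapes


shapes = gpt2_param_shapes()
decay = {n: s for n, s in shapes.items() if len(s) >= 2}      # matrices, embeddings
no_decay = {n: s for n, s in shapes.items() if len(s) < 2}    # biases, LayerNorm

decay_count = sum(math.prod(s) for s in decay.values())
no_decay_count = sum(math.prod(s) for s in no_decay.values())
print(len(decay), decay_count, len(no_decay), no_decay_count)
```

Almost all parameters sit in the decayed group; the 98 one-dimensional tensors together hold only about 0.1% of the weights.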
LR schedule: warmup + cosine + floor
def get_lr(it):
    if it < warmup_steps:
        return max_lr * (it + 1) / warmup_steps
    if it > max_steps:
        return min_lr
    decay_ratio = (it - warmup_steps) / (max_steps - warmup_steps)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return min_lr + coeff * (max_lr - min_lr)
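Linear ramp up to max_lr over warmup_steps (the first step uses max_lr / warmup_steps rather than exactly 0). Then cosine decay to min_lr = 0.1 * max_lr. The cosine starts at coeff=1 (full max_lr) and ends at coeff=0 (full min_lr). At step 7240, decay_ratio = (7240 - 715) / (19073 - 715) ≈ 0.355, so coeff = 0.5 * (1 + cos(π·0.355)) ≈ 0.72 and the LR is 0.1 * max_lr + 0.72 * 0.9 * max_lr ≈ 0.75 * max_lr. A bit over a third of the way through training, the LR is still at about three-quarters of max.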
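Evaluating the schedule at a few steps is a quick sanity check on the shape. The sketch below restates the same formula with the constants passed in explicitly; lr_at is a hypothetical wrapper name, and the constants are the ones from the hyperparameter table.

```python
import math


def lr_at(it, max_lr=6e-4, min_ratio=0.1, warmup_steps=715, max_steps=19073):
    """Linear warmup, cosine decay, then a floor at min_ratio * max_lr."""
    min_lr = max_lr * min_ratio
    if it < warmup_steps:
        return max_lr * (it + 1) / warmup_steps
    if it > max_steps:
        return min_lr
    decay_ratio = (it - warmup_steps) / (max_steps - warmup_steps)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return min_lr + coeff * (max_lr - min_lr)


for step in (0, 714, 7240, 19073):
    print(step, lr_at(step) / 6e-4)     # LR as a fraction of max_lr
```

The fraction printed for step 7240 comes out close to 0.75, and the last step lands on the 10% floor.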
HellaSwag inline eval
HellaSwag is a 4-way multiple-choice commonsense benchmark. Each example: a context + 4 candidate endings, pick the most plausible. Scoring: for each candidate, average the LM’s per-token loss over just the ending tokens. The candidate with lowest loss is the model’s pick.
def get_most_likely_row(tokens, mask, logits):
    shift_logits = logits[..., :-1, :].contiguous()
    shift_tokens = tokens[..., 1:].contiguous()
    shift_losses = F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_tokens.view(-1),
        reduction='none',
    ).view(tokens.size(0), -1)
    shift_mask = mask[..., 1:].contiguous()
    masked_losses = shift_losses * shift_mask
    avg_loss = masked_losses.sum(dim=1) / shift_mask.sum(dim=1)
    return avg_loss.argmin().item()
10042 examples, sharded across the 8 ranks. The accuracies all-reduce at the end of each eval block.
Random chance = 25%. OpenAI’s released GPT-2 (124M) baseline = 29.45%. Our reproduction = 26.99%.
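The scoring rule itself is small enough to trace by hand. Below is a list-based sketch of the masked average and argmin, using made-up per-token losses for four candidates; pick_ending is a hypothetical stand-in for the tensor version, and the numbers are invented purely for illustration.

```python
def pick_ending(per_token_losses, masks):
    """Index of the candidate with the lowest mean loss over its ending tokens."""
    averages = []
    for losses, mask in zip(per_token_losses, masks):
        n_ending = sum(mask)            # context and padding have mask 0
        averages.append(sum(l * m for l, m in zip(losses, mask)) / n_ending)
    best = min(range(len(averages)), key=averages.__getitem__)
    return best, averages


# Invented losses: 2 context positions, then up to 3 ending positions.
losses = [
    [9.0, 9.0, 3.0, 4.0, 5.0],
    [9.0, 9.0, 2.0, 2.5, 0.0],          # shorter ending; last position is padding
    [9.0, 9.0, 6.0, 1.0, 2.0],
    [9.0, 9.0, 4.0, 4.0, 4.0],
]
masks = [
    [0, 0, 1, 1, 1],
    [0, 0, 1, 1, 0],
    [0, 0, 1, 1, 1],
    [0, 0, 1, 1, 1],
]

best, averages = pick_ending(losses, masks)
print(best, averages)
```

Averaging rather than summing matters here: the shorter second ending is compared on per-token loss and is not favored just for having fewer tokens.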
Training memory: optimizer state is most of it
| What lives in GPU memory | Training | Inference |
|---|---|---|
| Model weights (124M × 4 bytes) | ~500 MB | ~500 MB |
| Gradients (one per param) | ~500 MB | — |
| Adam m running mean | ~500 MB | — |
| Adam v running variance | ~500 MB | — |
| Activations for backward | huge, ∝ B×T | — |
| Forward activation buffers | yes | yes (ephemeral) |
| Total | 2-4 GB minimum | ~600 MB |
The optimizer state is what kills training on small cards. A 4 GB GTX 1650 can comfortably infer 124M at fp16 but can’t comfortably train it. AdamW needs 3× the model size in extra state, before you’ve allocated a single activation.
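The fixed part of the table is four copies of the parameter count. A back-of-the-envelope sketch, assuming fp32 (4 bytes) for weights, gradients, and both Adam moments, and ignoring activations entirely:

```python
n_params = 124_475_904
bytes_per_value = 4                     # fp32 assumed for all four copies

weights_mb = n_params * bytes_per_value / 1e6
fixed_training_mb = 4 * weights_mb      # weights + grads + Adam m + Adam v
fp16_inference_mb = n_params * 2 / 1e6  # weights only, at 2 bytes each

print(round(weights_mb), round(fixed_training_mb), round(fp16_inference_mb))
```

About 2 GB is committed before the first activation is allocated, which is half of a 4 GB card.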
Cost ledger
| Phase | Time | Cost |
|---|---|---|
| FineWeb-Edu tokenization (multiprocess on 8× A100 inst.) | ~30 min | ~$8 |
| Training: 19,073 steps × ~472 ms/step | ~2.5 hours | ~$40 |
| Setup, idle, HF download throttles, debugging | ~1 hour | ~$13 |
| Total | ~4 hours | $61.23 |
8× A100 SXM4 on Lambda was $15.92/hr at the time. Every minute typing slowly costs real money — have the next command ready before SSH’ing in.
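The ledger and the throughput figures in the summary table can be cross-checked from three numbers: the hourly rate, the step time, and the step count. A short sketch using only values quoted in this post:

```python
rate_per_hour = 15.92                   # 8x A100 SXM4 hourly rate quoted above
ms_per_step = 472
steps = 19073
tokens_per_step = 524288

train_hours = steps * ms_per_step / 1000 / 3600
train_cost = train_hours * rate_per_hour
tokens_per_sec = tokens_per_step / (ms_per_step / 1000)
steps_per_min = 60_000 / ms_per_step
billed_hours = 61.23 / rate_per_hour    # total bill expressed in instance-hours

print(f"{train_hours:.2f} h, ${train_cost:.2f}, "
      f"{tokens_per_sec / 1e6:.2f}M tok/s, {steps_per_min:.0f} steps/min, "
      f"{billed_hours:.2f} h billed")
```

Training accounts for roughly two-thirds of the bill; the rest is time the instance spent doing something other than training.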
tmux new -s train keeps the training run alive across SSH drops. Doesn’t save money, but the alternative is a $40 training run dying because your laptop went to sleep.
What the trained model produces
Sample completions at val_loss 3.40 with temperature=1.0, top_k=50, prompt "Hello, I'm a language model,":
- Coherent prose about teachers, students, classrooms. Grammatical for 3 sentences.
- Rambling but grammatical paragraph on social media and jobs.
- Repetition loop: "I'm a language model. I'm a language model. I'm a language model."
- Has dialogue with quote marks, sentence-level coherence.
Diagnosis: working LM with sentence-level coherence and occasional repetition loops — exactly the failure mode of a 124M model at val_loss 3.4. At 3.1 the loops fade. With temperature=0.7, top_k=20 the output is cleaner but less creative.
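For readers unfamiliar with the two sampling knobs, here is a minimal sketch of temperature plus top-k sampling over a toy logit vector. It is a generic illustration in plain Python, not the sampling code used for the completions above; sample_top_k and the four-entry logits list are assumptions made for the example.

```python
import math
import random


def sample_top_k(logits, k, temperature, rng):
    """Sample one token id from the k highest logits after temperature scaling."""
    top_ids = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    scaled = [logits[i] / temperature for i in top_ids]
    peak = max(scaled)                  # subtract the max for numerical stability
    weights = [math.exp(s - peak) for s in scaled]
    return rng.choices(top_ids, weights=weights, k=1)[0]


rng = random.Random(0)
logits = [2.0, 0.5, 1.0, -1.0]          # toy vocabulary of 4 tokens

greedy_like = [sample_top_k(logits, k=1, temperature=1.0, rng=rng) for _ in range(20)]
top2 = {sample_top_k(logits, k=2, temperature=0.7, rng=rng) for _ in range(200)}
print(set(greedy_like), top2)
```

Lowering k removes unlikely tokens outright, and lowering the temperature sharpens the distribution over the ones that remain, which is why the tighter settings read cleaner but less varied.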
What the model actually is
124,475,904 floating-point numbers plus ~200 lines of Python that combine them. The checkpoint file is those numbers plus a kilobyte of config. Inference loads the numbers and runs the architecture forward.
| Tensor | Shape | Count |
|---|---|---|
| wte.weight (tied with lm_head) | 50304 × 768 | 38,633,472 |
| wpe.weight | 1024 × 768 | 786,432 |
| 12 × Block | each ~7.09M | 85,054,464 |
| ln_f (γ + β) | 768 × 2 | 1,536 |
| Total | | 124,475,904 |
Random init and trained model have the same 124M numbers in identical shapes. The only difference is the numerical values. Training is 19,073 nudges of lr × gradient × (-1) applied to every number. No single number means anything. Meaning emerges from the collective behavior of all of them running through the matmuls.
What would close the remaining gap
- Train 5K more steps on 8× A100 to hit OpenAI val loss exactly. ~$80, gets val loss to ~3.29.
- Switch base dataset from FineWeb-Edu to OpenWebText. Educational text is high-quality but narrow. WebText is broader and gives better commonsense — closes the HellaSwag gap.
- Quantize for inference. int8 or int4 brings the inference footprint from 600 MB to ~150 MB, runs much faster on the 1650.
Code and plots: github.com/debtirthasaha/gpt2-124m-reproduction. The trained checkpoint (523 MB) is on Hugging Face at MR0b0t/gpt2-124m-reproduction.