Nilesh Sarkar / Research

Interpretability Experiments - Erdős AI Lab

Overview

These experiments are conducted under Erdős AI Lab as part of an ongoing program in mechanistic interpretability. The goal is to look inside trained transformer language models - not just measure their outputs - and recover the internal computations they actually perform.

Current threads include sparse autoencoder (SAE) feature extraction across layers, attention-circuit analysis on small/medium models, and feature-attribution probes that connect specific behaviors back to identifiable internal structure.

Status: Active and ongoing - February 2026 to present. Now live below: the full SAE experiment gallery from the COLM 2026 knowledge-distillation paper - toy-model validation of the KD minimum-width theorem, plus three real-LM trials on Pythia-410M (24 layers × 6 token checkpoints each, published on Hugging Face). In active development: attention-circuit analysis on IOI / induction heads, feature-attribution probes on reasoning benchmarks, and cross-model generalisation tests (Pythia-410M → Pythia-1.4B → Pythia-2.8B).
Adjacent research thread (early). Exploring a framework for better generalisation on low-sample medical imaging without leaning on generative deepfake augmentation - i.e. squeezing more from small, real datasets instead of synthesising them. More to share once the formulation stabilises.

Themes

SAE Experiments

Sparse Autoencoders (SAEs) are the workhorse of modern mechanistic interpretability - overcomplete dictionaries trained on residual-stream activations, aiming for sparse, monosemantic features. Everything below ties to a knowledge-distillation paper currently under review at COLM 2026. The full image gallery (every variant + selection rationale) lives here.
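
For readers new to SAEs, here is a minimal NumPy sketch of what one SAE computes on a batch of residual-stream vectors. It also computes the three quantities tracked throughout this page: reconstruction MSE, L1 penalty and L0. The ReLU encoder, the pre-encoder bias subtraction, summing the L1 term over features, and the name `sae_forward` are all illustrative assumptions, not the architecture used in the paper.

```python
import numpy as np

def sae_forward(x, W_enc, b_enc, W_dec, b_dec, l1_coeff):
    """Hypothetical SAE forward pass on a batch of residual-stream vectors.

    x:     (batch, d_model) activations
    W_enc: (d_model, F) encoder weights, F > d_model (overcomplete)
    W_dec: (F, d_model) decoder weights, one dictionary atom per row
    """
    # Encode: non-negative, ideally sparse feature activations via ReLU.
    f = np.maximum(0.0, (x - b_dec) @ W_enc + b_enc)
    # Decode: reconstruct x as a weighted sum of dictionary atoms.
    x_hat = f @ W_dec + b_dec
    mse = np.mean((x - x_hat) ** 2)           # reconstruction error
    l1 = np.mean(np.sum(np.abs(f), axis=1))   # |f| summed over features, averaged over batch
    l0 = np.mean(np.sum(f > 0, axis=1))       # average active features per token
    loss = mse + l1_coeff * l1                # combined training objective
    return x_hat, f, {"mse": mse, "l1": l1, "l0": l0, "loss": loss}

# Toy usage: d_model = 8, F = 32 (4x overcomplete), random weights.
rng = np.random.default_rng(0)
d_model, F = 8, 32
x = rng.normal(size=(16, d_model))
W_enc = rng.normal(scale=0.1, size=(d_model, F))
W_dec = rng.normal(scale=0.1, size=(F, d_model))
_, f, stats = sae_forward(x, W_enc, np.zeros(F), W_dec, np.zeros(d_model), l1_coeff=3e-4)
print(stats)
```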

Experiment 1 - Toy-Model Validation of the KD Theorem

Before any real LM, a controlled toy environment validates the KD minimum-width theorem: for a target sparsity, theory predicts the SAE dictionary width d*_S at which reconstruction loss bottoms out. Toy data lets every quantity be computed in closed form, so we can check theory against ground truth without LM training noise. All three figures below appear in the paper.

fig1 - Loss-floor curves. Loss vs SAE width d_S for each (α, F) configuration. Every curve bottoms out at the d*_S predicted by KD theory. Picked as the main paper result.

Each curve traces the reconstruction loss as the SAE dictionary width d_S grows, for a different (alpha, F) configuration. Theory predicts that loss hits a floor at the exact d*_S value, and the data agrees on every curve. This is the headline plot of the toy experiment because it shows the predicted minimum width is not just a tendency but a sharp transition.
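
One simple way to read off where a loss-vs-width curve bottoms out is to take the smallest swept width whose loss is within a small relative tolerance of the sweep's minimum. The function name, the 1% tolerance and the synthetic curve in the usage line are assumptions for illustration. The paper's own floor criterion may differ.

```python
def floor_width(widths, losses, rel_tol=0.01):
    """Return the smallest width whose loss is within rel_tol of the sweep minimum.

    widths, losses: equal-length Python lists; losses assumed non-negative.
    """
    if not widths or len(widths) != len(losses):
        raise ValueError("widths and losses must be non-empty and the same length")
    pairs = sorted(zip(widths, losses))  # order the sweep by width
    floor = min(loss for _, loss in pairs)
    # The minimum itself always qualifies, so this loop always returns.
    for width, loss in pairs:
        if loss <= floor * (1 + rel_tol):
            return width

# Synthetic curve (not paper data): loss falls linearly, then floors at width 64.
widths = [8, 16, 32, 64, 128, 256]
losses = [max(1 - w / 64, 0) + 0.05 for w in widths]
print(floor_width(widths, losses))  # → 64
```

A relative tolerance, rather than an exact argmin, keeps the estimate stable when the post-floor region is flat but slightly noisy.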

fig5 - All-configurations summary heatmap. Rows/columns sweep (α, F); colour = how well theory predicts measured loss. The "everything works" plot - good for a one-glance executive overview.

All (alpha, F) configurations on one canvas. Rows and columns sweep the two parameters; cell colour tells you how closely measured loss matches what KD theory predicted. Bright cells mean a near-perfect match. The point is to show the prediction holds uniformly across the parameter grid, not just on a handful of nice curves.

fig8 - Floor vs d*_S. Measured floor position (the width at which loss stops improving) against the theoretical optimum width. Points lying on y = x are the evidence that KD's prediction holds. Picked as the paper's "money plot".

Each point plots the width at which a run's loss actually floors against its theoretical d*_S value. If theory and experiment agree, the points sit on the y = x diagonal. They do. This single plot collapses the entire toy experiment into one image showing that the prediction holds.
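
Agreement with the y = x line can be summarised numerically by scoring residuals against the identity line itself rather than against a fitted regression. A fitted line could have slope ≠ 1 and still give a high ordinary R². The sketch below is minimal. `identity_fit` is a hypothetical helper, and its inputs are assumed to be paired predicted and measured widths.

```python
def identity_fit(predicted, measured):
    """Score how closely (predicted, measured) pairs sit on the y = x line.

    Returns (r2_identity, max_rel_err):
      r2_identity = 1 - SS_res / SS_tot, with residuals measured - predicted
                    (i.e. against y = x, not a fitted line);
      max_rel_err = worst-case |measured - predicted| / predicted.
    Assumes predicted values are positive (they are widths).
    """
    n = len(predicted)
    if n < 2 or n != len(measured):
        raise ValueError("need at least two paired points")
    mean_m = sum(measured) / n
    ss_res = sum((m - p) ** 2 for p, m in zip(predicted, measured))
    ss_tot = sum((m - mean_m) ** 2 for m in measured)
    r2 = 1.0 - ss_res / ss_tot if ss_tot > 0 else float("nan")
    max_rel = max(abs(m - p) / p for p, m in zip(predicted, measured))
    return r2, max_rel

print(identity_fit([100, 200, 400], [100, 200, 400]))  # → (1.0, 0.0)
```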

Experiment 2 - Real-LM SAEs on Pythia-410M

After the toy validation, the same predictions are applied to Pythia-410M residual-stream activations. Each trial trains 24 SAEs (one per layer, dictionary size F = 32,768) for 300 M tokens, with checkpoints at 50 / 100 / 150 / 200 / 250 / 300 M tokens. Three trials sweep how the L1 sparsity coefficient is set:

L1 = 3e-4 (fixed): Loosest sparsity. Avg L0 ≈ 6,184, α ≈ 0.81, ~98% alive features. Best reconstruction, MSE ≈ 0.0019. d*_S ≈ 9,963.
L1 = 8e-5 (fixed, paper-exact): Tuned so the L1 penalty matches the paper-exact sum formula. Avg L0 ≈ 1,962, α ≈ 0.94, MSE ≈ 0.019, d*_S ≈ 5,487.
L1 = 5e-4 (adaptive, L0 = 150): Adaptive controller hits target L0 ≈ 152, α ≈ 0.994 (very sparse). MSE ≈ 0.23, d*_S ≈ 781 - the narrow, monosemantic-leaning regime.
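
The reported α values are consistent with α = 1 - L0/F, the fraction of the F = 32,768 dictionary that is inactive on an average token. This is an inference from the numbers above rather than a definition stated on this page. For the adaptive trial the same formula gives ≈ 0.995 rather than the quoted 0.994. The small gap may come from averaging α per layer instead of computing it from the average L0. `alpha_from_l0` is a hypothetical helper.

```python
F = 32_768  # dictionary size used in all three Pythia-410M trials

def alpha_from_l0(avg_l0, dict_size=F):
    """Sparsity alpha as the inactive fraction of the dictionary (assumed definition)."""
    if not 0 <= avg_l0 <= dict_size:
        raise ValueError("avg_l0 must lie in [0, dict_size]")
    return 1.0 - avg_l0 / dict_size

for name, l0 in [("3e-4 fixed", 6_184), ("8e-5 fixed", 1_962), ("5e-4 adaptive", 152)]:
    print(f"L1 = {name}: alpha = {alpha_from_l0(l0):.3f}")
# → L1 = 3e-4 fixed: alpha = 0.811
# → L1 = 8e-5 fixed: alpha = 0.940
# → L1 = 5e-4 adaptive: alpha = 0.995
```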

Trial A - L1 = 3e-4 fixed (dense regime)

Looser-sparsity end of the sweep: features stay active across many tokens, almost nothing dies. Useful as a high-reconstruction baseline against which the sparser trials can be judged, but not the regime for monosemanticity claims - each feature still encodes too many things at once.

Training curves - reconstruction MSE, L1 penalty, combined loss across all 24 layers over 300 M tokens. Each line is one layer.

Reconstruction MSE, the L1 sparsity penalty, and the combined loss across all 24 SAE layers over 300 M training tokens. With L1 held loose at 3e-4 you get the densest activations (avg L0 around 6,184) and the best reconstruction. Each line is a layer, so the smooth simultaneous declines confirm no layer blew up.

Layer × metric heatmap. Rows = layers, columns = metrics (L0, α, g(α), MSE, …). At-a-glance shape of the whole run.

A compact summary card with one row per layer and one column per metric (L0, alpha, g(alpha), MSE, and so on). Shade tells you the relative value within the column. This is the at-a-glance shape of the whole trial.
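
"Shade tells you the relative value within the column" suggests each metric column is scaled independently before colouring. The sketch below is minimal and assumes per-column min-max scaling. The actual figure code is not shown on this page, and `column_shades` is a hypothetical name.

```python
import numpy as np

def column_shades(table):
    """Min-max scale each column of a (layers x metrics) table to [0, 1].

    Each value is placed relative to its own metric column, so metrics on very
    different scales (L0 in the thousands, alpha near 1) can share one colour map.
    Constant columns map to 0.
    """
    t = np.asarray(table, dtype=float)
    lo = t.min(axis=0)
    span = t.max(axis=0) - lo
    span[span == 0] = 1.0  # avoid dividing by zero for constant columns
    return (t - lo) / span

# Three invented layers x two metrics (L0, alpha); not trial data.
print(column_shades([[6000, 0.80], [2000, 0.90], [4000, 0.85]]))
```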

The three KD quantities. α (sparsity coefficient), g(α) (capacity function), and d*_S (predicted optimum width) per layer. Shows the theoretical prediction this trial yields.

The three KD quantities plotted side by side per layer. Alpha measures sparsity, g(alpha) is the capacity function, d*_S is the predicted optimum dictionary width. These are the inputs theory needs to produce a width prediction.

Sparsity vs reconstruction trade-off, per layer. L0 on one axis, MSE on the other - reveals which layers are naturally denser/sparser.

L0 (average active features per token) on one axis, reconstruction MSE on the other. Layers that are naturally denser sit on one side, sparser layers on the other. Reveals which transformer depths need more dictionary capacity.

Predicted loss floor vs d_S, per layer. KD theory predicts where reconstruction error bottoms out as the dictionary widens - this plot shows the prediction across a sweep of d_S.

For each layer we sweep the dictionary width d_S and ask theory where loss bottoms out. This figure draws the predicted curve; comparing it against the actual training loss is what tells you whether the prediction is right.

Trial B - L1 = 8e-5 fixed (paper-exact)

Lowest raw L1 coefficient of the three trials, though not the least sparse (avg L0 ≈ 1,962, versus ≈ 6,184 for the 3e-4 trial): it is tuned so the L1 penalty matches the paper-exact sum formula used in the theoretical derivation. This is the trial whose numbers go into the paper's main tables. It also has the richest visual output: seven rendering variants of the same underlying per-layer measurements, generated to support different downstream surfaces (paper, blog, presentation, dark-mode site, social card).
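
Summing |f_i| over the dictionary instead of averaging it changes the penalty's effective strength by a factor of F. Raw L1 coefficients are therefore only comparable between trials that use the same reduction. The sketch below is minimal and assumes the "paper-exact sum formula" means summing over all F features per token; the page does not spell the formula out. `l1_penalty` is a hypothetical helper.

```python
import numpy as np

def l1_penalty(f, coeff, reduction="sum"):
    """L1 sparsity penalty on feature activations f of shape (batch, F).

    reduction="sum":  coeff * batch-mean of sum_i |f_i|
    reduction="mean": coeff * batch-mean of mean_i |f_i|  (F times smaller)
    """
    per_token = np.abs(f).sum(axis=1)
    if reduction == "mean":
        per_token = per_token / f.shape[1]
    elif reduction != "sum":
        raise ValueError("reduction must be 'sum' or 'mean'")
    return coeff * per_token.mean()

F = 32_768
f = np.zeros((4, F))
f[:, :100] = 0.5  # 100 active features per token, each at 0.5 (invented)
s = l1_penalty(f, 8e-5, "sum")
m = l1_penalty(f, 8e-5, "mean")
print(s / m)  # the ratio is F, up to floating-point rounding
```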

SAE_Premium_Heatmap. Clean palette, annotated metric labels, large fonts. Picked for the paper - best balance of density and legibility.

The clean paper-grade rendering of the L1 = 8e-5 paper-exact trial. Rows are layers, columns are SAE metrics. Annotated labels and a balanced palette are why this version was picked as the paper figure.

SAE_Professional_Heatmap. Compact LinkedIn / slide-deck variant. Not picked for the paper - colour palette is less expressive at print resolution. Kept for social posts.

Same data as the premium heatmap, with a more compact and conservative palette. Built for LinkedIn cards and slide decks. Did not make the paper because the colour band is less expressive at print resolution.

SAE_Layer_Analysis_Heatmap. Per-layer deep dive with per-layer mini-stats in the sidecar. Not picked for the paper body - too wide to fit a single column. Used in the appendix.

Same data again, but each layer gets a sidecar with its own mini statistics. Used in the paper appendix, where horizontal space is less constrained than in the main column.

SAE_HeatmapVisual_All_Layers. Small-multiples view - every layer's full metric vector on one page. Not picked for the paper - duplicates the Premium heatmap in less detail; kept for "show me everything" presentations.

Small-multiples view: every layer's full metric vector tiled onto one page. Useful in show-me-everything talks where the audience wants to inspect each layer in isolation.

SAE_Gradient_Heatmap. Smooth gradients in place of discrete buckets. Not picked for the paper - prettier, but exact values are harder to read. Used as the blog hero image.

Smooth-gradient rendering of the same metrics. Looks great as a blog header, but the gradient hides exact bucket boundaries so it is harder to read off a precise number.

SAE_Correlation_Heatmap. Which SAE metrics move together across layers (e.g. α ↔ dead-feature count). Not picked for the paper body - it's a derived view; included in the appendix as theoretical commentary.

Which SAE metrics move together across layers. For example, alpha tracks dead-feature count. Useful for theoretical commentary about what each metric is really measuring.
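
The correlation view can be sketched as the Pearson correlation between metric columns of the per-layer table, with layers as the observations. `metric_correlations` and the four-layer example values are made up for illustration. They are not measurements from these trials.

```python
import numpy as np

def metric_correlations(table, names):
    """Pearson r between every pair of metric columns in a (layers x metrics) table.

    Rows are layers (observations), columns are metrics (variables).
    A constant column has undefined correlation and yields nan.
    """
    t = np.asarray(table, dtype=float)
    corr = np.corrcoef(t, rowvar=False)
    return {(a, b): float(corr[i, j])
            for i, a in enumerate(names)
            for j, b in enumerate(names) if i < j}

# Illustrative 4-layer table (invented numbers): alpha, dead features, MSE.
table = [[0.80, 100, 4.0],
         [0.85, 200, 3.0],
         [0.90, 300, 2.0],
         [0.95, 400, 1.0]]
for pair, r in metric_correlations(table, ["alpha", "dead", "mse"]).items():
    print(pair, round(r, 3))
```

With only 24 layers as observations, individual correlations are noisy, so this view suits qualitative commentary better than quantitative claims.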

SAE_KPI_Summary. Single-card overview - mean α, mean L0, total dead features, etc. Not picked for the paper body - used as the social/cover image and on the abstract page of the public-facing arXiv PDF.

Single card with the headline numbers: mean alpha, mean L0, total dead features. Used as the social and cover image, not as a paper-body figure.
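
The KPI card's headline numbers can be derived from per-layer activation matrices, as in the sketch below. Several pieces are assumptions here: counting a feature as "dead" when it is zero on every evaluated token (a common convention), the α = 1 - L0/F definition, and the helper names. The dead count also depends on how many tokens are evaluated.

```python
import numpy as np

def dead_feature_mask(acts):
    """Boolean mask over features: True where a feature never fires on any token.

    acts: (tokens, F) array of SAE feature activations for one layer.
    """
    return ~np.any(acts > 0, axis=0)

def kpi_summary(per_layer_acts):
    """Mean L0, mean alpha, and total dead features across layers."""
    l0s, alphas, total_dead = [], [], 0
    for acts in per_layer_acts:
        n_features = acts.shape[1]
        l0 = np.mean(np.sum(acts > 0, axis=1))  # avg active features per token
        l0s.append(l0)
        alphas.append(1.0 - l0 / n_features)     # assumed alpha definition
        total_dead += int(dead_feature_mask(acts).sum())
    return {"mean_l0": float(np.mean(l0s)),
            "mean_alpha": float(np.mean(alphas)),
            "total_dead": total_dead}

# Two toy layers, 2 tokens x 4 features each (invented values).
layer_1 = np.array([[1.0, 0, 0, 0], [0, 2.0, 0, 0]])
layer_2 = np.array([[1.0, 1.0, 0, 0], [1.0, 1.0, 1.0, 0]])
print(kpi_summary([layer_1, layer_2]))
# → {'mean_l0': 1.75, 'mean_alpha': 0.5625, 'total_dead': 3}
```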

Trial C - L1 = 5e-4 adaptive (target L0 ≈ 150)

The L1 coefficient starts at 5e-4 and is adapted layer-by-layer by a proportional controller (rate 0.01) so that the average number of active features per token converges to L0 ≈ 150. Picked for the paper specifically because the controller pins every layer to the same sparsity target, giving a clean apples-to-apples comparison across the three regimes. Trade-off: ~7k dead features per layer (out of 32,768), but the surviving features are far more monosemantic.
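
Below is a minimal sketch of one proportional-control step on a layer's L1 coefficient. The page gives only the starting value (5e-4), the rate (0.01) and the target (L0 ≈ 150). The multiplicative update on relative L0 error, the helper names, and the toy 1/L1 response used to exercise it are all assumptions.

```python
def adapt_l1(l1, observed_l0, target_l0=150.0, rate=0.01):
    """One proportional-control step on a layer's L1 coefficient (assumed form).

    If too many features are active (observed_l0 > target_l0), L1 is raised to
    push toward sparsity; if too few, it is lowered. The step is proportional
    to the relative L0 error. Because observed_l0 >= 0, error >= -1, so the
    multiplier stays positive whenever rate < 1.
    """
    error = (observed_l0 - target_l0) / target_l0
    return l1 * (1.0 + rate * error)

def toy_l0(l1, k=0.15):
    """Illustrative stand-in for a layer: L0 falls as 1 / L1 (not measured SAE behaviour)."""
    return k / l1

l1 = 5e-4  # Trial C's starting coefficient; toy L0 starts at 300
for _ in range(2000):
    l1 = adapt_l1(l1, toy_l0(l1))
print(round(toy_l0(l1), 2))  # → 150.0
```

With this toy response, the relative gap between L1 and its fixed point shrinks by a factor of (1 - rate) per step. A small rate therefore trades convergence speed for stability.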

Dashboard. Per-layer α, g(α), reconstruction loss, dead-feature count arranged as a layer-by-metric matrix. Picked for the paper comparison panel.

Dashboard view of the L1 = 5e-4 adaptive trial. Composite of alpha, g(alpha), reconstruction loss, dead-feature count arranged as a layer-by-metric grid. This is the figure the paper uses for the comparison panel because it carries hard numbers.

Metric landscape. The same per-layer measurements rendered as a continuous landscape - early/middle/late layers behave differently. Not picked for the paper - landscape view is qualitative; the dashboard carries the same content with hard numbers.

The same per-layer measurements rendered as a continuous landscape so you can read how each property evolves through depth. Qualitative but very quick to scan.

Why these three trials (and not more)?

The three trials are not a random ablation - they were chosen to cover the corners of the sparsity-vs-reconstruction frontier the KD theory predicts: dense (3e-4), paper-exact (8e-5), and very-sparse adaptive (5e-4 → L0=150). Other L1 settings (1e-4, 2e-4, 7e-4) were tried during pilots but discarded - either too close to an already-included trial (no new information) or so far from theory's regime that the KD prediction breaks down for unrelated reasons (e.g. dead-feature collapse).

Reproducibility

Each trial publishes 144 checkpoints (24 layers × 6 token snapshots: 50 / 100 / 150 / 200 / 250 / 300 M) to Hugging Face so the full training trajectory is recoverable.
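
The 24 × 6 = 144 checkpoint grid per trial can be enumerated as below. The path format is invented for illustration and is not the actual layout of the Hugging Face repositories.

```python
LAYERS = range(24)                                  # Pythia-410M: 24 layers, one SAE each
TOKEN_SNAPSHOTS_M = [50, 100, 150, 200, 250, 300]  # checkpoint points, millions of tokens

def checkpoint_names(trial):
    """One hypothetical name per (layer, snapshot) pair for a trial."""
    return [f"{trial}/layer_{layer:02d}/{tokens}M"
            for layer in LAYERS
            for tokens in TOKEN_SNAPSHOTS_M]

names = checkpoint_names("trial_a")
print(len(names), names[0], names[-1])
# → 144 trial_a/layer_00/50M trial_a/layer_23/300M
```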

Why It Matters

Capability is moving fast; understanding is not. Interpretability work is one of the few research directions that scales with capability rather than against it - bigger models give richer internal structure to study, and reliable mechanistic accounts of behavior are a prerequisite for trusting AI systems in high-stakes settings.

Related