Research@Erdős AI Lab

Nilesh Sarkar - Founding Researcher

Role Overview

We exist at the intersection of rigorous mathematics and unbounded curiosity. Erdős AI Lab is a collective of young researchers devoted to advancing the frontier of artificial intelligence - from continual learning and world models to brain-computer interfaces and neuromorphic computing. We believe the next decade of AI will not be won by scale alone, but by ideas. Sharp, original, uncompromising ideas. That is what Erdős is built for.

Our current work centres on Knowledge Distillation (KD) and teacher-student model research - making large, powerful models more efficient and accessible by transferring their knowledge to smaller, faster student networks without significant loss in performance. We develop new KD algorithms, analyse the transfer of representations, and build robust pipelines for training and evaluating teacher-student architectures at scale.

Key Projects & Contributions

Publications & Research

Geometric Limits of Knowledge Distillation: A Minimum-Width Theorem via Superposition Theory

Nilesh Sarkar et al.  |  Conference on Language Modeling (COLM 2026) - Under Review  |  arXiv Preprint


Abstract

Knowledge distillation (KD) compresses large teacher models into smaller, deployable student networks. However, student performance consistently saturates at a loss floor that persists regardless of training method, objective function, or hyperparameter tuning. We prove that this floor is geometric in origin, not an artefact of optimisation.

Core Insight: Superposition Creates Hard Limits

Modern neural networks pack far more learned features into their hidden layers than they have dimensions - a phenomenon called superposition. When a teacher with hidden width $d_T$ learns $N > d_T$ features by encoding them as near-orthogonal directions, a student with width $d_S < d_T$ can faithfully capture at most $d_S \times g(\alpha)$ of those features, where $g(\alpha)$ is a geometric packing function determined by the interference tolerance $\alpha$. Features beyond this capacity are irrecoverably lost, setting a hard floor on distillation fidelity.
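To build intuition for the packing function, one can count how many unit directions fit into d dimensions when every pair must have |cosine| at or below a tolerance α. The sketch below does this by greedy random packing in NumPy. It assumes α acts as a maximum pairwise-cosine bound, it is not the paper's definition of g(α), and `greedy_pack` is a hypothetical helper. Because greedy packing is suboptimal, the count it returns is only a rough lower bound on what is achievable.

```python
import numpy as np


def greedy_pack(d, alpha, n_proposals=2000, seed=0):
    """Greedily keep random unit vectors in R^d whose |cosine| with every
    previously kept vector is at most alpha (illustration only)."""
    rng = np.random.default_rng(seed)
    kept = np.empty((0, d))
    for _ in range(n_proposals):
        v = rng.standard_normal(d)
        v /= np.linalg.norm(v)
        # Accept v only if its interference with all kept directions is within tolerance
        if kept.shape[0] == 0 or np.max(np.abs(kept @ v)) <= alpha:
            kept = np.vstack([kept, v])
    return kept


d = 16
for alpha in (0.1, 0.3, 0.5):
    W = greedy_pack(d, alpha)
    # len(W) / d is a crude empirical stand-in for the packing ratio g(alpha)
    print(f"alpha={alpha}: {len(W)} directions in {d} dims, ratio {len(W) / d:.2f}")
```

Loosening the tolerance lets many more than d directions coexist, which is the superposition effect the theorem builds on.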

The Minimum-Width Theorem

We formalise this as the Minimum-Width Theorem: for a student to achieve loss within $\epsilon$ of the teacher, it must have hidden width at least $d_{min} = N / g(\alpha(\epsilon))$. Below this width, no amount of training - longer schedules, better objectives, curriculum strategies - can close the gap. The bound is tight: we construct matching lower-bound examples using sparse-autoencoder (SAE) feature dictionaries extracted from trained transformers.

Experimental Validation

We validate the theory on transformer language models by training SAEs to extract superposed feature dictionaries at each layer. Experiments measure alive-neuron counts, L0 sparsity, feature importance distributions, and per-layer capacity utilisation. The observed loss floors align quantitatively with the geometric predictions, confirming that architecture width - not training recipe - is the binding constraint on distillation quality.
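The per-layer quantities listed above can be computed from a batch of SAE activations roughly as follows. This is a minimal NumPy sketch: `sae_metrics` is a hypothetical helper, "alive" here means a feature fired at least once in the sample, mean activation stands in for feature importance, and alive features divided by dictionary size stands in for capacity utilisation. The paper's exact definitions may differ.

```python
import numpy as np


def sae_metrics(acts, x, x_hat, threshold=0.0):
    """Summary statistics for one SAE layer.

    acts:     (tokens, d_dict) feature activations
    x, x_hat: (tokens, d_model) SAE inputs and reconstructions
    """
    active = acts > threshold
    alive = active.any(axis=0)  # feature fired at least once in this sample
    return {
        "L0": float(active.sum(axis=1).mean()),   # avg active features per token
        "alive": int(alive.sum()),
        "dead": int((~alive).sum()),
        "utilisation": float(alive.mean()),       # hypothetical proxy: alive / d_dict
        "importance": acts.mean(axis=0),          # hypothetical proxy: mean activation
        "mse": float(np.mean((x - x_hat) ** 2)),  # reconstruction error
    }


rng = np.random.default_rng(0)
acts = np.maximum(rng.standard_normal((128, 64)) - 1.5, 0.0)  # synthetic sparse activations
x = rng.standard_normal((128, 16))
m = sae_metrics(acts, x, x + 0.01 * rng.standard_normal(x.shape))
print(m["L0"], m["alive"], m["dead"], round(m["mse"], 6))
```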

Practical Implications

This research shifts the focus of model compression from "how to train" to "how to design." By estimating a teacher's superposition ratio ($N/d_T$) using Sparse Autoencoders (SAEs), we can determine the smallest viable student architecture ($d_{min}$) before beginning compute-intensive training. This provides a diagnostic tool to visualize exactly which features are "lost" during distillation, bridging hardware efficiency with model transparency.

  • Architecture Benchmarking: If a student's width falls below $d_{min}$, further training compute is mathematically wasted.
  • Feature Prioritization: When width is constrained, we can use SAE importance distributions to ensure high-value features are preserved.
  • Interpretability Link: SAE-based feature extraction diagnoses capacity limits and reveals the specific representational collapse of the student.
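The design-first workflow above can be sketched as a small planning helper. It assumes N is estimated as the number of live SAE features in a teacher layer and that the value of g(α(ε)) is already known; no closed form for g is given here, so it is passed in as a number. `plan_student` and its arguments are hypothetical names.

```python
import math

import numpy as np


def plan_student(importance, d_teacher, d_student, g_alpha, alive_eps=0.0):
    """Hypothetical student-sizing helper based on d_min = N / g(alpha).

    importance: per-feature importance scores from an SAE on a teacher layer
    g_alpha:    value of the packing function g(alpha(eps)), supplied by the caller
    """
    importance = np.asarray(importance, dtype=float)
    n_features = int((importance > alive_eps).sum())  # N: live SAE features
    d_min = math.ceil(n_features / g_alpha)            # smallest viable student width
    capacity = math.floor(d_student * g_alpha)         # features a width-d_student model can hold
    order = np.argsort(importance)[::-1]               # most important features first
    keep = order[: min(capacity, n_features)]          # features to prioritise when width is short
    return {
        "N": n_features,
        "ratio": n_features / d_teacher,               # superposition ratio N / d_T
        "d_min": d_min,
        "wide_enough": d_student >= d_min,
        "keep": keep,
    }


rng = np.random.default_rng(0)
importance = rng.exponential(size=4096)  # made-up importance scores
plan = plan_student(importance, d_teacher=1024, d_student=512, g_alpha=2.5)
print(plan["ratio"], plan["d_min"], plan["wide_enough"], len(plan["keep"]))
```

If `wide_enough` is false, the theorem says extra training compute will not close the gap, and `keep` lists which features to protect within the available capacity.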

Mathematical Framework: The Minimum-Width Theorem

For a student model to achieve a loss within $\epsilon$ of its teacher, it must possess a minimum hidden width:

$d_{min} = \frac{N}{g(\alpha(\epsilon))}$

where $N$ is the number of features in superposition and $g(\alpha)$ is the geometric packing function. Below this width, the student undergoes catastrophic representational collapse.
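One way to see where the bound comes from: start from the capacity statement in the core insight (a student of width $d_S$ faithfully holds at most $d_S\,g(\alpha)$ features) and assume that staying within $\epsilon$ of the teacher requires keeping all $N$ features at the tolerance $\alpha(\epsilon)$. Rearranging gives the theorem's threshold:

```latex
\begin{aligned}
\#\{\text{features a width-}d_S\text{ student represents faithfully}\} &\le d_S \, g(\alpha) \\
\text{loss within } \epsilon \text{ of the teacher} \;\Rightarrow\; N &\le d_S \, g(\alpha(\epsilon)) \\
\Longrightarrow\quad d_S &\ge \frac{N}{g(\alpha(\epsilon))} = d_{min}
\end{aligned}
```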

Key Experimental Results

Training Curves: Loss, L0 Sparsity, Alive Neurons

Reconstruction MSE, the L1 sparsity penalty, and the combined loss across all 24 SAE layers over 300 M training tokens. With L1 held loose at 3e-4 you get the densest activations (avg L0 around 6,184) and the best reconstruction. Each line is a layer, so the smooth simultaneous declines confirm no layer blew up.

Per-Layer Metrics Heatmap

A compact summary card with one row per layer and one column per metric (L0, alpha, g(alpha), MSE, and so on). Shade tells you the relative value within the column. This is the at-a-glance shape of the whole trial.

Feature Importance Distribution

Feature-importance distribution from the earlier Erdős experiments. Tells you how attribution mass spreads across features rather than concentrating in a handful.

SAE All-Layers Summary Table

Per-layer summary table from the earlier experiments. The same kind of compact metric matrix that the SAE_Premium_Heatmap replaces in the newer work.

SAE Feature Analysis Heatmap

Earlier representation-learning experiment: the layer-by-feature heatmap rendered in a high-contrast cyberpunk style. Same shape of analysis as the later SAE heatmaps, just predating the KD framework that organises the new work.

Knowledge Distillation · Superposition Theory · Minimum-Width Theorem · Sparse Autoencoders · Representation Learning · COLM 2026

SAE Experiments (KD paper, COLM 2026)

Sparse Autoencoders (SAEs) are the central interpretability tool for the current knowledge-distillation paper. The work has two halves: a toy-model validation of the KD minimum-width theorem, then a real-LM sweep training 24 SAEs (one per Pythia-410M layer) for 300 M tokens at three different sparsity regimes. The gallery below shows every figure produced for each trial, with notes on which variants made it into the paper and which did not. It is also mirrored on the Interpretability Experiments page.
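For readers new to SAEs, the objective that the training curves below track (reconstruction MSE plus an L1 penalty on feature activations) looks roughly like this. It is a minimal NumPy sketch of a standard ReLU SAE with unit-norm decoder rows, forward pass and loss only; the class name, the initialisation, and the exact L1 reduction are assumptions, not the configuration used for the Pythia-410M runs. In practice the same computation would be written in PyTorch and trained with autograd.

```python
import numpy as np


class ToySAE:
    """Minimal ReLU SAE: f = ReLU((x - b_dec) W_enc + b_enc), x_hat = f W_dec + b_dec."""

    def __init__(self, d_model, d_dict, seed=0):
        rng = np.random.default_rng(seed)
        self.W_enc = rng.standard_normal((d_model, d_dict)) / np.sqrt(d_model)
        # Decoder initialised as the encoder transpose, then rows scaled to unit norm
        self.W_dec = self.W_enc.T.copy()
        self.W_dec /= np.linalg.norm(self.W_dec, axis=1, keepdims=True)
        self.b_enc = np.zeros(d_dict)
        self.b_dec = np.zeros(d_model)

    def forward(self, x):
        f = np.maximum((x - self.b_dec) @ self.W_enc + self.b_enc, 0.0)  # sparse codes
        x_hat = f @ self.W_dec + self.b_dec                               # reconstruction
        return f, x_hat

    def loss(self, x, l1_coeff):
        f, x_hat = self.forward(x)
        mse = np.mean((x - x_hat) ** 2)
        l1 = np.mean(np.abs(f).sum(axis=1))  # L1 summed over features, averaged over tokens
        return mse + l1_coeff * l1, mse, l1


sae = ToySAE(d_model=16, d_dict=64)
x = np.random.default_rng(1).standard_normal((32, 16))
total, mse, l1 = sae.loss(x, l1_coeff=3e-4)
print(f"total={total:.4f}  mse={mse:.4f}  l1={l1:.2f}")
```

A larger `l1_coeff` pushes more activations to zero (lower L0) at the cost of reconstruction error, which is the trade-off the three trials explore.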

Experiment 1 - Toy-Model Validation

Before any real LM, a controlled toy environment validates the KD minimum-width theorem: theory predicts the dictionary width d*_S at which reconstruction loss bottoms out. All three figures here go into the paper.

Loss vs SAE width curves bottoming out at KD-predicted optimal widths.
fig1 - Loss-floor curves. Loss vs SAE width d_S for each (α, F) config. Each curve bottoms out at theory's d*_S. Picked as main paper result.

Each curve traces the reconstruction loss as the SAE dictionary width d_S grows, for a different (alpha, F) configuration. Theory predicts that loss hits a floor at the exact d*_S value, and the data agrees on every curve. This is the headline plot of the toy experiment because it shows the predicted minimum width is not just a tendency but a sharp transition.

All-configurations summary heatmap.
fig5 - All-configurations summary. (α, F) sweep, colour = how well theory predicts measured loss. The "everything works" plot.

All (alpha, F) configurations on one canvas. Rows and columns sweep the two parameters; cell colour tells you how closely measured loss matches what KD theory predicted. Bright cells mean a near-perfect match. The point is to show the prediction holds uniformly across the parameter grid, not just on a handful of nice curves.

Floor vs predicted d*_S scatter on y=x.
fig8 - Floor vs d*_S. Measured floor width vs theoretical optimum width. Points on y = x mean theory and measurement agree. Picked as the money plot.

Every point compares the width at which a run's loss reaches its floor to the theoretical d*_S for that run. If theory and experiment agree, the points sit on the y = x diagonal. They do. This single plot collapses the entire toy experiment into one image that says the theorem holds.
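The comparison behind fig1 and fig8 can be sketched in two steps: find the width at which each measured loss curve reaches its floor, then measure how far those widths sit from the predicted d*_S values on the y = x line. The 2% tolerance for deciding that a curve has "bottomed out" and both helper names are hypothetical choices, not the paper's criterion; widths are assumed sorted in increasing order and losses positive.

```python
import numpy as np


def measured_floor_width(widths, losses, rel_tol=0.02):
    """Smallest width whose loss is within rel_tol of the curve's minimum."""
    widths = np.asarray(widths)
    losses = np.asarray(losses, dtype=float)
    on_floor = losses <= losses.min() * (1.0 + rel_tol)
    return int(widths[np.argmax(on_floor)])  # argmax returns the first True index


def agreement(measured, predicted):
    """Mean relative deviation from the y = x line (0 means perfect agreement)."""
    m = np.asarray(measured, dtype=float)
    p = np.asarray(predicted, dtype=float)
    return float(np.mean(np.abs(m - p) / p))


widths = [256, 512, 1024, 2048, 4096]
losses = [0.90, 0.41, 0.205, 0.201, 0.200]  # made-up curve for illustration
print(measured_floor_width(widths, losses))
print(agreement(measured=[1024, 2048], predicted=[1000, 2100]))
```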

Trial A - L1 = 3e-4 fixed (dense regime)

Loosest sparsity. Avg L0 ≈ 6,184, α ≈ 0.81, MSE ≈ 0.0019, d*_S ≈ 9,963. Useful as a high-reconstruction baseline, not the regime for monosemanticity claims.

Training curves across 24 SAE layers over 300M tokens.
Training curves. Reconstruction MSE, L1 penalty, combined loss for all 24 layers over 300 M tokens.


Layer x metric heatmap.
Layer × metric heatmap. Rows = layers, columns = L0, α, g(α), MSE. At-a-glance shape of the whole run.


The three KD quantities per layer.
α, g(α), d*_S per layer. The three KD quantities that drive the theoretical prediction.

The three KD quantities plotted side by side per layer. Alpha measures sparsity, g(alpha) is the capacity function, d*_S is the predicted optimum dictionary width. These are the inputs theory needs to produce a width prediction.

Per-layer L0 vs reconstruction MSE.
L0 vs MSE per layer. Which layers are naturally denser/sparser.

L0 (average active features per token) on one axis, reconstruction MSE on the other. Layers that are naturally denser sit on one side, sparser layers on the other. Reveals which transformer depths need more dictionary capacity.

Predicted loss floor vs dictionary width.
Predicted loss floor vs d_S. KD prediction across a sweep of dictionary widths.

For each layer we sweep the dictionary width d_S and ask theory: where does loss bottom out? This figure draws the predicted curve. Comparing it against the actual training loss is what tells you the prediction is right.

Trial B - L1 = 8e-5 fixed (paper-exact)

Tuned so the L1 penalty matches the paper-exact sum formula. Seven rendering variants of the same per-layer measurements follow, each chosen or rejected according to the medium (paper body, appendix, slides, social) where it reads best.
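One reason L1 coefficients are hard to compare across trials is that the penalty can be reduced over dictionary features either by a sum or by a mean. Assuming "sum formula" means summing |f| over features for each token (the text does not spell this out), the sketch below shows that the summed penalty is d_dict times the averaged one, so the same coefficient is far stronger under the sum convention. `l1_penalty` is a hypothetical helper.

```python
import numpy as np


def l1_penalty(f, coeff, reduce="sum"):
    """L1 penalty on SAE activations f of shape (tokens, d_dict).

    reduce="sum":  sum |f| over features, then average over tokens
    reduce="mean": average |f| over features, then average over tokens
    """
    abs_f = np.abs(f)
    per_token = abs_f.sum(axis=1) if reduce == "sum" else abs_f.mean(axis=1)
    return coeff * per_token.mean()


f = np.random.default_rng(0).exponential(size=(8, 4096))  # synthetic activations
ratio = l1_penalty(f, 8e-5, "sum") / l1_penalty(f, 8e-5, "mean")
print(ratio)  # equals d_dict up to floating-point rounding
```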

Premium per-layer SAE metrics heatmap.
SAE_Premium_Heatmap. Clean palette, large fonts. Picked for the paper.

The clean paper-grade rendering of the L1 = 8e-5 paper-exact trial. Rows are layers, columns are SAE metrics. Annotated labels and a balanced palette are why this version was picked as the paper figure.

Compact professional rendering.
SAE_Professional_Heatmap. Compact slide/LinkedIn variant. Not picked for the paper - palette less expressive in print.

Same data as the premium heatmap, with a more compact and conservative palette. Built for LinkedIn cards and slide decks. Did not make the paper because the colour band is less expressive at print resolution.

Per-layer deep-dive with sidecar stats.
SAE_Layer_Analysis_Heatmap. Per-layer deep dive with sidecar stats. Used in the paper appendix - too wide for the body.

Same data again, but each layer gets a sidecar with its own mini statistics. Used in the paper appendix where horizontal space is not as tight as the main column.

Small-multiples view.
SAE_HeatmapVisual_All_Layers. Small-multiples view of every layer. Not picked - duplicates the Premium in less detail.

Small-multiples view: every layer's full metric vector tiled onto one page. Useful in show-me-everything talks where the audience wants to inspect each layer in isolation.

Gradient-style rendering.
SAE_Gradient_Heatmap. Smooth gradients in place of buckets. Not picked for the paper - prettier but harder to read exact values.

Smooth-gradient rendering of the same metrics. Looks great as a blog header, but the gradient hides exact bucket boundaries so it is harder to read off a precise number.

Cross-metric correlation heatmap.
SAE_Correlation_Heatmap. Which metrics move together across layers (e.g. α ↔ dead-feature count). Appendix-only.

Which SAE metrics move together across layers. For example, alpha tracks dead-feature count. Useful for theoretical commentary about what each metric is really measuring.
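A cross-metric correlation matrix like the one in this figure can be produced from per-layer metric vectors with `np.corrcoef`. The sketch assumes Pearson correlation over layers (the figure does not say which coefficient is used); `metric_correlations` is a hypothetical helper and the values in the usage lines are made up.

```python
import numpy as np


def metric_correlations(metrics):
    """Pearson correlation between per-layer metric vectors.

    metrics: dict mapping metric name -> sequence of per-layer values (equal lengths).
    A metric that is constant across layers yields NaN entries.
    """
    names = list(metrics)
    M = np.array([metrics[n] for n in names], dtype=float)  # (n_metrics, n_layers)
    return names, np.corrcoef(M)                             # rows are treated as variables


names, C = metric_correlations({
    "alpha": [0.80, 0.85, 0.90],  # made-up per-layer values
    "dead": [6000, 6800, 7200],
    "L0": [6100, 3000, 160],
})
print(names)
print(np.round(C, 2))
```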

KPI summary card.
SAE_KPI_Summary. Single-card overview - mean α, mean L0, total dead features. Used as social / cover image, not in the paper body.

Single card with the headline numbers: mean alpha, mean L0, total dead features. Used as the social and cover image, not as a paper-body figure.

Trial C - L1 = 5e-4 adaptive (target L0 ≈ 150)

Adaptive controller converges on L0 ≈ 152, α ≈ 0.994 - very sparse. Roughly 7k features per layer are dead, but the survivors are doing real monosemantic work. Picked for the paper comparison panel precisely because the controller holds L0 at the same target across trials.
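The controller itself is not specified here, but a simple version adjusts the L1 coefficient multiplicatively in log space: raise it when measured L0 is above target, lower it when below. The sketch is a hypothetical stand-in (`update_l1`, the gain, and the clamping bounds are all assumptions), and the usage loop replaces real SAE training with a made-up relation L0 ≈ 0.15 / coeff purely to show convergence.

```python
import math


def update_l1(coeff, measured_l0, target_l0=150.0, gain=0.1, lo=1e-6, hi=1e-1):
    """One step of a hypothetical multiplicative L1 controller.

    Raises the coefficient when features are too dense (L0 above target) and
    lowers it when they are too sparse; working in log space keeps steps scale-free.
    measured_l0 must be positive.
    """
    error = math.log(measured_l0 / target_l0)
    new = coeff * math.exp(gain * error)
    return min(max(new, lo), hi)  # clamp to a sane range


coeff = 5e-4
for _ in range(200):
    l0 = 0.15 / coeff  # made-up stand-in for measuring L0 on a training batch
    coeff = update_l1(coeff, l0)
print(f"coeff={coeff:.2e}, L0={0.15 / coeff:.1f}")
```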

Composite per-layer dashboard.
Dashboard. Per-layer α, g(α), reconstruction loss, dead-feature count. Picked for the paper.

Dashboard view of the L1 = 5e-4 adaptive trial. Composite of alpha, g(alpha), reconstruction loss, dead-feature count arranged as a layer-by-metric grid. This is the figure the paper uses for the comparison panel because it carries hard numbers.

Metric landscape across depth.
Metric landscape. SAE properties across depth. Not picked - landscape is qualitative; the dashboard has hard numbers.

The same per-layer measurements rendered as a continuous landscape so you can read how each property evolves through depth. Qualitative but very quick to scan.

Why these three trials. Together they cover the corners of the sparsity-vs-reconstruction frontier the KD theory predicts: dense (3e-4), paper-exact (8e-5), very-sparse adaptive (5e-4 → L0 = 150). Other L1 settings (1e-4, 2e-4, 7e-4) were piloted but discarded - either redundant with an included trial or far outside the regime where KD applies.

Reproducibility. 144 checkpoints per trial (24 layers × 6 snapshots) on Hugging Face: colm-run-exp-2-t1 · colm-run-exp-2-t2 · colm-run-trial-2 · source on GitHub.

Technologies & Methodology

PyTorch · TensorFlow · CUDA · Distributed Training · Knowledge Distillation · Teacher-Student Models · Deep Learning