Nilesh Sarkar / Research

Grokking for Under-Sampled Datasets · Independent Research

Overview

This is a deep dive into an independent research thread of mine: can grokking-style training help a model generalize when the dataset is small? Small labelled datasets are the norm in domains like medical imaging, and models trained on them tend to latch onto easy spurious signals instead of the real task. The goal of this thread is to find out whether grokking-favorable optimization can be turned into a usable framework for squeezing generalization out of under-sampled data.

The work below is the first study in that direction. It is written up as a workshop-style paper, Interventional Analysis of Shortcut Geometry Under Grokking-Favorable Training. Rather than judging models by accuracy alone, it opens them up with four interventions and asks how the spurious shortcut is actually stored inside the network.

Status: Exploratory and ongoing. This is one slice of a larger effort - the sample sizes are small, none of the intervention results are statistically significant yet, and the training regimes differ along several axes at once, so nothing here is a finished claim. It is a first pass: a reusable suite of interventions, a clear negative result on the accuracy question, and a representational difference worth chasing.

Why Under-Sampled Datasets Fail

When a network is trained on patches of stained tissue collected from several hospitals, it routinely learns to read the hospital's stain colour instead of the tumour. Stain colour is a high-variance, easy-to-fit signal that happens to correlate with the label in the training set; tumour morphology is a lower-variance, harder-to-fit signal. The network takes the easy path. This is the textbook shortcut learning failure, and it is exactly why a model that scores well in-distribution can collapse the moment it sees a new hospital.

Smaller datasets make this worse: with fewer examples there is less pressure to find the real signal, and the shortcut is, by comparison, even easier to fit. The open question this thread cares about is not just that the shortcut is learned, but how it is stored - as a tidy low-dimensional direction, or smeared across many features - because that geometry decides whether the shortcut can be found and removed.

Grokking, and the Conjecture

Grokking is the phenomenon where a heavily regularized model, long after it has memorized its training set, suddenly switches from a memorizing solution to a generalizing one - test accuracy jumps after a long plateau. It was first reported on small arithmetic tasks and later extended to vision models.

The conjecture this study tests: if that switch could be induced on purpose, a small dataset might start to behave like a larger one. More precisely - even if grokking never produces a sustained accuracy jump, grokking-favorable optimization might still compress the shortcut into a more structured, lower-dimensional representation. The experiment does not assume grokking works; it tests what grokking-favorable training actually does to a model's internals.

The Experimental Setup

The testbed is Camelyon17, a cancer-detection benchmark of H&E-stained tissue patches from five hospitals. Training uses three hospitals; an in-distribution validation set comes from those same three; and a fourth, unseen hospital is held out as a true out-of-distribution (OOD) test. To force the under-sampled regime, training sets are deliberately shrunk to 300-1000 images.
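The write-up does not say how the 300-1000 image subsets were drawn. A fixed-seed stratified sample is one reasonable way to do it. The sketch below assumes each training patch is an `(image_id, hospital, label)` record and keeps the hospital and tumour-label mix of the three training hospitals. The record layout and the function name are hypothetical.

```python
import random
from collections import defaultdict


def stratified_subsample(records, n, seed=0):
    """Draw roughly n records while keeping the (hospital, label) mix of the pool.

    records: (image_id, hospital, label) tuples (a hypothetical layout).
    Per-stratum counts are rounded, so the total can differ from n by a few.
    """
    strata = defaultdict(list)
    for record in records:
        _, hospital, label = record
        strata[(hospital, label)].append(record)

    rng = random.Random(seed)  # fixed seed -> the same subset on every run
    sample = []
    for key in sorted(strata):
        group = strata[key]
        k = min(len(group), round(n * len(group) / len(records)))
        sample.extend(rng.sample(group, k))
    return sample
```

Stratifying keeps the pool's hospital-label correlation (the very shortcut under study) exactly rather than only in expectation, as a uniform random draw would.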

Eleven ResNet-18 models were trained from scratch (no pretraining, 96×96 input, 3000 epochs, checkpointed throughout) across two regimes:

Grokking-favorable: high weight decay, expanded initialization scale, and Grokfast gradient-EMA filtering - the standard recipe associated with grokking-like training dynamics.
Standard: ordinary weight decay, default initialization, no Grokfast - a plain cross-entropy baseline trained on the same data.

A confound, stated up front. The two regimes differ along three hyperparameter axes at once - weight decay, initialization scale, and Grokfast EMA. Single-axis ablations have not been run, so no cross-regime difference here can be pinned on grokking dynamics specifically; high weight decay alone could explain much of it. Every comparison below should be read against that limit.
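To make the confound concrete, the two regimes can be written as opposite corners of a three-axis configuration space. Single-axis ablations are then the corners one step away from the standard baseline. This is a hypothetical sketch: the field names are illustrative, and the 50× weight-decay multiplier is the figure quoted in the mechanistic analysis. The initialization scale and Grokfast settings are reduced to on/off flags because their values are not given here.

```python
from dataclasses import dataclass
from itertools import product


@dataclass(frozen=True)
class Regime:
    # Hypothetical field names; each field is one of the three confounded axes.
    weight_decay_multiplier: float  # relative to the standard baseline
    expanded_init: bool             # True = enlarged initialization scale
    grokfast_ema: bool              # True = Grokfast gradient-EMA filtering on


GROKKING_FAVORABLE = Regime(weight_decay_multiplier=50.0, expanded_init=True, grokfast_ema=True)
STANDARD = Regime(weight_decay_multiplier=1.0, expanded_init=False, grokfast_ema=False)


def factorial_grid():
    """All 2x2x2 combinations of the three axes."""
    return [
        Regime(wd, init, ema)
        for wd, init, ema in product((1.0, 50.0), (False, True), (False, True))
    ]


def single_axis_ablations():
    """Regimes that differ from STANDARD along exactly one axis."""
    def n_changed(r):
        return sum([
            r.weight_decay_multiplier != STANDARD.weight_decay_multiplier,
            r.expanded_init != STANDARD.expanded_init,
            r.grokfast_ema != STANDARD.grokfast_ema,
        ])
    return [r for r in factorial_grid() if n_changed(r) == 1]
```

This study ran only the two extreme corners of the eight-point grid. The three single-axis regimes are the ones that could attribute a cross-regime difference to one specific knob.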

Result 1 - Universal Ungrokking

The headline result is a clean negative: no grokking happened. Across all eleven runs, OOD accuracy on the held-out hospital peaked early and then decayed for the rest of training - the exact opposite of grokking's delayed jump. This pattern is consistent enough to name: ungrokking. On raw accuracy, the grokking-favorable regime did not beat the standard baseline; if anything, standard was slightly ahead.
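"Peaks early, then decays" can be made operational as a per-run check on an OOD-accuracy trajectory. In this sketch the first-quarter window mirrors how the 300-image runs are described. The 0.05 minimum drop is an illustrative threshold, not one taken from the study, and the function names are hypothetical.

```python
def peak_fraction(ood_acc):
    """Where the best OOD accuracy occurs, as a fraction of training (0 = start, 1 = end)."""
    if len(ood_acc) < 2:
        raise ValueError("need at least two checkpoints")
    peak_index = max(range(len(ood_acc)), key=ood_acc.__getitem__)
    return peak_index / (len(ood_acc) - 1)


def is_ungrokking(ood_acc, early_fraction=0.25, min_drop=0.05):
    """True if OOD accuracy peaks early and ends well below that peak.

    ood_acc: one run's OOD accuracy at evenly spaced checkpoints.
    early_fraction and min_drop are illustrative thresholds, not the study's.
    """
    return (peak_fraction(ood_acc) <= early_fraction
            and max(ood_acc) - ood_acc[-1] >= min_drop)
```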

Ungrokking. Per-seed OOD accuracy over training at 300 training images. Every run peaks in the first quarter of training, then decays toward chance - no delayed generalization transition appears in any run.

The same shape holds at every dataset size tested. Plotting peak OOD accuracy against training-set size shows the conventional "sustained grokking" threshold is never crossed - the critical dataset size for this architecture and task, if it exists at all, is larger than anything tested here.
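The critical-size statement reduces to a search over (training-set size, peak OOD accuracy) pairs. Below is a hypothetical helper. The sustained-grokking threshold is left as a parameter because its value is not stated in the text.

```python
def critical_dataset_size(peak_by_size, threshold):
    """Smallest training-set size whose peak OOD accuracy reaches `threshold`.

    peak_by_size maps training-set size -> peak OOD accuracy for that size.
    Returns None when no tested size reaches the threshold.
    """
    reaching = [size for size, peak in peak_by_size.items() if peak >= threshold]
    return min(reaching) if reaching else None
```

A `None` result corresponds to the outcome reported here: the critical size, if it exists, lies beyond the largest size tested.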

Critical fraction. Peak OOD accuracy versus training-set size for grokking-favorable runs. The dashed line is the conventional sustained-grokking threshold; no tested dataset size reaches it.

Looking Inside: Four Interventions

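Accuracy alone says the two regimes behave the same. To see whether they store the shortcut differently, every checkpoint is probed with one correlational diagnostic and three causal interventions, all operating on the network's pooled feature representation:

Layer-wise Linear Probing (correlational): how recoverable hospital identity and tumour label are from each layer's features.
Subspace Ablation (causal): remove the hospital-correlated subspace from the features and measure the change in OOD accuracy.
Activation Steering (causal): push features along or against the hospital direction and measure the response.
Targeted Neuron Ablation (causal): zero the most hospital-selective neurons and compare against a random-neuron control.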

Result 2 - A Representational Difference Under Intervention

Even though both regimes generalize about equally badly, the interventions tell a more interesting story. The clearest signal comes from Activation Steering: pushing features along the hospital direction produces a consistent, monotonic response in 4 of 5 grokking-favorable runs, versus only 1 of 3 standard runs. Pushing against the direction tends to improve OOD accuracy for grokking-favorable models; for standard models, pushing either way mostly just hurts.
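Below is a minimal sketch of activation steering on pooled features. It assumes the hospital direction is the normalized difference between two training hospitals' mean feature vectors. That is a common estimator, but the write-up does not say how its direction was obtained. `classifier_head` stands in for the network's final linear layer, and all names are hypothetical.

```python
import math


def mean_vector(vectors):
    n = len(vectors)
    return [sum(column) / n for column in zip(*vectors)]


def unit(vector):
    norm = math.sqrt(sum(v * v for v in vector))
    return [v / norm for v in vector]


def hospital_direction(features_a, features_b):
    """Unit vector from hospital B's mean pooled feature to hospital A's (assumed estimator)."""
    mu_a, mu_b = mean_vector(features_a), mean_vector(features_b)
    return unit([a - b for a, b in zip(mu_a, mu_b)])


def steer(features, direction, alpha):
    """Add alpha * direction to every feature vector; alpha < 0 pushes against it."""
    return [[f + alpha * d for f, d in zip(vec, direction)] for vec in features]


def steering_curve(features, labels, direction, classifier_head, alphas):
    """Accuracy of classifier_head on steered features for each steering strength."""
    curve = []
    for alpha in alphas:
        steered = steer(features, direction, alpha)
        correct = sum(classifier_head(v) == y for v, y in zip(steered, labels))
        curve.append(correct / len(labels))
    return curve


def is_monotonic(curve):
    """True if the curve never changes direction (non-decreasing or non-increasing)."""
    steps = [b - a for a, b in zip(curve, curve[1:])]
    return all(s >= 0 for s in steps) or all(s <= 0 for s in steps)
```

In these terms, the "consistent, monotonic response" is `is_monotonic(steering_curve(...))` holding for a run, and the standard runs' peak-in-the-middle shape is a curve where it fails.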

Intervention response. Subspace ablation (left) and activation steering (right), averaged across seeds. Under steering, the grokking-favorable runs lean consistently one way along the hospital direction; the standard runs peak in the middle and fall off either side - the directional difference the accuracy numbers hide.

The per-run view shows the same contrast without the averaging. Subspace ablation degrades OOD accuracy in both regimes by similar amounts - both still rely on hospital-correlated information - while the steering curves separate by regime.
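Subspace ablation can be sketched as an orthogonal projection: remove each feature vector's component inside a hospital subspace, then re-evaluate the head on what is left. The sketch assumes the spanning vectors of that subspace are already given. How the study estimated them is not specified, and the function names are hypothetical. The damage measure is then OOD accuracy before ablation minus accuracy after.

```python
import math


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def orthonormalize(vectors, tol=1e-12):
    """Gram-Schmidt: orthonormal basis spanning `vectors` (drops dependent ones)."""
    basis = []
    for v in vectors:
        w = list(v)
        for b in basis:
            coeff = dot(w, b)
            w = [wi - coeff * bi for wi, bi in zip(w, b)]
        norm = math.sqrt(dot(w, w))
        if norm > tol:
            basis.append([wi / norm for wi in w])
    return basis


def ablate_subspace(features, subspace_vectors):
    """Project every feature vector onto the orthogonal complement of the subspace."""
    basis = orthonormalize(subspace_vectors)
    ablated = []
    for f in features:
        g = list(f)
        for b in basis:
            coeff = dot(g, b)
            g = [gi - coeff * bi for gi, bi in zip(g, b)]
        ablated.append(g)
    return ablated
```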

Per-run interventions. Subspace ablation (top) and activation steering (bottom), grokking-favorable on the left, standard on the right. Each line is a single training run.

Targeted Neuron Ablation corroborates this. The effect of zeroing the most hospital-selective neurons separates from the random-neuron control only in a narrow window of neuron counts: around 64 neurons, 3 of 5 grokking-favorable runs show the targeted neurons mattering more than random ones, against 0 of 3 standard runs. The gap washes out once many neurons are removed.
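Here is a sketch of the targeted-versus-random comparison. It assumes hospital selectivity is the absolute difference in a neuron's mean activation between two hospitals, one simple choice among several; the write-up does not define it. The k highest-scoring neurons are zeroed, and the control zeroes k neurons drawn at random with a fixed seed. `evaluate` stands in for running the classifier head, and every name here is hypothetical.

```python
import random


def selectivity(features_a, features_b):
    """Per-neuron |mean activation in hospital A - mean in hospital B| (assumed score)."""
    mean_a = [sum(col) / len(features_a) for col in zip(*features_a)]
    mean_b = [sum(col) / len(features_b) for col in zip(*features_b)]
    return [abs(a - b) for a, b in zip(mean_a, mean_b)]


def top_k_neurons(scores, k):
    """Indices of the k highest-scoring neurons."""
    return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]


def random_neurons(n_neurons, k, seed=0):
    """k distinct neuron indices drawn uniformly at random (the control)."""
    return random.Random(seed).sample(range(n_neurons), k)


def zero_neurons(features, indices):
    """Copy of features with the given neuron indices set to zero."""
    drop = set(indices)
    return [[0.0 if i in drop else v for i, v in enumerate(vec)] for vec in features]


def targeted_minus_random_drop(evaluate, features, labels, targeted, control):
    """Accuracy drop from zeroing `targeted` minus the drop from zeroing `control`.

    evaluate(features, labels) -> accuracy. Positive values mean the targeted
    neurons matter more than the random ones.
    """
    base = evaluate(features, labels)
    drop_targeted = base - evaluate(zero_neurons(features, targeted), labels)
    drop_control = base - evaluate(zero_neurons(features, control), labels)
    return drop_targeted - drop_control
```

A single random draw is noisy, so in practice the control drop would usually be averaged over several seeds before comparing.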

Targeted vs random neuron ablation. Effect on OOD accuracy as more neurons are switched off, for targeted shortcut neurons, random neurons, and tumour neurons. The targeted-vs-random gap is localized to small neuron counts.
Per-seed neuron ablation. The same comparison broken out by seed, at a fixed number of ablated neurons. Effect sizes sit close to the random-baseline noise floor, which is why neuron ablation is treated as the noisiest of the three interventions.

Finally, Layer-wise Linear Probing shows hospital identity stays highly recoverable in the early layers (which encode stain, as expected) and drops in the deeper layers over training - a correlational picture that is consistent with, but does not on its own prove, the intervention results.
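Layer-wise probing fits one linear classifier per layer to predict hospital identity from that layer's features, then reports held-out accuracy. The sketch uses a binary logistic-regression probe trained by full-batch gradient descent. It assumes two hospitals and features already extracted per layer. The study's actual probe (solver, regularization, number of classes) is not specified, so the hyperparameters and names here are illustrative.

```python
import math


def sigmoid(z):
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def train_probe(features, labels, lr=0.1, epochs=500):
    """Binary logistic-regression probe; labels are 0/1 hospital ids. Returns (w, b)."""
    dim, n = len(features[0]), len(features)
    w, b = [0.0] * dim, 0.0
    for _ in range(epochs):
        grad_w, grad_b = [0.0] * dim, 0.0
        for x, y in zip(features, labels):
            z = sum(wi * xi for wi, xi in zip(w, x)) + b
            err = sigmoid(z) - y  # d(cross-entropy)/d(logit)
            grad_w = [g + err * xi for g, xi in zip(grad_w, x)]
            grad_b += err
        w = [wi - lr * g / n for wi, g in zip(w, grad_w)]
        b -= lr * grad_b / n
    return w, b


def probe_accuracy(w, b, features, labels):
    """Fraction of examples whose predicted hospital matches the label."""
    preds = [int(sum(wi * xi for wi, xi in zip(w, x)) + b > 0) for x in features]
    return sum(p == y for p, y in zip(preds, labels)) / len(labels)


def layerwise_probe(train_by_layer, test_by_layer):
    """Held-out probe accuracy per layer; each value is a (features, labels) pair."""
    results = {}
    for layer, (x_tr, y_tr) in train_by_layer.items():
        w, b = train_probe(x_tr, y_tr)
        x_te, y_te = test_by_layer[layer]
        results[layer] = probe_accuracy(w, b, x_te, y_te)
    return results
```

High probe accuracy only shows the information is linearly present, not that the head uses it, which is why the probe is the correlational member of the four.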

Layer-wise probing. Hospital-probe and tumour-probe accuracy across six ResNet stages over training. Deep-layer hospital recoverability fades while tumour information is retained.

What It Means - Concentration, Not Elimination

The most parsimonious reading of the four probes is shortcut concentration, not shortcut elimination. Subspace ablation damages OOD accuracy in both regimes, so the model's head is still using hospital-correlated information either way - the shortcut is not gone. But under grokking-favorable training, that information appears organized into a more directionally consistent, lower-dimensional, more steerable form.

In other words: heavy regularization seems to compress the spurious signal into a tidier, more interventionally exposed place rather than removing the model's dependence on it. That is a lead worth following - a concentrated shortcut is one you can find and act on - but on its own it does not improve generalization, and in this study OOD accuracy still collapsed.

Why It Didn't Grok - Mechanistic Analysis

A clean negative result still needs an explanation. With the full training record in hand, there is no need to speculate - the trajectory itself contains the diagnosis, and it points overwhelmingly at a single cause.

The smoking gun: the weight norm grows the wrong way

Every mechanistic account of grokking asks for the same thing: weight decay must shrink the weight norm during the post-memorization plateau, until the network slides into the low-norm region where the generalizing solution is the efficient one. Our runs did the exact opposite - the weight norm climbed monotonically through the entire post-fit phase:

Weight norm, grokking-favorable regime (n = 1000):
epoch 1: 356
epoch 350 (OOD peak): 908
final: 1470 (+313% over epoch 1)
Weight norm, standard regime: final 813

The norm grows 4× across training; every grokking-favorable seed ends between 1457 and 1632. This is the single most important fact in the experimental record. The prerequisite mechanism for grokking - norm shrinkage - simply never happens, so no theory of grokking applies to this trajectory.
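The write-up does not define "weight norm". Assuming it is the global L2 norm over all trainable parameters, it can be computed as below, with parameters represented as named flat lists rather than framework tensors. A per-group breakdown is the natural way to check which group, such as the classifier head, drives the growth. The last lines reproduce the arithmetic behind the +313% figure.

```python
import math


def global_l2_norm(named_params):
    """sqrt of the sum of squares over every parameter in every group."""
    return math.sqrt(sum(v * v for values in named_params.values() for v in values))


def per_group_norms(named_params):
    """L2 norm of each parameter group."""
    return {name: math.sqrt(sum(v * v for v in values)) for name, values in named_params.items()}


def percent_change(start, end):
    return 100.0 * (end - start) / start


# Figures reported above for the grokking-favorable regime (n = 1000).
print(f"{percent_change(356, 1470):+.0f}%")  # +313%
print(f"{1470 / 356:.1f}x")                  # 4.1x
```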

Why the norm grows: BatchNorm + weight decay + cross-entropy

ResNet-18 keeps a BatchNorm layer after every convolution, and that breaks the recipe in two places at once.

On the convolutional layers, weight decay stops being a regularizer. BatchNorm is scale-invariant: multiply a convolution's weights by any positive constant and the output is unchanged, because BatchNorm renormalizes anyway. So shrinking those weights cannot change the function the network computes. What it does change is the size of each update relative to the weights, so weight decay behaves like a boosted learning rate (the van Laarhoven / Hoffer equivalence). Our "50× weight decay" was, functionally, a boost to the effective learning rate on the bulk of the network - not norm shrinkage.
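The scale-invariance argument can be checked numerically on a toy "convolution" reduced to one dot product per example, followed by BatchNorm-style normalization across the batch. The sketch deliberately drops the learnable affine and sets eps = 0. Scaling the weights by c leaves the output unchanged, the gradient at c·w equals the gradient at w divided by c, and the gradient is orthogonal to w. For plain SGD, this makes each step's effect on the weight direction scale as lr/‖w‖², which is why shrinking the norm acts like raising the learning rate. With an adaptive optimizer such as AdamW the exponent differs, but a smaller norm still means a larger effective step.

```python
import math


def batch_norm(zs):
    """Normalize a batch of scalar pre-activations to zero mean, unit variance (eps = 0)."""
    n = len(zs)
    mean = sum(zs) / n
    var = sum((z - mean) ** 2 for z in zs) / n
    return [(z - mean) / math.sqrt(var) for z in zs]


def conv_bn(w, xs):
    """Toy conv layer: one dot product per example, then BatchNorm over the batch."""
    return batch_norm([sum(wi * xi for wi, xi in zip(w, x)) for x in xs])


def loss(w, xs, targets):
    """Mean squared error of the normalized outputs against fixed targets."""
    outs = conv_bn(w, xs)
    return sum((o - t) ** 2 for o, t in zip(outs, targets)) / len(xs)


def numerical_grad(f, w, h=1e-6):
    """Central finite-difference gradient of f at w."""
    grad = []
    for i in range(len(w)):
        up, down = list(w), list(w)
        up[i] += h
        down[i] -= h
        grad.append((f(up) - f(down)) / (2 * h))
    return grad


XS = [[1.0, 2.0, 0.5], [0.3, -1.0, 2.0], [-0.7, 0.4, 1.1], [2.2, 0.1, -0.3]]
TARGETS = [1.0, -1.0, 0.5, -0.5]
W = [0.4, -0.2, 0.9]
C = 10.0


def f(w):
    return loss(w, XS, TARGETS)


scaled_w = [C * wi for wi in W]
g, g_scaled = numerical_grad(f, W), numerical_grad(f, scaled_w)

print(max(abs(a - b) for a, b in zip(conv_bn(W, XS), conv_bn(scaled_w, XS))))  # ~0: output unchanged
print(max(abs(C * gs - gi) for gs, gi in zip(g_scaled, g)))  # ~0: gradient shrank by 1/C
print(sum(wi * gi for wi, gi in zip(W, g)))  # ~0: gradient is orthogonal to w
```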

On the classifier head, cross-entropy grows the weights without bound. The final linear head is the one part of ResNet-18 with no BatchNorm after it - the only parameter group that is not scale-invariant. Once the 1000 training images are memorized, cross-entropy still has a nonzero gradient: it always wants more-confident predictions, which means larger logits, which means a larger head. With nothing normalizing it, the head simply grows - and it is the real engine of the late-stage norm climb.
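The cross-entropy pressure on an already-fitting head shows up on a toy problem: linearly separable data, a logistic (binary cross-entropy) head, and plain gradient descent with no normalization and no weight decay. Leaving out weight decay is a deliberate simplification that isolates the cross-entropy term. Every margin is positive, so each step has a positive component along the current weights. The norm therefore rises at every step, while the loss keeps falling without ever reaching zero.

```python
import math

# Toy linearly separable data: (features, label in {-1, +1}).
DATA = [((2.0, 1.0), 1), ((1.5, 2.0), 1), ((-1.0, -2.0), -1), ((-2.0, -0.5), -1)]


def margins(w):
    return [y * (w[0] * x[0] + w[1] * x[1]) for x, y in DATA]


def ce_loss(w):
    """Mean binary cross-entropy, written as log(1 + exp(-margin))."""
    return sum(math.log1p(math.exp(-m)) for m in margins(w)) / len(DATA)


def ce_grad(w):
    g = [0.0, 0.0]
    for (x, y), m in zip(DATA, margins(w)):
        s = 1.0 / (1.0 + math.exp(m))  # sigmoid(-m): shrinks with the margin but is never zero
        g[0] -= y * x[0] * s / len(DATA)
        g[1] -= y * x[1] * s / len(DATA)
    return g


def train_head(w, lr=0.5, steps=2000):
    """Plain gradient descent, no weight decay; records norm and loss after each step."""
    norms, losses = [], []
    for _ in range(steps):
        g = ce_grad(w)
        w = [wi - lr * gi for wi, gi in zip(w, g)]
        norms.append(math.hypot(*w))
        losses.append(ce_loss(w))
    return w, norms, losses


w_final, norms, losses = train_head([0.1, 0.1])  # already separates every point
print(norms[0], norms[-1])    # the norm keeps climbing
print(losses[0], losses[-1])  # the loss keeps falling but stays above zero
```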

Grokfast then amplifies exactly the wrong direction. Grokfast assumes the slow gradient component is the generalizing one - true on modular arithmetic, where the slow signal is a circuit emerging from noise. Here, with training already memorized, the dominant slow direction is the head-growth direction, and Grokfast amplified it. Put together, the regime we labelled "grokking-favorable" was really a high-effective-learning-rate, amplified-confidence-growth regime - mechanically the opposite of what the grokking literature exploits, even though it looks similar in a hyperparameter table.
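Below is a minimal scalar sketch of a Grokfast-EMA-style filter. It keeps an exponential moving average h of past gradients and adds λ·h back to each new gradient. The values alpha = 0.98 and lamb = 2.0 are assumed defaults, not settings taken from this study. Feeding it a persistent component plus alternating noise shows the asymmetry described above. The persistent direction is amplified by 1 + λ whether or not it generalizes, while fast noise is barely touched.

```python
def grokfast_ema(grad_stream, alpha=0.98, lamb=2.0):
    """Grokfast-EMA-style filtering of a stream of scalar gradients.

    h <- alpha * h + (1 - alpha) * g    (slow, low-pass component)
    g_filtered = g + lamb * h           (add the amplified slow component back)
    """
    filtered, h = [], None
    for g in grad_stream:
        h = g if h is None else alpha * h + (1 - alpha) * g  # first step: h starts at g
        filtered.append(g + lamb * h)
    return filtered


# A persistent component (think: the head-growth direction) plus alternating noise.
SLOW, FAST = 1.0, 0.5
stream = [SLOW + FAST * (-1) ** t for t in range(2000)]
out = grokfast_ema(stream)

slow_gain = (out[-1] + out[-2]) / 2 / SLOW     # tends to 1 + lamb = 3.0
fast_gain = abs(out[-1] - out[-2]) / 2 / FAST  # tends to 1 + lamb*(1-alpha)/(1+alpha), about 1.02
print(slow_gain, fast_gain)
```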

The standard regime ungrokks too - the cleanest control

The most revealing number is that the plain baseline ungrokks as well: all three standard runs also decay on held-out data, just with a smaller final weight norm (~800 versus ~1500). That makes the standard runs an internal control, and it says something important - ungrokking is not caused by the grokking-favorable knobs. It is a property of training a BatchNorm ResNet-18 with AdamW and cross-entropy to convergence on a hospital-shortcut task at small sample size. The grokking knobs do not create the failure; they sit further along the same dysfunctional trajectory and amplify it.

This reframes the whole comparison. "Grokking-favorable" versus "standard" was never a contrast between two qualitatively different training regimes - it was two points on one trajectory, separated by roughly 2× in final weight norm. Any difference the interventions find is a difference between two non-grokking regimes.

Secondary contributors

So what did the interventions actually measure?

The four interventions are internally coherent, and they do show a real difference - but not a grokking one. High weight decay on a BatchNorm network (acting as a higher effective learning rate) squeezes the model into a lower-rank, more head-aligned feature space: effective rank falls from about 110 to about 36. But the standard regime's rank also ends around 33-37 - roughly the same place. So the deep-layer probe drop, the slightly smaller ablation damage, the more monotonic steering, and the more separable neurons are all consistent with one quantitative fact: a more concentrated, lower-rank representation that is equally shortcut-reliant in absolute terms. In one line: heavy weight decay on a BatchNorm CNN yields a more concentrated feature space; grokking was never the active ingredient.
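The write-up does not say which effective-rank definition it uses. A common one, following Roy and Vetterli, is the exponential of the Shannon entropy of the normalized singular values of the feature matrix. The sketch takes the singular values as input, for example from `numpy.linalg.svd(features, compute_uv=False)` on a (samples × features) matrix, and otherwise uses only the standard library.

```python
import math


def effective_rank(singular_values, eps=1e-12):
    """exp(entropy) of the normalized singular-value distribution (Roy & Vetterli style)."""
    total = sum(singular_values)
    probs = [s / total for s in singular_values if s / total > eps]
    entropy = -sum(p * math.log(p) for p in probs)
    return math.exp(entropy)
```

k equal singular values give an effective rank of exactly k, and a single dominant one gives 1. Tracked per checkpoint, the number falls as the representation concentrates into fewer directions.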

The reframe. This result is not "grokking concentrates the shortcut." It is: a ResNet-18 with BatchNorm, trained with AdamW and cross-entropy on small-sample Camelyon17, ungrokks under every regularization regime tested - and the high-weight-decay, large-init, Grokfast recipe accelerates the ungrokking rather than reversing it, because BatchNorm prevents weight decay from doing what grokking requires. That claim is readable straight off the weight-norm curve, with no interpretability machinery needed - and it is more defensible than the steering result.

Limitations

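This is exploratory work, and the limits matter as much as the result: the sample sizes are small, none of the intervention results is statistically significant yet, and the two regimes differ along three hyperparameter axes at once (weight decay, initialization scale, Grokfast EMA) with no single-axis ablations run.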

Stripped of the grokking framing, the defensible core is one sentence: certain heavily regularized training regimes appear to compress hospital-correlated information into a more steerable low-dimensional direction, and this representational concentration coexists with - but does not produce - sustained OOD generalization.

Status & What's Next

This is an early, exploratory thread, and it is still moving. The next round of experiments has not been decided yet; more, with results, is coming soon.

A shorter summary of this work also appears on the Interpretability Experiments page.
