Overview
This is a deep dive into an independent research thread of mine: can grokking-style training help a model generalize when the dataset is small? Small labelled datasets are the norm in domains like medical imaging, and models trained on them tend to latch onto easy spurious signals instead of the real task. The goal of this thread is to find out whether grokking-favorable optimization can be turned into a usable framework for squeezing generalization out of under-sampled data.
The work below is the first study in that direction. It is written up as a workshop-style paper, "Interventional Analysis of Shortcut Geometry Under Grokking-Favorable Training". Rather than judging models by accuracy alone, it opens them up with four probes (one correlational diagnostic and three causal interventions) and asks how the spurious shortcut is actually stored inside the network.
Why Under-Sampled Datasets Fail
When a network is trained on patches of stained tissue collected from several hospitals, it routinely learns to read the hospital's stain colour instead of the tumour. Stain colour is a high-variance, easy-to-fit signal that happens to correlate with the label in the training set; tumour morphology is a lower-variance, harder-to-fit signal. The network takes the easy path. This is the textbook shortcut learning failure, and it is exactly why a model that scores well in-distribution can collapse the moment it sees a new hospital.
Smaller datasets make this worse: with fewer examples there is less pressure to find the real signal, and the shortcut is comparatively even easier to fit. The open question this thread cares about is not just that the shortcut is learned, but how it is stored - as a tidy low-dimensional direction, or smeared across many features - because that geometry decides whether the shortcut can be found and removed.
Grokking, and the Conjecture
Grokking is the phenomenon where a heavily regularized model, long after it has memorized its training set, suddenly switches from a memorizing solution to a generalizing one - test accuracy jumps after a long plateau. It was first reported on small arithmetic tasks and later extended to vision models.
The conjecture this study tests: if that switch could be induced on purpose, a small dataset might start to behave like a larger one. More precisely - even if grokking never produces a sustained accuracy jump, grokking-favorable optimization might still compress the shortcut into a more structured, lower-dimensional representation. The experiment does not assume grokking works; it tests what grokking-favorable training actually does to a model's internals.
The Experimental Setup
The testbed is Camelyon17, a cancer-detection benchmark of H&E-stained tissue patches from five hospitals. Training uses three hospitals; an in-distribution validation set comes from those same three; and a fourth, unseen hospital is held out as a true out-of-distribution (OOD) test. To force the under-sampled regime, training sets are deliberately shrunk to 300-1000 images.
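The split logic can be sketched in a few lines. This is a minimal illustration on synthetic metadata, not the study's data loader: the record layout `(patch_id, hospital, label)`, the hospital ids, the validation size, and the helper name `split_by_hospital` are all assumptions made for the example.

```python
import random

def split_by_hospital(records, train_hospitals, ood_hospital, n_train, n_val, seed=0):
    """Split (patch_id, hospital, label) records into train / in-distribution val / OOD test."""
    rng = random.Random(seed)
    # Patches from the training hospitals feed both the train and the ID validation sets.
    pool = [r for r in records if r[1] in train_hospitals]
    rng.shuffle(pool)
    train = pool[:n_train]                    # deliberately tiny training set
    id_val = pool[n_train:n_train + n_val]    # same hospitals, disjoint patches
    # Every patch from the unseen hospital is held out as the OOD test.
    ood_test = [r for r in records if r[1] == ood_hospital]
    return train, id_val, ood_test

# Toy metadata: 5 hospitals (ids 0-4) with alternating labels. Purely synthetic.
records = [(i, i % 5, (i // 5) % 2) for i in range(5000)]
train, id_val, ood_test = split_by_hospital(
    records, train_hospitals={0, 1, 2}, ood_hospital=3, n_train=300, n_val=200)
print(len(train), len(id_val), len(ood_test))
```

The point of the design is that the OOD hospital never contributes a single patch to training or to in-distribution validation, so any accuracy gap between `id_val` and `ood_test` reflects the hospital shift rather than ordinary overfitting.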
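Eleven ResNet-18 models were trained from scratch (no pretraining, 96×96 input, 3000 epochs, checkpointed throughout) with AdamW and cross-entropy across two regimes: a grokking-favorable regime (50× weight decay, 4× initialization scale, and Grokfast EMA gradient filtering) and a standard baseline.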
Result 1 - Universal Ungrokking
The headline result is a clean negative: no grokking happened. Across all eleven runs, OOD accuracy on the held-out hospital peaked early and then decayed for the rest of training - the exact opposite of grokking's delayed jump. This pattern is consistent enough to name: ungrokking. On raw accuracy, the grokking-favorable regime did not beat the standard baseline; if anything, standard was slightly ahead.
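One way to make "peaked early and then decayed" concrete is to summarize each OOD accuracy curve by where its peak falls and how much is lost by the end. The sketch below is illustrative only: the thresholds `early_frac` and `min_drop` and both toy curves are invented for the example and are not the study's criteria or data.

```python
def summarize_ood_curve(ood_acc):
    """Return (peak_index, peak_acc, final_acc, drop) for per-checkpoint OOD accuracies."""
    peak_idx = max(range(len(ood_acc)), key=lambda i: ood_acc[i])
    peak = ood_acc[peak_idx]
    final = ood_acc[-1]
    return peak_idx, peak, final, peak - final

def looks_like_ungrokking(ood_acc, early_frac=0.25, min_drop=0.02):
    """True when the peak sits in the first part of training and accuracy then decays."""
    peak_idx, _, _, drop = summarize_ood_curve(ood_acc)
    return peak_idx < early_frac * len(ood_acc) and drop >= min_drop

# Two made-up curves: an early peak followed by decay, and a delayed jump.
early_peak = [0.60, 0.80, 0.75, 0.70, 0.68, 0.66, 0.65, 0.64]
delayed_jump = [0.50, 0.50, 0.50, 0.50, 0.50, 0.90, 0.90, 0.90]
print(looks_like_ungrokking(early_peak), looks_like_ungrokking(delayed_jump))
```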
The same shape holds at every dataset size tested. Plotting peak OOD accuracy against training-set size shows the conventional "sustained grokking" threshold is never crossed - the critical dataset size for this architecture and task, if it exists at all, is larger than anything tested here.
Looking Inside: Four Interventions
Accuracy alone says the two regimes behave the same. To see whether they store the shortcut differently, every checkpoint is probed with one correlational diagnostic and three causal interventions, all operating on the network's pooled feature representation:
- Layer-wise Linear Probing - measuring, layer by layer, how recoverable hospital identity and tumour identity are from the features. A correlational diagnostic, not a behavioural test.
- Subspace Ablation of the Hospital Shortcut - building the subspace that carries hospital information, projecting features orthogonal to it, and checking whether the model's predictions break.
- Activation Steering - identifying the dominant between-hospital direction and pushing features along it, then watching how the prediction moves. Unlike ablation, this adds or subtracts a direction rather than removing information.
- Targeted Neuron Ablation - switching off the most hospital-selective neurons, measured against a random-neuron control and a tumour-neuron control, to test whether the shortcut is localized.
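As a rough illustration of the subspace ablation in the list above, the sketch below builds a hospital subspace from the centred per-hospital mean features and projects it out. The write-up does not say how the subspace was constructed, so the class-mean construction, the function names, and the synthetic features are all assumptions.

```python
import numpy as np

def hospital_subspace(feats, hospital_ids):
    """Orthonormal basis (d, k) spanning the centred per-hospital mean features."""
    grand_mean = feats.mean(axis=0)
    means = np.stack([feats[hospital_ids == h].mean(axis=0) - grand_mean
                      for h in np.unique(hospital_ids)])
    _, s, vt = np.linalg.svd(means, full_matrices=False)
    keep = s > 1e-8 * s.max()          # drop numerically empty directions
    return vt[keep].T

def ablate_subspace(feats, basis, center):
    """Remove the component of (feats - center) that lies inside the subspace."""
    coords = (feats - center) @ basis
    return feats - coords @ basis.T

# Synthetic pooled features: 3 hospitals, 100 patches each, 16 dimensions.
rng = np.random.default_rng(0)
hospital_ids = np.repeat(np.arange(3), 100)
offsets = rng.normal(scale=3.0, size=(3, 16))
feats = rng.normal(size=(300, 16)) + offsets[hospital_ids]

basis = hospital_subspace(feats, hospital_ids)
ablated = ablate_subspace(feats, basis, feats.mean(axis=0))
```

After ablation the per-hospital means coincide, so a linear readout of hospital identity from the means is no longer possible. Feeding `ablated` to the classifier head and comparing accuracy against the unablated features is then the behavioural check.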
Result 2 - A Representational Difference Under Intervention
Even though both regimes generalize about equally badly, the interventions tell a more interesting story. The clearest signal comes from Activation Steering: pushing features along the hospital direction produces a consistent, monotonic response in 4 of 5 grokking-favorable runs, versus only 1 of 3 standard runs. Pushing against the direction tends to improve OOD accuracy for grokking-favorable models; for standard models, pushing either way mostly just hurts.
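A steering sweep of this kind can be sketched as follows. The direction here is the leading singular vector of the centred hospital means and the head is a single linear layer on pooled features; both are assumptions for illustration, as are the steering strengths and the synthetic data.

```python
import numpy as np

def dominant_hospital_direction(feats, hospital_ids):
    """Unit vector along the largest-variance axis of the centred hospital means."""
    grand_mean = feats.mean(axis=0)
    means = np.stack([feats[hospital_ids == h].mean(axis=0) - grand_mean
                      for h in np.unique(hospital_ids)])
    _, _, vt = np.linalg.svd(means, full_matrices=False)
    return vt[0]

def steering_curve(feats, labels, direction, head_w, head_b, alphas):
    """Accuracy of a linear head after shifting every feature by alpha * direction."""
    curve = []
    for alpha in alphas:
        logits = (feats + alpha * direction) @ head_w + head_b
        curve.append(float(((logits > 0).astype(int) == labels).mean()))
    return np.array(curve)

def is_monotonic(values):
    """True when the sequence never changes direction."""
    diffs = np.diff(values)
    return bool(np.all(diffs >= 0) or np.all(diffs <= 0))

# Synthetic stand-ins for pooled features, tumour labels, and the classifier head.
rng = np.random.default_rng(0)
hospital_ids = np.repeat(np.arange(3), 100)
feats = rng.normal(size=(300, 16)) + rng.normal(scale=2.0, size=(3, 16))[hospital_ids]
labels = rng.integers(0, 2, size=300)
head_w, head_b = rng.normal(size=16), 0.0

direction = dominant_hospital_direction(feats, hospital_ids)
alphas = np.array([-4.0, -2.0, 0.0, 2.0, 4.0])
curve = steering_curve(feats, labels, direction, head_w, head_b, alphas)
print(curve, is_monotonic(curve))
```

Negative `alphas` push against the hospital direction and positive ones push along it; a run counts as having a "consistent, monotonic response" in this sketch when `is_monotonic(curve)` holds.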
The per-run view shows the same contrast without the averaging. Subspace ablation degrades OOD accuracy in both regimes by similar amounts - both still rely on hospital-correlated information - while the steering curves separate by regime.
Targeted Neuron Ablation corroborates this. Zeroing the most hospital-selective neurons separates from the random-neuron control in a narrow window of neuron counts - around 64 neurons, 3 of 5 grokking-favorable runs show the targeted neurons mattering more than random, against 0 of 3 standard runs - and the gap washes out once many neurons are removed.
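The targeted-versus-random comparison can be sketched like this. The selectivity score used here (the share of each neuron's variance explained by hospital identity) is an assumed metric, not necessarily the one used in the study, and the features are synthetic with hospital signal planted in the first 8 neurons.

```python
import numpy as np

def hospital_selectivity(feats, hospital_ids):
    """Per-neuron share of activation variance explained by hospital identity (0 to 1)."""
    grand = feats.mean(axis=0)
    between = np.zeros(feats.shape[1])
    for h in np.unique(hospital_ids):
        mask = hospital_ids == h
        between += mask.mean() * (feats[mask].mean(axis=0) - grand) ** 2
    return between / (feats.var(axis=0) + 1e-12)

def zero_neurons(feats, idx):
    """Return a copy of the features with the chosen neurons switched off."""
    out = feats.copy()
    out[:, idx] = 0.0
    return out

rng = np.random.default_rng(0)
hospital_ids = np.repeat(np.arange(3), 100)
feats = rng.normal(size=(300, 128))
feats[:, :8] += 3.0 * (hospital_ids[:, None] - 1)   # plant hospital signal in 8 neurons

k = 8
selectivity = hospital_selectivity(feats, hospital_ids)
targeted = np.argsort(-selectivity)[:k]                          # most hospital-selective
random_ctrl = rng.choice(feats.shape[1], size=k, replace=False)  # random-neuron control

targeted_feats = zero_neurons(feats, targeted)
control_feats = zero_neurons(feats, random_ctrl)
```

Running the classifier head on `targeted_feats` and on `control_feats` and comparing the accuracy drops gives the targeted-minus-random gap. Scoring neurons against tumour labels instead of hospital ids would give a tumour-neuron control in the same way.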
Finally, Layer-wise Linear Probing shows hospital identity stays highly recoverable in the early layers (which encode stain, as expected) and drops in the deeper layers over training - a correlational picture that is consistent with, but does not on its own prove, the intervention results.
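A linear probe of this kind can be sketched with a closed-form ridge regression onto one-hot hospital labels. The probe type, the ridge strength, and the two synthetic "layers" below are assumptions for illustration; the study may well have used a logistic-regression probe instead.

```python
import numpy as np

def fit_linear_probe(x, y, n_classes, ridge=1e-3):
    """Ridge-regularised least-squares probe onto one-hot targets."""
    xb = np.hstack([x, np.ones((len(x), 1))])        # append a bias column
    targets = np.eye(n_classes)[y]                   # one-hot labels
    gram = xb.T @ xb + ridge * np.eye(xb.shape[1])
    return np.linalg.solve(gram, xb.T @ targets)     # shape (d + 1, n_classes)

def probe_accuracy(w, x, y):
    """Held-out accuracy of the fitted probe."""
    xb = np.hstack([x, np.ones((len(x), 1))])
    return float(((xb @ w).argmax(axis=1) == y).mean())

# Two synthetic layers: one that encodes the hospital, one that does not.
rng = np.random.default_rng(0)
hospital = rng.integers(0, 3, size=600)
early = rng.normal(size=(600, 8))
early[np.arange(600), hospital] += 4.0               # hospital signal present
deep = rng.normal(size=(600, 8))                     # hospital signal absent

train, test = slice(0, 300), slice(300, 600)
accs = {}
for name, layer in [("early", early), ("deep", deep)]:
    w = fit_linear_probe(layer[train], hospital[train], n_classes=3)
    accs[name] = probe_accuracy(w, layer[test], hospital[test])
print(accs)
```

High probe accuracy only says the information is linearly recoverable from that layer, not that the model's head uses it, which is why the probe is treated as a correlational diagnostic.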
What It Means - Concentration, Not Elimination
The most parsimonious reading of the four probes is shortcut concentration, not shortcut elimination. Subspace ablation damages OOD accuracy in both regimes, so the model's head is still using hospital-correlated information either way - the shortcut is not gone. But under grokking-favorable training, that information appears organized into a more directionally consistent, lower-dimensional, more steerable form.
In other words: heavy regularization seems to compress the spurious signal into a tidier, more interventionally exposed place rather than removing the model's dependence on it. That is a lead worth following - a concentrated shortcut is one you can find and act on - but on its own it does not improve generalization, and in this study OOD accuracy still collapsed.
Why It Didn't Grok - Mechanistic Analysis
A clean negative result still needs an explanation. With the full training record in hand, there is no need to speculate - the trajectory itself contains the diagnosis, and it points overwhelmingly at a single cause.
The smoking gun: the weight norm grows the wrong way
Every mechanistic account of grokking asks for the same thing: weight decay must shrink the weight norm during the post-memorization plateau, until the network slides into the low-norm region where the generalizing solution is the efficient one. Our runs did the exact opposite - the weight norm climbed monotonically through the entire post-fit phase:
Weight norm, grokking-favorable regime: epoch 1 → 356 · epoch 350 (OOD peak) → 908 · final → 1470 (+313%)
Weight norm, standard regime: final → 813
The norm grows roughly 4× across training; every grokking-favorable seed ends between 1457 and 1632. This is the single most important fact in the experimental record. The prerequisite mechanism for grokking - norm shrinkage - simply never happens, so the norm-shrinkage accounts of grokking cannot apply to this trajectory.
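The growth figures follow directly from the three norms quoted above; the short calculation below just makes the arithmetic explicit (variable names are illustrative).

```python
# Weight norms quoted in the text.
norm_start, norm_at_peak, norm_final = 356.0, 908.0, 1470.0
standard_final = 813.0

growth_ratio = norm_final / norm_start                        # about 4.13
growth_pct = 100.0 * (norm_final - norm_start) / norm_start   # about 312.9
regime_ratio = norm_final / standard_final                    # about 1.81

print(f"{growth_ratio:.2f}x, +{growth_pct:.0f}%, {regime_ratio:.2f}x vs standard")
```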
Why the norm grows: BatchNorm + weight decay + cross-entropy
ResNet-18 keeps a BatchNorm layer after every convolution, and that breaks the recipe in two places at once.
On the convolutional layers, weight decay stops being a regularizer. BatchNorm is scale-invariant - multiply a convolution's weights by any positive constant and the output is unchanged, because BatchNorm renormalizes anyway. So weight decay cannot shrink the function the network computes; instead it behaves like a boosted learning rate (the van Laarhoven / Hoffer equivalence). Our "50× weight decay" was, functionally, a 50× learning-rate boost on the bulk of the network - not norm shrinkage.
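The scale-invariance argument can be checked numerically. The sketch below stands in for a convolution with a plain matrix multiply and uses BatchNorm without its affine parameters or epsilon; the toy loss and the finite-difference gradient are assumptions made to keep the example self-contained. It shows three things: the normalized output is unchanged when the weights are rescaled, the gradient is orthogonal to the weights, and the gradient shrinks in proportion as the weight norm grows, which is the route by which weight norm turns into an effective learning rate.

```python
import numpy as np

def batchnorm(z):
    """Per-feature batch normalisation without affine scale/shift or epsilon."""
    return (z - z.mean(axis=0)) / z.std(axis=0)

rng = np.random.default_rng(0)
x = rng.normal(size=(64, 10))      # a batch of inputs
w = rng.normal(size=(10, 4))       # stand-in for a conv kernel
t = rng.normal(size=(64, 4))       # arbitrary fixed targets for a toy loss

def loss(weights):
    return float(np.sum(batchnorm(x @ weights) * t))

def numeric_grad(f, weights, h=1e-5):
    """Central-difference gradient, one weight at a time."""
    grad = np.zeros_like(weights)
    for idx in np.ndindex(*weights.shape):
        step = np.zeros_like(weights)
        step[idx] = h
        grad[idx] = (f(weights + step) - f(weights - step)) / (2 * h)
    return grad

g = numeric_grad(loss, w)
g_scaled = numeric_grad(loss, 5.0 * w)

same_output = np.allclose(batchnorm(x @ w), batchnorm(x @ (5.0 * w)))
orthogonal = abs(np.sum(w * g)) < 1e-5                     # gradient has no radial component
shrinks = np.allclose(g_scaled, g / 5.0, atol=1e-6)        # 5x the norm, 1/5 the gradient
print(same_output, orthogonal, shrinks)
```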
On the classifier head, cross-entropy grows the weights without bound. The final linear head is the one part of ResNet-18 with no BatchNorm after it - the only weight matrix that is not scale-invariant. Once the 1000 training images are memorized, cross-entropy still has gradient: it always wants more-confident predictions, which means larger logits, which means a larger head. With nothing normalizing it, the head simply grows - and it is the real engine of the late-stage norm climb.
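The "cross-entropy still has gradient" step can be seen in a toy calculation. Assume every training example is already classified correctly with some positive logit margin (the margins below are made up); scaling the head up then always lowers the loss, so gradient descent keeps pushing the head norm upward.

```python
import math

def mean_cross_entropy(margins, scale):
    """Mean binary cross-entropy when each correct-class margin is multiplied by `scale`."""
    return sum(math.log1p(math.exp(-scale * m)) for m in margins) / len(margins)

# Hypothetical margins for a memorized training set: all positive.
margins = [0.5, 1.0, 2.0]
scales = (1, 2, 4, 8)
losses = [mean_cross_entropy(margins, s) for s in scales]

for s, value in zip(scales, losses):
    print(f"head scale {s}: loss {value:.4f}")
```

The loss approaches zero but never reaches it at any finite scale, which is why an unnormalized head has no natural stopping point under cross-entropy.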
Grokfast then amplifies exactly the wrong direction. Grokfast assumes the slow gradient component is the generalizing one - true on modular arithmetic, where the slow signal is a circuit emerging from noise. Here, with training already memorized, the dominant slow direction is the head-growth direction, and Grokfast amplified it. Put together, the regime we labelled "grokking-favorable" was really a high-effective-learning-rate, amplified-confidence-growth regime - mechanically the opposite of what the grokking literature exploits, even though it looks similar in a hyperparameter table.
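To see why the filter favours a persistent direction, here is a minimal sketch of a Grokfast-style EMA step on a two-coordinate gradient. The values `alpha=0.98` and `lamb=2.0` are illustrative and not necessarily the study's settings, and the toy gradient (one constant coordinate, one sign-flipping coordinate) is invented for the example.

```python
def grokfast_ema_step(grad, ema, alpha=0.98, lamb=2.0):
    """Update the gradient EMA, then add lamb times the EMA back onto the raw gradient."""
    new_ema = [alpha * e + (1.0 - alpha) * g for e, g in zip(ema, grad)]
    filtered = [g + lamb * e for g, e in zip(grad, new_ema)]
    return filtered, new_ema

ema = [0.0, 0.0]
for step in range(500):
    # Coordinate 0: a slow, constant push (think: head growth).
    # Coordinate 1: a fast component that flips sign every step.
    grad = [1.0, 1.0 if step % 2 == 0 else -1.0]
    filtered, ema = grokfast_ema_step(grad, ema)

print(f"slow coordinate: {filtered[0]:.3f}  fast coordinate: {filtered[1]:.3f}")
```

The constant coordinate ends up amplified by roughly `1 + lamb`, while the oscillating one is left almost untouched. If the persistent direction after memorization is head growth, that is the direction the filter boosts.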
The standard regime ungroks too - the cleanest control
The most revealing number is that the plain baseline ungroks as well: all three standard runs also decay on held-out data, just with a smaller final weight norm (~800 versus ~1500). That makes the standard runs an internal control, and it says something important - ungrokking is not caused by the grokking-favorable knobs. It is a property of training a BatchNorm ResNet-18 with AdamW and cross-entropy to convergence on a hospital-shortcut task at small sample size. The grokking knobs do not create the failure; they sit further along the same dysfunctional trajectory and amplify it.
This reframes the whole comparison. "Grokking-favorable" versus "standard" was never a contrast between two qualitatively different training regimes - it was two points on one trajectory, separated by roughly 2× in final weight norm. Any difference the interventions find is a difference between two non-grokking regimes.
Secondary contributors
- The IRM collapse was a red herring. The invariance penalty dropped to ~10^-13 early on, but only for a trivial reason: once the training loss hits zero (memorization), every gradient vanishes, and the penalty vanishes with it. The penalty was also never part of the training loss, so there was no invariance pressure at any point.
- Init scale and loss function. The 4× init is far below published grokking recipes (closer to 100× for cross-entropy), and BatchNorm renormalizes weight scaling away regardless. Cross-entropy is also the harder loss to grok with - the literature favours MSE.
- Dataset size and training length. 1000 images is far below any plausible critical dataset size, and 3000 epochs is short. These would matter - but only if the optimization dynamics were pointed the right way. With the weight norm diverging, more data and more steps would not change the qualitative outcome.
So what did the interventions actually measure?
The four interventions are internally coherent, and they do show a real difference - but not a grokking one. High weight decay on a BatchNorm network (acting as a higher effective learning rate) squeezes the model into a lower-rank, more head-aligned feature space: effective rank falls from about 110 to about 36. But the standard regime's rank also ends around 33-37 - roughly the same place. So the deep-layer probe drop, the slightly smaller ablation damage, the more monotonic steering, and the more separable neurons are all consistent with one quantitative fact: a more concentrated, lower-rank representation that is equally shortcut-reliant in absolute terms. In one line: heavy weight decay on a BatchNorm CNN yields a more concentrated feature space; grokking was never the active ingredient.
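For readers who want to reproduce an effective-rank number, one common definition is the exponential of the Shannon entropy of the normalized singular values of the centred feature matrix. The write-up does not state which definition it used, so treat the one below as an assumption; the two synthetic feature matrices are only there to show the contrast.

```python
import numpy as np

def effective_rank(feats):
    """exp(entropy) of the normalised singular values of the centred features."""
    centred = feats - feats.mean(axis=0)
    s = np.linalg.svd(centred, compute_uv=False)
    p = s / s.sum()
    p = p[p > 0]                                   # avoid log(0)
    return float(np.exp(-(p * np.log(p)).sum()))

rng = np.random.default_rng(0)
spread = rng.normal(size=(600, 64))                                   # variance in every direction
concentrated = rng.normal(size=(600, 4)) @ rng.normal(size=(4, 64))   # variance in 4 directions

rank_spread = effective_rank(spread)
rank_concentrated = effective_rank(concentrated)
print(round(rank_spread, 1), round(rank_concentrated, 1))
```

Other definitions, such as the participation ratio of the covariance eigenvalues, give different absolute values, so effective ranks are only comparable when computed the same way on both regimes.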
The reframe. This result is not "grokking concentrates the shortcut." It is: a ResNet-18 with BatchNorm, trained with AdamW and cross-entropy on small-sample Camelyon17, ungroks under every regularization regime tested - and the high-weight-decay, large-init, Grokfast recipe accelerates the ungrokking rather than reversing it, because BatchNorm prevents weight decay from doing what grokking requires. That claim is readable straight off the weight-norm curve, with no interpretability machinery needed - and it is more defensible than the steering result.
Limitations
This is exploratory work, and the limits matter as much as the result:
- Not statistically significant. With 5 grokking-favorable and 3 standard seeds, none of the intervention tests reach the conventional significance bar. The steering result is the strongest, and even it only reaches the borderline range. The findings are reported as a direction-consistent observation across three intervention types, not as a proven effect.
- The three-axis confound. Grokking-favorable differs from standard in weight decay, initialization scale, and Grokfast EMA simultaneously. High weight decay alone is known to concentrate features into lower-dimensional directions - which would reproduce the steering pattern with no grokking mechanism involved.
- A perturbation-scale caveat. Steering is scaled per-regime, and high weight decay produces tighter feature distributions, so the two regimes receive different absolute perturbations at a fixed steering level. A fixed-norm steering control is needed to fully separate a real directional difference from this scale mismatch.
- One dataset, small architecture. Everything here is Camelyon17 with ResNet-18. Replication on other clinical benchmarks and larger backbones is open.
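The perturbation-scale caveat in the list above can be made concrete. In the sketch below, per-regime steering scales the step by the spread of that regime's features along the direction, while the fixed-norm control uses the same absolute step everywhere. How the study actually scaled its steering is not specified, so the standard-deviation scaling, the function names, and the two synthetic feature sets are assumptions.

```python
import numpy as np

def steer_scaled(feats, direction, level):
    """Step size proportional to this regime's feature spread along the direction."""
    scale = (feats @ direction).std()
    return feats + level * scale * direction

def steer_fixed(feats, direction, step_norm):
    """Same absolute step size for every regime."""
    return feats + step_norm * direction

def mean_step(before, after):
    """Average Euclidean distance each feature vector was moved."""
    return float(np.linalg.norm(after - before, axis=1).mean())

rng = np.random.default_rng(0)
direction = np.zeros(16)
direction[0] = 1.0
wide = rng.normal(scale=1.0, size=(500, 16))    # stand-in for a looser feature distribution
tight = rng.normal(scale=0.3, size=(500, 16))   # stand-in for a tighter feature distribution

scaled_wide = mean_step(wide, steer_scaled(wide, direction, 2.0))
scaled_tight = mean_step(tight, steer_scaled(tight, direction, 2.0))
fixed_wide = mean_step(wide, steer_fixed(wide, direction, 1.0))
fixed_tight = mean_step(tight, steer_fixed(tight, direction, 1.0))
print(scaled_wide, scaled_tight, fixed_wide, fixed_tight)
```

At the same nominal steering level the tighter regime is moved a shorter absolute distance, so differences in the response curves could partly reflect step size. Repeating the sweep with `steer_fixed` removes that mismatch.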
Stripped of the grokking framing, the defensible core is one sentence: certain heavily regularized training regimes appear to compress hospital-correlated information into a more steerable low-dimensional direction, and this representational concentration does not by itself produce sustained OOD generalization - in this study it coexisted with decaying OOD accuracy.
Status & What's Next
This is an early, exploratory thread, and it is still moving. The next round of experiments is still being decided; further results will follow.