The geometry of grokking
Grokking is the strangest replicable phenomenon in modern ML. A model that has spent thousands of steps memorizing a task suddenly, for no apparent reason, generalizes. I think the reason is geometric, and I think it tells us something important about what 'understanding' actually is.
Take a small transformer. Train it on modular arithmetic. Watch its loss curve. For thousands of steps, training loss drops while validation loss stays flat - the model is memorizing. Then, often abruptly, validation loss collapses. The model 'gets it'. This is grokking, and it has been replicated enough times that we can stop arguing about whether it is real.
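For readers who want to reproduce the curve, here is a minimal sketch of a common toy setup. Every pair (a, b) is labelled (a + b) mod p, and only a fraction of the pairs is used for training. The modulus p = 97, the 30% training fraction, and the function name are illustrative assumptions. The model itself (the small transformer) is left out. Grokking would show up as validation accuracy on the held-out pairs jumping long after training accuracy saturates.

```python
import random

def modular_addition_split(p: int, train_frac: float, seed: int = 0):
    """Enumerate every pair (a, b) labelled (a + b) mod p, then split at random."""
    examples = [((a, b), (a + b) % p) for a in range(p) for b in range(p)]
    random.Random(seed).shuffle(examples)  # deterministic shuffle for a fixed seed
    n_train = int(train_frac * len(examples))
    return examples[:n_train], examples[n_train:]

# Illustrative choices, not prescriptions: a prime modulus and a small training fraction.
train, val = modular_addition_split(p=97, train_frac=0.3)
print(len(train), len(val))
```

With most pairs held out, a model cannot score well on validation just by storing the training table. It has to find the rule.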
What is going on? Why does the model memorize first and generalize later? Why does the flip happen so suddenly? And what, if anything, does this tell us about how generalization happens in larger systems?
Memorization and generalization are different geometries
Memorizing and generalizing networks have visibly different representations. Visualize the embeddings of the inputs - in the modular-arithmetic case, the integers - before and after grokking. Before: unstructured clouds. After: a clean circle. The structure of the task is no longer stored in the weights as a giant lookup table; it lives in the geometry of the representations.
Generalization is a phase transition in geometry, not a continuous improvement in fit.
This is what makes grokking feel important to me. Before the flip, the network has solved the task one input at a time. After the flip, it has discovered a representation in which the task becomes simple. The actual computation that produces correct answers is now riding on a low-dimensional structure that mirrors the structure of the problem.
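A hand-built toy shows why a circle is the right geometry for this task. It is an illustration, not a trained network. The frequency k is a hypothetical parameter, and the function names are invented for the example. Once each residue sits on the unit circle at angle 2πka/p, addition mod p becomes rotation: multiplying the two points as complex numbers adds their angles. Decoding the result is a nearest-point lookup over p positions, not over p × p answers.

```python
import math

def embed(a: int, p: int, k: int = 1) -> tuple[float, float]:
    """Place residue a on the unit circle at angle 2*pi*k*a/p (k is a hypothetical frequency)."""
    theta = 2 * math.pi * k * a / p
    return (math.cos(theta), math.sin(theta))

def rotate_add(u, v):
    """Multiply two points as complex numbers, so their angles (and residues) add."""
    (ux, uy), (vx, vy) = u, v
    return (ux * vx - uy * vy, ux * vy + uy * vx)

def decode(point, p: int, k: int = 1) -> int:
    """Return the residue whose embedding is closest to the given point."""
    return min(range(p), key=lambda c: math.dist(point, embed(c, p, k)))

p = 97
a, b = 45, 80
print(decode(rotate_add(embed(a, p), embed(b, p)), p))  # (45 + 80) mod 97
```

Any frequency k coprime to p works equally well, because a ↦ ka mod p just permutes the positions on the circle. The design point is that a single operation, angle addition, replaces the whole answer table.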
Why grokking takes time
There are good first-pass explanations. Weight decay slowly pushes toward simpler solutions. The flat memorizing solution and the sharp generalizing solution exist at different points in the loss landscape, and SGD wanders for a long time before falling into the better basin. Recent work has shown that grokking can be sped up dramatically when you bias optimization toward solutions with the right geometric structure.
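The weight-decay story can be sketched in a deliberately tiny setting. The simplifying assumptions are a two-parameter linear model, a single training point, and a hypothetical "true rule" y = x0 + x1. This is an analogy, not grokking itself. Every w with w0 + w1 = 2 fits the training point exactly, and the high-norm solution (5, -3) stands in for memorization. Without weight decay, gradient descent never leaves it, because the training loss is already zero. With weight decay, the component of w that the training data does not constrain shrinks by a factor of (1 - lr * weight_decay) per step. The model therefore drifts, slowly, to a low-norm solution that also gets a held-out point right.

```python
def predict(w, x):
    return sum(wi * xi for wi, xi in zip(w, x))

def fit(w, weight_decay, lr=0.1, steps=20_000):
    """Plain gradient descent on 0.5 * residual**2 + 0.5 * weight_decay * |w|**2."""
    x_train, y_train = (1.0, 1.0), 2.0  # a single training example
    for _ in range(steps):
        residual = predict(w, x_train) - y_train
        w = [wi - lr * (residual * xi + weight_decay * wi) for wi, xi in zip(w, x_train)]
    return w

memorizer = [5.0, -3.0]            # fits the training point exactly, but with a large norm
x_test, y_test = (1.0, 0.0), 1.0   # held-out point under the assumed rule y = x0 + x1

for wd in (0.0, 0.01):
    w = fit(memorizer, wd)
    print(f"weight_decay={wd}: w={[round(v, 3) for v in w]}, test prediction={predict(w, x_test):.3f}")
```

Without weight decay the weights stay at (5, -3) and the test prediction is 5. With weight decay they settle near (1, 1), the ridge solution 2 / (2 + weight_decay) in each coordinate. The training fit barely changes along the way.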
What I think is most underrated is the role of curvature. The memorizing solution lives in a wide, flat region - many solutions look similar, so the model is not strongly pulled toward any specific one. The generalizing solution is a narrow valley - precise, low-energy, but small. Optimization has to wander across the flat region long enough to find the narrow valley, and weight decay is the pressure that bends the trajectory toward it.
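The curvature picture can be made concrete with a one-dimensional cartoon. The depths and widths are invented, and nothing here is measured from a real network. A wide, shallow Gaussian well stands in for the memorizing region, and a narrow, deeper one stands in for the generalizing valley. The central finite-difference estimate of the second derivative at each center comes out close to depth / width**2. The narrow well is therefore both lower and far sharper.

```python
import math

def well(w, center, depth, width):
    """A Gaussian well: depth sets how low it goes, width sets how flat it is."""
    return -depth * math.exp(-((w - center) ** 2) / (2 * width ** 2))

def loss(w):
    wide = well(w, center=-2.0, depth=1.0, width=1.0)    # stand-in for the memorizing region
    narrow = well(w, center=2.0, depth=1.5, width=0.1)   # stand-in for the generalizing valley
    return wide + narrow

def curvature(f, w, h=1e-3):
    """Central finite-difference estimate of f''(w)."""
    return (f(w + h) - 2 * f(w) + f(w - h)) / h ** 2

for name, w in (("wide", -2.0), ("narrow", 2.0)):
    print(f"{name}: loss={loss(w):.3f}, curvature={curvature(loss, w):.1f}")
```

In real networks, sharpness is usually estimated from the Hessian (for example its top eigenvalue or trace) rather than a single second derivative. The one-dimensional version only shows the shape of the argument.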
What this means for the bigger picture
If grokking is a geometric phase transition - if generalization happens when representations align with the structure of the task - then the question of what large models 'understand' has a sharper formulation than people usually give it.
- What does the representational geometry of a large model look like for a given concept? Is it a low-dimensional structure that mirrors the concept's structure - a circle for periodicity, a graph for relational reasoning, a manifold for continuous attributes - or is it a sloppy lookup?
- When models 'know but cannot use' something, that often looks like the right geometric structure being present but poorly routed into the residual stream: generalization that exists locally but not globally.
- When we add new data and a model improves, are we shaping the geometry or just expanding the lookup table? Large pretraining tends to produce surprisingly good geometry; targeted fine-tuning often does not.
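One way to make the first question testable, at least for periodic concepts, is a sketch that swaps in plain Fourier analysis as the probe. Treat each embedding dimension as a signal over the integers 0 .. p - 1. A circular representation then concentrates its energy at one frequency, while an unstructured cloud spreads it across all of them. Both embedding tables below are hypothetical and synthetic: a circle at frequency 5 mixed into 8 dimensions, and Gaussian noise. For a real model, the table would have to be extracted from trained weights. A concept without a natural integer index has no test this direct.

```python
import cmath
import math
import random

def frequency_shares(embeddings):
    """Share of (non-constant) embedding energy at each frequency k = 1 .. p // 2."""
    p, dims = len(embeddings), len(embeddings[0])
    energy = []
    for k in range(1, p // 2 + 1):
        total = 0.0
        for d in range(dims):
            # Discrete Fourier coefficient of dimension d, viewed as a signal over a = 0 .. p - 1.
            coeff = sum(embeddings[a][d] * cmath.exp(-2j * math.pi * k * a / p) for a in range(p))
            total += abs(coeff) ** 2
        energy.append(total)
    norm = sum(energy)
    return [e / norm for e in energy]

def dominant_frequency(embeddings):
    """Return (k, share) for the frequency carrying the most energy."""
    shares = frequency_shares(embeddings)
    best = max(range(len(shares)), key=lambda i: shares[i])
    return best + 1, shares[best]

p, dims = 97, 8
rng = random.Random(0)

# Hypothetical "grokked" table: a circle at frequency 5, linearly mixed into 8 dimensions.
u = [rng.gauss(0, 1) for _ in range(dims)]
v = [rng.gauss(0, 1) for _ in range(dims)]
circle = [
    [math.cos(2 * math.pi * 5 * a / p) * u[d] + math.sin(2 * math.pi * 5 * a / p) * v[d] for d in range(dims)]
    for a in range(p)
]
# Hypothetical "memorizing" table: unstructured Gaussian noise.
cloud = [[rng.gauss(0, 1) for _ in range(dims)] for _ in range(p)]

print("circle:", dominant_frequency(circle))
print("cloud: ", dominant_frequency(cloud))
```

The circle puts essentially all of its energy at k = 5. The cloud's top frequency holds only a small share.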
Where I am taking this
Two threads. First, in Parallax, the workspace is in part designed to be a place where this kind of clean, shared geometry can crystallize - because content has to compete to enter it, only structure that pays its way survives. The hope is that the workspace is what separates a memorizing system from a generalizing one.
Second, I want to understand how these phase transitions scale. Grokking on toy tasks is well-studied. Grokking on language modeling is harder to see, partly because the loss curves do not have the same clean shape, but I suspect the same kind of representational reorganization happens at scale, just stretched out over millions of tokens. If we can see it, we can engineer for it.
There is something almost cosmic about grokking. A bunch of weights, a gradient, weight decay - simple machinery - and out of it falls a clean, structured representation of a mathematical fact. That should not be a footnote. It should be a clue.