Brain-Inspired Mechanisms for Sequence Modeling
What neuroscience can teach us about building efficient language models.
The human brain runs on about 20 watts. The cortex's actual computation may account for as little as 0.2 watts of that. GPT-4's training run probably drew something like 50 megawatts. That's a gap of roughly six orders of magnitude.
Now, comparing biological neurons to GPUs is tricky. They're different substrates doing different things. But six orders of magnitude is a lot. It suggests we might be doing something fundamentally wasteful.
I've been reading neuroscience papers to understand how the brain achieves this efficiency. This post summarizes what I found and what it might mean for neural architecture design.
Why is the Brain So Efficient?
Three factors stand out. First, communication costs about 35x more than computation in biological neurons. The brain evolved to minimize communication, not computation. This led to sparse coding, where only about 1% of neurons fire at any given time.
Second, the brain does "just enough" computation. Neural signals are noisy. Synaptic transmission fails about half the time. But the system is robust enough that it doesn't matter. There's no IEEE 754 floating point here, just good-enough analog computation.
Third, the brain adjusts energy use based on task demands. Easy predictions require less activity. Hard ones recruit more resources. Transformers do the opposite: they use the same computation for every token regardless of difficulty.
Sparse Coding
The numbers here are striking. Only about 1% of cortical neurons are active at any moment. Each memory is encoded by roughly 2-5% of neurons in the relevant area. Each neuron participates in about 5% of stored memories.
Why does sparsity help? For one, it maximizes storage capacity. If you have N neurons and k are active at a time, capacity scales roughly as N²/k (ignoring logarithmic factors, which depend on the exact model). Fewer active neurons means more distinct patterns you can store. Sparsity also minimizes interference, because sparse patterns overlap less and are easier to tell apart.
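To put rough numbers on that, here's a quick back-of-the-envelope calculation. It's just the scaling rule above plugged into arbitrary example sizes, not a capacity proof:

```python
# Rough capacity estimate for a sparse associative memory:
# with N neurons and k active per pattern, storable patterns ~ N^2 / k
# (ignoring logarithmic factors, which depend on the exact model).
N = 10_000
for k in (5_000, 500, 100):          # dense -> increasingly sparse codes
    print(f"k = {k:>5}: ~{N**2 / k:,.0f} patterns")
# Cutting activity from 50% to 1% of neurons buys roughly 50x more patterns.
```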
The mechanism is simple: lateral inhibition. Active neurons suppress their neighbors. Only the most strongly activated neurons survive. This creates winner-take-all dynamics that enforce sparsity without explicit regularization.
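Here's a minimal sketch of that idea in code. I'm using a hard top-k over a layer's activations as a stand-in for lateral inhibition; real inhibition is a dynamic process, this is just its end result:

```python
import numpy as np

def k_winner_take_all(activations: np.ndarray, k: int) -> np.ndarray:
    """Keep the k most active units, silence the rest.

    A crude stand-in for lateral inhibition: rather than neighbors
    suppressing each other dynamically, we threshold at the k-th
    largest activation.
    """
    if k >= activations.size:
        return activations
    threshold = np.partition(activations, -k)[-k]      # k-th largest value
    return np.where(activations >= threshold, activations, 0.0)

rng = np.random.default_rng(0)
x = rng.normal(size=1000)
sparse = k_winner_take_all(x, k=10)                    # keep ~1% of units
print(np.count_nonzero(sparse))                        # 10
```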
Compare this to transformers. Every attention head attends over every position, and every feed-forward layer activates for every token. Dense representations throughout. There is recent work on sparse attention and mixture-of-experts routing, but it's bolted on rather than fundamental to the architecture.
Predictive Coding
Here's an organizing principle the brain seems to use: silence equals success. Higher cortical areas send predictions down to lower areas. Lower areas compare those predictions to actual input. If they match, nothing propagates upward. Only prediction errors, the surprising information, get transmitted.
This is energy efficient because most of the time, predictions are correct. You don't need to signal "yes, the world is still there" every millisecond. You only need to signal when something unexpected happens.
It's also naturally hierarchical. Low-level areas predict sensory features. Higher areas predict more abstract patterns. Errors at each level indicate violations of expectations at that level of abstraction.
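To make the idea concrete, here's a toy predictive-coding layer. The `PredictiveLayer` class and its linear predictor are my own simplification, not anyone's published model; the point is only that traffic up the hierarchy shrinks as predictions improve:

```python
import numpy as np

class PredictiveLayer:
    """One level of a toy predictive-coding hierarchy.

    The layer predicts its input from top-down context and passes only
    the residual error upward; predictable input means near-zero traffic.
    """
    def __init__(self, dim: int, lr: float = 0.1):
        self.W = np.zeros((dim, dim))      # top-down prediction weights
        self.lr = lr

    def forward(self, x: np.ndarray, context: np.ndarray) -> np.ndarray:
        prediction = self.W @ context
        error = x - prediction             # the only signal sent upward
        # Hebbian-style update so the prediction improves next time.
        self.W += self.lr * np.outer(error, context)
        return error

layer = PredictiveLayer(dim=8)
context = np.ones(8)                       # fixed higher-level state
x = np.linspace(0.0, 1.0, 8)               # a perfectly repeatable input
for step in range(5):
    error = layer.forward(x, context)
    print(step, round(float(np.abs(error).mean()), 4))   # shrinks toward 0
```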
Transformers have no such prediction-and-cancel pathway between layers. Every layer processes every token with equal intensity. There's no "this is expected, skip it" shortcut. The attention mechanism asks "what's relevant?" but never "what's surprising?"
Two Memory Systems
The brain doesn't have one memory system. It has two that work together. The hippocampus learns fast. It stores new experiences immediately in sparse, separated patterns. The neocortex learns slow. It gradually extracts structure from repeated exposure.
During sleep, the hippocampus replays recent experiences to the neocortex. This "memory consolidation" transfers important information from fast storage to long-term distributed representations. Not everything gets consolidated, only experiences tagged as important during waking hours.
The hippocampus also does something clever: content-addressable retrieval. You give it a partial cue, and it completes the pattern. This is fundamentally different from attention, which enumerates all positions and computes a weighted sum. Pattern completion takes a handful of recurrent settling steps no matter how much is stored; attention scales as O(n²) with context length.
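Here's what pattern completion looks like in the textbook Hopfield formulation (binary patterns, Hebbian storage). This is a standard toy construction, not a model of the hippocampus, but it shows retrieval from a partial cue without enumerating anything:

```python
import numpy as np

rng = np.random.default_rng(1)
N, n_patterns = 200, 10
patterns = rng.choice([-1, 1], size=(n_patterns, N))

# Hebbian storage: one outer-product update per pattern, no gradient descent.
W = (patterns.T @ patterns) / N
np.fill_diagonal(W, 0)

# Retrieval: start from a corrupted cue and let the network settle.
state = patterns[0].copy()
state[: N // 2] = rng.choice([-1, 1], size=N // 2)   # scramble half the cue
for _ in range(10):
    state = np.sign(W @ state)
    state[state == 0] = 1                            # break ties consistently

overlap = state @ patterns[0] / N
print(f"overlap with the stored pattern: {overlap:.2f}")   # ~1.0 means completed
```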
Transformers have a single learning system: gradient descent on all parameters. No separation between fast episodic storage and slow structural learning. No consolidation phase. No content-addressable retrieval.
Temporal Hierarchy
Different brain regions operate at different timescales. Early auditory areas respond to features at ~10ms resolution. Language areas integrate over ~1 second. Prefrontal areas maintain context over tens of seconds.
This temporal hierarchy matches the hierarchical structure of language itself. Phonemes unfold over tens of milliseconds. Words over hundreds. Sentences over seconds. Discourse over minutes.
The mechanism involves recurrent connections with different time constants. NMDA receptors in higher areas have slower dynamics, naturally integrating over longer periods. You don't need to learn the timescale hierarchy because it's built into the biophysics.
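One cheap way to approximate those built-in time constants is a leaky integrator per level, with slower decay at higher levels. This is a sketch of the idea, not a biophysical model:

```python
import numpy as np

def leaky_integrate(inputs: np.ndarray, tau: float) -> np.ndarray:
    """Exponential moving average with time constant tau (in steps).

    Larger tau means slower decay, so the unit averages over a longer
    window, the way slower receptor dynamics do in higher areas.
    """
    alpha = 1.0 / tau
    state, outputs = 0.0, []
    for x in inputs:
        state = (1.0 - alpha) * state + alpha * x
        outputs.append(state)
    return np.array(outputs)

signal = np.random.default_rng(2).normal(size=500)
fast = leaky_integrate(signal, tau=2)      # "early sensory": tracks detail
slow = leaky_integrate(signal, tau=50)     # "prefrontal": slow-moving context
print(round(fast.std(), 3), round(slow.std(), 3))   # the slow level varies much less
```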
Transformers treat all positions equally. Position 1 and position 1000 have the same computational weight. Positional encodings distinguish them, but there's no built-in sense that nearby tokens are more related than distant ones. The architecture has no temporal inductive bias.
Neural Oscillations
This one surprised me. Working memory capacity is about 4 items (not 7, as previously thought). The explanation involves neural oscillations.
The hippocampus generates a theta rhythm at 4-8 Hz. Within each theta cycle, gamma bursts (30-100 Hz) encode individual items. You can fit about 4-8 gamma cycles within one theta cycle. That's your working memory capacity: the items that can be activated within a single theta cycle.
Position in the sequence is encoded by phase, not by learned positional embeddings. Item 1 fires early in the theta cycle. Item 4 fires late. The temporal structure itself carries positional information.
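A toy illustration of the arithmetic, with illustrative frequencies rather than measured ones:

```python
# Toy phase code: an item's position is carried by when it fires in the cycle.
theta_hz, gamma_hz = 7.0, 40.0
theta_period_ms = 1000.0 / theta_hz              # ~143 ms per theta cycle
gamma_period_ms = 1000.0 / gamma_hz              # ~25 ms per gamma slot
slots = int(theta_period_ms // gamma_period_ms)
print(f"{slots} gamma slots per theta cycle")    # ~5: the working-memory limit

for i, item in enumerate(["cat", "sat", "on", "mat"]):
    fire_time = i * gamma_period_ms              # ms after the cycle starts
    phase = 360.0 * fire_time / theta_period_ms
    print(f"{item!r} fires at {fire_time:.0f} ms (theta phase {phase:.0f} deg)")
```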
Transformers have no oscillatory dynamics. Position is explicit, learned, and arbitrary. There's no mechanism that naturally groups items into chunks or represents sequential order through timing.
What This Might Mean for Architecture Design
I see a few concrete directions worth exploring.
Sparse attention with winner-take-all. Instead of computing attention over all positions, compute it sparsely. For each query, only the top-k keys contribute. Given a way to find those keys without scoring all of them (or a kernel that exploits the sparsity), this reduces complexity from O(n²) toward O(nk), and it might match how the brain selectively attends.
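Here's roughly what that looks like for a single head, in plain NumPy. It still scores every key (so it only saves compute once you add a retrieval structure or sparse kernel), but it shows the masking logic:

```python
import numpy as np

def topk_attention(Q, K, V, k=8):
    """Single-head attention where each query keeps only its k best keys.

    Scores outside the top-k are masked to -inf before the softmax, so
    they contribute exactly zero to the output.
    """
    scores = Q @ K.T / np.sqrt(Q.shape[-1])                  # (n_q, n_k)
    kth = np.partition(scores, -k, axis=-1)[:, -k:-k + 1]    # k-th largest per query
    scores = np.where(scores >= kth, scores, -np.inf)
    scores -= scores.max(axis=-1, keepdims=True)             # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V

rng = np.random.default_rng(3)
n, d = 128, 32
Q, K, V = (rng.normal(size=(n, d)) for _ in range(3))
print(topk_attention(Q, K, V, k=8).shape)                    # (128, 32)
```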
Predictive filtering. Add a prediction mechanism where each layer predicts the next layer's input. Only propagate the error, not the full representation. This could dramatically reduce the information flow through the network for predictable inputs.
Hierarchical timescales. Instead of uniform transformer layers, use layers with different temporal receptive fields. Lower layers process short contexts quickly. Higher layers integrate over longer contexts more slowly. Let the architecture match the structure of language.
Content-addressable memory. Replace or augment attention with Hopfield-like pattern completion. Store key-value pairs in an associative memory that retrieves via similarity rather than enumeration. This is closer to how the hippocampus works.
The Honest Caveats
I want to be clear about what's risky here. Nobody has successfully scaled these ideas to competitive language modeling. Transformers work well partly because they parallelize efficiently on GPUs. Brain-inspired mechanisms might not parallelize at all.
"Biologically plausible" doesn't mean "computationally optimal." The brain evolved for survival in a physical world, not for predicting the next token. Its architecture is constrained by evolution, development, and the physics of biological tissue.
There's also the Bitter Lesson. Richard Sutton argues that general methods which leverage computation ultimately beat approaches built on human-designed structure. Maybe transformers with more compute will always beat clever brain-inspired architectures with less.
But I'm not sure the Bitter Lesson applies indefinitely. Energy constraints are real. We can't scale to exawatts. At some point, efficiency must matter. The brain proves that efficient sequence processing is possible. The question is whether we can translate its principles into practical algorithms.
Where I'm Going Next
I'm starting with the simplest intervention: sparse attention. Take a standard transformer, replace softmax attention with top-k hard attention, and measure the quality/efficiency tradeoff. If 5% of positions carry 95% of the information, we should be able to prove it.
After that, hierarchical predictive coding. Build a two-layer predictive network where layer 2 predicts layer 1's output and only receives errors. See if the sparsity of error signals matches the theoretical predictions.
The goal isn't to match transformers immediately. It's to find a fundamentally different path that could scale better as compute budgets grow. The brain shows it's possible. Now we need to figure out how.