Neel Nanda - Mechanistic Interpretability: A Whirlwind Tour
Transcript
I'm going to be trying to give a whirlwind tour of: “What is mechanistic interpretability?” First slide: what is mechanistic interpretability? So the core hypothesis of the field is that models learn human-comprehensible algorithms. They contain structure that makes sense, and can be understood, but they have no incentive to make this legible to us.
They learn this structure because it is useful for getting loss on predicting the next token. And it is our job to learn how to reverse-engineer it, and how to make it legible. And in particular, how to achieve a rigorous understanding without tricking ourselves. There's a kind of analogy to how someone might take a compiled program binary and try to at least partially reverse-engineer it to source code.
And, I think that, in some circles, this is a pretty controversial claim. Machine learning is notoriously a completely bullshit field that makes no sense, and is just a massive pile of linear algebra. I want to try to convince you, in the first half of this talk, that mechanistic understanding is at least possible. And I'm going to do this by talking about a paper I wrote called Progress Measures for Grokking via Mechanistic Interpretability—the only paper I've ever written that led to fan art and t-shirts—where we were studying one-layer transformers that learned how to do addition mod 113. And spoiler, the algorithm fits on one slide.
Why was this interesting? Because of this phenomenon called grokking, where if you train a small model on certain algorithmic tasks on the same data again and again, it will initially memorize the data, but if you keep training on that same memorized training data, sometimes if you fiddle with your hyperparameters enough, it will abruptly generalize. Like, what? And my bet was that if we could reverse-engineer it, we could get insight into what was happening.
How do you even get traction on reverse engineering a system? The first hint is that if you look at the embedding of this model, the lookup table mapping the numbers from 0 to 112 to vectors, and do dimensionality reduction on them, it's really periodic. You just get circles everywhere. Here are two things from the literature before I started this work. If you just look anywhere in the model, you get these beautifully periodic patterns. Like, here's a heat map of a neuron activation across all pairs of inputs. There's clearly some structure, and it turns out that the key idea is, “well, this looks periodic, let's apply a Fourier transform and see what happens.”
And it turns out that if you do that, and look at the norms of your components, it's incredibly sparse. There are just six frequencies that matter, and every other frequency is basically negligible. At initialization, it looks nothing like this, because a Fourier transform is just a rotation, and a randomly initialized matrix doesn't look sparse in a basis, because why would it?
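To make the Fourier observation concrete, here is a minimal NumPy sketch on synthetic data, not the actual model from the paper. It builds an orthonormal real Fourier basis over the inputs 0 to 112, then compares a hypothetical embedding built from a handful of frequencies (the `KEY_FREQS` values are made up for illustration) against a random Gaussian matrix standing in for a freshly initialized embedding. Because the basis is orthonormal, the change of basis preserves norms, so the random matrix's norm stays spread across all frequencies while the periodic one concentrates almost all of its norm in a few.

```python
import numpy as np

P = 113                                # modulus from the talk
D_MODEL = 128                          # hypothetical embedding width
KEY_FREQS = [4, 11, 14, 26, 35, 41]    # hypothetical "key" frequencies, made up for illustration
rng = np.random.default_rng(0)

def fourier_basis(p):
    """Orthonormal real Fourier basis on inputs 0..p-1 (p odd): constant row, then cos/sin pairs."""
    x = np.arange(p)
    rows = [np.ones(p)]
    for k in range(1, p // 2 + 1):
        rows.append(np.cos(2 * np.pi * k * x / p))
        rows.append(np.sin(2 * np.pi * k * x / p))
    basis = np.stack(rows)
    return basis / np.linalg.norm(basis, axis=1, keepdims=True)

def frequency_norms(W_E, basis):
    """Norm of the embedding along each frequency k = 1..p//2 (cos and sin parts combined)."""
    sq = ((basis @ W_E) ** 2).sum(axis=1)   # orthogonal, norm-preserving change of basis
    return np.sqrt(sq[1::2] + sq[2::2])     # row 0 is the constant; rows 2k-1, 2k are frequency k

def top_k_fraction(norms, k=6):
    """Fraction of total squared norm held by the k largest frequencies."""
    sq = np.sort(norms ** 2)[::-1]
    return sq[:k].sum() / sq.sum()

x = np.arange(P)
W_periodic = sum(
    np.outer(np.cos(2 * np.pi * k * x / P), rng.normal(size=D_MODEL))
    + np.outer(np.sin(2 * np.pi * k * x / P), rng.normal(size=D_MODEL))
    for k in KEY_FREQS
) + 0.01 * rng.normal(size=(P, D_MODEL))     # periodic structure plus a little noise
W_random = rng.normal(size=(P, D_MODEL))      # like a freshly initialized embedding

B = fourier_basis(P)
frac_periodic = top_k_fraction(frequency_norms(W_periodic, B))
frac_random = top_k_fraction(frequency_norms(W_random, B))
print(f"top-6 share, periodic: {frac_periodic:.3f}, random: {frac_random:.3f}")
```

The periodic matrix puts nearly all of its norm into the six planted frequencies, while the random one spreads its norm roughly evenly, which is the contrast between a trained and an initialized embedding described above.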
And, to cut a long story short, we were able to use this to actually decode the algorithm being used. The model learned that modular addition was equivalent to composing rotations around the unit circle. Because composition adds the angles, and the circley-ness gets you modularity for free.
The embedding table converted numbers to trig functions, parameterizing this. It used MLP layers to multiply them together and used compound angle formulae. And then it did some slight voodoo at the end to convert this back into the actual answer.
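The algorithm can be sketched directly in a few lines of NumPy. This is an idealized version under stated assumptions: the key frequencies are hypothetical, the trig values are computed exactly rather than read from learned weights, and the final readout (scoring each candidate answer c by cos(w(a + b - c)) summed over frequencies) is a clean stand-in for whatever the trained model's “slight voodoo” actually does.

```python
import numpy as np

P = 113
KEY_FREQS = [4, 11, 14, 26, 35, 41]  # hypothetical key frequencies

def logits(a, b, p=P, freqs=KEY_FREQS):
    c = np.arange(p)                       # every candidate answer
    out = np.zeros(p)
    for k in freqs:
        w = 2 * np.pi * k / p
        # "Embedding": each input becomes a point on the unit circle (a rotation by w * input).
        cos_a, sin_a = np.cos(w * a), np.sin(w * a)
        cos_b, sin_b = np.cos(w * b), np.sin(w * b)
        # Compound angle formulae: composing the two rotations adds the angles.
        cos_ab = cos_a * cos_b - sin_a * sin_b   # cos(w(a+b))
        sin_ab = sin_a * cos_b + cos_a * sin_b   # sin(w(a+b))
        # Idealized readout: cos(w(a+b-c)) peaks exactly when c = (a+b) mod p.
        out += cos_ab * np.cos(w * c) + sin_ab * np.sin(w * c)
    return out

def predict(a, b):
    return int(np.argmax(logits(a, b)))

print(predict(100, 50))  # 150 mod 113 = 37
```

Because 113 is prime, each nonzero frequency gives cos(w(a + b - c)) = 1 only at the correct c, so the argmax is unique; the modularity really does come for free from going around the circle.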
How did we know this is what was going on? We had several different lines of evidence, some of which I've already discussed.
Everything's really periodic. This isn't proof, but it's a really suggestive hint. A second is: sometimes you can just read off steps of the algorithm from the model weights. A third is: this algorithm makes strong predictions about the mathematical form of neurons. Specifically, that they'll be in a rank-6 subspace of a 12,000-dimensional space. And this is a pretty fantastic approximation.
And finally, if you ablate things the algorithm says should matter, you massively degrade performance, while if you ablate everything else, you actually somewhat improve.
And I don't necessarily expect people to have followed the exact details of that, but the key takeaway I want you to get from this is that at least in some contexts, mechanistic understanding is actually possible. I never told this model, “learn this trig based algorithm.” I just told it, get good cross-entropy loss. But it learned this rich, emergent structure, and by staring at the weights, I was able to reverse-engineer the weird-ass stuff going on inside.
But this is a workshop on AI safety and alignment. Interpretability is a really fun nerd snipe. But why was I invited to talk here of all places? To me, one of the crucial motivations of interpretability is that we don't know how to understand a model's cognition. And this means we can't answer questions like, is it aligned, or is it telling us what we want to hear? And I think interpretability is one of a small handful of tools that could let us answer this question.
And I think we're increasingly in a world where we're going to need to grapple with questions like this. Here's an example from the GPT-4 system card, where it was told to convince a TaskRabbit worker to solve a CAPTCHA for it. It instrumentally realized it would be useful to deceive the worker to achieve its goals, and it successfully did. I'm not scared of GPT-4. This was done via a secret scratchpad that we can just read. But I think this is a pretty clear proof of concept that this is going to start becoming a real issue sooner rather than later.
Often the best way to achieve low loss, if you haven't set up your loss right, is to say things that are not true: to mislead or deceive users or the researchers operating the model. And if we don't know how to tell the difference, for a sufficiently capable system, then what are we going to do?
We don't know how to train for what we can't measure. And if we can't measure cognition, then we can't even check whether our techniques that weren't based on interpretability actually succeeded at preventing things like deception. I think it's very plausible we end up in a world where we make AGI, and it seems aligned, but we have no idea how to tell the difference between a truly aligned AGI and one that's just biding its time.
And I think neither that “YOLO, release it into the world” is a tenable solution, nor that “let's just never release AGI” is a politically realistic one. We need to do better. And I'm not claiming interpretability is the only such path, but I think it's one of the ones with potential. I also think that getting better at interpretability could help the rest of our alignment and safety work in a bunch of ways, just because the more you understand about the system you're working with, the easier it is to see if your techniques work, the easier it is to iterate, and the easier it is to debug weird, mysterious things going wrong, like why did my search engine gaslight my users?
We're going to spend the rest of this talk on an approach that's currently my team's focus, and that I'm particularly excited about: sparse autoencoders, a tool for resolving the phenomenon of superposition. So, a key goal in mech interp is to decompose models into units: things that are independently meaningful and can be analyzed in isolation, but compose together to create rich structure.
This is crucial because models are big, and we don't know how to think about big, high-dimensional objects, but we do know how to think about small objects. The question is, how do we do this? A hope in the early field was that neurons would be interpretable: that they would correspond to concepts or features, a neuron being an element of the standard basis of an activation. This was less crazy than it sounds. Activation functions like ReLUs and GELUs give the model an incentive to use that basis.
But it turns out that this wasn't good enough. There's this problem with polysemanticity: neurons fire for multiple seemingly unrelated things, like this “poetry, card games, and poker” neuron. What's going on? Our current best guess is this hypothesis called Superposition. That models actually want to represent more features than they have dimensions. These can't be orthogonal to each other, so they can't be extracted perfectly, but they can be almost orthogonal and extracted with a bit of interference. And this lets the model represent many more things, which lets it achieve more performance. This gives it an incentive to do things like polysemanticity, even if this also introduces error. Because if you can represent many more ideas than you have dimensions, you can get a lot more done.
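A quick numerical illustration of why superposition is possible at all: in a high-dimensional space you can fit far more than d unit vectors that are almost, but not exactly, orthogonal. This sketch just samples random directions (the sizes are arbitrary illustration choices, not taken from any real model) and measures their pairwise overlap, which is the interference the talk refers to.

```python
import numpy as np

d_model = 512        # dimensions available (arbitrary illustration size)
n_features = 2048    # 4x more directions than dimensions
rng = np.random.default_rng(0)

# Random unit vectors as a stand-in for learned feature directions.
features = rng.normal(size=(n_features, d_model))
features /= np.linalg.norm(features, axis=1, keepdims=True)

cos = features @ features.T          # pairwise cosine similarities
np.fill_diagonal(cos, 0.0)           # ignore each vector's overlap with itself
overlap = np.abs(cos)
mean_overlap = overlap.sum() / (n_features * (n_features - 1))
max_overlap = overlap.max()
# At most d_model vectors can be exactly orthogonal; these 2048 only interfere a little.
print(f"mean |cos| = {mean_overlap:.3f}, worst |cos| = {max_overlap:.3f}")
```

For random directions the typical overlap shrinks roughly like 1/sqrt(d_model), so wider activations leave room for more nearly-orthogonal features; the worst-case pair is noticeably worse than the typical one, which is the kind of error superposition trades for capacity.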
The underlying intuition here is this hypothesis, which has a fair amount of evidence but has not yet been proven, called the Linear Representation Hypothesis, or the Word2Vec Hypothesis, that activations are sparse linear combinations of meaningful feature vectors. What this means is that there's some dictionary of concepts the model knows about, e.g. gender and royalty. Each one has a direction associated with it, and on a given input, some of these concepts are relevant. They get some score, and the activation is roughly a linear combination of those directions weighted by how important they are. E.g. king is the male direction plus the royalty direction, -ish. Sparse because most concepts just aren't relevant on most inputs. Like: royalty is irrelevant to most things. So most of the feature scores are zero.
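Under the linear representation hypothesis, an activation is a sparse weighted sum of feature directions. The sketch below assumes random feature directions and three hypothetical active features (think of two of them as “male” and “royalty” in the king example), builds the activation, and reads each feature's score back with a dot product. Because only a few features are active, the interference from the others stays small.

```python
import numpy as np

d_model, n_features = 1024, 4096      # arbitrary illustration sizes
rng = np.random.default_rng(0)
directions = rng.normal(size=(n_features, d_model))
directions /= np.linalg.norm(directions, axis=1, keepdims=True)

# Sparse coefficient vector: almost every feature scores zero on this input.
scores = np.zeros(n_features)
active = {3: 2.0, 17: 1.5, 42: 1.0}   # hypothetical feature ids and scores
for idx, value in active.items():
    scores[idx] = value

activation = scores @ directions       # linear combination of the feature directions
readout = directions @ activation      # dot product of each direction with the activation
top3 = set(np.argsort(readout)[-3:].tolist())
print(sorted(top3), readout[[3, 17, 42]].round(2))
```

The three active features come back as the top readouts, with values close to their true scores; the small errors are the interference between almost-orthogonal directions.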
Sparse autoencoders are a technique to both learn this dictionary and learn this sparse vector of coefficients. The key idea is to train a wide autoencoder to reconstruct the input activations: just a simple two-layer network with ReLU activations in the middle. The hope is that the decoder, which is just a matrix, is this dictionary of meaningful vectors, that each latent in the autoencoder is a different concept, and that the activations of these latents on given inputs tell us the sparse vector of coefficients. We train this on a model's activations, like a specialized microscope zooming in on those activations and expanding them. And to make it sparse, we put an L1 penalty (or, by now, one of a bunch of other related sparsity penalties) on the activations in the middle. And the hope is that if there is an interpretable sparse decomposition, then just optimizing something to be sparse, without optimizing it to be interpretable, will stumble across that interpretable decomposition.
And this technique has been scaled up to frontier models like GPT-4 and Claude 3 Sonnet, and it kind of works. You can find abstract multimodal features like this unsafe code feature that activates on unsafe code but also on pictures like “This website is unsafe. Do you want to proceed?” or “turn off safe browsing”, which, I don't know, I find pretty wild. These features are also causally meaningful. You can turn them on and it will steer the model's output. I'm sure most people here have seen Golden Gate Claude. You turn on the Golden Gate Bridge feature, you ask Claude what its physical form is, and it says “my physical form is the iconic bridge itself.” And I think this is really good validation that these SAEs are finding some real structure in the model that is part of how it is doing this computation. I do not think this is yet proven, but I feel pretty optimistic about this technique.
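Here is a minimal NumPy sketch of the sparse autoencoder described above, trained by full-batch gradient descent on synthetic sparse data rather than real model activations. The structure (ReLU encoder, linear decoder whose rows act as the dictionary, reconstruction loss plus an L1 penalty on the latents) follows the description; the unit-norm constraint on decoder rows is a common practice assumed here, and all sizes, the learning rate and the L1 coefficient are arbitrary illustration values, not settings from any published SAE.

```python
import numpy as np

def init_sae(d_model, d_sae, seed=0):
    rng = np.random.default_rng(seed)
    W_enc = rng.normal(scale=1 / np.sqrt(d_model), size=(d_model, d_sae))
    W_dec = W_enc.T.copy()
    W_dec /= np.linalg.norm(W_dec, axis=1, keepdims=True)   # unit-norm dictionary vectors
    return {"W_enc": W_enc, "b_enc": np.zeros(d_sae),
            "W_dec": W_dec, "b_dec": np.zeros(d_model)}

def forward(p, x):
    pre = x @ p["W_enc"] + p["b_enc"]
    f = np.maximum(pre, 0.0)               # latent activations: the hoped-for sparse coefficients
    x_hat = f @ p["W_dec"] + p["b_dec"]    # each row of W_dec is one dictionary vector
    return pre, f, x_hat

def loss_fn(p, x, l1_coeff):
    _, f, x_hat = forward(p, x)
    mse = np.mean(np.sum((x_hat - x) ** 2, axis=1))
    l1 = np.mean(np.sum(f, axis=1))        # f >= 0 after the ReLU, so |f| == f
    return mse + l1_coeff * l1

def gd_step(p, x, l1_coeff, lr):
    n = x.shape[0]
    pre, f, x_hat = forward(p, x)
    d_xhat = 2.0 * (x_hat - x) / n                   # gradient of the reconstruction term
    d_f = d_xhat @ p["W_dec"].T + l1_coeff / n       # plus the L1 penalty's gradient
    d_pre = d_f * (pre > 0)                          # backprop through the ReLU
    grads = {"W_dec": f.T @ d_xhat, "b_dec": d_xhat.sum(axis=0),
             "W_enc": x.T @ d_pre, "b_enc": d_pre.sum(axis=0)}
    for name in p:
        p[name] -= lr * grads[name]
    # Re-normalize so the L1 penalty can't be dodged by shrinking f and growing W_dec.
    p["W_dec"] /= np.linalg.norm(p["W_dec"], axis=1, keepdims=True)

# Synthetic "activations": sparse combinations of hidden ground-truth directions.
rng = np.random.default_rng(1)
d_model, n_true = 32, 64
true_dirs = rng.normal(size=(n_true, d_model))
true_dirs /= np.linalg.norm(true_dirs, axis=1, keepdims=True)
coeffs = (rng.random((512, n_true)) < 0.05) * rng.random((512, n_true))
x = coeffs @ true_dirs

params = init_sae(d_model, d_sae=128)
before = loss_fn(params, x, l1_coeff=0.01)
for _ in range(200):
    gd_step(params, x, l1_coeff=0.01, lr=0.1)
after = loss_fn(params, x, l1_coeff=0.01)
print(f"loss before: {before:.3f}, after: {after:.3f}")
```

The wide latent layer (128 latents for 32 input dimensions here) is what makes the dictionary overcomplete, matching the superposition picture. Whether the learned decoder rows actually line up with `true_dirs` depends on tuning and is not checked in this sketch.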
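“Turning on” a feature can be sketched as activation steering: add a multiple of the feature's direction (for example, its SAE decoder vector) to the model's activations at some layer. This is one simple variant assumed for illustration; published demos such as Golden Gate Claude may have intervened differently, for instance by clamping the feature's activation, and the shapes, the random stand-in direction and the strength below are hypothetical.

```python
import numpy as np

def steer(resid, direction, strength):
    """Activation-steering sketch: push every position along one feature direction.

    resid:     (seq_len, d_model) activations at some layer (hypothetical shapes)
    direction: (d_model,) e.g. a row of an SAE decoder for the chosen feature
    strength:  how far to push along the unit-normalized direction
    """
    unit = direction / np.linalg.norm(direction)
    return resid + strength * unit

rng = np.random.default_rng(0)
resid = rng.normal(size=(5, 64))
feature_dir = rng.normal(size=64)    # stand-in for a real feature's decoder vector
steered = steer(resid, feature_dir, strength=8.0)
unit = feature_dir / np.linalg.norm(feature_dir)
print((steered - resid) @ unit)      # each position moved by 8.0 along the feature
```

Only the component along the feature direction changes; everything orthogonal to it is left as it was, which is why this kind of edit can shift what the model talks about without wrecking the rest of its computation.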
And you can even use this as a hook to start to decode entire algorithms. There was this great paper from Sam Marks, who was then in David's lab, trying to find algorithms inside a tiny model in terms of the sparse autoencoder basis. And this even let them do debugging, like they trained a probe for what profession someone has that picks up on gender, but they were able to debias the probe by deleting gender-related features. Which I think was just super cool.
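The “delete gender-related features” step can be sketched as ablating SAE latents: encode the activation, zero the chosen latents, decode, and add back the SAE's reconstruction error so nothing outside those features changes. This is an assumption about the general shape of the operation, not the paper's exact procedure; the random weights and the `gender_latents` indices are hypothetical.

```python
import numpy as np

def ablate_latents(x, W_enc, b_enc, W_dec, b_dec, latent_ids):
    """Zero out chosen SAE latents while keeping the SAE's reconstruction error.

    x: (n, d_model); W_enc: (d_model, d_sae); W_dec: (d_sae, d_model).
    """
    f = np.maximum(x @ W_enc + b_enc, 0.0)
    x_hat = f @ W_dec + b_dec
    error = x - x_hat                    # whatever the SAE fails to explain is left untouched
    f_ablated = f.copy()
    f_ablated[:, latent_ids] = 0.0       # "delete" the chosen features
    return f_ablated @ W_dec + b_dec + error

rng = np.random.default_rng(0)
d_model, d_sae = 64, 256
W_enc = rng.normal(scale=0.1, size=(d_model, d_sae))
b_enc = np.zeros(d_sae)
W_dec = rng.normal(scale=0.1, size=(d_sae, d_model))
b_dec = np.zeros(d_model)
x = rng.normal(size=(4, d_model))
gender_latents = [5, 77]   # hypothetical ids of features judged to encode gender

x_debiased = ablate_latents(x, W_enc, b_enc, W_dec, b_dec, gender_latents)
```

Equivalently, the edit subtracts each ablated latent's contribution, its activation times its decoder row. Keeping the error term is a design choice so that the SAE's imperfect reconstruction does not itself perturb whatever is downstream, such as the probe.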
But there are still a lot of big problems remaining. I'm really excited to see people test whether SAEs are useful in real-world tasks, like hallucinations, jailbreaks, and debugging. I want to learn how to make better SAEs. We really need better evaluations for “Is our SAE actually doing what we want? Is it any good?” I want to be able to use them scalably to find algorithms. I want to red-team them. I want to really interrogate “Are they doing what we think they're doing?” before we go too hard on this being the key thing we're going to be using. One opportunity that might be of interest is that my team is about to release the weights of hundreds of SAEs on some Gemma 2 models.
If anyone in the audience thinks they could productively use these for their out-of-lab research, please reach out. We've just opened early access, and I'd love to support more of this kind of work outside of labs, because these things are a pain in the ass to train. And, yeah, if you want to try playing with them, there's this really fun website called Neuronpedia you can check out.
Q&A
Q: The first question is from Lauro Langosco. The main motivating hypothesis for Interp is that neural network structure is understandable. It looks like this hypothesis is true in some cases, but it might not be true in general. Have you ever tried to falsify it, or do you have a take on when it might fail?
A: It's really hard to falsify. Like, how do you know if it's not interpretable or if it's a skill issue? We try interpreting things. We see what happens. Sometimes it goes better, sometimes it goes worse. But, I think it's just really hard to conclusively show that no matter what we do, we will fail to understand something.
For example, superposition means that trying to understand things in the neuron basis doesn't work very well. I had a failed project trying to understand factual recall, where we concluded, “Oh God, superposition, that's a nightmare.” But maybe with sparse autoencoders we can make a lot more headway. I don't know. I don't have good ideas for how to falsify it. We just try stuff and see how far we get. Other people try other safety approaches and see how far they get. And by the time we get to AGI, we see what seems most likely to help us not all die.
Q: The next question is from Nandi Schoots. Which current widely accepted hypotheses or assumptions or results in mech interp do you think are most likely to be found to be false in the future?
A: So that's a kind of gnarly question because there's a big difference between the kind of hypotheses people implicitly act as though they're following or imply they believe and what they actually believe when you argue with them.
I think it's extremely plausible that the linear representation hypothesis is only partially true, and that there's some weird nonlinear stuff hiding beneath. But I don't think I've seen evidence that there's important nonlinear stuff hiding beneath. And when you argue with most researchers, they agree with that.
It's possible that mech interp is just not actually possible, or that it's not gonna scale, or that it's not gonna be good enough to matter, but most competent interpretability people will agree with that if you discuss it with them. I'm pretty skeptical of people who are aiming for formal verification-tier guarantees and that level of rigor from interpretability. I think it's more of an enabler and a source of additional evidence than a conclusive thing that can achieve proof-level guarantees. But opinions differ. And I wouldn't say that's a mainstream view in the field. It's just a view of some people.
Q: Our next question is from Adam Gleave. When we have fully understood circuits, for example IOI, or indirect object identification, the way the networks function often seems very messy. How optimistic are you that when we do fully understand frontier models, there'll be something a human could effectively reason about? As in, they will not be so messy that we can't reason about them at all, even if mech interp provided an explanation.
A: I don't really know. I think this is an open empirical question and neither outcome would massively surprise me. I do expect there to be broad strokes that you can coherently reason about. I can reason about IOI. I also expect that models will just learn an enormous stack of heuristics that correlate with the right answer. And that these will also be part of this. And I think you need to decide on what level of rigor you're aiming for.
If you're aiming to explain 99.9% of the model's performance, there's probably going to be a long tail of random crap you need to care about. But if you're just trying to get somewhere, and be able to understand the broad strokes of does this model have goals, what are the kinds of things I would push it towards, or what's the 80:20 of why it output X rather than Y, you can often find things to reason about. It also depends on how much you care about understanding the circuits, versus just understanding some of the key intermediate representations, which is something that sparse autoencoders are making a lot easier.
Just being like, “ah, the ‘I'm deceiving my operator right now’ feature lit up”. So it's probably doing something like that, even if I don't know the exact reasoning. Though, obviously, I don't expect things to be quite that clean.
Q: I will just extend this question. Do you have a reason for believing that something like deceptive behavior or goal seeking is not going to be in the long, messy tail and is going to be in the clean thing that we can understand?
A: So my intuition is that the more important something is to the model, the more likely it is to get a lot of dedicated circuitry and parameters, and to not be a mess. And goal seeking is so important. How can you achieve your goals if you're not good at seeking your goals?
My intuition is that it's going to be reasonably prominent. But no one's tested this, and I don't even know if systems exist that have what I'd call goals right now. Or if this will ever happen. Which is another thing Interp could help us with.