Atticus Geiger - State of Interpretability & Ideas for Scaling Up
Transcript
Hey everyone, I'm Atticus, and I'm going to be talking about interpretability: the state of things, where we are, my thoughts on how we should be scaling it up, and maybe some issues that currently exist with how people are doing it. Okay, cool. So I'll just hop in.
So I think a rough way to trisect what interpretability might be able to do for us is: it might give us the ability to predict, control, and understand models. I'm going to quickly go through each of these topics, and then talk to you about what I think are the state-of-the-art methods for doing each of these things.
And then I'll get to what it would look like to try to scale up interpretability. Alright. For prediction, really, I think it hasn't changed too much since before large language models, maybe even back to the AlexNet era. You can just train another classifier on model-internal representations to predict a concept.
If a model has French text input, then you can train a classifier, “What language is this text?” And people have been doing this for a while. And you can also think of this within information theory. If you have an arbitrarily powerful probe, all probing is doing is measuring “What is the mutual information content between a concept and a hidden representation?”
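To make that concrete, here is a minimal sketch of a probe in PyTorch. Everything here is illustrative: `hidden_states` stands in for activations cached from a frozen model at some layer, `labels` for the concept annotations (say, whether the input text is French), and the hidden size is a placeholder.

```python
import torch
import torch.nn as nn

# Illustrative sketch: `hidden_states` is a (num_examples, hidden_dim) tensor of
# cached activations from a frozen model, and `labels` says whether each input
# had the concept of interest (e.g. 1 if the text was French, 0 otherwise).
hidden_dim = 4096  # placeholder for the model's hidden size
probe = nn.Linear(hidden_dim, 2)
optimizer = torch.optim.Adam(probe.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

def train_probe(hidden_states, labels, epochs=10):
    for _ in range(epochs):
        logits = probe(hidden_states)   # predict the concept from the hidden representation
        loss = loss_fn(logits, labels)
        optimizer.zero_grad()
        loss.backward()                 # only the probe's parameters are updated
        optimizer.step()
    return probe
```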
Typically we train probes on labeled input data, but you might also think about training a probe on what behavior a model will have, and that could be more relevant in a safety setting. But that's pretty much the standard approach we have for predicting what models are thinking about.

Alright, so then moving on to control. Control is a weird one. Steering has become a pretty popular way to think about interpretability and what we might be able to do with it, but I think a lot of the time we're not as eager to run the really strong, powerful baselines that we have for controlling a model. The first of those is fine-tuning.
So you could steer a model to make it an angry model, or you could just fine-tune the model on some angry data to make it an angry model. We should be running these baselines so we can actually figure out when fine-tuning isn't the thing we want to do. There's also parameter-efficient fine-tuning; full fine-tuning can be really expensive, and I think that's the main reason we might not want to do it.
And then we have prompt engineering, which is another very competitive method for controlling models. You can also just ask the model to be angry, or ask the model to speak in French, or ask it to talk about the Golden Gate Bridge. So this is a really strong baseline for controlling the model, and it is unclear to me, like, how much work it's gonna take to get to a point where prompt engineering or fine-tuning isn't just the right way to go about manifesting some desired behavior in a model.
And methods like DSPy have the ability to even automatically generate prompts for you, given a specific goal, and it's just very cheap and easy to interact with. And then we get to steering, which is, I think, really just a collection of methods that are focused on doing an intervention to some model representation, giving it a little tweak to then get some desired behavior, whether that's removing something or enabling something.
And so a basic steering method is just: maybe you have a bunch of angry texts, and a bunch of calm texts, you run them through the model, you take the average angry representation, the average calm representation, find the difference between those two, and then that is a direction along which you can modulate the model representation to make it angrier, or to make it calmer.
So this is a very basic supervised method for doing this. And it really only works for binary concepts, because you can't really do this in a k-ary way, or at least I don't think I've seen anyone doing it, but maybe Andy will tell me I'm wrong or something. Yep. So yeah, here's just the example: "I love talking about weddings" and "I hate talking about weddings" are the contrast prompts, which are then used to steer the model.
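As a rough sketch of that recipe (with `get_hidden` standing in for whatever activation-caching machinery you use at a chosen layer):

```python
import torch

# Illustrative difference-of-means steering. `get_hidden` is a stand-in for a
# function that returns the hidden state at a chosen layer for a given prompt.
def steering_direction(angry_prompts, calm_prompts, get_hidden):
    angry_mean = torch.stack([get_hidden(p) for p in angry_prompts]).mean(dim=0)
    calm_mean = torch.stack([get_hidden(p) for p in calm_prompts]).mean(dim=0)
    return angry_mean - calm_mean  # points from "calm" toward "angry"

# At inference time, add a scaled copy of the direction to the hidden state:
#   h_steered = h + alpha * direction
# Positive alpha pushes the model toward angrier outputs, negative toward calmer.
```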
And there are other simple ways to get the directions we would want for steering. Maybe you train a linear probe, and whatever direction that probe is sensitive to is also something you could use for steering. Or dimensionality reduction techniques like PCA or SVD will provide directions that explain a lot of the variance in the data, and those might be useful too.
But really, what I think for steering, the most crucial thing to do, is just to establish a setting where steering is a good idea. I don't think that exists yet. And some ideas I have on what this might look like, is if you're really overusing your prompt space, so maybe you have a thousand things that you want to cram into the prompt, maybe some of those could be handled through steering, and you give the model more room to think.
But I have not seen this yet, and I think that's basically where we need to go first in this research direction, to motivate whether it's one we want to pursue at all. Yeah, and then, the last thing I want to mention: representation fine-tuning is work that I'm involved with, and it sits somewhere between parameter-efficient fine-tuning and steering.
The sort of idea is that instead of just adding a vector to some hidden representation, you add a vector to the hidden representation, and then you also add a little adapter that is a low-rank matrix transformation, and all that does is basically read out a small amount of information, and then inject a small amount of information at a specific location.
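Here is a hedged sketch of what such an intervention might look like in PyTorch. The module and parameter names are mine, and I'm omitting details of the actual parameterization (for instance, the orthogonality constraint the published method places on the read-out matrix), so treat this as the flavor of the idea rather than the implementation:

```python
import torch
import torch.nn as nn

class LowRankInterventionSketch(nn.Module):
    """Rough sketch of a ReFT-style edit: read a low-rank projection of the
    hidden state, compute a small learned replacement for it, and write the
    difference back into that subspace. Names and details are illustrative."""

    def __init__(self, hidden_dim, rank):
        super().__init__()
        self.read = nn.Linear(hidden_dim, rank, bias=False)  # low-rank "read" directions
        self.write = nn.Linear(hidden_dim, rank)             # learned source of the edit

    def forward(self, h):
        edit = self.write(h) - self.read(h)   # how the low-rank readout should change
        return h + edit @ self.read.weight    # inject the change along the read directions
```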
So far, we've only benchmarked this method on just parameter-efficient fine-tuning sort of things, so just standard fine-tuning stuff, but I think where representation fine-tuning is going to shine the most is when we're trying to do surgically precise, layered compositions of steering.
And you can just think of this as a way of taking the power of supervised machine learning and combining it with the targeted manipulations of steering. Cool. So, prediction and control. They're not really what I would consider the core of mechanistic interpretability. A lot of the time when we talk about the field, we use phrases like 'basic science' and 'reverse-engineering neural networks into algorithms'.
And I personally think that my research, and the research I've done with my collaborators over the years, takes those phrases really seriously, in a way I don't think a lot of other research in the area does. We actually have a mathematically and philosophically precise idea of when an algorithm is being realized by a deep learning model, or really any dynamical system.
So I'll go through how I think about what gold-standard mech interp looks like. Really, this is just a general phenomenon: we explain dynamical systems with the algorithms that they implement. Humans doing stuff is one example; the entirety of cognitive science and the functionalist theory of mind rests on the idea that we're something like computers, and that we can explain our behavior and how we interact with the world using something like an algorithm.
And then we probably have the same intuition about deep learning models (if you're Gary Marcus, you never had this intuition about deep learning models, but it's one I've always had myself). And then there's the physical computer, the most canonical case of implementation: that system is clearly one that implements a computational object.
Something very cool for me in my undergraduate years was realizing that this notion of algorithmic implementation… we just hadn't made it precise. We didn't have the math, and the philosophy was pretty underdone. I think we now do have a way of talking about that, and we've made a lot of progress throughout my PhD and up to now on what it looks like for a computational explanation to be good and for an algorithm to actually be realized by a system.
So basically, our framework is that we represent the dynamical system and also a hypothesized algorithm as causal models. All a causal model is, is a bunch of circles, arrows between those circles, and rules for determining the value of each circle based on the circles pointing into it.
It's just a discrete graph with some control-flow structure to it, basically. You can represent a neural network that way, where the variables are individual dimensions of vectors, or entire vectors; there are lots of ways to do it. And you can also represent algorithms themselves as causal models.
And so then, how do we understand the relationship between a simple causal model and a more complicated, larger causal model? With causal abstraction. It's a pretty basic idea. You just go through the high-level model with its small number of variables, and you ask: where are they localized in the deep learning model?
You'll make claims about where the variables in the algorithm are being causally realized, or where information flow is being mediated in the network. And then, now that we have the language to make these claims in the first place, the intervention experiments we do are going to be hypothesis testing, and actually figuring out whether the hypothesis that we formed is accurate.
Causal abstraction has interesting origins: over the last 10 or 15 years it developed as a mathematical framework for supporting people analyzing weather data and brain data. And I think we're actually much better suited to do this sort of analysis in interpretability, in a really deep way.
This task of aggregating micro-variables into macro-variables is really hard with the weather and with the brain, because we don't have interventional access to these objects. It is very expensive and very difficult to do even basic intervention experiments in those two settings. But with neural networks we really have arbitrary precision of intervention and measurement.
We have an ideal sort of neuroscientist situation where we can manipulate and rerun and freeze and patch in whatever information we want. And so I think there's good reason that these new deep learning models that are so good are the first opportunity we have basically in human history to do this sort of hierarchical reverse-engineering.
And, yeah, that's why I'm really excited about interpretability, and I think it's a promising and unique moment in history for us. So the basic recipe, which we can walk through with a little example, is: the hypothesis you have about the structure of some neural network is a high-level model, plus an alignment between the high-level variables and the low-level variables.
Suppose we have an algorithm that adds together three numbers, X, Y, and Z, and we have a network that performs this task. Then we have two runs of the network: one on inputs 1, 2, and 3, and one on inputs 4, 5, and 6. The intermediate sum X plus Y is 3 on the left and 9 on the right. If we do an intervention on the left run where we replace that 3 with a 9, the output should go up by 6, from 6 to 12.
That's a very obvious statement about this very simple algorithm for addition. But then, what it means for these variables to be realized in the network is that if we do the corresponding intervention on the neural network's hidden states, we get the same output: 12. It's really a matter of aggregating a lot of these intervention experiments and showing that there's a counterfactual correspondence between quantities at the high level that are easily understood and these seemingly uninterpretable low-level hidden representations.
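To make the example concrete, here's the same interchange intervention written as a toy sketch, with the high-level algorithm as a small Python function (the names here are purely illustrative):

```python
# Toy sketch of an interchange intervention on the high-level model
# for adding three numbers. All names are illustrative.

def high_level_model(x, y, z, intervene_sum=None):
    s = x + y                    # intermediate variable: the sum X + Y
    if intervene_sum is not None:
        s = intervene_sum        # patch in the value from another run
    return s + z

base_output = high_level_model(1, 2, 3)          # 6
source_sum = 4 + 5                               # X + Y from the (4, 5, 6) run: 9
patched_output = high_level_model(1, 2, 3, intervene_sum=source_sum)
print(base_output, patched_output)               # 6 12 -- the output goes up by 6

# Testing the hypothesis means doing the *corresponding* patch on the network:
# swap the hidden-state components aligned with X + Y from the (4, 5, 6) run
# into the (1, 2, 3) run and check that the network's output is also 12.
```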
And then, I think one of the most exciting developments that has come out of the research that I've been involved with, is that we have a supervised gradient descent-based method for localizing concepts in this way. The very short idea is that you freeze the network parameters and then you learn an alignment between the high level model and the low level model.
And that alignment is parameterized as a rotation matrix that basically tells you which linear subspaces of a representation you're going to intervene on. I'm not going to go into the details; you can look at the paper. But the idea is that if you have counterfactual data telling you what the network should do if a concept had been intervened on, you can use that data as a training signal to localize that concept to model-internal representations.
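Here's a hedged sketch of the core intervention, in the spirit of distributed alignment search: learn an orthogonal map, swap a k-dimensional subspace of the base activation with the source activation, and map back. The class name and training description are my own simplification, not the paper's implementation:

```python
import torch
import torch.nn as nn
from torch.nn.utils.parametrizations import orthogonal

class RotatedSubspaceSwapSketch(nn.Module):
    """Simplified sketch of a DAS-style intervention: rotate two activations,
    interchange the first k rotated dimensions, and rotate back."""

    def __init__(self, hidden_dim, k):
        super().__init__()
        self.k = k
        # The orthogonal parameterization keeps the learned map a rotation.
        self.rotate = orthogonal(nn.Linear(hidden_dim, hidden_dim, bias=False))

    def forward(self, h_base, h_source):
        r_base = self.rotate(h_base)
        r_source = self.rotate(h_source)
        mixed = torch.cat([r_source[..., :self.k], r_base[..., self.k:]], dim=-1)
        return mixed @ self.rotate.weight  # rotate back into the original basis

# Training: splice this intervention into a frozen model at a chosen layer and
# minimize the loss against counterfactual labels (what the model *should*
# output if the concept had taken the source run's value). Only the rotation
# is learned; the model's weights never change.
```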
And if you want to localize concepts, I really think this is the state-of-the-art method for doing so. Yeah. So now, for the last part, I'm going to give my hot takes and thoughts on what we should be doing to scale interpretability, and talk briefly about sparse autoencoders.
Alright, so just a refresher on what sparse autoencoders are. Sparse autoencoders are trained on the hidden representations of neural networks to reconstruct them. They project them into a high-dimensional space and then they reconstruct that representation by projecting it down into a smaller space.
And these are basically the methods that people have for featurizing the residual stream or vector representations in deep learning models. And the typical SAE is just a linear layer followed by a ReLU, followed by another linear layer for decoding, and it is trained to just reconstruct the residual stream representations for tokens on a very large corpus.
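As a minimal sketch of that typical architecture (the class name, expansion factor, and sparsity coefficient are placeholders, and real training setups add details like bias handling and decoder normalization):

```python
import torch
import torch.nn as nn

class SparseAutoencoderSketch(nn.Module):
    """Minimal SAE sketch: encode the residual-stream vector into a wider,
    sparse feature space, then linearly decode it back."""

    def __init__(self, hidden_dim, num_features):
        super().__init__()
        self.encoder = nn.Linear(hidden_dim, num_features)  # project up
        self.decoder = nn.Linear(num_features, hidden_dim)  # project back down

    def forward(self, h):
        features = torch.relu(self.encoder(h))  # sparse, non-negative feature activations
        return self.decoder(features), features

def sae_loss(h, reconstruction, features, l1_coeff=1e-3):
    # Reconstruction error plus an L1 penalty that encourages sparse features.
    return ((h - reconstruction) ** 2).mean() + l1_coeff * features.abs().mean()
```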
And features are automatically labeled using some auto-interpretability pipeline, and current pipelines really rely just on, here are some examples for which this feature activated, and that's how they label these features, these new directions in the residual stream space. Yeah, and so one thing I think SAEs would be good for, is that they are unsupervised, so they might be able to come up with a feature we haven't thought of… but I really do think if you have a specific concept that you want to localize, SAEs will not be what you would ever use.
If you want to predict, you'd use a probe; if you want to control, you'd use something like steering or ReFT (representation fine-tuning for language models); and if you want to understand, you'd use something like distributed alignment search. And SAEs will basically never give you features that will be used in critical pipelines.
Because if you can articulate what the feature needs to do, that's exactly the data you need to train for that feature with supervision and get a better one. Yeah, and so here's what I want to see moving forward in scaling interpretability.
First, I think it's really insane that our goal is to mechanistically understand an object using an SAE that has been trained only on the observational distribution of that object. It's like doing science without ever doing experiments. If we want mechanistic understanding, we need to simulate counterfactual states of networks with interventions, which we can do very easily and cheaply.
And we need to incorporate that type of training data into our scaling efforts for interpretability. There's no reason to think that you could jump up a whole level of Pearl's causal hierarchy just by training on observational data, and all of a sudden recover a full mechanistic reverse engineering of an object.
I think that's like the strongest problem I have with SAEs right now. Then we need to pre-register our criterion for success. We need benchmarks. We need to say what it would look like to be successful at a task, and then hold ourselves accountable if we're not able to meet that.
I think that slipperiness is something you really want to avoid. SAEs have context-independent features, meaning they unpack the residual stream in the exact same way no matter what context you're in, and that is not an assumption I think we should be making either. Also, I think we need to remain open-minded about non-linear features.
And, yeah, I have been working on benchmarking interpretability methods. I also have an alternative idea on how we should scale things up using counterfactual data, and some food for thought on non-linear representations. But these are papers that I'm currently writing, or that are about to be out, and I'd love to talk to you guys about them.
Q: So the first question we have is: Is there a way to use causal mediation analysis to benchmark SAE faithfulness?
A: Yeah, so I would say causal mediation analysis is exactly a special sub-case of causal abstraction, and absolutely we should be doing that.
And the benchmark that I have on this slide, RAVEL, is doing this. It's a benchmark for disentangling pieces of information and figuring out which components and features are causal mediators of that information.
Q: Great. The next question is: your causal abstraction idea has been published since 2021. It's been fairly influential, and I think it's usually reflected in the way people think about circuits, but people usually check whether hypotheses match the model in a very ad-hoc way rather than using the full formalism. Do you think that's a problem? Do you think it should change? Where should things go here? Is it fine?
A: I think that's a really good question. Honestly, yeah, using the full formalism is basically never necessary.
If you can explain your ideas in a better way without the formalism or the math, you should totally go about doing it. But I do think that Redwood Research is… I was a big fan of their interpretability stuff, and I think they really held themselves accountable for hypothesis testing. They were realistic about what failure looked like, and their conclusion was that interpretability is really hard and they didn't think it was going to be a viable way forward.
I honestly think that, yeah, that is really the way we should be thinking about things. And I think some of the other interpretability efforts we see embrace the slipperiness of not having defined criteria for success.
Q: What do you mean exactly by intervention data? Observation data is equivalent to intervention data when all variables have been observed.
A: I don't think that's right in any way. Is that right? Okay, wait, what was the sentence?
Q: You can identify any causal estimand if you've observed and measured all the confounders. So in this particular scenario you don't have anything that's unobserved or unmeasured. So observational data is effectively the same as interventional data.
A: Yeah, so what I mean by interventional data is this: the training data for SAEs are just the hidden representations, the values they take on when we put in an input. The data I think we need is: once we put in an input and let the model run, what happens when we do interventions on the hidden states along the way?
That's definitely data we don't have, and don't use, and it's really easy to generate.
Q: Excellent. Maybe one last question. Okay. If there is more. How well does distributed alignment search scale in terms of compute? What's the algorithmic complexity?
A: Yeah, the only trained parameters in distributed alignment search are in the rotation matrix.
In the initial version of it, you had to choose a location where you wanted to do your disentangling. And then you would just train this rotation matrix, whose size is basically the number of dimensions you're going to intervene on squared, or that number times the dimensionality of the representation.
But then there's HyperDAS, a project I've been working on, which is really thinking about how we scale up distributed alignment search in some crazy ways. We're trying to do things with a hypernetwork that is able to select locations in the residual stream to do intervention experiments on, and have that network take in natural language commands about the concept to localize.
It's pretty ambitious, and the preprint is basically gonna say: we put this together end-to-end, the gradients worked, and it was able to be optimized to success in a very toy setting. But yeah, I think there are alternatives for scaling interpretability that aren't SAEs, and this is food for thought about what they might look like.