2024-08-17

According to the **superposition hypothesis,** vectors of activations inside a neural network are best viewed as superpositions of "feature vectors"—directions in activation space with specific, interpretable meanings. This hypothesis is related to the classic connectionist idea of **distributed representation** and supported by the surprising fact that a low-dimensional linear projection
$y = \sum_i f_i x_i = F x$
can encode a much higher-dimensional sparse vector $x.$

How much information about $x$ can we expect a projection $y$ to encode, and how can we decode it? Suppose $x \in \left\{ 0, 1 \right\}^N$ is a binary vector coding for a set of $k$ "active features." If the dimension of $y$ is at least $\Omega(k \ln N)$ and $F$ is a suitable random matrix, every entry of $x$ can be decoded from $y$ by a simple linear classifier with high probability. However, the constant factor is far from ideal: as a rule of thumb, $y$ will need at least $8$ dimensions for each bit of the sparse signal. In comparison, relatively efficient methods suggested by compressed sensing can decode a random projection that stores a little more than **one bit per dimension**.

In this post, we review the idea of distributed representations and investigate the efficiency of "superposition representations" using some simple math and empirical results. Based on a fact about language model scaling laws, we also suggest that transformers might process superpositions of features that are too "efficiently packed" to be decoded by linear classifiers.

Inside a connectionist learning machine, information is represented by numbers. Unfortunately, unlike bytes of computer memory, individual numbers inside a deep neural network are usually a bad starting point for reverse engineering. One key challenge is **polysemanticity**, the tendency of individual neurons to code for multiple unrelated concepts. (This term was recently suggested by Anthropic in the context of large language models.) If the computation of a neural network involves interpretable "units of meaning," like variables on a stack or entries in a key-value store, they don't seem to be reliably coded by individual "units of memory."

What if units of meaning are coded by patterns of activity spread over *many* neurons? Let's call this general possibility a **distributed representation**. Already in 1986, Hinton, McClelland, and Rumelhart realized that the strengths and weaknesses of distributed representations mirror those of human cognition. (For example, see Chapter 3 of Parallel Distributed Processing.) These representations are typically **coarse codes**: schemes where units of memory hold unsystematically assigned, high-entropy properties of the signal to be encoded. Borrowing an example from PDP, a coarse code might store a position on the plane by letting each unit of memory tell us if the position activates a certain "receptive field."

Much like a polysemantic neuron does not describe a specific concept, a single unit of this code does not describe a specific position. However, a moderate number of large receptive fields cut the plane into many small regions, so the whole coarse code stores our position with relatively good accuracy. In fact, using smaller receptive fields would decrease the number of regions and make the code worse on average.

So far, you may be reminded of the place value systems normally used to store fixed-precision numbers. In comparison, our random receptive fields are less information-efficient but have some other interesting properties. For example, an inner product between the coarse codes of different positions approximates how close they are. It also turns out that *superimposing* coarse codes can encode small sets of positions. (Can you see why?)
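
To make this concrete, here's a small Python sketch of a coarse code for planar positions. The circular receptive fields, their number, and their radius are arbitrary choices of mine, not anything prescribed by PDP.

```python
import numpy as np

rng = np.random.default_rng(0)
m, radius = 256, 0.5  # number of receptive fields and their (deliberately large) radius

# Each receptive field is a disk with a random center in the unit square.
centers = rng.uniform(0, 1, size=(m, 2))

def coarse_code(p):
    """One unit per receptive field: 1 if the position activates it, 0 otherwise."""
    return (np.linalg.norm(centers - np.asarray(p), axis=1) < radius).astype(float)

a = coarse_code([0.20, 0.20])
b = coarse_code([0.25, 0.20])  # close to a
c = coarse_code([0.80, 0.90])  # far from a
print(a @ b, a @ c)  # nearby positions share many active units; distant ones share few
```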

In the roughly 40 years since *Parallel Distributed Processing* was published, the case has grown increasingly convincing that these strange representations are part of the right approach to building intelligent machines. Under this philosophy, "uninterpretable" neuron activations are a sign we're on the right track, rather than a reason to design more "interpretable" networks! However, little is known about how large neural networks use distributed representations in practice. How does the residual stream of a deep transformer model store information, and what information does it store?

Under the **superposition hypothesis**, a vector $y$ of activations is a superposition
$y = \sum_{i \in \mathcal X} f_i x_i = Fx$
of several "feature vectors" $f_i$ corresponding to a small subset $\mathcal X$ of non-zero entries in an underlying sparse vector $x.$ The indices $i$ are called "features" and are expected to have fairly specific and interpretable meanings. For $Fx$ to be a good code for many sparse vectors, it turns out that most features $f_i$ must be distributed across many dimensions. (In other terms, most neurons cannot have limited "receptive fields.") When this happens, the projection $Fx$ is a kind of distributed representation for $x.$

Using the superposition hypothesis, researchers have made some promising headway into decoding hidden activations in large language models. For example, see Anthropic's seminal work on **sparse autoencoders** from October 2023. (For a particularly simple, convincing example of how a transformer can discover features and represent them in linear superposition, see Neel Nanda's experiment on Othello-GPT's "world model".) However, it's still unclear how close we can get to decoding a transformer's entire "thought process" using this technique.

In this post, we focus on an easier question. Hypothetically, how much information about a sparse vector could our neural network store in an $n$-dimensional projection? To keep things really simple, we'll suppose that $x \in \left\{ 0, 1 \right\}^N$ is a $k$-sparse binary vector encoding a set of $k$ features chosen uniformly at random and that the feature vectors $f_i$ are random.

Let's see an example of how superpositions of random vectors can encode sets. We'll think of the set of all short strings as our feature set and use a hash function to map strings to feature vectors $f_i \in \{-1, 1\}^{128}.$

To represent the set of three strings above, just sum together their vectors.
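
In code, the construction might look like the following sketch. The hash (BLAKE2b, used here only to seed NumPy's generator) and the three example strings are placeholders of my own choosing, not the ones from the widget.

```python
import hashlib
import numpy as np

n = 128

def feature_vector(s: str) -> np.ndarray:
    """Hash a string to a pseudo-random feature vector in {-1, +1}^n."""
    seed = int.from_bytes(hashlib.blake2b(s.encode(), digest_size=8).digest(), "big")
    return np.random.default_rng(seed).choice([-1, 1], size=n)

# Represent a set of strings by summing (superimposing) their feature vectors.
members = ["alpha", "beta", "gamma"]
y = sum(feature_vector(s) for s in members)
```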

Can we use the superposition $y$ to detect which strings belong in our set?

A natural idea is to check which feature vectors $f_i$ are unusually correlated with $y.$ Let's try this out: enter a string in the widget below to compute the element-wise product of its feature vector with $y.$ Positive numbers are colored blue.
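
Continuing the sketch above (reusing `feature_vector` and `y`), the widget's computation amounts to:

```python
def correlation(query: str) -> float:
    """Average element-wise product of the query's feature vector with y."""
    return float(np.mean(feature_vector(query) * y))

print(correlation("alpha"), correlation("delta"))  # a member of the set vs. a non-member
```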

The average value of this product should be close to $1$ if the query belongs to our set and close to $0$ if it does not. In other words, we can try to decode an entry $x_i$ of our "sparse signal" by the hard-thresholding perceptron
$\hat{x}_i(y) = \begin{cases}
1 : \displaystyle \frac{1}{n} \sum_{j=1}^n f_{ij} y_j > 1/2 \\
0 : \text{otherwise}.
\end{cases}$
In fact, this is the best we can do if we think of $y$ as a signal
$y = x_i f_i + Z$
carrying the message $x_i$ on the "frequency" $f_i$ and model the sum $Z$ of other feature vectors as isotropic noise independent from $x_i.$ This is a problem of *scalar estimation*, and the estimate $x_i \approx \langle f_i, y \rangle/n$ is known to electrical engineers as a matched filter. A concentration of measure argument related to the Johnson-Lindenstrauss lemma shows this method will recover any given sparse vector $x$ with probability at least $1 - P$ so long as
$n \ge 8 k(\ln P^{-1} + \ln N).$
Roughly, this is because the normalized inner product $\frac{1}{n} \langle f_i, y \rangle$ between one feature vector and a superposition $y$ of $k$ other features has a standard deviation of about
$\frac{\lVert y \rVert}{n} \approx \sqrt{\frac{k}{n}}$
(since $\lVert y \rVert \approx \sqrt{nk}$ for a sum of $k$ random sign vectors),
and therefore will concentrate tightly enough around its expected value for the estimator $\hat{x}_i$ to succeed so long as $n \ge C \cdot k.$ In our example where $n = 128$ and $k = 3,$ the probability that any given test $\hat{x}_i$ will fail is close to $1/2000.$ However, this decoding method only scratches the surface of the true "channel capacity" of $y.$
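
Here's the same estimator in vectorized form, as a self-contained sketch with problem sizes of my own choosing:

```python
import numpy as np

rng = np.random.default_rng(0)
N, n, k = 1024, 256, 3                     # features, dimensions, active features

F = rng.choice([-1.0, 1.0], size=(n, N))   # columns of F are the feature vectors f_i
x = np.zeros(N)
x[rng.choice(N, size=k, replace=False)] = 1.0
y = F @ x

# Matched filter plus hard threshold, applied to every feature at once.
x_hat = (F.T @ y / n > 0.5).astype(float)
print("decoding errors:", int(np.sum(x_hat != x)))  # usually 0 when n is well above 8 k ln N
```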

To start understanding why, suppose we know that only one feature is active. Then, even in the presence of noise, we can just choose the feature with maximum correlation. If only two features are active, we might look at the few features whose correlation with $y$ is the highest, subtract each from the signal one at a time, and check in each case if the residual $y - f_i$ is well-explained by a single remaining feature. In general, we get the sense that letting our estimates for each scalar $x_i$ "communicate" in some way might make decoding more reliable without sacrificing too much efficiency.

In the mid-2000s, the field of **compressed sensing** studied the problem of efficiently and stably decoding a sparse vector $x$ from random linear projections. Relying on more subtle properties of the random matrix $F,$ Candès, Romberg, and Tao showed in 2005 that $x$ coincides with the solution of the linear program
$\begin{align*}
\textbf{minimize} & \quad \lVert x \rVert_1 \\
\textbf{subject to} & \quad F x = y
\end{align*}$
under surprisingly mild conditions. In particular, we can often recover $x$ from this linear program even when $y$ has far too few dimensions for the estimates $\hat{x}_i$ above to be accurate. In 2009, Donoho, Maleki and Montanari showed that an iterative thresholding algorithm called approximate message passing, inspired by message passing in graphical models, succeeds under about the same conditions.
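
Here's a minimal basis-pursuit sketch of that linear program, using SciPy's `linprog` with the usual split $x = x^+ - x^-$ into nonnegative parts. The problem sizes are small and chosen only for illustration.

```python
import numpy as np
from scipy.optimize import linprog

rng = np.random.default_rng(0)
N, n, k = 512, 64, 6                       # features, projection dimension, active features

F = rng.choice([-1.0, 1.0], size=(n, N))
x = np.zeros(N)
x[rng.choice(N, size=k, replace=False)] = 1.0
y = F @ x

# minimize ||x||_1  subject to  F x = y,  written as an LP over (x+, x-) >= 0.
c = np.ones(2 * N)
res = linprog(c, A_eq=np.hstack([F, -F]), b_eq=y, bounds=(0, None), method="highs")
x_hat = res.x[:N] - res.x[N:]
print("recovered exactly:", np.allclose(x_hat, x, atol=1e-6))
```

Note that $n = 64$ here is far below the roughly $8 k \ln N \approx 300$ dimensions the matched filter bound would ask for.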

It's not easy to find the best conditions on $n,$ $k$ and $N$ under which sparse recovery is guaranteed for a certain distribution of random matrices. Furthermore, results from the compressed sensing literature often assume that the ratio $n/N$ and sparsity level $k/N$ are bounded below by a constant. This makes sense in classical applications of compressed sensing, but for the superposition hypothesis it is more natural to think of $N$ as being exponentially large compared to $n$ and $k.$ I'm (so far) not aware of good bounds that apply in this regime. However, our empirical results suggest that, for moderate values of $N$ and $k \ll N,$ $x$ can be decoded with high probability so long as
$n \ge k (\log_2(N/k) + 1/\ln 2),$
where the RHS is a relatively tight upper bound on the number of bits needed to represent a $k$-element set. In other terms, we are saying that only **one dimension per bit** is enough. (I have no idea why this is the case. Please contact me if you do!) As we will show empirically, this is a big improvement over matched filters.

The remainder of this post is divided into two parts. First, we'll make an empirical comparison of matched filter decoding with a compressed sensing method. Then, we present a tentative argument that transformers encode more information in superposition than matched filters can reliably decode. We call this the **dense superposition hypothesis**.

For $N \in \left\{ 2^8, 2^9, \dots, 2^{16} \right\}$ and $k \in \{1, \dots, 100\},$ let's see how large the dimension $n$ of the random projection $y = F x$ needs to be for $x$ to be recoverable by either matched filters or by a simple compressed sensing algorithm. (The code used for these experiments is available here. As our compressed sensing method we used weakly regularized lasso regression optimized with coordinate descent—a very conceptually simple algorithm.)
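
For reference, one cell of such an experiment might look roughly like the sketch below, using scikit-learn's `Lasso` (which is fit by coordinate descent). The regularization strength and the rounding threshold are my own guesses, not necessarily the settings of the linked code.

```python
import numpy as np
from sklearn.linear_model import Lasso

def recovered(N: int, k: int, n: int, seed: int = 0) -> bool:
    """Can weakly regularized lasso recover a random k-sparse binary x from y = Fx?"""
    rng = np.random.default_rng(seed)
    F = rng.choice([-1.0, 1.0], size=(n, N))
    x = np.zeros(N)
    x[rng.choice(N, size=k, replace=False)] = 1.0
    y = F @ x

    model = Lasso(alpha=1e-4, fit_intercept=False, max_iter=100_000)
    model.fit(F, y)
    x_hat = (model.coef_ > 0.5).astype(float)  # round the coefficients to {0, 1}
    return bool(np.array_equal(x_hat, x))

# Scan n to see roughly where decoding starts to succeed for a given N and k.
print([n for n in range(40, 121, 10) if recovered(N=1024, k=8, n=n)])
```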

Blue points on the graph stand for completely successful decodings, while red points are decodings that failed on at least one coordinate. The dashed lines show the relations $\begin{align*} n & = 8k(\ln N - 3), \\ n & = k (\log_2(N/k) + 1/\ln 2) \approx \log_2 \binom N k. \end{align*}$

Let's first look at the performance of matched filters. Earlier, we said that a Johnson-Lindenstrauss-style argument shows
$n \ge 8 k (\ln P ^{-1} + \ln N)$
is enough to guarantee, with probability $1 - P,$ that all entries of $x$ are decoded by the perceptrons
$\hat{x}_i(y) = \begin{cases}
1 : \displaystyle \frac{1}{n} \langle f_i, y \rangle \ge 1/2 \\
0 : \text{otherwise}.
\end{cases}$
To see why, consider the inner product
$\frac{1}{n}\left\langle f_i, \sum_{j = 1}^k f_j \right\rangle$
of one feature $f_i$ with a sum of $k$ independent random features $f_j.$ This is distributed like $S_{nk}/n$ where $S_{nk}$ is a sum of $nk$ independent Rademacher variables. By a Chernoff bound,
$\begin{align*}
\P(S_{nk} / n \ge 1/2) & = \P(S_{nk} / nk \ge 1/2k) \\
& \le \exp\left( -\frac{n}{8k} \right).
\end{align*}$
In fact, it can be shown that this is the best possible exponential rate with respect to $n$ for large $k.$ This is where the magic constant $8$ comes from: it is the best constant for which
$n \ge 8 k \ln P ^{-1}$
is enough, for large $k$ and small $P,$ for a matched filter to be accurate on a superposition of $k$ random features with probability at least $1 - P.$ (I strongly suspect, but have not actually proven, that $8$ is the best possible constant for any distribution of independent feature vectors.) To guarantee that *every* matched filter is accurate at the same time, the usual strategy is to take a union bound
$\P(\exists i, \hat{x}_i \neq x_i) \le N \P(\hat{x}_i \neq x_i).$
This gives the condition $n \ge 8k (\ln P ^{-1} + \ln N)$ above. However, this conclusion is pessimistic to the extent that the failures $\hat{x}_i \neq x_i$ are correlated.
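
As a quick numerical illustration (my own check, using a normal approximation rather than a full simulation), here is how the Chernoff bound compares to the actual tail probability in the earlier example with $n = 128$ and $k = 3$:

```python
from math import erfc, exp, sqrt

n, k = 128, 3

# S_{nk} is a sum of nk Rademacher variables, so it has standard deviation sqrt(nk).
z = (n / 2) / sqrt(n * k)  # the threshold n/2, measured in standard deviations

print("Chernoff bound:      ", exp(-n / (8 * k)))      # about 1/200
print("normal approximation:", erfc(z / sqrt(2)) / 2)  # about 1/2000, as quoted earlier
```

The exponents agree exactly here, since $z^2/2 = n/8k,$ so the gap between the two numbers is essentially the Gaussian prefactor $1/(z \sqrt{2\pi})$ that the Chernoff bound drops.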

Empirically, we find that the minimum $n$ required by matched filters grows linearly with respect to $k$ for fixed $N,$ and that the required number of dimensions per feature, $n/k,$ grows like $8 \ln N.$ So far, this agrees with the Johnson-Lindenstrauss bound. However, the ratio $n/k$ needs to be a little *smaller* than $8\ln N$ before matched filters fail with moderate probability. For the range of parameters we considered, it turns out that the inequality
$n \le 8k(\ln N - 3)$
accurately predicts the regime where matched filters fail reliably, and at least $8k(\ln N - 1)$ dimensions are needed for a reasonable chance of reconstruction.

Now, let's consider our compressed sensing results. Conditions for sparse recovery in the CS literature often take the form $k \le C \cdot \frac{n}{\ln (N/n)},$ meaning that $n$ need only be a constant factor larger than $k$ so long as the "aspect ratio" $N/n$ of the matrix $F$ is bounded above. (For example, see Section 1.3 of Candès et al.) However, for our purposes it is more natural to consider a stronger requirement of the form $k \le C \cdot n / \ln(N/k),$ which can be rewritten as $n \ge C ^{-1} \cdot k \ln (N/k).$ There is a simple proof, relying only on the Johnson-Lindenstrauss lemma, that this kind of requirement is enough to upper bound restricted isometry constants. However, I haven't yet understood how to prove a good sparse recovery guarantee with small $C^{-1}$ in the regime we consider here.

Empirically, it turns out that $C = \ln 2$ is close to the optimal value, and
$n \ge k (\log_2(N/k) + 1/\ln 2)$
is always a *sufficient* condition for sparse recovery. This has a surprising interpretation: since the RHS is a close approximation for the entropy $\log_2 \binom N k$ of a $k$-element subset when $k \ll N,$ it appears that $x$ can be recovered so long as $y$ encodes not much more than "one bit per dimension."

To compare the performance of our two methods, let's differentiate these bounds with respect to $k.$ For sparse recovery to succeed, we need about $\frac{d}{dk} k (\log_2(N/k) + 1/\ln 2) = \log_2(N/k)$ more dimensions per additional active feature. On the other hand, matched filters need around $8 (\ln N - 1) \approx 5.5 \log_2 N - 8$ more dimensions per active feature. When $N = 2^{16},$ our respective marginal prices per feature are about $16 - \log_2 k$ and $80.$
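
Spelled out, the arithmetic behind these marginal prices is just the two expressions above evaluated at $N = 2^{16}$:

```python
from math import log, log2

N = 2 ** 16

def marginal_cs(k: int) -> float:
    """Extra dimensions needed per extra active feature, compressed sensing bound."""
    return log2(N / k)

marginal_mf = 8 * (log(N) - 1)  # matched filters: independent of k

print(marginal_cs(1), marginal_cs(64), marginal_mf)  # roughly 16, 10, and 80
```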

These prices can easily differ by a factor of $8$ for moderate values of $k.$ Roughly, this means that matched filters can require about $8$ "dimensions per bit." Given $n = 1000$ dimensions—a typical number for the residual layer of a transformer—this is the difference between reliably storing a set with $100$ elements and inconsistently storing a dozen. In a regime where $k$ is at least some fraction of $N,$ matched filters are inefficient in an *asymptotic* sense: the required number of dimensions per bit diverges to infinity. (This is the normal regime of compressed sensing.)

In practice, how good are transformers at dealing with sparse superpositions of features? Can they reliably process superpositions that are too "densely packed" to be decoded by matched filters? It might be possible for a neural network to run a computation on an efficiently encoded sparse signal *without leaving superposition*—that is, without first "decoding" features and storing them as activations. Let's call this the **dense superposition hypothesis**.

The simple idea that activation vectors code for anything like uniformly random subsets of features is almost certainly wrong. For example, some features may be much more common than others. Furthermore, in this post we have supposed the feature vectors $f_i$ are random, but a neural network is in principle free to choose them optimally. However, we can still draw a distinction between codes that can be reliably decoded by matched filters and ones that cannot, and we might hope that the gap between the "ideal" information density of a superposition code and the information efficiency attained by matched filter decodings looks something like what we described above. (To the best of my understanding, random feature vectors are essentially optimal when the set of active features is uniformly random.)

One very tentative argument for the dense superposition hypothesis is based on an observation by Pavan Katta about scaling laws. It goes as follows.

Consider the following informal assumptions:

1. More accurate models need to represent more features.
2. A model's ability to represent features is bottlenecked by its residual stream dimension.

Together, these imply that the performance of a model should be strongly dependent on its dimension. However, according to empirical "scaling laws," performance depends weakly on dimension when the number of non-embedding parameters is held fixed. For example, Kaplan et al. found that a model with $6$ layers and $4288$ dimensions performed very similarly to a model with $48$ layers and $1600$ dimensions, given that each had around $250$ million parameters. In other examples, similarly performing models differed in dimension by more than a factor of $10.$

If features must be decodable by matched filters, then the relatively dismal performance of matched filters we saw above makes assumption (2) seem likely. It would follow that assumption (1) fails: models must be able to sacrifice features without sacrificing performance. Given our intuition that features code for "concepts," this is hard to believe. On the other hand, if the dense superposition hypothesis is true, assumption (2) may fail: our model might not be bottlenecked by the "information capacity" of the residual stream but instead by its ability to *use* its capacity. In lower dimensions, features suffer more interference but take fewer parameters to "address," so a model of smaller dimension may not be at an utter disadvantage.

If the dense superposition hypothesis is true, sparse autoencoders will never be enough to fully decode our models. Perhaps a subset of important features are visible—like features that are important for attention—but many more undecoded features are being processed within MLP layers. This matches recent observations that sparse autoencoders are only "scratching the surface" of language models. On the other hand, it's still very unclear to me whether MLPs can parameterize reliable computations on "efficiently encoded" superpositions, and it's possible that sparse autoencoders are limited for other reasons. (In fact, researchers at Anthropic were already aware that compressed sensing methods are stronger than linear decoders, but worried that they may be too strong for their applications.) If you're also interested in testing the dense superposition hypothesis, I'd like to get in touch.

An ideal encoding for a $k$-element subset of $\{1, \dots, N \}$ takes $\log_2 \binom N k$ bits. Throughout this post, we claimed this is well-approximated by $k (\log_2 (N/k) + 1/\ln 2)$ when $k \ll N.$ Let's check why.

Our usual tools to estimate the binomial coefficient are the inequalities $\left( \frac{N}{k} \right)^k \le \binom N k \le \left( \frac{N e}{k} \right)^k.$ (To prove the upper bound, remember that $n! \ge (n/e)^n.$) If $k = o(N),$ meaning the ratio $N/k$ tends to infinity, then the factor $e^k$ is small when we take a logarithm, and overall $\ln \binom N k = k \ln (N/k) + O(k) \sim k \ln(N/k).$ Note that this is sublinear in $N.$

Asymptotically for large $N,$ the $O(k)$-sized gap between our bounds on $\ln \binom N k$ widens when the ratio $k/N$ is bounded away from $0$ and $1.$ To get a better approximation in this regime, we can plug the leading-order Stirling approximation $\ln n! = n (\ln n - 1) + O(\ln n)$ into the binomial coefficient: $\begin{align*} \ln \binom N k & = (N - k) \ln\left( \frac{N}{N - k} \right) \\ & + k \ln\left( \frac{N}{k} \right) + O(\ln N). \end{align*}$ So if $k \sim S N,$ $\ln \binom N k = N H(S) + O(\ln N)$ where $H(S) = -S \ln S - (1 - S)\ln(1 - S)$ is the familiar binary entropy function. In particular, the number $\binom N k$ of subsets to encode grows exponentially. (Binary entropy should not come as a surprise, since an $S$-sparse subset of a large set is very similar to a process of independent Bernoulli variables with mean $S.$) Since $H(S) = -S \ln S + S + O(S^2)$ with $-S \ln S$ dominating when $S \to 0,$ we find that the upper bound $\ln \binom N k \le k (\ln (N/k) + 1) = N (-S \ln S + S)$ is very reasonable for small $S,$ in the sense that it overestimates the rate of growth $\ln \binom N k / N$ by only $O(S^2).$ In comparison, the lower bound underestimates this rate by $S + O(S^2).$

As this suggests, I've found that the upper bound is better as a rule of thumb. When $N = 65536$ and $k$ ranges between $20$ and $5000,$ Mathematica tells me the approximation $\log_2 \binom N k \approx k (\log_2(N/k) + 1/ \ln 2)$ is off by no more than $1.4\%.$
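
The same check is easy to reproduce without Mathematica. Here's a quick verification of my own using `lgamma` for the exact binomial coefficient:

```python
from math import lgamma, log, log2

N = 65536

def log2_binom(N: int, k: int) -> float:
    """Exact log2 of the binomial coefficient, via log-gamma."""
    return (lgamma(N + 1) - lgamma(k + 1) - lgamma(N - k + 1)) / log(2)

worst = max(
    abs(k * (log2(N / k) + 1 / log(2)) / log2_binom(N, k) - 1)
    for k in range(20, 5001)
)
print(f"worst relative error: {worst:.2%}")  # should come out around 1.4%
```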