2024-08-17

When Numbers Are Bits: How Efficient Are Distributed Codes?

According to the superposition hypothesis, vectors of activations inside a neural network are best viewed as superpositions of "feature vectors"—directions in activation space with specific, interpretable meanings. This hypothesis is related to the classic connectionist idea of distributed representation and supported by the surprising fact that a low-dimensional linear projection $y = \sum_i f_i x_i = Fx$ can encode a much higher-dimensional sparse vector $x$.

How much information about $x$ can we expect a projection $y$ to encode, and how can we decode it? Suppose $x \in \{0, 1\}^N$ is a binary vector coding for a set of $k$ "active features." If the dimension of $y$ is at least $\Omega(k \ln N)$ and $F$ is a suitable random matrix, every entry of $x$ can be decoded from $y$ by a simple linear classifier with high probability. However, the constant factor is far from ideal: as a rule of thumb, $y$ will need at least $8$ dimensions for each bit of the sparse signal. In comparison, relatively efficient methods suggested by compressed sensing can decode a random projection that stores a little more than one bit per dimension.

In this post, we review the idea of distributed representations and investigate the efficiency of "superposition representations" using some simple math and empirical results. Based on a fact about language model scaling laws, we also suggest that transformers might process superpositions of features that are too "efficiently packed" to be decoded by linear classifiers.

Background and Main Ideas

Inside a connectionist learning machine, information is represented by numbers. Unfortunately, unlike bytes of computer memory, individual numbers inside a deep neural network are usually a bad starting point for reverse engineering. One key challenge is polysemanticity, the tendency of individual neurons to code for multiple unrelated concepts. (This term was recently suggested by Anthropic in the context of large language models.) If the computation of a neural network involves interpretable "units of meaning," like variables on a stack or entries in a key-value store, they don't seem to be reliably coded by individual "units of memory."

What if units of meaning are coded by patterns of activity spread over many neurons? Let's call this general possibility a distributed representation. Already in 1986, Hinton, McClelland, and Rumelhart realized that the strengths and weaknesses of distributed representations mirror those of human cognition. (For example, see Chapter 3 of Parallel Distributed Processing.) These representations are typically coarse codes: schemes where units of memory hold unsystematically assigned, high-entropy properties of the signal to be encoded. Borrowing an example from PDP, a coarse code might store a position on the plane by letting each unit of memory tell us if the position activates a certain "receptive field."

Much like a polysemantic neuron does not describe a specific concept, a single unit of this code does not describe a specific position. However, a moderate number of large receptive fields cut the plane into many small regions, so the whole coarse code stores our position with relatively good accuracy. In fact, using smaller receptive fields would decrease the number of regions and make the code worse on average.

So far, you may be reminded of the place value systems normally used to store fixed-precision numbers. In comparison, our random receptive fields are less information-efficient but have some other interesting properties. For example, an inner product between the coarse codes of different positions approximates how close they are. It also turns out that superimposing coarse codes can encode small sets of positions. (Can you see why?)

In the roughly 40 years since Parallel Distributed Processing was published, it's become increasingly convincing that these strange representations are part of the right approach for building intelligent machines. Under this philosophy, "uninterpretable" neuron activations are a sign we're on the right track, rather than a reason to design more "interpretable" networks! However, little is known about how large neural networks use distributed representations in practice. How does the residual stream of a deep transformer model store information, and what information does it store?

Under the superposition hypothesis, a vector $y$ of activations is a superposition $y = \sum_{i \in \mathcal{X}} f_i x_i = Fx$ of several "feature vectors" $f_i$ corresponding to a small subset $\mathcal{X}$ of non-zero entries in an underlying sparse vector $x$. The indices $i$ are called "features" and are expected to have fairly specific and interpretable meanings. For $Fx$ to be a good code for many sparse vectors, it turns out that most features $f_i$ must be distributed across many dimensions. (In other terms, most neurons cannot have limited "receptive fields.") When this happens, the projection $Fx$ is a kind of distributed representation for $x$.

Using the superposition hypothesis, researchers have made some promising headway into decoding hidden activations in large language models. For example, see Anthropic's seminal work on sparse autoencoders from October 2023. (For a particularly simple, convincing example of how a transformer can discover features and represent them in linear superposition, see Neel Nanda's experiment on Othello-GPT's "world model".) However, it's still unclear how close we can get to decoding a transformer's entire "thought process" using this technique.

In this post, we focus on an easier question. Hypothetically, how much information about a sparse vector could our neural network store in an $n$-dimensional projection? To keep things really simple, we'll suppose that $x \in \{0, 1\}^N$ is a $k$-sparse binary vector encoding a set of $k$ features chosen uniformly at random and that the feature vectors $f_i$ are random.

Let's see an example of how superpositions of random vectors can encode sets. We'll think of the set of all short strings as our feature set and use a hash function to map strings to binary feature vectors $f_i \in \{-1, 1\}^{128}$.

To represent the set of three strings above, just sum together their vectors.
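As a minimal sketch of this scheme (the three strings and the hash below are hypothetical stand-ins, since the widget above supplies its own), we can derive each feature vector from a hash of the string and add the vectors up:

```python
import hashlib
import numpy as np

def feature_vector(s: str, n: int = 128) -> np.ndarray:
    """Hash a string to a pseudorandom vector in {-1, +1}^n."""
    digest = hashlib.sha256(s.encode("utf-8")).digest()
    rng = np.random.default_rng(int.from_bytes(digest[:8], "little"))
    return rng.choice([-1, 1], size=n)

# Hypothetical three-element set; any short strings would do.
strings = ["apple", "banana", "cherry"]
y = sum(feature_vector(s) for s in strings)  # the superposition
```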

Can we use the superposition $y$ to detect which strings belong in our set?

A natural idea is to check which feature vectors $f_i$ are unusually correlated with $y$. Let's try this out: enter a string in the widget below to compute the element-wise product of its feature vector with $y$. Positive numbers are colored blue.

The average value of this product should be close to $1$ if the query belongs to our set and close to $0$ if it does not. In other words, we can try to decode an entry $x_i$ of our "sparse signal" by the hard-thresholding perceptron

$$\hat{x}_i(y) = \begin{cases} 1 : \displaystyle \frac{1}{n} \sum_{j=1}^n f_{ij} y_j > 1/2 \\ 0 : \text{otherwise}. \end{cases}$$

In fact, this is the best we can do if we think of $y$ as a signal $y = x_i f_i + Z$ carrying the message $x_i$ on the "frequency" $f_i$ and model the sum $Z$ of other feature vectors as isotropic noise independent from $x_i$. This is a problem of scalar estimation, and the estimate $x_i \approx \langle f_i, y \rangle / n$ is known to electrical engineers as a matched filter. A concentration of measure argument related to the Johnson-Lindenstrauss lemma shows this method will recover any given sparse vector $x$ with probability at least $1 - P$ so long as $n \ge 8k(\ln P^{-1} + \ln N)$. Roughly, this is because the normalized inner product $\frac{1}{n} \langle f_i, y \rangle$ between one feature vector and a superposition of $k$ other features has a standard deviation of about

$$\frac{\lVert f_i \rVert \lVert y \rVert}{\sqrt{n}} = \lVert f_i \rVert \sqrt{\frac{k}{n}},$$

and therefore will concentrate tightly enough around its expected value for the estimator $\hat{x}_i$ to succeed so long as $n \ge C \cdot k$. In our example, where $n = 128$ and $k = 3$, the probability that any given test $\hat{x}_i$ will fail is close to $1/2000$. However, this decoding method only scratches the surface of the true "channel capacity" of $y$.
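Continuing the hypothetical string sketch above, the matched-filter membership test is one line: correlate the query's feature vector with $y$, normalize by $n$, and threshold at $1/2$.

```python
def contains(y: np.ndarray, query: str, n: int = 128) -> bool:
    """Matched-filter test: is the normalized correlation with y above 1/2?"""
    return float(np.dot(feature_vector(query, n), y)) / n > 0.5

print(contains(y, "banana"))  # True: "banana" is in the set
print(contains(y, "durian"))  # False with high probability
```

This is exactly the estimator $\hat{x}_i$ above, specialized to our string example.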

To start understanding why matched filters fall short of this capacity, suppose we know that only one feature is active. Then, even in the presence of noise, we can just choose the feature with maximum correlation. If only two features are active, we might look at the few features whose correlation with $y$ is the highest, subtract each from the signal one at a time, and check in each case if the residual $y - f_i$ is well-explained by a single remaining feature. In general, we get the sense that letting our estimates for each scalar $x_i$ "communicate" in some way might make decoding more reliable without sacrificing too much efficiency.

In the mid-2000s, the field of compressed sensing studied the problem of efficiently and stably decoding a sparse vector $x$ from random linear projections. Relying on more subtle properties of the random matrix $F$, Candès, Romberg and Tao showed in 2005 that $x$ coincides with the solution of the linear program

$$\begin{align*} \textbf{minimize} & \quad \lVert x \rVert_1 \\ \textbf{subject to} & \quad Fx = y \end{align*}$$

under surprisingly mild conditions. In particular, we can often recover $x$ from this linear program even when $y$ has far too few dimensions for the estimates $\hat{x}_i$ above to be accurate. In 2009, Donoho, Maleki and Montanari showed that an iterative thresholding algorithm called approximate message passing, inspired by message passing in graphical models, succeeds under about the same conditions.
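For concreteness, here is one standard way to pose this $\ell_1$ minimization as a linear program (a sketch under the assumptions of this post, not the authors' code): split $x$ into nonnegative parts $x = u - v$ so the objective becomes linear, then hand it to an off-the-shelf LP solver.

```python
import numpy as np
from scipy.optimize import linprog

def basis_pursuit(F: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Solve min ||x||_1 subject to Fx = y by writing x = u - v with u, v >= 0."""
    n, N = F.shape
    c = np.ones(2 * N)             # objective: sum(u) + sum(v) = ||x||_1
    A_eq = np.hstack([F, -F])      # F(u - v) = y
    res = linprog(c, A_eq=A_eq, b_eq=y, bounds=(0, None), method="highs")
    u, v = res.x[:N], res.x[N:]
    return u - v

# Demo: recover a 3-sparse x in dimension N = 200 from only n = 40 measurements.
rng = np.random.default_rng(0)
N, n, k = 200, 40, 3
F = rng.choice([-1.0, 1.0], size=(n, N))
x_true = np.zeros(N)
x_true[rng.choice(N, size=k, replace=False)] = 1.0
x_hat = basis_pursuit(F, F @ x_true)
print(np.allclose(x_hat, x_true, atol=1e-4))  # True for most seeds
```

The weakly regularized lasso used in the experiments below optimizes a penalized version of the same $\ell_1$ objective.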

It's not easy to find the best conditions on $n$, $k$ and $N$ under which sparse recovery is guaranteed for a certain distribution of random matrices. Furthermore, results from the compressed sensing literature often assume that the ratio $n/N$ and sparsity level $k/N$ are bounded below by a constant. This makes sense in classical applications of compressed sensing, but for the superposition hypothesis it is more natural to think of $N$ as being exponentially large compared to $n$ and $k$. I'm (so far) not aware of good bounds that apply in this regime. However, our empirical results suggest that, for moderate values of $N$ and $k \ll N$, $x$ can be decoded with high probability so long as

$$n \ge k \left( \log_2(N/k) + 1/\ln 2 \right),$$

where the RHS is a relatively tight upper bound on the number of bits needed to represent a $k$-element set. In other terms, we are saying that only one dimension per bit is enough. (I have no idea why this is the case. Please contact me if you do!) As we will show empirically, this is a big improvement over matched filters.

The remainder of this post is divided into two parts. First, we'll make an empirical comparison of matched filter decoding with a compressed sensing method. Then, we propose our tentative argument that transformers encode more information in superposition than matched filters can reliably decode. We call this the dense superposition hypothesis.

Empirical Results

For $N \in \{2^8, 2^9, \dots, 2^{16}\}$ and $k \in \{1, \dots, 100\}$, let's see how large the dimension $n$ of the random projection $y = Fx$ needs to be for $x$ to be recoverable by either matched filters or by a simple compressed sensing algorithm. (The code used for these experiments is available here. As our compressed sensing method we used weakly regularized lasso regression optimized with coordinate descent—a very conceptually simple algorithm.)
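The linked repository has the real experiment; the sketch below (parameter choices are mine, not the original settings) shows the shape of a single trial: draw a random $F$ and a random $k$-sparse $x$, then compare a weakly regularized lasso decoder against matched filters.

```python
import numpy as np
from sklearn.linear_model import Lasso

def decode_lasso(F: np.ndarray, y: np.ndarray, alpha: float = 1e-4) -> np.ndarray:
    """Weakly regularized lasso (fit by coordinate descent), rounded to {0, 1}."""
    model = Lasso(alpha=alpha, fit_intercept=False, max_iter=50_000)
    model.fit(F, y)
    return (model.coef_ > 0.5).astype(int)

def decode_matched_filter(F: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Threshold the normalized correlations <f_i, y>/n at 1/2."""
    n = F.shape[0]
    return (F.T @ y / n > 0.5).astype(int)

rng = np.random.default_rng(0)
N, k, n = 2**12, 20, 400   # illustrative sizes, not the exact sweep
F = rng.choice([-1.0, 1.0], size=(n, N))
x = np.zeros(N)
x[rng.choice(N, size=k, replace=False)] = 1.0
y = F @ x
print("lasso exact recovery:", np.array_equal(decode_lasso(F, y), x))
print("matched filter exact recovery:", np.array_equal(decode_matched_filter(F, y), x))
```

As in the plots, "success" means every coordinate is recovered exactly; with these parameters we'd expect the lasso decoder to succeed while the matched filter flips at least a few coordinates.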

Blue points on the graph stand for completely successful decodings, while red points are decodings that failed on at least one coordinate. The dashed lines show the relations

$$\begin{align*} n &= 8k(\ln N - 3), \\ n &= k \left( \log_2(N/k) + 1/\ln 2 \right) \approx \log_2 \binom{N}{k}. \end{align*}$$

Let's first look at the performance of matched filters. Earlier, we said that a Johnson-Lindenstrauss-style argument shows $n \ge 8k(\ln P^{-1} + \ln N)$ is enough to guarantee, with probability $1 - P$, that all entries of $x$ are decoded by the perceptrons

$$\hat{x}_i(y) = \begin{cases} 1 : \displaystyle \frac{1}{n} \langle f_i, y \rangle \ge 1/2 \\ 0 : \text{otherwise}. \end{cases}$$

To see why, consider the inner product $\frac{1}{n} \left\langle f_i, \sum_{j=1}^k f_j \right\rangle$ of one feature $f_i$ with a sum of $k$ independent random features $f_j$. This is distributed like $S_{nk}/n$, where $S_{nk}$ is a sum of $nk$ independent Rademacher variables. By a Chernoff bound,

$$\begin{align*} \mathbb{P}(S_{nk}/n \ge 1/2) &= \mathbb{P}(S_{nk}/nk \ge 1/(2k)) \\ &\le \exp\left( -\frac{n}{8k} \right). \end{align*}$$

In fact, it can be shown that this is the best possible exponential rate with respect to $n$ for large $k$. This is where the magic constant $8$ comes from: it is the best constant for which $n \ge 8k \ln P^{-1}$ is enough, for large $k$ and small $P$, for a matched filter to be accurate on a superposition of $k$ random features with probability at least $1 - P$. (I strongly suspect, but have not actually proven, that $8$ is the best possible constant for any distribution of independent feature vectors.) To guarantee that every matched filter is accurate at the same time, the usual strategy is to take a union bound

$$\mathbb{P}(\exists i,\ \hat{x}_i \neq x_i) \le N \, \mathbb{P}(\hat{x}_i \neq x_i).$$

This gives the condition $n \ge 8k(\ln P^{-1} + \ln N)$ above. However, this conclusion is pessimistic to the extent that the failures $\hat{x}_i \neq x_i$ are correlated.
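As a quick numerical sanity check (an illustration I added, not part of the linked experiments), we can sample $S_{nk}/n$ directly and compare the empirical tail probability to $\exp(-n/8k)$ for one choice of $n$ and $k$:

```python
import numpy as np

rng = np.random.default_rng(0)
n, k, trials = 256, 8, 500_000

# S_{nk} is a sum of nk Rademacher variables, i.e. 2 * Binomial(nk, 1/2) - nk.
m = n * k
S = 2 * rng.binomial(m, 0.5, size=trials) - m
empirical = np.mean(S / n >= 0.5)
print(f"empirical tail: {empirical:.4f}, Chernoff bound exp(-n/8k): {np.exp(-n / (8 * k)):.4f}")
```

At these sizes the Chernoff bound is fairly loose, but the exponential dependence on $n/k$ is what drives the scaling discussed below.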

Empirically, we find that the minimum $n$ required by matched filters grows linearly with respect to $k$ for fixed $N$, and that the required dimensions per feature $n/k$ grows like $8 \ln N$. So far, this agrees with the Johnson-Lindenstrauss bound. However, it turns out that the ratio $n/k$ needs to be a little smaller than $8 \ln N$ for matched filters to fail with moderate probability. For the range of parameters we considered, the inequality $n \le 8k(\ln N - 3)$ accurately predicts the regime where matched filters fail reliably, and at least $8k(\ln N - 1)$ dimensions are needed for a reasonable chance of reconstruction.

Now, let's consider our compressed sensing results. Conditions for sparse recovery in the CS literature often take the form

$$k \le C \cdot \frac{n}{\ln(N/n)},$$

meaning that $n$ need be only some constant proportion larger than $k$ so long as the "aspect ratio" $N/n$ of the matrix $F$ is bounded above. (For example, see Section 1.3 from Candès et al.) However, for our purposes it is more natural to consider a stronger requirement of the form $k \le C \cdot n / \ln(N/k)$, which can be rewritten as

$$n \ge C^{-1} \cdot k \ln(N/k).$$

There is a simple proof, relying only on the Johnson-Lindenstrauss lemma, that this kind of requirement is enough to upper bound restricted isometry constants. However, I haven't yet understood how to prove a good sparse recovery guarantee with small $C^{-1}$ in the regime we consider here.

Empirically, it turns out that $C = \ln 2$ is close to the optimal value, and $n \ge k(\log_2(N/k) + 1/\ln 2)$ is always a sufficient condition for sparse recovery. This has a surprising interpretation: since the RHS is a close approximation for the entropy $\log_2 \binom{N}{k}$ of a $k$-element subset when $k \ll N$, it appears that $x$ can be recovered so long as $y$ encodes not much more than "one bit per dimension."

To compare the performance of our two methods, let's differentiate these bounds with respect to $k$. For sparse recovery to succeed, we need about

$$\frac{d}{dk}\, k \left( \log_2(N/k) + 1/\ln 2 \right) = \log_2(N/k)$$

more dimensions per additional active feature. On the other hand, matched filters need around $8(\ln N - 1) \approx 5.5 \log_2 N - 8$ more dimensions per active feature. When $N = 2^{16}$, our respective marginal prices per feature are about $16 - \log_2 k$ and $80$.

These prices can easily vary by a factor of $8$ for moderate values of $k$. Roughly, this means that matched filters can require about $8$ "dimensions per bit." Given $n = 1000$ dimensions—a typical number for the residual layer of a transformer—this is the difference between reliably storing a set with $100$ elements and inconsistently storing a dozen. In a regime where $k$ is at least some fraction of $N$, matched filters are inefficient in an asymptotic sense: the number of dimensions required per bit diverges to infinity. (This is the normal regime of compressed sensing.)
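To put numbers on this claim (a back-of-the-envelope script assuming $N = 2^{16}$, as above, and the empirical thresholds from the previous sections), we can solve each bound for the largest $k$ that fits in $n = 1000$ dimensions:

```python
import math

n, N = 1000, 2**16

def cs_dims(k: int) -> float:
    """Empirical sufficient dimension for lasso decoding: k(log2(N/k) + 1/ln 2)."""
    return k * (math.log2(N / k) + 1 / math.log(2))

def mf_dims(k: int) -> float:
    """Approximate dimension needed by matched filters: 8k(ln N - 1)."""
    return 8 * k * (math.log(N) - 1)

cs_k = max(k for k in range(1, n) if cs_dims(k) <= n)
mf_k = max(k for k in range(1, n) if mf_dims(k) <= n)
print(f"n = {n}: about {cs_k} features via compressed sensing, {mf_k} via matched filters")
```

For these parameters the script reports roughly ninety features for the lasso decoder and about a dozen for matched filters.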

The Dense Superposition Hypothesis

In practice, how good are transformers at dealing with sparse superpositions of features? Can they reliably process superpositions that are too "densely packed" to be decoded by matched filters? It might be possible for a neural network to run a computation on an efficiently encoded sparse signal without leaving superposition—that is, without first "decoding" features and storing them as activations. Let's call this the dense superposition hypothesis.

The simple idea that activation vectors code for anything like uniformly random subsets of features is almost certainly wrong. For example, some features may be much more common than others. Furthermore, in this post we have supposed the feature vectors $f_i$ are random, but a neural network is in principle free to choose them optimally. However, we can still draw a distinction between codes that can be reliably decoded by matched filters and ones that cannot, and we might hope that the gap between the "ideal" information density of a superposition code and the information efficiency attained by matched filter decodings looks something like what we described above. (To the best of my understanding, random feature vectors are essentially optimal when the set of active features is uniformly random.)

One very tentative argument for the dense superposition hypothesis is based on an observation by Pavan Katta about scaling laws. It goes as follows.

Consider the following informal assumptions:

  1. More accurate models need to represent more features.

  2. A model's ability to represent features is bottlenecked by its residual stream dimension.

Together, these imply that the performance of a model should be strongly dependent on its dimension. However, according to empirical "scaling laws," performance depends weakly on dimension when the number of non-embedding parameters is held fixed. For example, Kaplan et al. found that a model with $6$ layers and $4288$ dimensions performed very similarly to a model with $48$ layers and $1600$ dimensions, given that each had around $250$ million parameters. In other examples, similarly performing models differed in dimension by more than a factor of $10$.

If features must be decodable by matched filters, then the relatively dismal performance of matched filters we saw above makes assumption (2) seem likely. It would follow that assumption (1) fails: models must be able to sacrifice features without sacrificing performance. Given our intuition that features code for "concepts," this is hard to believe. On the other hand, if the dense superposition hypothesis is true, assumption (2) may fail: our model might not be bottlenecked by the "information capacity" of the residual stream but instead by its ability to use that capacity. In lower dimensions, features suffer more interference but take fewer parameters to "address," so a model of smaller dimension may not be at an utter disadvantage.

If the dense superposition hypothesis is true, sparse autoencoders will never be enough to fully decode our models. Perhaps a subset of important features are visible—like features that are important for attention—but many more undecoded features are being processed within MLP layers. This matches recent observations that sparse autoencoders are only "scratching the surface" of language models. On the other hand, it's still very unclear to me whether MLPs can parameterize reliable computations on "efficiently encoded" superpositions, and it's possible that sparse autoencoders are limited for other reasons. (In fact, researchers at Anthropic were already aware that compressed sensing methods are stronger than linear decoders, but worried that they may be too strong for their applications.) If you're also interested in testing the dense superposition hypothesis, I'd like to get in touch.

Appendix: Asymptotics for the Binomial

An ideal encoding for a $k$-element subset of $\{1, \dots, N\}$ takes $\log_2 \binom{N}{k}$ bits. Throughout this post, we claimed this is well-approximated by $k(\log_2(N/k) + 1/\ln 2)$ when $k \ll N$. Let's check why.

Our usual tools to estimate the binomial coefficient are the inequalities

$$\left( \frac{N}{k} \right)^k \le \binom{N}{k} \le \left( \frac{Ne}{k} \right)^k.$$

(To prove the upper bound, remember that $n! \ge (n/e)^n$.) If $k = o(N)$, meaning the ratio $N/k$ tends to infinity, then the factor $e^k$ is small when we take a logarithm, and overall

$$\ln \binom{N}{k} = k \ln(N/k) + O(k) \sim k \ln(N/k).$$

Note that this is sublinear in $N$.

Asymptotically for large $N$, the $O(k)$-sized gap between our bounds on $\ln \binom{N}{k}$ becomes significant when the ratio $k/N$ is bounded away from $0$ and $1$. To get a better approximation in this regime, we can plug the leading-order Stirling approximation $\ln n! = n \ln n - n + O(\ln n)$ into the binomial coefficient:

$$\begin{align*} \ln \binom{N}{k} &= (N - k) \ln\left( \frac{N}{N - k} \right) + k \ln\left( \frac{N}{k} \right) + O(\ln N). \end{align*}$$

So if $k \sim SN$,

$$\ln \binom{N}{k} = N H(S) + O(\ln N),$$

where $H(S) = -S \ln S - (1 - S)\ln(1 - S)$ is the familiar binary entropy function. In particular, the number $\binom{N}{k}$ of subsets to encode grows exponentially. (Binary entropy does not appear as a surprise, since an $S$-sparse subset of a large set is very similar to a process of independent Bernoulli variables with mean $S$.) Since $H(S) = -S \ln S + S + O(S^2)$ with $-S \ln S$ dominating when $S \to 0$, we find that the upper bound

$$\ln \binom{N}{k} \le k(\ln(N/k) + 1) = N(-S \ln S + S)$$

is very reasonable for small $S$, in the sense that it overestimates the rate of growth $\ln \binom{N}{k} / N$ by only $O(S^2)$. In comparison, the lower bound underestimates this rate by $S + O(S^2)$.

As this suggests, I've found that the upper bound is better as a rule of thumb. When $N = 65536$ and $k$ ranges between $20$ and $5000$, Mathematica tells me the approximation $\log_2 \binom{N}{k} \approx k(\log_2(N/k) + 1/\ln 2)$ is off by no more than $1.4\%$.
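The same check is easy to reproduce without Mathematica; here is a small Python version using `lgamma` (my own snippet, reproducing the computation rather than the original notebook):

```python
import math

def log2_binom(N: int, k: int) -> float:
    """Exact log2 of the binomial coefficient via log-gamma."""
    return (math.lgamma(N + 1) - math.lgamma(k + 1) - math.lgamma(N - k + 1)) / math.log(2)

def approx(N: int, k: int) -> float:
    """The rule-of-thumb approximation k(log2(N/k) + 1/ln 2)."""
    return k * (math.log2(N / k) + 1 / math.log(2))

N = 65536
worst = max(abs(approx(N, k) / log2_binom(N, k) - 1) for k in range(20, 5001))
print(f"worst relative error for 20 <= k <= 5000: {worst:.2%}")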
