← cgad.ski 2024-01-23

Backprop Isn't Just for Calculus

Given a directed graph, how can we count the number of paths from one vertex to another?

The natural approach is to build up our answer incrementally. For example, we might define $x(b)$ to be the number of paths from a fixed vertex $a$ to a given endpoint $b$. (In our illustrated graph, we would make $a$ the left-most vertex.) Since there is only one trivial "path" from $a$ to itself, we can say at once that $x(a) = 1$. From here, we can compute the values of $x(b)$ for other vertices $b$, one by one, relying on values we've already determined.

If you're not familiar with this method, take a minute to figure it out! (The original post illustrates the reasoning with an interactive widget, in which you click or tap to select a vertex.)
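For readers who want to check their answer, here is a minimal sketch of the forward computation in Python. The graph `EDGES` is a small hypothetical DAG, not the one illustrated in the post, and the function name `count_paths_from` is made up for this example. The idea is that each $x(b)$ is the sum of $x$ over the vertices with an edge into $b$.

```python
from graphlib import TopologicalSorter

# Hypothetical example DAG (not the graph from the post): vertex -> successors.
EDGES = {"a": ["b", "c"], "b": ["c", "d"], "c": ["d"], "d": []}

def count_paths_from(edges, source):
    """Return x, where x[v] is the number of paths from `source` to v."""
    # TopologicalSorter expects vertex -> predecessors, so invert the edge map.
    preds = {v: [] for v in edges}
    for u, succs in edges.items():
        for v in succs:
            preds[v].append(u)
    x = {v: 0 for v in edges}
    x[source] = 1  # the single trivial path from the source to itself
    for v in TopologicalSorter(preds).static_order():
        if v != source:
            # Every path into v arrives through exactly one predecessor.
            x[v] = sum(x[u] for u in preds[v])
    return x

x = count_paths_from(EDGES, "a")
print(x)
```

In this toy graph there are three paths from `a` to `d`, namely `a-b-d`, `a-b-c-d` and `a-c-d`.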

We can also go backward, starting with the right-most vertex. Specifically, we could define $x^*(a)$ to count the number of paths from a given vertex $a$ to a fixed endpoint $b$, assign $x^*(b) = 1$, and proceed from right to left. This is the same as running our previous computation in a world where arrows go backward.
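The backward computation can be sketched the same way. Again, `EDGES` is a hypothetical toy DAG and `count_paths_to` is a name invented for this example. Handing the successor map to `TopologicalSorter`, which expects predecessors, is one way to get an ordering for the "world where arrows go backward."

```python
from graphlib import TopologicalSorter

# Hypothetical example DAG (not the graph from the post): vertex -> successors.
EDGES = {"a": ["b", "c"], "b": ["c", "d"], "c": ["d"], "d": []}

def count_paths_to(edges, target):
    """Return x_star, where x_star[v] is the number of paths from v to `target`."""
    x_star = {v: 0 for v in edges}
    x_star[target] = 1  # the single trivial path from the target to itself
    # Treating successors as "predecessors" yields an order in which every
    # vertex appears after all of its successors (a reverse topological order).
    for v in TopologicalSorter(edges).static_order():
        if v != target:
            x_star[v] = sum(x_star[w] for w in edges[v])
    return x_star

x_star = count_paths_to(EDGES, "d")
print(x_star)
```

The intermediate values differ from those of a forward pass, but `x_star["a"]` is the same total number of paths from `a` to `d`.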

Although the intermediate values we're writing down are different, both "forward" and "backward" computations agree that there are $22$ paths through our graph.

Now, let's forget about path-counting and focus only on the nature of our forward computation. Suppose we modify a certain intermediate value $x(b)$ and re-evaluate all other quantities $x(c)$ that "depend" on $x(b)$. How much will they change?

As it turns out, increasing $x(b)$ by $1$ will always increase $x(c)$ by the number of paths from $b$ to $c$. We could write down the details, but the best proof is to stare at the widget above. (If you like linear algebra, you'll recognize the linearity at play.) We can say that the number of paths from $b$ to $c$ measures the sensitivity of $x(c)$ to the value of $x(b)$, or in other words how much $x(c)$ depends on $x(b)$.
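This claim is easy to test numerically. The sketch below uses a hypothetical toy DAG given as a predecessor map `PREDS`, and a made-up `bump` argument that adds an amount to an intermediate value before anything downstream reads it.

```python
# Hypothetical toy DAG (not the graph from the post): vertex -> predecessors,
# with keys listed in a topological order.
PREDS = {"a": [], "b": ["a"], "c": ["a", "b"], "d": ["b", "c"]}

def forward(source, bump=None):
    """Count paths from `source`, adding bump[v] to x[v] before it is used downstream."""
    bump = bump or {}
    x = {}
    for v in PREDS:  # dicts preserve insertion order, so this order is topological
        base = 1 if v == source else sum(x[u] for u in PREDS[v])
        x[v] = base + bump.get(v, 0)
    return x

baseline = forward("a")
bumped = forward("a", bump={"b": 1})  # increase x(b) by 1 and re-evaluate
change = {v: bumped[v] - baseline[v] for v in PREDS}
paths_from_b = forward("b")  # number of paths from b to each vertex
print(change)
print(paths_from_b)
```

For this graph, the change in every $x(c)$ matches the number of paths from `b` to that vertex.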

What if we're interested in knowing how much a fixed variable $x(c)$ depends on a set $\{ x(b_i) \}$ of other variables? We could run multiple forward computations through the graph to count paths from each $b_i$ to $c$, but it's faster to use our backward computation, starting at $c$, to compute all these numbers in the same pass.
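To make the comparison concrete, the following sketch computes the same numbers both ways on a hypothetical toy DAG: one forward pass per starting vertex, versus a single backward pass from the output. The names `SUCCS`, `ORDER`, `paths_from` and `paths_to` are invented for this example.

```python
# Hypothetical toy DAG (not the graph from the post): vertex -> successors.
SUCCS = {"a": ["b", "c"], "b": ["c", "d"], "c": ["d"], "d": []}
ORDER = ["a", "b", "c", "d"]  # a topological order of the toy graph

def paths_from(source):
    """Forward pass: x[v] = number of paths from `source` to v."""
    x = {v: 0 for v in ORDER}
    x[source] = 1
    for u in ORDER:  # x[u] is final once all earlier vertices have pushed into it
        for w in SUCCS[u]:
            x[w] += x[u]
    return x

def paths_to(target):
    """Backward pass: x_star[v] = number of paths from v to `target`."""
    x_star = {v: 0 for v in ORDER}
    x_star[target] = 1
    for u in reversed(ORDER):
        if u != target:
            x_star[u] = sum(x_star[w] for w in SUCCS[u])
    return x_star

many_forward_passes = {b: paths_from(b)["d"] for b in ORDER}  # one pass per b_i
one_backward_pass = paths_to("d")                             # a single pass
print(many_forward_passes, one_backward_pass)
```

Both dictionaries hold the sensitivity of `x["d"]` to each other variable, but the second costs one traversal of the graph rather than one per vertex.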

This whole discussion works when the edges of our graph have weights, and we ask to sum up products of weights over paths. I'll spare the details, which you can work out for yourself. When our weights are real numbers, our forward computation describes a network of variables bound by linear functions, and our backward computation, known as the backpropagation algorithm, lets us determine the sensitivity of a given output to a whole set of inputs at once.
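Here is one way the weighted case might look, as a sketch under assumptions: the edge weights in `WEIGHTS` are arbitrary made-up numbers on a hypothetical toy DAG, and the `bump` argument exists only to check the backward pass against a direct perturbation.

```python
# Hypothetical weighted DAG: (source, destination) -> real-valued weight.
WEIGHTS = {("a", "b"): 2.0, ("a", "c"): 3.0, ("b", "c"): 0.5,
           ("b", "d"): 4.0, ("c", "d"): -1.0}
ORDER = ["a", "b", "c", "d"]  # a topological order of the toy graph

def forward(x_a, bump=None):
    """Each x[v] is a linear function of its predecessors: sum of weight * x[u]."""
    bump = bump or {}
    x = {v: 0.0 for v in ORDER}
    x["a"] = x_a
    for u in ORDER:
        x[u] += bump.get(u, 0.0)  # optional perturbation of an intermediate value
        for (src, dst), w in WEIGHTS.items():
            if src == u:
                x[dst] += w * x[u]
    return x

def backward(target):
    """g[v] = sum over paths v -> target of the product of edge weights."""
    g = {v: 0.0 for v in ORDER}
    g[target] = 1.0
    for u in reversed(ORDER):
        for (src, dst), w in WEIGHTS.items():
            if src == u:
                g[u] += w * g[dst]
    return g

x = forward(1.0)
g = backward("d")
effect_of_bumping_b = forward(1.0, bump={"b": 1.0})["d"] - x["d"]
print(g, effect_of_bumping_b)
```

A single call to `backward("d")` gives the sensitivity of `x["d"]` to every variable, and the perturbation check agrees with `g["b"]`.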

However, our weights need not be numbers. For example, let's say our variables are sets of strings of symbols. If $x$ and $y$ are two sets of strings, we'll define

\begin{align*} x + y & = x \cup y, \\ xy & = \{ A \cdot B : A \in x, B \in y\} \end{align*}

where $A \cdot B$ is a concatenation. If the "weight" on each edge is a singleton set holding a code for the edge, then a "sum of products over paths" is just a set of strings coding for paths. Our forward and backward computations above become two natural algorithms to enumerate paths through our graph.
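As a rough illustration, the forward computation over sets of strings could be written as follows. The one-letter edge codes in `EDGE_CODES` and the toy graph are assumptions made for this sketch. The set `{""}` holding only the empty string plays the role of $1$, and the empty set plays the role of $0$.

```python
# Hypothetical edge codes on a toy DAG: (source, destination) -> one-letter code.
EDGE_CODES = {("a", "b"): "p", ("a", "c"): "q", ("b", "c"): "r",
              ("b", "d"): "s", ("c", "d"): "t"}
ORDER = ["a", "b", "c", "d"]  # a topological order of the toy graph

def concat(x, y):
    """The 'product' xy: every string of x followed by every string of y."""
    return {A + B for A in x for B in y}

def enumerate_paths_from(source):
    """x[v] = set of strings coding for the paths from `source` to v."""
    x = {v: set() for v in ORDER}  # the empty set acts as zero
    x[source] = {""}               # the empty string codes for the trivial path
    for u in ORDER:
        for (src, dst), code in EDGE_CODES.items():
            if src == u:
                x[dst] |= concat(x[u], {code})  # 'sum' is set union
    return x

x = enumerate_paths_from("a")
print(sorted(x["d"]))
```

The structure is identical to path counting. Only the meanings of "plus" and "times" have changed.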

Let $\Delta$ be some set of strings. Where $\partial x(c) / \partial x(b)$ denotes the set of "path strings" coding for paths from $b$ to $c$, we find that replacing $x(b)$ with $x(b) + \Delta$ causes $x(c)$ to become $x(c) + \Delta \frac{\partial x(c)}{\partial x(b)}$. So, even though $x(c)$ and $x(b)$ are not numbers, there turns out to be a natural way to "measure the dependency" of one on the other by a single "quantity." If we're interested in a fixed output variable, the backward computation describes a way to propagate these quantifications of dependency through the graph.
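The identity can be checked on a small case. In this sketch the toy graph, the edge codes and the choice $\Delta = \{\texttt{Z}\}$ are all assumptions, and `dx["b"]` stands in for $\partial x(d) / \partial x(b)$. Concatenation is not commutative, so the backward pass has to put each edge code on the left of what follows it.

```python
# Hypothetical edge codes on a toy DAG: (source, destination) -> one-letter code.
EDGE_CODES = {("a", "b"): "p", ("a", "c"): "q", ("b", "c"): "r",
              ("b", "d"): "s", ("c", "d"): "t"}
ORDER = ["a", "b", "c", "d"]  # a topological order of the toy graph

def concat(x, y):
    return {A + B for A in x for B in y}

def forward(delta_at_b=frozenset()):
    """Enumerate paths from a, after replacing x(b) with x(b) + delta_at_b."""
    x = {v: set() for v in ORDER}
    x["a"] = {""}
    for u in ORDER:
        if u == "b":
            x[u] |= delta_at_b
        for (src, dst), code in EDGE_CODES.items():
            if src == u:
                x[dst] |= concat(x[u], {code})
    return x

def backward(target):
    """dx[v] = set of strings coding for the paths from v to `target`."""
    dx = {v: set() for v in ORDER}
    dx[target] = {""}
    for u in reversed(ORDER):
        for (src, dst), code in EDGE_CODES.items():
            if src == u:
                dx[u] |= concat({code}, dx[dst])  # edge code goes on the left
    return dx

delta = {"Z"}
dx = backward("d")
perturbed = forward(delta)["d"]
predicted = forward()["d"] | concat(delta, dx["b"])
print(sorted(perturbed), sorted(predicted))
```

Here `perturbed` and `predicted` are the same set, as the formula says they should be.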

Formally, this magic happens so long as operations on our "weights" satisfy the axioms of a semiring. (Since weights only need to be multiplied when they are on adjacent edges we can even think of them as morphisms in a pre-additive category, but I don't think this is an important generalization.) If you like abstract algebra, you might recognize the "sets of path strings" construction as almost being a free unital semiring over the edges of our graph. For the universal property to hold, we would need strings in our sets to count with multiplicity—that is, to consider multisets.
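One way to see the role of the semiring axioms is to write the backward pass once, against an abstract pair of operations, and then plug in different instances. The `Semiring` class, the instances `COUNTING` and `MAX_PLUS`, and the toy graph with its edge lengths are all hypothetical names and data chosen for this sketch.

```python
from dataclasses import dataclass
from typing import Any, Callable

@dataclass(frozen=True)
class Semiring:
    zero: Any                       # identity for add
    one: Any                        # identity for mul
    add: Callable[[Any, Any], Any]
    mul: Callable[[Any, Any], Any]

COUNTING = Semiring(0, 1, lambda p, q: p + q, lambda p, q: p * q)
MAX_PLUS = Semiring(float("-inf"), 0.0, max, lambda p, q: p + q)

# Hypothetical toy DAG, as a list of (source, destination) edges.
EDGES = [("a", "b"), ("a", "c"), ("b", "c"), ("b", "d"), ("c", "d")]
ORDER = ["a", "b", "c", "d"]  # a topological order of the toy graph

def backward(sr, weight, target):
    """g[v] = semiring sum, over paths v -> target, of the product of edge weights."""
    g = {v: sr.zero for v in ORDER}
    g[target] = sr.one
    for v in reversed(ORDER):
        for src, dst in EDGES:
            if src == v:
                g[v] = sr.add(g[v], sr.mul(weight[(src, dst)], g[dst]))
    return g

unit = {e: 1 for e in EDGES}
length = {("a", "b"): 2.0, ("a", "c"): 3.0, ("b", "c"): 0.5,
          ("b", "d"): 4.0, ("c", "d"): 1.0}
path_counts = backward(COUNTING, unit, "d")
longest = backward(MAX_PLUS, length, "d")
print(path_counts, longest)
```

With `COUNTING` the pass counts paths. With `MAX_PLUS`, where "sum" is `max` and "product" is ordinary addition, the same code returns the length of the longest path from each vertex to `d`.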

By universality, this is the end of the story: every other situation where our magic applies is a homomorphic image of the free semiring case. On the other hand, free semirings have a lot of homomorphic images. Can you think of something to backpropagate besides real numbers under addition and multiplication?

It's also interesting to consider what happens if our directed graph has cycles. Strictly speaking, our computation will not terminate, but in some cases it nevertheless admits a natural "limit" or least fixed point. (For example, we can still define sets of paths through a graph with cycles.) In these cases, the method of backpropagation still works. Can we use this idea?
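As a small numerical illustration of the limit idea, consider a hypothetical graph with edges `a -> b` and `b -> c` plus a self-loop on `b` of weight $0.5$. There are infinitely many paths from `a` to `c`, but the sum of products over them is the geometric series $1 + 0.5 + 0.25 + \cdots = 2$. The sketch below repeats the backward update until the values stop changing. It assumes real weights for which that iteration converges, which is not guaranteed in general.

```python
# Hypothetical graph with a cycle: a -> b, a self-loop b -> b, and b -> c.
WEIGHTS = {("a", "b"): 1.0, ("b", "b"): 0.5, ("b", "c"): 1.0}
VERTICES = ["a", "b", "c"]

def backward_fixed_point(target, tol=1e-12, max_sweeps=1000):
    """Iterate g <- e_target + W g from g = 0 until successive sweeps agree."""
    g = {v: 0.0 for v in VERTICES}
    for _ in range(max_sweeps):
        new_g = {v: (1.0 if v == target else 0.0) for v in VERTICES}
        for (src, dst), w in WEIGHTS.items():
            new_g[src] += w * g[dst]
        if max(abs(new_g[v] - g[v]) for v in VERTICES) < tol:
            return new_g
        g = new_g
    raise RuntimeError("iteration did not converge")

g = backward_fixed_point("c")
print(g)
```

After $n$ sweeps, `g` accounts for exactly the paths with fewer than $n$ edges, so the iteration approaches the full sum over paths from below.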

After writing this post, I found a paper, "Generalizing Backpropagation for Gradient-Based Interpretability" by Kevin Du et al., describing two possible uses for semiring backpropagation in interpretability. I also found an earlier work, "Algorithmic Complexities in Backpropagation and Tropical Neural Networks" by Ozgur Ceyhan, describing backprop over the tropical (max-plus) semiring in the context of tropical neural networks.
