Reverse Mode Automatic Differentiation
The first time the backpropagation algorithm was explained to me, it was done by symbolically deriving from scratch the update rule for the weights of a neural network. Obviously, this approach is suitable only when the architecture is fairly simple, making it unfeasible for larger and more complex architectures. Common deep learning frameworks such as PyTorch allow you to create neural networks by applying operations to tensors. Then, to get the numerical gradient of that network (no matter how complex), you can simply call the backward() method. This post is my personal attempt at explaining the basic idea of the algorithm behind that method: reverse mode automatic differentiation.
Computation Graphs
Having an efficient and automatic way to precisely compute derivatives on a computer is extremely important for many scientific areas. Let's assume for example that we have a simple function \(f(x) = 2 [\sin(x^3)^2]\) whose derivative we want to compute. In a computer program we would represent this function straightforwardly like this:
import math

def f(x: float) -> float:
    return 2 * math.sin(x**3)**2
To prepare for automatic differentiation, we can rewrite the same function as a sequence of elementary operations, each producing an intermediate value:

def f(x: float) -> float:
    v_0 = x**3
    v_1 = math.sin(v_0)
    v_2 = v_1**2
    v_3 = 2*v_2
    return v_3
This decomposition can be drawn as a computational graph, a compact graphical representation that shows exactly the dependencies between operations and the flow of values through the computation. I will use computational graphs in this post because they also happen to be quite useful for visualizing how gradients flow through the computation, which may help to better ground the concepts.
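Since the graph is easier to appreciate in a drawing, here is a minimal plain-text stand-in of my own (not part of the original implementation), just to show the structure implied by the decomposition above:

# Each node maps to the node it feeds into; the comments recall how it is computed.
graph = {
    "x":   "v_0",  # v_0 = x**3
    "v_0": "v_1",  # v_1 = sin(v_0)
    "v_1": "v_2",  # v_2 = v_1**2
    "v_2": "v_3",  # v_3 = 2*v_2
    "v_3": None,   # output of the function
}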
The Chain Rule
The core idea behind automatic differentiation is also a powerful concept in math that many of us may remember from high school calculus: the chain rule. \[ \frac{\partial s}{\partial t} = \frac{\partial s}{\partial u} \frac{\partial u}{\partial t} \] In short, this formula tells us that if a value \(s\) depends directly on another variable \(u\), which in turn depends on \(t\), then the way \(s\) changes with respect to \(t\) can be found by multiplying how \(s\) changes with \(u\) by how \(u\) changes with \(t\). Put in plainer terms, to know how much a variable changes with respect to another, we multiply how each intermediate value changes with respect to its input. Intuitively, it traces how a small change in \(t\) ripples through \(u\) and finally affects \(s\).
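As a tiny worked example of my own (not taken from the function above): if \(u = t^2\) and \(s = \sin(u)\), then \[ \frac{\partial s}{\partial t} = \frac{\partial s}{\partial u}\frac{\partial u}{\partial t} = \cos(u) \cdot 2t = 2t\cos(t^2) \] which is exactly what we would get by differentiating \(\sin(t^2)\) with respect to \(t\) directly.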
Sometimes this formula is expressed with \(s = f(x)\) and \(t = x\), but I deliberately wrote it in the most general form to make it clearer that this rule can be applied repeatedly for each intermediate variable. We can apply it recursively until there are no more intermediate variables left (basically reaching the "end" of the expression). Following the previous analogy, it can be roughly compared to tracing how a small change in the input ripples through each step of the computation all the way to the final output.
Looking at the formula, we can note that there are basically three variables that must be decided when first applying the chain rule:
- The output variable \(s\) to differentiate.
- The input variable \(t\) we differentiate with respect to.
- The intermediate variable \(u\) from which we first apply the chain rule.
Let's get a little more concrete by differentiating the example expression by means of the chain rule. We start by setting the output and input variables \(s = v_3\) and \(t = x\). Now, if we want to start from the output and move towards the input, we have to set \(u = v_2\), and then recursively apply the chain rule going backwards. This is also the most common way the chain rule is applied. I have put boxes around the terms of the expression where the recursive relation is applied, and I also delimited each expansion with parentheses, to make the direction of the expansion clearer. \[ \begin{align} \frac{\partial v_3}{\partial x} &= \frac{\partial v_3}{\partial v_2} \hbbox{\frac{\partial v_2}{\partial x}}\\[.5em] &= \frac{\partial v_3}{\partial v_2}\left( \frac{\partial v_2}{\partial v_1}\hbbox{\frac{\partial v_1}{\partial x}} \right ) \\[.5em] &= \frac{\partial v_3}{\partial v_2}\left( \frac{\partial v_2}{\partial v_1} \left ( \frac{\partial v_1}{\partial v_0}\frac{\partial v_0}{\partial x}\right ) \right ) \\[.5em] \end{align} \]
Conversely, we can also apply the rule in the opposite order. To do so, we choose \(u = v_0\) and recursively apply the rule until we reach the output. It may seem weird initially, but I suggest focusing exclusively on the symbolic application of the formula (using the computational graph as a guide at each step may also help). \[ \begin{align} \frac{\partial v_3}{\partial x} &= \hbbox{\frac{\partial v_3}{\partial v_0}}\frac{\partial v_0}{\partial x}\\[.5em] &= \left ( \hbbox{\frac{\partial v_3}{\partial v_1}}\frac{\partial v_1}{\partial v_0} \right) \frac{\partial v_0}{\partial x}\\[.5em] &= \left( \left ( \frac{\partial v_3}{\partial v_2}\frac{\partial v_2}{\partial v_1} \right ) \frac{\partial v_1}{\partial v_0} \right) \frac{\partial v_0}{\partial x}\\[.5em] \end{align} \]
We end up with the same final derivative, but the terms can be expanded in two different orders. In the first expansion, the terms accumulate from left to right (starting from the output), and in the other, from right to left (starting from the input). In other words, in the first case, we assume that the leftmost term is already known and expand the rightmost term. \[ \frac{\partial s}{\partial t} = \frac{\partial s}{\partial u} \hbbox{\frac{\partial u}{\partial t}} \] In the second case we do the opposite, expanding the leftmost term while assuming the other is already known. \[ \frac{\partial s}{\partial t} = \hbbox{\frac{\partial s}{\partial u}} \frac{\partial u}{\partial t} \] If we look at both the expanded formulas, we can see that if we compute these terms following the order imposed by the parentheses, we are in fact accumulating the derivatives (just like in a reduction operation) in two different orders.
So far the discussion might sound like symbolic math, but automatic differentiation actually does something different. It never builds a giant symbolic derivative of your whole expression. Instead, it computes the numerical value of the derivative at a single point. The trick is that the term of the chain rule that we assume to be known can be evaluated numerically at each step (where we expanded the chain rule in the example above). In this way, we only materialize the numerical value, accumulating (by means of multiplication) the terms at each step. The order in which the chain rule is applied, as discussed so far, directly determines the direction in which the gradients are accumulated. In fact, it turns out that the automatic differentiation algorithm is nothing more than an implementation of this evaluation of the chain rule.
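To make the "reduction in two different orders" idea concrete, here is a small sketch of my own (not part of the original post) that multiplies the numerical local derivatives of our example function in both orders:

import math
from functools import reduce

x = 10.0
v_0 = x**3
v_1 = math.sin(v_0)

# Local derivatives of each step, listed from output to input:
# dv_3/dv_2, dv_2/dv_1, dv_1/dv_0, dv_0/dx
local_derivatives = [2.0, 2 * v_1, math.cos(v_0), 3 * x**2]

# Accumulate from the output towards the input...
output_to_input = reduce(lambda acc, d: acc * d, local_derivatives, 1.0)
# ...or from the input towards the output.
input_to_output = reduce(lambda acc, d: acc * d, reversed(local_derivatives), 1.0)

# For scalars both orders give the same derivative df/dx at x = 10.
assert math.isclose(output_to_input, input_to_output)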
Reverse Accumulation Mode
When we expanded the chain rule from outputs to inputs, we supposed that the partial \(\frac{\partial s}{\partial u}\) was known and recursively expanded the other partial in the next step. To implement this, for each intermediate node \(v_i\) of the expression we compute \[ \overline{u} = \frac{\partial s}{\partial u} \] where in our case \(u = v_i\) with \(i = 0,\dots,3\) and \(s = v_3\). These quantities are called adjoints or co-tangents in the literature. Since the terms appear from outputs to inputs during the expansion, we must compute the adjoints in that same order. \[ \begin{align} \overline{v_3} &= 1 \\[.5em] \overline{v_2} &= \overline{v_3} \cdot \frac{\partial{v_3}}{\partial{v_2}} = \overline{v_3} \cdot 2 \\[.5em] \overline{v_1} &= \overline{v_2} \cdot \frac{\partial{v_2}}{\partial{v_1}} = \overline{v_2} \cdot 2v_1\\[.5em] \overline{v_0} &= \overline{v_1} \cdot \frac{\partial{v_1}}{\partial{v_0}} = \overline{v_1} \cdot \cos{v_0} \\[.5em] \overline{x} &= \overline{v_0} \cdot \frac{\partial{v_0}}{\partial{x}} = \overline{v_0} \cdot 3x^2 \\[.5em] \end{align} \] If you look carefully at the derivations, each adjoint computation depends directly on values and adjoints that come later in the computational graph. This is why these computations must be performed after the computational graph has been entirely evaluated. In the literature we say that we do a forward pass of the computational graph to get all the intermediate values instantiated. Then, we perform a backward pass, where we accumulate the gradients in reverse. That's why we call this method reverse mode (also called backpropagation in ML lingo).
def df(x: float) -> float:
    # Forward evaluation of the function
    v_0 = x**3
    v_1 = math.sin(v_0)
    v_2 = v_1**2
    v_3 = 2*v_2
    # Reverse gradient accumulation
    dv_3 = 1
    dv_2 = dv_3 * 2
    dv_1 = dv_2 * 2 * v_1
    dv_0 = dv_1 * math.cos(v_0)
    dx = dv_0 * 3 * x**2
    # Return the gradient of the function at x
    return dx
If executed, this function will return the exact value of the derivative at the specified point. You can verify that it computes the derivative as intended by comparing it with the symbolic form \(f'(x) = 12x^2 \cos(x^3) \sin(x^3)\).
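For instance, here is a quick sanity check of my own (not in the original post), assuming the df function defined above is in scope:

import math

x = 10.0
# Symbolic derivative of f evaluated at the same point
symbolic = 12 * x**2 * math.cos(x**3) * math.sin(x**3)
# The hard-coded reverse mode version should match it
assert math.isclose(df(x), symbolic)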
As we can see from the code, we have two computational graphs: one for the function and one for the gradient computation. In reality, we can represent them directly on the same graph, indicating both the flow of values and the flow of gradients at the same time. Note that the operation applied to gradients is not explicitly shown, because gradients are implicitly multiplied together at each step.
For now, we have just hard-coded a specific execution of the reverse mode automatic differentiation algorithm. Obviously, we can build a much more generic implementation based on this idea, one that doesn't require writing every intermediate step and its corresponding gradient computation explicitly each time. To do that, we need a basic mechanism to keep track of intermediate values and their dependencies in a computation. A very simple way to do so is to represent both values and operations as small objects that remember how they were computed.
import math
# Generic value type in a computational graph.
# It's just a wrapper of a numeric value
class Value:
def __init__(self):
self.value = None
# Unary operation that tracks its single input
class UnaryOp(Value):
def __init__(self, x: Value):
super().__init__()
self.x = x
# Binary operation that tracks both of its inputs
class BinaryOp(Value):
def __init__(self, x: Value, y: Value):
super().__init__()
self.x = x
self.y = y
# A generic input variable
class Input(Value):
def __init__(self, value: float):
super().__init__()
self.value = value
self.grad = 0
def forward(self) -> None:
pass
def backward(self, df_dv: float) -> None:
self.grad = df_dv
# A generic constant
class Constant(Value):
def __init__(self, value: float):
super().__init__()
self.value = value
def forward(self) -> None:
pass
def backward(self, df_dv: float) -> None:
return
# Sin operation
class Sin(UnaryOp):
def __init__(self, x: Value):
super().__init__(x)
def forward(self) -> None:
self.x.forward()
self.value = math.sin(self.x.value)
def backward(self, df_dv: float) -> None:
dv_dx = df_dv * math.cos(self.x.value)
self.x.backward(dv_dx)
# Pow (by const) operation
class Pow(UnaryOp):
def __init__(self, x: Value, exp: Constant):
super().__init__(x)
self.exp = exp
def forward(self) -> None:
self.x.forward()
self.value = self.x.value ** self.exp.value
def backward(self, df_dv: float) -> None:
dv_dx = df_dv * self.exp.value * (self.x.value ** (self.exp.value - 1))
self.x.backward(dv_dx)
# Mul (by const) operation
class Mul(UnaryOp):
def __init__(self, x: Value, y: Constant):
super().__init__(x)
self.y = y
def forward(self) -> None:
self.x.forward()
self.value = self.x.value * self.y.value
def backward(self, df_dv: float) -> None:
dv_dx = df_dv * self.y.value
self.x.backward(dv_dx)
Each node in the graph that represents an operation stores the numerical value of that operation and a pointer to its inputs. To evaluate the graph we use the forward() method, which computes the value produced by each single node. The highlight here is the backward() method, which computes the local derivative (dv_dx) and recursively pushes the accumulated gradient to its input by passing it as the argument of the next backward() call. With this extremely simple DSL we can create a static computational graph of our example function, evaluate it, and take its derivative (at a specific point). All of this without explicitly deriving the gradient.
# we evaluate it on x=10
x = Input(10)
y = Mul(Pow(Sin(Pow(x, Constant(3))), Constant(2)), Constant(2))
# evaluate the graph
y.forward()
# evaluate the gradient
y.backward(1)
# print value and gradient
print(f"f({x.value}) = {y.value}")
print(f"df_dx({x.value}) = {x.grad}")
We can verify that our solution is working by comparing it with a battle-tested autodiff framework like PyTorch. The following snippet returns the same numeric values.
import torch
x = torch.Tensor([10])
x.requires_grad = True
y = torch.pow(torch.sin(torch.pow(x, 3)), 2) * 2
y.backward(gradient=torch.ones(1))
print(f"f({x.item()}) = {y.item()}")
print(f"df_dx({x.item()}) = {x.grad.item()}")
Notice that when we first called backward() we passed the value 1, which is the gradient of y with respect to itself. At the end of the backward pass, each input node will contain the gradient of the variable on which we called the method (in our case y) with respect to that input. This is why we print the x.grad value in the snippet above. You can also notice that, just as before, we need to evaluate the entire graph before executing the backward pass. The cool thing about the static graph implementation is that you can then re-evaluate it as many times as you like.
# evaluate on x=5
x.value = 5
y.forward()
y.backward(1)
print(f"f({x.value}) = {y.value}")
print(f"df_dx({x.value}) = {x.grad}")
# evaluate on x=3
x.value = 3
y.forward()
y.backward(1)
print(f"f({x.value}) = {y.value}")
print(f"df_dx({x.value}) = {x.grad}")
With this simple implementation we can now define any function composed of \(\sin\), powers, and multiplications by constants. There is a problem, though, that I haven't talked about until now. In our example, all operations were sequential in nature, but what happens when, for example, we have multiple inputs? Consider this slightly more complex function \(f(x, y) = \cos([x^2 + \sin(y)] \cdot x)\) and its corresponding computational graph.
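Since the graph is easier to follow with named intermediate values, here is one possible decomposition (my own labelling, chosen to be consistent with the partial derivatives used below):

import math

def f(x: float, y: float) -> float:
    v_0 = x**2
    v_1 = math.sin(y)
    v_2 = v_0 + v_1
    v_3 = v_2 * x   # x is used a second time here
    v_4 = math.cos(v_3)
    return v_4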
When we try to apply the chain rule to this expression to differentiate with respect to the input \(x\), we can immediately see that \(x\) can be reached along two different paths in the graph. This leads to two different expansions of the chain rule.
\[ \frac{\partial v_4}{\partial v_3} \frac{\partial v_3}{\partial x} \quad \text{or} \quad \frac{\partial v_4}{\partial v_3} \frac{\partial v_3}{\partial v_2} \frac{\partial v_2}{\partial v_0} \frac{\partial v_0}{\partial x} \] To account for these situations, it is sufficient to sum the gradients corresponding to each different path. In fact, the multi-variate chain rule formula is in reality the following.
\[ \frac{\partial s}{\partial t} = \sum_i \frac{\partial s}{\partial u_i}\frac{\partial u_i}{\partial t} \] where the sum runs over the intermediate variables \(u_i\) through which \(t\) influences \(s\). To support this more general case, we need to update our code by changing the gradient computation of the input nodes, accumulating the gradient at each call instead of overwriting it each time. We also set up a zero_grad() method to reset the gradients of the input values in the computational graph. We can also make our code more flexible by introducing more general versions of some of the operations introduced before. For example, Mul can be rewritten so that it supports multiplication between any two values, not just between a value and a constant.
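Before looking at the updated code, here is what this rule gives on our example when differentiating with respect to \(x\) (my own expansion, using the intermediate variables sketched above): \[ \frac{\partial v_4}{\partial x} = \frac{\partial v_4}{\partial v_3}\frac{\partial v_3}{\partial x} + \frac{\partial v_4}{\partial v_3}\frac{\partial v_3}{\partial v_2}\frac{\partial v_2}{\partial v_0}\frac{\partial v_0}{\partial x} = -\sin(v_3)\, v_2 - \sin(v_3)\cdot x \cdot 1 \cdot 2x = -\sin(v_3)\,\bigl(3x^2 + \sin(y)\bigr) \]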
import math
class Value:
def __init__(self):
self.value = None
class UnaryOp(Value):
def __init__(self, x: Value):
super().__init__()
self.x = x
# call zero grad recursively until reaching an Input value
def zero_grad(self) -> None:
self.x.zero_grad()
class BinaryOp(Value):
def __init__(self, x: Value, y: Value):
super().__init__()
self.x = x
self.y = y
def zero_grad(self) -> None:
self.x.zero_grad()
self.y.zero_grad()
class Input(Value):
def __init__(self, value: float):
super().__init__()
self.value = value
self.grad = 0
def forward(self) -> None:
pass
def backward(self, df_dv: float) -> None:
# now we sum the incoming gradient to the actual gradient
self.grad += df_dv
# reset the gradient
def zero_grad(self) -> None:
self.grad = 0
class Constant(Value):
def __init__(self, value: float):
super().__init__()
self.value = value
def forward(self) -> None:
pass
def backward(self, df_dv: float) -> None:
pass
def zero_grad(self) -> None:
pass
class Sin(UnaryOp):
def __init__(self, x: Value):
super().__init__(x)
def forward(self) -> None:
self.x.forward()
self.value = math.sin(self.x.value)
def backward(self, df_dv: float) -> None:
dv_dx = df_dv * math.cos(self.x.value)
self.x.backward(dv_dx)
# Pow operation (exponential)
class Pow(UnaryOp):
# now accepts any value as an exponential
def __init__(self, x: Value, exp: Value):
super().__init__(x)
self.exp = exp
def forward(self) -> None:
self.x.forward()
self.exp.forward()
self.value = self.x.value ** self.exp.value
def backward(self, df_dv: float) -> None:
dv_dx = df_dv * self.exp.value * (self.x.value ** (self.exp.value - 1))
dv_dexp = df_dv * (self.x.value ** self.exp.value) * math.log(self.x.value)
self.x.backward(dv_dx)
self.exp.backward(dv_dexp)
# Mul operation (between any values)
class Mul(BinaryOp):
# now accepts any value as the second argument
def __init__(self, x: Value, y: Value):
super().__init__(x, y)
def forward(self) -> None:
self.x.forward()
self.y.forward()
self.value = self.x.value * self.y.value
def backward(self, df_dv: float) -> None:
dv_dx = df_dv * self.y.value
dv_dy = df_dv * self.x.value
self.x.backward(dv_dx)
self.y.backward(dv_dy)
Since only the input nodes keep track of gradients, the zero_grad() method just recursively calls the same method on its children until it reaches an input node, which effectively sets its gradient to 0. We need this method because inputs act like "sink" nodes that accumulate gradients during the backward pass. If we call the backward() method and then want to re-evaluate the gradient at another point without resetting the gradients first, the new gradient would be added to the previous one, leading to an incorrect result. Also, the possibility to reset gradients at will is particularly useful when we need to accumulate the gradients of multiple points to take their mean (which happens quite often in deep learning).
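As a quick check of my own (not part of the original post), here is how the updated graph handles a function where the same input appears along two paths, \(g(x) = \sin(x) \cdot x^2\), and why zero_grad() is needed between evaluations:

# Build g(x) = sin(x) * x^2; the input x is reached along two paths.
x = Input(2.0)
g = Mul(Sin(x), Pow(x, Constant(2)))

g.forward()
g.backward(1)
# The two path contributions are summed: cos(x)*x^2 + sin(x)*2x
print(f"dg_dx({x.value}) = {x.grad}")

# Re-evaluate at another point: reset the accumulated gradient first.
g.zero_grad()
x.value = 3.0
g.forward()
g.backward(1)
print(f"dg_dx({x.value}) = {x.grad}")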
Conclusion
We saw the core idea behind reverse mode automatic differentiation. Starting from first principles, we saw how this algorithm comes up quite naturally just by reasoning about how we expand the chain rule. We also walked through a simple implementation from scratch in Python to demonstrate the simplicity of the algorithm. The discussion has been kept fairly superficial to keep it light and easy to follow. For example, I consciously avoided discussing forward mode, which I find much simpler to understand. If you want to know more, I encourage you to dig deeper by looking at the awesome references at the end of this article.
References
- Evaluating Derivatives: Principles and Techniques of Algorithmic Differentiation - Griewank, Walther
- Automatic differentiation in machine learning: a survey - Atilim Gunes Baydin et al.
- Reverse-mode automatic differentiation: a tutorial - Rufflewind
- Automatic Differentiation - Lei Mao