Automatic differentiation
1 Problem
Suppose we need to solve the following problem:
L(w) \to \min_{w \in \mathbb{R}^d}
Such problems typically arise in machine learning, when you need to find the optimal parameters w of an ML model (i.e., train a neural network). You may use many algorithms to approach this problem, but given the modern size of the problem, where d can be tens of billions, it is very challenging to solve it without gradient information, i.e., with zero-order optimization algorithms. That is why it is beneficial to be able to calculate the gradient vector \nabla_w L = \left( \frac{\partial L}{\partial w_1}, \ldots, \frac{\partial L}{\partial w_d}\right)^T. Typically, first-order methods perform much better in huge-scale optimization, while second-order methods require too much memory.
2 Finite differences
The naive approach to approximating the gradient is the finite differences approach. For each coordinate, one can calculate the partial derivative approximation:
\dfrac{\partial L}{\partial w_k} (w) \approx \dfrac{L(w+\varepsilon e_k) - L(w)}{\varepsilon}, \quad e_k = (0, \ldots, \underset{{\tiny k}}{1}, \ldots, 0)
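If one evaluation of L costs T, this scheme needs d+1 evaluations (roughly \mathcal{O}(dT) in total) to form the full gradient, and its accuracy is sensitive to the choice of \varepsilon: too large gives truncation error, too small gives round-off error. A minimal sketch of the idea (the function name finite_diff_grad is illustrative, not from any library):

import numpy as np

def finite_diff_grad(L, w, eps=1e-6):
    """Approximate the gradient of L at w with forward differences.

    Requires d + 1 evaluations of L for w in R^d; accuracy depends
    heavily on the choice of eps.
    """
    w = np.asarray(w, dtype=float)
    grad = np.zeros_like(w)
    L0 = L(w)
    for k in range(w.size):
        w_shift = w.copy()
        w_shift[k] += eps
        grad[k] = (L(w_shift) - L0) / eps
    return grad

# Example: L(w) = w_2 * log(w_1) + sqrt(w_2 * log(w_1))
L = lambda w: w[1] * np.log(w[0]) + np.sqrt(w[1] * np.log(w[0]))
print(finite_diff_grad(L, np.array([2.0, 3.0])))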
3 Forward mode automatic differentiation
To dive deeper into the idea of automatic differentiation, we will consider a simple function whose derivatives we want to calculate:
L(w_1, w_2) = w_2 \log w_1 + \sqrt{w_2 \log w_1}
Let’s draw a computational graph of this function:
Let’s go from the beginning of the graph to the end and calculate the derivative \dfrac{\partial L}{\partial w_1}:
Step | Function | Derivative
---|---|---
1 | w_1 = w_1, \; w_2 = w_2 | \dfrac{\partial w_1}{\partial w_1} = 1, \; \dfrac{\partial w_2}{\partial w_1} = 0
2 | v_1 = \log w_1 | \begin{aligned}\frac{\partial v_1}{\partial w_1} &= \frac{\partial v_1}{\partial w_1} \frac{\partial w_1}{\partial w_1}\\ &= \frac{1}{w_1} \cdot 1\end{aligned}
3 | v_2 = w_2 v_1 | \begin{aligned}\frac{\partial v_2}{\partial w_1} &= \frac{\partial v_2}{\partial v_1}\frac{\partial v_1}{\partial w_1} + \frac{\partial v_2}{\partial w_2}\frac{\partial w_2}{\partial w_1} \\&= w_2\frac{\partial v_1}{\partial w_1} + v_1\frac{\partial w_2}{\partial w_1}\end{aligned}
4 | v_3 = \sqrt{v_2} | \begin{aligned}\frac{\partial v_3}{\partial w_1} &= \frac{\partial v_3}{\partial v_2}\frac{\partial v_2}{\partial w_1} \\ &= \frac{1}{2\sqrt{v_2}}\frac{\partial v_2}{\partial w_1}\end{aligned}
5 | L = v_2 + v_3 | \begin{aligned}\frac{\partial L}{\partial w_1} &= \frac{\partial L}{\partial v_2}\frac{\partial v_2}{\partial w_1} + \frac{\partial L}{\partial v_3}\frac{\partial v_3}{\partial w_1} \\&= 1 \cdot \frac{\partial v_2}{\partial w_1} + 1 \cdot \frac{\partial v_3}{\partial w_1}\end{aligned}
Note that this approach does not require storing all intermediate computations, but for calculating the derivative \dfrac{\partial L}{\partial w_k} we need \mathcal{O}(T) operations, where T is the cost of one pass through the graph. This means that for the whole gradient we need d \cdot \mathcal{O}(T) operations, which is the same as for finite differences, but now there are no stability issues or inaccuracies (the formulas above are exact).
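A minimal sketch of forward mode with dual numbers (the Dual class and function names are illustrative, not from any library): each value carries its derivative with respect to one chosen input, so a single pass through the graph yields \dfrac{\partial L}{\partial w_1}, and a second pass with the seed on w_2 yields \dfrac{\partial L}{\partial w_2}.

import math

class Dual:
    """A value together with its derivative w.r.t. one chosen input."""
    def __init__(self, val, dot=0.0):
        self.val, self.dot = val, dot
    def __add__(self, other):
        return Dual(self.val + other.val, self.dot + other.dot)
    def __mul__(self, other):
        return Dual(self.val * other.val,
                    self.val * other.dot + self.dot * other.val)

def log(x):
    return Dual(math.log(x.val), x.dot / x.val)

def sqrt(x):
    return Dual(math.sqrt(x.val), x.dot / (2.0 * math.sqrt(x.val)))

def L(w1, w2):
    v1 = log(w1)        # step 2
    v2 = w2 * v1        # step 3
    v3 = sqrt(v2)       # step 4
    return v2 + v3      # step 5

# Seed dw1 = 1, dw2 = 0 to get dL/dw1; swap the seeds for dL/dw2.
dL_dw1 = L(Dual(2.0, 1.0), Dual(3.0, 0.0)).dot
dL_dw2 = L(Dual(2.0, 0.0), Dual(3.0, 1.0)).dot
print(dL_dw1, dL_dw2)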
4 Backward mode automatic differentiation
We will consider the same function
L(w_1, w_2) = w_2 \log w_1 + \sqrt{w_2 \log w_1}
with a computational graph:
Assume that we have some values of the parameters w_1, w_2 and that we have already performed a forward pass (i.e. a single propagation through the computational graph from left to right). Suppose also that we somehow saved all intermediate values of v_i. Let's go from the end of the graph to the beginning and calculate the derivatives \dfrac{\partial L}{\partial w_1}, \dfrac{\partial L}{\partial w_2}:
Step | Derivative
---|---
1 | \dfrac{\partial L}{\partial L} = 1
2 | \begin{aligned}\frac{\partial L}{\partial v_3} &= \frac{\partial L}{\partial L} \frac{\partial L}{\partial v_3}\\ &= \frac{\partial L}{\partial L} \cdot 1\end{aligned}
3 | \begin{aligned}\frac{\partial L}{\partial v_2} &= \frac{\partial L}{\partial v_3}\frac{\partial v_3}{\partial v_2} + \frac{\partial L}{\partial L}\frac{\partial L}{\partial v_2} \\&= \frac{\partial L}{\partial v_3}\frac{1}{2\sqrt{v_2}} + \frac{\partial L}{\partial L} \cdot 1\end{aligned}
4 | \begin{aligned}\frac{\partial L}{\partial v_1} &=\frac{\partial L}{\partial v_2}\frac{\partial v_2}{\partial v_1} \\ &= \frac{\partial L}{\partial v_2}w_2\end{aligned}
5 | \begin{aligned}\frac{\partial L}{\partial w_1} &= \frac{\partial L}{\partial v_1}\frac{\partial v_1}{\partial w_1} \\&= \frac{\partial L}{\partial v_1}\frac{1}{w_1}\end{aligned} \quad \begin{aligned}\frac{\partial L}{\partial w_2} &= \frac{\partial L}{\partial v_2}\frac{\partial v_2}{\partial w_2} \\&= \frac{\partial L}{\partial v_2}v_1\end{aligned}
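The same five steps written out in code (a hand-written sketch, not any particular framework's API): one forward pass stores v_1, v_2, v_3, and one backward pass reuses them to produce both partial derivatives at once.

import math

def forward_backward(w1, w2):
    # Forward pass: store every intermediate value.
    v1 = math.log(w1)
    v2 = w2 * v1
    v3 = math.sqrt(v2)
    L = v2 + v3

    # Backward pass: follow the table from the end of the graph.
    dL_dL  = 1.0
    dL_dv3 = dL_dL * 1.0                              # step 2
    dL_dv2 = dL_dv3 / (2.0 * math.sqrt(v2)) + dL_dL   # step 3
    dL_dv1 = dL_dv2 * w2                              # step 4
    dL_dw1 = dL_dv1 / w1                              # step 5
    dL_dw2 = dL_dv2 * v1                              # step 5
    return L, (dL_dw1, dL_dw2)

print(forward_backward(2.0, 3.0))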
4.1 What automatic differentiation (AD) is NOT:
- AD is not finite differences
- AD is not a symbolic derivative
- AD is not just the chain rule
- AD is not just backpropagation
At the same time, AD (reverse mode) is time-efficient and numerically stable, but memory-inefficient: you need to store all intermediate computations from the forward pass.
5 Important stories from matrix calculus
We will illustrate some important matrix calculus facts in specific cases.
5.1 Univariate chain rule
Suppose we have functions R: \mathbb{R} \to \mathbb{R}, L: \mathbb{R} \to \mathbb{R} and a scalar W \in \mathbb{R}. Then
\dfrac{\partial R}{\partial W} = \dfrac{\partial R}{\partial L} \dfrac{\partial L}{\partial W}
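For example (an illustrative choice), if R(L) = L^2 and L(W) = \sin W, then
\dfrac{\partial R}{\partial W} = \dfrac{\partial R}{\partial L} \dfrac{\partial L}{\partial W} = 2L \cos W = 2 \sin W \cos W.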
5.2 Multivariate chain rule
The simplest example:
\dfrac{\partial }{\partial t} f(x_1(t), x_2(t)) = \dfrac{\partial f}{\partial x_1} \dfrac{\partial x_1}{\partial t} + \dfrac{\partial f}{\partial x_2} \dfrac{\partial x_2}{\partial t}
Now, we’ll consider f: \mathbb{R}^n \to \mathbb{R}:
\dfrac{\partial }{\partial t} f(x_1(t), \ldots, x_n(t)) = \dfrac{\partial f}{\partial x_1} \dfrac{\partial x_1}{\partial t} + \ldots + \dfrac{\partial f}{\partial x_n} \dfrac{\partial x_n}{\partial t}
But if we add another dimension, f: \mathbb{R}^n \to \mathbb{R}^m, then the j-th output of f will be:
\dfrac{\partial }{\partial t} f_j(x_1(t), \ldots, x_n(t)) = \sum\limits_{i=1}^n \dfrac{\partial f_j}{\partial x_i} \dfrac{\partial x_i}{\partial t} = \sum\limits_{i=1}^n J_{ji} \dfrac{\partial x_i}{\partial t},
where the matrix J \in \mathbb{R}^{m \times n} is the Jacobian of f. Hence, we can write it in vector form:
\dfrac{\partial f}{\partial t} = J \dfrac{\partial x}{\partial t}\quad \iff \quad \left(\dfrac{\partial f}{\partial t}\right)^\top = \left( \dfrac{\partial x}{\partial t}\right)^\top J^\top
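A quick numerical check of this identity with JAX (the particular f and x below are arbitrary illustrations):

import jax
import jax.numpy as jnp

# f: R^2 -> R^3 composed with x: R -> R^2, giving t -> f(x(t)).
f = lambda x: jnp.array([x[0] * x[1], jnp.sin(x[0]), x[1] ** 2])
x = lambda t: jnp.array([t ** 2, jnp.exp(t)])

t0 = 0.7
lhs = jax.jacfwd(lambda t: f(x(t)))(t0)           # d f(x(t)) / dt directly
rhs = jax.jacobian(f)(x(t0)) @ jax.jacfwd(x)(t0)  # J * dx/dt
print(jnp.allclose(lhs, rhs))                     # True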
5.3 Backpropagation
Backpropagation is a specific application of reverse-mode automatic differentiation within neural networks. It is the standard algorithm for computing gradients in neural networks, especially for training with stochastic gradient descent. Here’s how it works:
- Perform a forward pass through the network to compute activations and outputs.
- Calculate the loss function at the output, which measures the difference between the network prediction and the actual target values.
- Commence the backward pass by computing the gradient of the loss with respect to the network’s outputs.
- Propagate these gradients back through the network, layer by layer, using the chain rule to calculate the gradients of the loss with respect to each weight and bias.
- The critical point of backpropagation is that it efficiently calculates the gradient of a complex, multilayered function by decomposing it into simpler derivative calculations. This is what makes updating the large number of parameters in deep networks computationally feasible (a minimal sketch of these steps follows this list).
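Below is a minimal hand-written sketch of the steps above for a one-hidden-layer network with a squared-error loss; the architecture, shapes, and names are illustrative assumptions, not a reference implementation:

import numpy as np

rng = np.random.default_rng(0)
x, y = rng.normal(size=(3,)), rng.normal(size=(2,))     # one training sample
W1, W2 = rng.normal(size=(4, 3)), rng.normal(size=(2, 4))

# 1-2. Forward pass and loss.
z = W1 @ x                 # pre-activation
h = np.tanh(z)             # hidden activation
y_hat = W2 @ h             # network output
loss = 0.5 * np.sum((y_hat - y) ** 2)

# 3. Gradient of the loss w.r.t. the network output.
dL_dy_hat = y_hat - y

# 4. Propagate back through each layer with the chain rule.
dL_dW2 = np.outer(dL_dy_hat, h)
dL_dh  = W2.T @ dL_dy_hat
dL_dz  = dL_dh * (1.0 - np.tanh(z) ** 2)   # tanh'(z)
dL_dW1 = np.outer(dL_dz, x)

print(dL_dW1.shape, dL_dW2.shape)          # gradients for an SGD step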
5.4 Jacobian vector product
The power of automatic differentiation is encapsulated in the computation of the Jacobian-vector product. Instead of calculating the entire Jacobian matrix, which is computationally expensive and often unnecessary, AD computes the product of the Jacobian and a vector directly. This is crucial for gradients in neural networks, where the Jacobian may be very large but the end goal is only its product with the gradient of the loss with respect to the outputs (a vector). The reason this works so fast in practice is that efficient Jacobian-vector rules for elementary operations are already implemented in automatic differentiation frameworks: typically, the full Jacobian is never constructed or stored, and the matvec is performed directly. Note that for some functions (for example, any element-wise function of the input vector) the matvec costs linear rather than quadratic time and requires no additional memory for a Jacobian. A JAX-based sketch follows the autodidact listing below.
See the examples of Vector-Jacobian Products from the autodidact library:
defvjp(anp.add,         lambda g, ans, x, y : unbroadcast(x, g),
                        lambda g, ans, x, y : unbroadcast(y, g))
defvjp(anp.multiply,    lambda g, ans, x, y : unbroadcast(x, y * g),
                        lambda g, ans, x, y : unbroadcast(y, x * g))
defvjp(anp.subtract,    lambda g, ans, x, y : unbroadcast(x, g),
                        lambda g, ans, x, y : unbroadcast(y, -g))
defvjp(anp.divide,      lambda g, ans, x, y : unbroadcast(x,   g / y),
                        lambda g, ans, x, y : unbroadcast(y, - g * x / y**2))
defvjp(anp.true_divide, lambda g, ans, x, y : unbroadcast(x,   g / y),
                        lambda g, ans, x, y : unbroadcast(y, - g * x / y**2))
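In JAX, the same machinery is exposed directly through jax.jvp (forward mode) and jax.vjp (reverse mode), so the full Jacobian is never materialized. A small sketch with an illustrative element-wise function:

import jax
import jax.numpy as jnp

f = lambda x: jnp.sin(x) * x          # element-wise, so the Jacobian is diagonal
x = jnp.arange(1.0, 4.0)
v = jnp.ones_like(x)

# Forward mode: Jacobian-vector product J(x) @ v.
_, jvp_out = jax.jvp(f, (x,), (v,))

# Reverse mode: vector-Jacobian product v^T J(x).
_, vjp_fun = jax.vjp(f, x)
(vjp_out,) = vjp_fun(v)

print(jvp_out, vjp_out)               # equal here, since the Jacobian is diagonal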
5.5 Hessian vector product
Interestingly, a similar idea can be used to compute Hessian-vector products, which are essential for second-order optimization and conjugate gradient methods. For a scalar-valued function f : \mathbb{R}^n \to \mathbb{R} with continuous second derivatives (so that the Hessian matrix is symmetric), the Hessian at a point x \in \mathbb{R}^n is written as \partial^2 f(x). A Hessian-vector product function is then able to evaluate
v \mapsto \partial^2 f(x) \cdot v
for any vector v \in \mathbb{R}^n.
The trick is not to instantiate the full Hessian matrix: if n is large, perhaps in the millions or billions in the context of neural networks, then that might be impossible to store. Luckily, grad (in JAX/Autograd/PyTorch/TensorFlow) already gives us a way to write an efficient Hessian-vector product function. We just have to use the identity
\partial^2 f (x) v = \partial [x \mapsto \partial f(x) \cdot v] = \partial g(x),
where g(x) = \partial f(x) \cdot v is a new scalar-valued function that dots the gradient of f at x with the vector v. Notice that we're only ever differentiating scalar-valued functions of vector-valued arguments, which is exactly where we know grad is efficient.
import jax.numpy as jnp
from jax import grad

def hvp(f, x, v):
    # Differentiate the scalar map x -> <grad f(x), v>.
    return grad(lambda x: jnp.vdot(grad(f)(x), v))(x)
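As a sanity check, the hvp above can be compared with the explicitly materialized Hessian on a toy function (feasible only because n is tiny here); the test function is an arbitrary illustration:

import jax
import jax.numpy as jnp

# Reusing hvp from the snippet above on a small test function.
f = lambda x: jnp.sum(x ** 3) + jnp.dot(x, x)
x = jnp.array([1.0, 2.0, 3.0])
v = jnp.array([0.5, -1.0, 2.0])

print(jnp.allclose(hvp(f, x, v), jax.hessian(f)(x) @ v))   # expected: True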
6 Code
7 Materials
- Autodidact - a pedagogical implementation of Autograd
- CSC321 Lecture 6
- CSC321 Lecture 10
- Why you should understand backpropagation :)
- JAX autodiff cookbook
- Materials from the CS207: Systems Development for Computational Science course, with a very intuitive explanation.
- Great lecture on AD from Dmitry Kropotov (in Russian).
Footnotes
Linnainmaa S. The representation of the cumulative rounding error of an algorithm as a Taylor expansion of the local rounding errors. Master’s Thesis (in Finnish), Univ. Helsinki, 1970.↩︎