The module method register_full_backward_hook is somewhat esoteric. The user is supposed to provide a function hook(module, grad_input, grad_output) -> tuple(Tensor) or None, which will be executed “every time the gradients with respect to a module are computed.” But what exactly are grad_input and grad_output? I think one of the simpler ways to understand them is through an adjoint formulation.
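For concreteness, registering such a hook on a single module can look roughly like this (the nn.Linear layer and the print statement are just placeholders for illustration):

```python
import torch
import torch.nn as nn

def hook(module, grad_input, grad_output):
    # grad_input and grad_output are tuples of Tensors (entries can be None)
    print(type(module).__name__, grad_output[0].shape)

layer = nn.Linear(4, 4)
handle = layer.register_full_backward_hook(hook)

out = layer(torch.randn(2, 4))
out.sum().backward()  # the hook fires during this backward pass
handle.remove()
```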
For the sake of simplicity, we consider an $N$-layer neural network with layers of the form
$$
x_{n+1} = \Phi_n(x_n)
$$
where $x_n$ is the input into the $n$th layer, and $\Phi_n$ is some parameterized function representing the $n$th layer. Suppose now that we consider an $L^2$ loss
$$
\min \frac{1}{2} ||y - x_N||^2
$$
where $y$ is the sample label.
We can actually recast the above as a constrained minimization problem:
$$
\min \frac{1}{2} ||y - x_N||^2
$$
such that $x_N = \Phi_{N-1}(x_{N-1}), \ldots, x_1 = \Phi_0(x_0)$ for some sample data $x_0$.
From a math point of view, we can use Lagrange multipliers
$$
\mathcal L = \frac{1}{2} ||y - x_N||^2 + \sum_{i=0}^{N-1} \langle \lambda_{i+1} ,\Phi_{i}(x_{i}) - x_{i+1} \rangle.
$$
where $\langle \cdot, \cdot \rangle$ is simply the dot product. If we take the gradient with respect to $\lambda_1, \ldots, \lambda_{N}$ and set it equal to zero, we obtain our forward dynamics. This is pretty straightforward and actually reflects the register_forward_hook.
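Spelling this out, differentiating $\mathcal L$ with respect to $\lambda_{i+1}$ and setting the result to zero gives
$$
\frac{\partial \mathcal L}{\partial \lambda_{i+1}} = \Phi_i(x_i) - x_{i+1} = 0
\quad\Longleftrightarrow\quad
x_{i+1} = \Phi_i(x_i),
$$
which is exactly the forward pass; the intermediate $x_{i+1}$ is precisely the output a forward hook receives.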
The real fun part is when we take the gradient with respect to the variables $x_1, \ldots, x_N$:
\begin{align*}
(y - x_N) - \lambda_{N}^T &= 0 \\
\lambda_{N}^T \nabla \Phi_{N-1}(x_{N-1}) - \lambda_{N-1}^T &= 0 \\
\lambda_{N-1}^T \nabla \Phi_{N-2}(x_{N-2}) - \lambda_{N-2}^T &= 0 \\
\vdots &= \vdots.
\end{align*}
In particular, we can rewrite the above as backwards dynamics on the so-called “adjoint” variable, with initial condition $\lambda_N = (y - x_N)^T$ and dynamics $\lambda_{i-1} = (\nabla \Phi_{i-1}(x_{i-1}))^T \lambda_{i}$. Note that we’re pretty liberal with our notation on the gradients, signs, and transposes, so some dimension errors may have occurred. As it turns out, this adjoint variable is exactly the quantity being calculated by register_full_backward_hook!
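Concretely, for the module implementing $\Phi_n$ (which maps $x_n$ to $x_{n+1}$, assuming a single input and a single output), the tuples passed to the hook line up with the adjoint variables, up to the sign/transpose liberties just mentioned:
$$
\texttt{grad\_output[0]} \;\leftrightarrow\; \lambda_{n+1},
\qquad
\texttt{grad\_input[0]} \;\leftrightarrow\; \lambda_n .
$$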
But before we use some code to see this in practice, why is it useful? Why does PyTorch calculate $\lambda_i$ when doing the backward pass? We introduce one more bit of notation: let $\theta_n$ be the parameters of $\Phi_n$. The backward pass requires us to calculate $\frac{\partial \mathcal L}{\partial \theta_n}$ for each $n$. Since $\theta_n$ only enters through $x_{n+1} = \Phi_n(x_n)$,
$$
\frac{\partial \mathcal L}{\partial \theta_n} = \frac{\partial \mathcal L}{\partial x_{n+1}} \frac{\partial x_{n+1}}{\partial \theta_n}.
$$
The second term depends on the layer type and can be easily calculated, but the first term is simply the adjoint variable $\lambda_{n+1}$! Thus the adjoint can be reused to calculate the parameter gradients.
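As a quick sanity check, for a plain linear layer $\Phi_n(x_n) = W_n x_n$ (a generic illustration, not yet the residual layers used below), the chain rule above reduces to
$$
\frac{\partial \mathcal L}{\partial W_n} = \lambda_{n+1}\, x_n^T,
$$
an outer product of the adjoint with the stored activation, which is how the weight gradient of a linear layer is typically computed.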
Now with all the theory out of the way, let’s do a simple example in PyTorch to see that practice matches theory. For simplicity, we consider the following linear layers
$$
x_{n+1} = (A_n + I)x_n
$$
where $A_n$ is a $k \times k$ matrix, $I$ is the $k \times k$ identity, and $x_n$ is the input from the previous layer. While the final output could be represented by a single matrix-vector product, we will stick with this layered formulation. The adjoint dynamics then amount to multiplying by $(A_n + I)^T$ at each step. A short code block confirms that the derivation above matches what the hooks report.
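Here is a minimal sketch of such an experiment (the module ResidualLinear, the helper make_hook, and the sizes k and N are placeholder choices for illustration); it builds the layers above, records grad_output[0] via a full backward hook on each layer, and compares against the explicitly propagated adjoint:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

k, N = 3, 4  # layer width and number of layers (placeholder sizes)

# Each layer computes x_{n+1} = (A_n + I) x_n.
class ResidualLinear(nn.Module):
    def __init__(self, k):
        super().__init__()
        self.A = nn.Parameter(torch.randn(k, k))

    def forward(self, x):
        return x + self.A @ x

layers = nn.ModuleList([ResidualLinear(k) for _ in range(N)])

# Record grad_output[0] (= dLoss/dx_{n+1}) from a full backward hook on each layer.
hook_grads = {}

def make_hook(n):
    def hook(module, grad_input, grad_output):
        hook_grads[n] = grad_output[0].detach().clone()
    return hook

for n, layer in enumerate(layers):
    layer.register_full_backward_hook(make_hook(n))

# Forward pass, keeping every intermediate state x_0, ..., x_N.
x0 = torch.randn(k)
y = torch.randn(k)
xs = [x0]
for layer in layers:
    xs.append(layer(xs[-1]))

loss = 0.5 * torch.sum((y - xs[-1]) ** 2)
loss.backward()  # fires the hooks

# Explicit adjoint recursion: lambda_N is the gradient of the loss w.r.t. x_N
# (the (y - x_N)^T above, up to sign), then lambda_n = (A_n + I)^T lambda_{n+1}.
with torch.no_grad():
    adjoints = {N: xs[-1] - y}
    for n in reversed(range(N)):
        adjoints[n] = (layers[n].A + torch.eye(k)).T @ adjoints[n + 1]

# The hook on layer n reports lambda_{n+1}; the recursion also produces one
# extra value, lambda_0 (the gradient w.r.t. the input x_0), which is not
# among the recorded grad_outputs.
for n in range(N):
    print(n, torch.allclose(hook_grads[n], adjoints[n + 1]))
```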
Running such a script shows that the values of the adjoint match what is being calculated in the hooks. Note that the explicit adjoint calculation includes an additional step which does not appear in the hook output.