PyTorch and register_full_backward_hook

The module method register_full_backward_hook is somewhat esoteric. The user is supposed to provide a function hook(module, grad_input, grad_output) -> tuple(Tensor) or None which will be executed “every time the gradients with respect to a module are computed.” But what exactly are grad_input and grad_output? I think one of the simpler ways to understand them is through an adjoint formulation.
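For reference, the registration itself is only a few lines; here is a minimal sketch on an nn.Linear module (what exactly gets printed is the subject of the rest of this post):

import torch
import torch.nn as nn

def hook_fn(module, grad_input, grad_output):
    # grad_output: gradient of the loss with respect to the module's output
    # grad_input: gradient of the loss with respect to the module's input
    print(f'{grad_output=}, {grad_input=}')

layer = nn.Linear(3, 3)
handle = layer.register_full_backward_hook(hook_fn)

x = torch.randn(3, requires_grad=True)
layer(x).sum().backward()  # the hook fires during this backward pass
handle.remove()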

For the sake of simplicity, we consider an $N$-layer neural network with layers of the form
$$
x_{n+1} = \Phi_n(x_n)
$$
where $x_n$ is the input into the $n$th layer, and $\Phi_n$ is some parameterized function representing the $n$th layer. Suppose now that we consider an $L^2$ loss
$$
\min \frac{1}{2} ||y - x_N||^2
$$
where $y$ is the sample's label.

We can actually rewrite the above as a constrained minimization problem
$$
\min \frac{1}{2} ||y - x_N||^2
$$
such that $x_N = \Phi_{N-1}(x_{N-1}), \ldots, x_1 = \Phi_0(x_0)$ for some sample data $x_0$.

From a math point of view, we can use Lagrange multipliers
$$
\mathcal L = \frac{1}{2} ||y - x_N||^2 + \sum_{i=0}^{N-1} \langle \lambda_{i+1}, \Phi_{i}(x_{i}) - x_{i+1} \rangle.
$$
where $\langle \cdot, \cdot \rangle$ is simply the dot product. If we take the gradients with respect to $\lambda_1, \ldots, \lambda_{N}$ and set them equal to zero, we obtain our forward dynamics. This is pretty straightforward and is what register_forward_hook reflects.
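Concretely, differentiating $\mathcal L$ with respect to $\lambda_{i+1}$ just returns the corresponding constraint:
$$
\frac{\partial \mathcal L}{\partial \lambda_{i+1}} = \Phi_i(x_i) - x_{i+1} = 0
\quad \Longrightarrow \quad
x_{i+1} = \Phi_i(x_i),
$$
which is exactly the forward pass.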

The real fun part is when we take the gradient with respect to the variables $x_1, \ldots, x_N$:
\begin{align*}
(x_N - y)^T - \lambda_{N}^T &= 0 \\
\lambda_{N}^T \nabla \Phi_{N-1}(x_{N-1}) - \lambda_{N-1}^T &= 0 \\
\lambda_{N-1}^T \nabla \Phi_{N-2}(x_{N-2}) - \lambda_{N-2}^T &= 0 \\
\vdots &= \vdots.
\end{align*}
In particular, we can rewrite the above as a backwards dynamics on the so-called “adjoint” variable, with initial condition $\lambda_N = x_N - y$ and dynamics $\lambda_{i-1} = (\nabla \Phi_{i-1}(x_{i-1}))^T \lambda_{i}$. Note that we’re pretty liberal with our notation on the gradients/transposes, and some dimension errors may have occurred. As it turns out, this adjoint variable is exactly what register_full_backward_hook reports: for the layer computing $x_{i+1} = \Phi_i(x_i)$, grad_output is $\lambda_{i+1}$ and grad_input is $\lambda_i$.
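For instance, if each layer is linear, say $\Phi_n(x) = W_n x$ (as in the example below, with $W_n = A_n + I$), the adjoint step is just a transposed matrix-vector product:
$$
\lambda_n = W_n^T \lambda_{n+1}.
$$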

But before we turn to some code and see this in practice, why is this useful? Why does PyTorch compute $\lambda_i$ during the backward pass? We introduce one more bit of notation: let $\theta_n$ be the parameters of $\Phi_n$. The backward pass requires us to calculate $\frac{\partial \mathcal L}{\partial \theta_n}$ for each $n$. Since $\theta_n$ only enters the network through $x_{n+1} = \Phi_n(x_n)$, the chain rule gives
$$
\frac{\partial \mathcal L}{\partial \theta_n} = \frac{\partial \mathcal L}{\partial x_{n+1}} \frac{\partial x_{n+1}}{\partial \theta_n}.
$$
The second term depends on the layer type and can be computed locally, but the first term is simply the adjoint variable $\lambda_{n+1}$! Thus the adjoint is exactly what is needed to assemble the parameter gradients.
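As a concrete instance (and a preview of the example below), take $\Phi_n(x) = (A_n + I)x$ with parameter $A_n$; then the parameter gradient is an outer product of the adjoint with the layer input:
$$
\frac{\partial \mathcal L}{\partial A_n} = \lambda_{n+1}\, x_n^T.
$$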

Now with all the theory out of the way, let’s do a simple example in PyTorch to see that practice matches theory. For simplicity, we consider the following linear layers
$$
x_{n+1} = (A_n + I)x_n
$$
where $A_n$ and $I$ are $k$ by $k$ matrices ($I$ being the identity), and $x_n$ is the input from the previous layer. While the final output could be collapsed into a single matrix-vector product, we will stick with this layered formulation. The adjoint dynamics is then simply $\lambda_n = (A_n + I)^T \lambda_{n+1}$. In the code below every layer is initialized with the same matrix $A$, so the Adjoint module can get away with iterating over the layers in forward order; in general it would traverse them in reverse. The following code block shows that the above derivation matches what PyTorch computes.

import torch
import torch.nn as nn

# We define each layer; corresponds to \Phi
class Layer(nn.Module):
    def __init__(self):
        super(Layer, self).__init__()
        # Just some fixed example matrix
        self.A = nn.Parameter(
            torch.tensor([[1.0, -2.0, -1.0], [-2.0, 1.0, -2.0], [-1.0, -2.0, 1.0]]) / 4
        )

    def forward(self, u):
        return u + self.A @ u

# The full model composed of all layers
class Model(nn.Module):
    def __init__(self, n=3):
        super(Model, self).__init__()
        self.layers = nn.ModuleList([Layer() for _ in range(n)])

    def forward(self, u):
        for layer in self.layers:
            u = layer(u)
        return u

# This is the adjoint; just the transpose in this case
class Adjoint(nn.Module):
    def __init__(self, forward_model):
        super(Adjoint, self).__init__()
        self.layers = forward_model.layers

    def forward(self, u):
        print(f'Input: {u=}')
        for layer in self.layers:
            print(u, end=' \t')
            u = u + layer.A.T @ u
            print(u)
        return u

# Generate random data and label
u0 = torch.randn(3)
y = torch.randn(3)
model = Model()
model_adjoint = Adjoint(model)

# Define the hook function; hook must be of this form
def hook_fn(module, grad_input, grad_output):
    print(f'{grad_output=}, {grad_input=}')

# Register the hook function to each layer
hooks = []
for i, layer in enumerate(model.layers):
    hook = layer.register_full_backward_hook(hook_fn)
    hooks.append(hook)

# Perform a forward pass and backpropagate the loss
out = model(u0)
loss = 0.5 * torch.norm(out - y) ** 2
loss.backward()

# Run the explicit adjoint recursion, seeded with x_N - y
print(model_adjoint(out - y))

# Remove hooks
for hook in hooks:
    hook.remove()

Running it results in, for example, something like

grad_output=(tensor([-3.4323,  2.1327,  0.9475]),), grad_input=(tensor([-5.5936,  3.9083,  0.9761]),)
grad_output=(tensor([-5.5936, 3.9083, 0.9761]),), grad_input=(tensor([-9.1902, 7.1942, 0.6644]),)
grad_output=(tensor([-9.1902, 7.1942, 0.6644]),), grad_input=(None,)
Input: u=tensor([-3.4323, 2.1327, 0.9475], grad_fn=<SubBackward0>)
tensor([-3.4323, 2.1327, 0.9475], grad_fn=<SubBackward0>) tensor([-5.5936, 3.9083, 0.9761], grad_fn=<AddBackward0>)
tensor([-5.5936, 3.9083, 0.9761], grad_fn=<AddBackward0>) tensor([-9.1902, 7.1942, 0.6644], grad_fn=<AddBackward0>)
tensor([-9.1902, 7.1942, 0.6644], grad_fn=<AddBackward0>) tensor([-15.2510, 13.2556, -0.4691], grad_fn=<AddBackward0>)
tensor([-15.2510, 13.2556, -0.4691], grad_fn=<AddBackward0>)

Hence the values of the adjoint match what is being computed in the hooks. Note that the explicit adjoint calculation performs one additional step at the end: it also computes $\lambda_0$, the gradient with respect to the input $u_0$, which the hooks do not report (the first layer's grad_input is None since $u_0$ does not require a gradient).
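To tie this back to the parameter gradients from earlier, here is a small sketch that can be appended to the script above (it reuses model, u0, and y, and assumes loss.backward() has already been called): it recomputes the layer inputs $x_n$ and the adjoints $\lambda_n$ by hand and checks that each layer's A.grad equals the outer product $\lambda_{n+1} x_n^T$. If the derivation above is right, each comparison should come out True.

# Recompute the forward pass to collect each layer input x_n
with torch.no_grad():
    xs = [u0]
    for layer in model.layers:
        xs.append(layer(xs[-1]))

    # Backward adjoint recursion: lambda_N = x_N - y, lambda_n = (A_n + I)^T lambda_{n+1}
    lam = xs[-1] - y
    lams = [lam]
    for layer in list(model.layers)[::-1]:
        lam = lam + layer.A.T @ lam
        lams.append(lam)
    lams.reverse()  # now lams[n] == lambda_n

    # The gradient of the loss with respect to A_n is the outer product lambda_{n+1} x_n^T
    for n, layer in enumerate(model.layers):
        print(torch.allclose(layer.A.grad, torch.outer(lams[n + 1], xs[n])))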
