This is a quick and dirty demonstration of training Neural ODEs using Jax. The paper is "Neural Ordinary Differential Equations" by Chen et al. (NeurIPS 2018).
import jax.numpy as jnp
import numpy as np
from jax import grad, jacfwd, jit
from matplotlib import pyplot as plt
from typing import Callable
from tqdm.notebook import tqdm
from IPython import display
Let’s first define a simple ODE and solve/plot it using a forward Euler time step as a proof of concept. Note that essentially all of the sophisticated explicit time steppers are built out of combinations of Euler steps. We will not be using any sophisticated Jax functions for now. Our ODE is of the form $z_t = \theta z$, where $\theta$ is a 2-by-2 matrix.
theta_true = jnp.array([[0.0, 1.0], [-1.0, 0.0]])
def f(u: jnp.ndarray, t: jnp.float32, theta: jnp.ndarray = theta_true) -> jnp.ndarray:
return theta @ u
def ODESolve(z_t0: jnp.ndarray, f: Callable, t0: jnp.float32, t1: jnp.float32, theta: jnp.ndarray) -> jnp.ndarray:
delta_t = 0.001
N = int((t1 - t0) / delta_t)
z_holder = jnp.array(z_t0)
for i in range(N):
        z_holder += delta_t * f(z_holder, t0 + i * delta_t, theta)
return z_holder
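As a quick sanity check (not in the original), for this particular $\theta$ the exact solution is a rotation, $z(t) = e^{\theta t} z(0)$, so we can compare the Euler solve against it:
# Sanity check: for theta_true = [[0, 1], [-1, 0]] the exact flow is a rotation,
# z(t) = [[cos t, sin t], [-sin t, cos t]] @ z(0).
t_check = 1.0
z0_check = jnp.array([0.0, 1.0])
exact = jnp.array([[jnp.cos(t_check), jnp.sin(t_check)],
                   [-jnp.sin(t_check), jnp.cos(t_check)]]) @ z0_check
print(exact, ODESolve(z0_check, f, 0.0, t_check, theta_true))  # should agree to roughly the step size, 1e-3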
First, let’s make some benchmark data and fake data with only the initial condition off. Are we able to recover the true $z_{t_0}$?
z_t0_true = jnp.array([0, 1.0])
z_t1_true = ODESolve(z_t0_true, f, 0.0, 1.0, theta_true)
z_t0_noisy = jnp.array([0, 1.1])
z_t1_noisy = ODESolve(z_t0_noisy, f, 0.0, 1.0, theta_true)
Now, let’s make a loss function and see if we can write a gradient descent algorithm that recovers the true initial condition.
def L(z_t1: jnp.ndarray):
return jnp.linalg.norm(z_t1 - z_t1_true) ** 2 / 2
Before proceeding, let’s note that Jax allows one to find gradients and Jacobians very easily!
# partial L/ partial z_t1
plpzt1_noisy = grad(L)(z_t1_noisy)
# Jacobians with respect to z, and theta respectively
f_z = jacfwd(f, 0)
f_theta = jacfwd(f, 2)
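Since f(u, t, theta) = theta @ u is linear in u, its Jacobian with respect to u is just theta itself, which gives an easy check (a small illustration, not in the original):
# f is linear in u, so its Jacobian with respect to u is theta itself;
# the Jacobian with respect to theta has shape (2, 2, 2) (output dim x theta shape).
print(f_z(jnp.array([0.0, 1.0]), 0.0, theta_true))            # equals theta_true
print(f_theta(jnp.array([0.0, 1.0]), 0.0, theta_true).shape)  # (2, 2, 2)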
Now we need to evaluate the gradient of L and the Jacobians of f at specific points along the trajectory. This is simply Algorithm 1 of the paper (the adjoint method).
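For reference, the quantities being integrated are the adjoint $a(t) = \partial L/\partial z(t)$ and the running parameter gradient; in the paper's continuous-time formulation they satisfy

$$\frac{da(t)}{dt} = -a(t)^\top \frac{\partial f(z(t), t, \theta)}{\partial z}, \qquad \frac{dL}{d\theta} = -\int_{t_1}^{t_0} a(t)^\top \frac{\partial f(z(t), t, \theta)}{\partial \theta}\, dt,$$

so Algorithm 1 integrates the augmented state $(z, a, \partial L/\partial \theta)$ from $t_1$ back to $t_0$.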
def algorithm1(theta: jnp.ndarray, t0: jnp.float32, t1: jnp.float32, z_t1: jnp.ndarray, plpzt1: jnp.ndarray):
# Define initial augmented state
s0 = jnp.concatenate((z_t1, plpzt1, jnp.zeros_like(theta).flatten()))
def aug_dynamics(u: jnp.ndarray, t: jnp.float32, theta: jnp.ndarray):
"""
        u consists of z(t), a(t) and the running dL/dtheta (a flattened 2 by 2); everything is negated because ODESolve only integrates forward in time, so we integrate the time-reversed dynamics instead of solving backwards
"""
return -jnp.concatenate((f(u[0:2], t, theta),
- u[2:4].T @ f_z(u[0:2], t, theta) ,
(-u[2:4].T @ f_theta(u[0:2], t, theta)).flatten()))
    output = ODESolve(s0, aug_dynamics, t0, t1, theta)
# Split the augmented data back to our regular inputs
return output[0:2], output[2:4], output[4:].reshape((2,2))
print(algorithm1(theta_true, 0.0, 1.0, z_t1_noisy, plpzt1_noisy))
(DeviceArray([1.0890653e-06, 1.1010987e+00], dtype=float32), DeviceArray([1.3050158e-07, 1.0009895e-01], dtype=float32), DeviceArray([[0.03007106, 0.03902269], [0.03902263, 0.08009262]], dtype=float32))
Note that the derivative with respect to the initial condition, $\partial L/\partial z(t_0)$, is essentially $(0, 0.1)$, which indicates that we have successfully obtained the correct gradient (see the short calculation below). Note that two points are not enough to properly define an ODE, and we could certainly change the dynamics to go from z_t0_noisy to z_t1_true by simply changing theta.
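To see why $(0, 0.1)$ is the expected answer (a quick check, using that the flow of $z_t = \theta z$ over one time unit is the rotation $R = e^{\theta}$ for this antisymmetric $\theta$):

$$\frac{\partial L}{\partial z(t_0)} = R^\top \frac{\partial L}{\partial z(t_1)} = R^\top R \left(z(t_0)_{\text{noisy}} - z(t_0)_{\text{true}}\right) = (0,\ 0.1).$$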
Let’s encapsulate the above in a gradient descent algorithm… which is a bit slow…
my_theta = jnp.array(theta_true)
learning_rate = 0.1
def train(initial_z0, initial_theta):
# Make copies because I don't know how Jax works
z_t0 = jnp.array(initial_z0)
theta = jnp.array(initial_theta)
for i in tqdm(range(5)):
        # Given the current initial condition and parameters, solve the ODE forward to t1
z_t1 = ODESolve(z_t0, f, 0.0, 1.0, theta)
plpzt1 = grad(L)(z_t1)
_, plpzt0, plptheta = algorithm1(theta, 0.0, 1.0, z_t1, plpzt1)
z_t0 = z_t0 - learning_rate * plpzt0
theta = theta - learning_rate * plptheta.reshape((2,2))
return z_t0, theta
print(train(z_t0_noisy, theta_true))
(DeviceArray([8.1869512e-06, 1.0662646e+00], dtype=float32), DeviceArray([[-0.00988956, 0.9871685 ], [-1.0128397 , -0.02635211]], dtype=float32))
This seems to work! Success.
Note that the above code is rather slow, so let’s see how we can use Jax’s tools to speed it up! Using XLA-compiled code via jit and lax.fori_loop, even a third-order Runge-Kutta solver turns out to be roughly ten times faster than the plain Python Euler loop.
from jax import lax
@jit
def f(u: jnp.ndarray, t: jnp.float32, theta: jnp.ndarray = theta_true) -> jnp.ndarray:
return theta @ u
def ODESolve_lax(z_t0: jnp.ndarray, f: Callable, t0: jnp.float32, t1: jnp.float32, theta: jnp.ndarray) -> jnp.ndarray:
"""
We use the lax loop instead, which makes things faster so we use a third order RK3... which is still faster
"""
delta_t = 0.001
N = jnp.floor_divide(t1 - t0, delta_t).astype(jnp.int32)
def rk3(i: jnp.int32, val: jnp.ndarray):
        k1 = f(val, t0 + i * delta_t, theta)
        k2 = f(val + delta_t * k1, t0 + i * delta_t + delta_t, theta)
        k3 = f(val + delta_t * (1 / 4 * k1 + 1 / 4 * k2), t0 + i * delta_t + 1 / 2 * delta_t, theta)
return val + delta_t * (k1 / 6 + k2 / 6 + 2 * k3 / 3)
return lax.fori_loop(0, N, rk3, z_t0)
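The timing numbers below presumably came from %timeit cells along these lines (a reconstruction; the exact calls aren’t shown here):
%timeit ODESolve(z_t0_noisy, f, 0.0, 2 * jnp.pi, theta_true)
%timeit ODESolve_lax(z_t0_noisy, f, 0.0, 2 * jnp.pi, theta_true)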
647 ms ± 20.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
57.1 ms ± 1.44 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
print(ODESolve(z_t0_noisy, f, 0.0, 2 * jnp.pi, theta_true))
print(ODESolve_lax(z_t0_noisy, f, 0.0, 2 * jnp.pi, theta_true))
[-2.0616804e-04 1.1034613e+00]
[-2.0406539e-04 1.1000000e+00]
Now that ODESolve_lax is much faster (this is only on CPU… performance should be even better on a GPU), we can try applying jit to the other functions to speed them up as well, and then fit an actual ODE with multiple data points. First, let’s define a loss function and use jit to define the Jacobians:
@jit
def L(z_t1: jnp.ndarray, z_t1_true: jnp.ndarray):
"""
Simple function; easy to jit
"""
return jnp.linalg.norm(z_t1 - z_t1_true) ** 2 / 2
f_z = jit(jacfwd(f, 0))
f_theta = jit(jacfwd(f, 2))
Now, rewriting algorithm1 with jit is fairly easy…
@jit
def algorithm1(theta: jnp.ndarray, t0: jnp.float32, t1: jnp.float32, z_t1: jnp.ndarray, plpzt1: jnp.ndarray):
# Define initial augmented state
s0 = jnp.concatenate((z_t1, plpzt1, jnp.zeros_like(theta).flatten()))
    # Dimension of the state z, used to split the augmented vector back apart
z_size = len(z_t1)
@jit
def aug_dynamics(u: jnp.ndarray, t: jnp.float32, theta: jnp.ndarray):
"""
        u consists of z(t), a(t) and the running dL/dtheta (flattened); negated because the solver only integrates forward in time, so we integrate the time-reversed dynamics
"""
z = u[0:z_size]
a = u[z_size:2 * z_size]
return -jnp.concatenate((f(z, t, theta),
- a.T @ f_z(z, t, theta) ,
(-a.T @ f_theta(z, t, theta)).flatten()))
    output = ODESolve_lax(s0, aug_dynamics, t0, t1, theta)
return output[0:z_size], output[z_size:2 * z_size], output[2 * z_size:].reshape((z_size, z_size))
print(algorithm1(theta_true, 0.0, 1.0, z_t1_noisy, plpzt1_noisy))
(DeviceArray([1.1002590e-03, 1.1005477e+00], dtype=float32), DeviceArray([0.00010013, 0.10004898], dtype=float32), DeviceArray([[0.03002402, 0.03898256], [0.03898243, 0.07997467]], dtype=float32))
As written, the above algorithm doesn’t support batching, which is a huge pain in the butt. I couldn’t get vmap from jax to work out of the box, but we won’t need it here.
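For what it’s worth, here is the sort of thing one might try with vmap (just a sketch, assuming we want to batch over the endpoint state and its gradient; it is not used below and I haven’t verified it against the lax.fori_loop inside ODESolve_lax):
from jax import vmap

# Sketch: batch algorithm1 over several (z_t1, dL/dz_t1) pairs while holding
# theta and the time interval fixed; the batched inputs would have shape (batch, 2).
batched_algorithm1 = vmap(algorithm1, in_axes=(None, None, None, 0, 0))
# z_batch, plpz0_batch, plptheta_batch = batched_algorithm1(
#     theta_true, 0.0, 1.0, z_t1_batch, plpzt1_batch)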
Now the gradient computation is much faster. Let’s make a cooler true ODE solution and plot it.
theta_true = jnp.array([[-.1, .9], [-.9, -.2]])
init_true = jnp.array([0.0, 1.0])
total_time = 4 * jnp.pi
num_data = 20
@jit
def predict(initial_point: jnp.ndarray, theta: jnp.ndarray, total_time = 4 * jnp.pi, num_data = 20):
# jnp arrays are immutable
prediction = jnp.zeros((num_data, 2))
prediction = prediction.at[0, :].set(initial_point)
    for i in range(1, num_data):
        prediction = prediction.at[i, :].set(ODESolve_lax(prediction[i - 1, :], f, total_time / num_data * (i - 1),
                                                          total_time / num_data * i, theta))
return jnp.array(prediction)
all_data = predict(init_true, theta_true)
# TODO: we might want to add some sort of noise in the future
plt.plot(all_data[:, 0], all_data[:, 1], '-o')
[Figure: the sampled points of the true trajectory (all_data), plotted as a connected scatter]
With the data done, we want to recover both the dynamics and the initial condition. I think the ODE demo on the authors’ GitHub only recovers the dynamics.
Since we care about the exact initial condition of the above dynamics, every sample that updates the initial condition should start from our current (dirty) estimate of it. The following training loop does exactly that: it takes a sample of intervals $(t_0, t_k)$, where $k$ is a random index (there can be a batch of them), solves forward from the current initial condition, and then updates the parameters from the resulting gradients.
The result is sensitive to the initial theta; this sort of model doesn’t seem to optimize across bifurcations very well, so if it doesn’t converge, just run it again.
# We start with some initial parameter and points
my_theta = jnp.array([[0.0, 1.0], [-1.0, 0.0]]) + np.random.normal(0, .4, size=(2,2))
initial_point = init_true + np.random.normal(0, .2, size=(2,))
# Custom gradient descent; hopefully no need for Adam in this case
learning_rate = 0.1
gradL = jit(grad(L))
batch_size = 3
epochs = 1000
for epoch in tqdm(range(epochs)):
    # Plot progress periodically (the exact interval is arbitrary)
    if epoch % 10 == 0:
        z = predict(initial_point, my_theta)
        plt.plot(z[:, 0], z[:, 1], alpha=(epoch / epochs) * .7 + 0.3)
        plt.xlim([-.9, 0.9])
        plt.ylim([-1.0, 1.4])
        plt.title(f"Epoch {epoch}; Loss {L(z, all_data)}")
        display.clear_output(wait=True)
        display.display(plt.gcf())
        # Once we get close enough it's fine
        if L(z, all_data) < 1e-4:
            break
# Choose batch of time
times_indices = np.random.choice(np.arange(1, len(all_data)), size=batch_size, replace=False)
for ind in times_indices:
my_point = ODESolve_lax(initial_point, f, 0, ind * total_time / num_data, my_theta)
        # Get gradients via the adjoint (this is the part that would ultimately need batching)
_, plpz0, plptheta = algorithm1(
my_theta,
0.0,
ind * total_time / num_data,
my_point,
gradL(my_point, all_data[ind, :])
)
# Apply gradient
initial_point = initial_point - learning_rate * plpz0
my_theta = my_theta - learning_rate * plptheta.reshape((2,2))
plt.plot(all_data[:, 0], all_data[:, 1], '-o', linestyle='dashed', alpha=.9)
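As a quick follow-up (not part of the original notebook), one could compare what training recovered against the ground truth:
# Hypothetical post-training check: compare the recovered parameters and initial
# condition against the values the data were generated from.
print("Recovered theta:\n", my_theta, "\nTrue theta:\n", theta_true)
print("Recovered initial point:", initial_point, "True initial point:", init_true)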