Did it really snow if the snow is not there at the end of the day?
Simple Neural ODE Code
This is a quick and dirty demonstration of training Neural ODEs using Jax. The paper is
import jax.numpy as jnp
import numpy as np
from jax import grad, jacfwd, jit
from matplotlib import pyplot as plt
from typing import Callable
from tqdm.notebook import tqdm
from IPython import display
Let’s first define some a simple ODE and solve/plot it using a forward Euler time step as proof of concept. Note that all the sophisticated explicit time steppers basically are combinations of Euler’s. At the same time, we will not be using any sophisticated Jax functions for now. Our ODE is of the form $z_t = \theta z$ where $\theta$ is a 2 by 2 matrix.
theta_true = jnp.array([[0.0, 1.0], [-1.0, 0.0]])
def f(u: jnp.ndarray, t: jnp.float32, theta: jnp.ndarray = theta_true) -> jnp.ndarray:
return theta @ u
def ODESolve(z_t0: jnp.ndarray, f: Callable, t0: jnp.float32, t1: jnp.float32, theta: jnp.ndarray) -> jnp.ndarray:
delta_t = 0.001
N = int((t1 - t0) / delta_t)
z_holder = jnp.array(z_t0)
for i in range(N):
z_holder += delta_t * f(z_holder, i * delta_t, theta)
return z_holder
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
First, let’s make some benchmark data and fake data with only the initial condition off. Are we able to recover the true $z_{t_0}$?
z_t0_true = jnp.array([0, 1.0])
z_t1_true = ODESolve(z_t0_true, f, 0.0, 1.0, theta_true)
z_t0_noisy = jnp.array([0, 1.1])
z_t1_noisy = ODESolve(z_t0_noisy, f, 0.0, 1.0, theta_true)
Now, let’s make a loss function, and see if we can make a gradient decent algorithm such that we can find the true initial condition.
def L(z_t1: jnp.ndarray):
return jnp.linalg.norm(z_t1 - z_t1_true) ** 2 / 2
Before proceeding, let’s note that Jax allows one to find gradients and Jacobians very easily!
# partial L/ partial z_t1
plpzt1_noisy = grad(L)(z_t1_noisy)
# Jacobians with respect to z, and theta respectively
f_z = jacfwd(f, 0)
f_theta = jacfwd(f, 2)
Now, we need to find the gradients of L and f at specific points and places. This is simply algorithm 1 of the paper
def algorithm1(theta: jnp.ndarray, t0: jnp.float32, t1: jnp.float32, z_t1: jnp.ndarray, plpzt1: jnp.ndarray):
# Define initial augmented state
s0 = jnp.concatenate((z_t1, plpzt1, jnp.zeros_like(theta).flatten()))
def aug_dynamics(u: jnp.ndarray, t: jnp.float32, theta: jnp.ndarray):
"""
u consists of z(t), a(t) and theta (4 by 4), also negative since I can't do backwards lol
"""
return -jnp.concatenate((f(u[0:2], t, theta),
- u[2:4].T @ f_z(u[0:2], t, theta) ,
(-u[2:4].T @ f_theta(u[0:2], t, theta)).flatten()))
output = ODESolve(s0, aug_dynamics, 0.0, 1.0, theta_true)
# Split the augmented data back to our regular inputs
return output[0:2], output[2:4], output[4:].reshape((2,2))
print(algorithm1(theta_true, 0.0, 1.0, z_t1_noisy, plpzt1_noisy))
(DeviceArray([1.0890653e-06, 1.1010987e+00], dtype=float32), DeviceArray([1.3050158e-07, 1.0009895e-01], dtype=float32), DeviceArray([[0.03007106, 0.03902269], [0.03902263, 0.08009262]], dtype=float32))
Note that the derivative with respect to the initial condition $\partial L/ \partial z(t_0)$ is essentially $(0, .1)$ which indicates that we have succesfully obtained the correct gradient. Note that two points is not enough to properly define a ODE and we can certainly change the dynamics to go from z_t0_noisy to z_t1_true by simply changing theta.
Let’s encapsulate the above in a gradient descent algorithm… which is a bit slow…
my_theta = jnp.array(theta_true)
learning_rate = 0.1
def train(initial_z0, initial_theta):
# Make copies because I don't know how Jax works
z_t0 = jnp.array(initial_z0)
theta = jnp.array(initial_theta)
for i in tqdm(range(5)):
# Given current initial conditions, let's
z_t1 = ODESolve(z_t0, f, 0.0, 1.0, theta)
plpzt1 = grad(L)(z_t1)
_, plpzt0, plptheta = algorithm1(theta, 0.0, 1.0, z_t1, plpzt1)
z_t0 = z_t0 - learning_rate * plpzt0
theta = theta - learning_rate * plptheta.reshape((2,2))
return z_t0, theta
print(train(z_t0_noisy, theta_true))
(DeviceArray([8.1869512e-06, 1.0662646e+00], dtype=float32), DeviceArray([[-0.00988956, 0.9871685 ], [-1.0128397 , -0.02635211]], dtype=float32))
This seems to work! Success.
Note that the above code is rather slow, so let’s see how we can use jax things to speed it up! Using the XLA, and compiled code with jit, we see that using a 3rd order solver is even faster than the Euler time step by 10 folds!
from jax import lax
@jit
def f(u: jnp.ndarray, t: jnp.float32, theta: jnp.ndarray = theta_true) -> jnp.ndarray:
return theta @ u
def ODESolve_lax(z_t0: jnp.ndarray, f: Callable, t0: jnp.float32, t1: jnp.float32, theta: jnp.ndarray) -> jnp.ndarray:
"""
We use the lax loop instead, which makes things faster so we use a third order RK3... which is still faster
"""
delta_t = 0.001
N = jnp.floor_divide(t1 - t0, delta_t).astype(jnp.int32)
z_holder = jnp.array(z_t0)
def rk3(i: jnp.int32, val: jnp.ndarray):
k1 = f(val, i * delta_t, theta)
k2 = f(val + delta_t * k1, i * delta_t + delta_t, theta)
k3 = f(val + delta_t * (1 / 4 * k1 + 1 / 4 * k2), i * delta_t + 1 / 2 * delta_t, theta)
return val + delta_t * (k1 / 6 + k2 / 6 + 2 * k3 / 3)
return lax.fori_loop(0, N, rk3, z_t0)
647 ms ± 20.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) 57.1 ms ± 1.44 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
print(ODESolve(z_t0_noisy, f, 0.0, 2 * jnp.pi, theta_true))
print(ODESolve_lax(z_t0_noisy, f, 0.0, 2 * jnp.pi, theta_true))
[-2.0616804e-04 1.1034613e+00] [-2.0406539e-04 1.1000000e+00]
Now that ODESolve_lax is much faster (this is only on CPU… performance should be better if we export to GPU), we can try to apply jit to other functions to make it faster, and try to fit an actual ODE with mulitple points. First let’s define a loss function, and use jit to define the Jacobians:
@jit
def L(z_t1: jnp.ndarray, z_t1_true: jnp.ndarray):
"""
Simple function; easy to jit
"""
return jnp.linalg.norm(z_t1 - z_t1_true) ** 2 / 2
f_z = jit(jacfwd(f, 0))
f_theta = jit(jacfwd(f, 2))
Now, rewriting algorithm1 with jit is fairly easy…
@jit
def algorithm1(theta: jnp.ndarray, t0: jnp.float32, t1: jnp.float32, z_t1: jnp.ndarray, plpzt1: jnp.ndarray):
# Define initial augmented state
s0 = jnp.concatenate((z_t1, plpzt1, jnp.zeros_like(theta).flatten()))
# Just an index for the length of the problem
z_size = len(z_t1)
@jit
def aug_dynamics(u: jnp.ndarray, t: jnp.float32, theta: jnp.ndarray):
"""
u consists of z(t), a(t) and theta, also negative since I can't do backwards lol
"""
z = u[0:z_size]
a = u[z_size:2 * z_size]
return -jnp.concatenate((f(z, t, theta),
- a.T @ f_z(z, t, theta) ,
(-a.T @ f_theta(z, t, theta)).flatten()))
output = ODESolve_lax(s0, aug_dynamics, t0, t1, theta_true)
return output[0:z_size], output[z_size:2 * z_size], output[2 * z_size:].reshape((z_size, z_size))
print(algorithm1(theta_true, 0.0, 1.0, z_t1_noisy, plpzt1_noisy))
(DeviceArray([1.1002590e-03, 1.1005477e+00], dtype=float32), DeviceArray([0.00010013, 0.10004898], dtype=float32), DeviceArray([[0.03002402, 0.03898256], [0.03898243, 0.07997467]], dtype=float32))
As written the above algorithm doesn’t support batching, which is a huge pain in the butt. I couldn’t get vmap from jax to work out of the box but we won’t need it here.
Now the gradient formula is much faster. Let’s make a cooler true ODE solution, and plot it
theta_true = jnp.array([[-.1, .9], [-.9, -.2]])
init_true = jnp.array([0.0, 1.0])
total_time = 4 * jnp.pi
num_data = 20
@jit
def predict(initial_point: jnp.ndarray, theta: jnp.ndarray, total_time = 4 * jnp.pi, num_data = 20):
# jnp arrays are immutable
prediction = jnp.zeros((num_data, 2))
prediction = prediction.at[0, :].set(initial_point)
for i in range(1, num_data):
prediction = prediction.at[i, :].set(ODESolve_lax(prediction[i - 1, :], f, total_time / num_data * i,
total_time / num_data * (i + 1), theta))
return jnp.array(prediction)
all_data = predict(init_true, theta_true)
# TODO: we might want to add some sort of noise in the future
plt.plot(all_data[:, 0], all_data[:, 1], '-o')
[<matplotlib.lines.Line2D at 0x7f30fc134c18>]
With the data done, we want to recover the both the dynamics and the initial condition. I think the ODE demo on the author github only recovers dynamics.
Since we care about the exact initial conditions for the above dynamics, it only makes sense that all our samples which alters the initial conditions should contain our dirty initial conditions. The following training loop does exactly that; it takes some sample of $(t_0, t_K)$ where $K$ is some random number (could be a batch) and then alters the parameters due to that.
The initial theta is very sensitive; this sort of model can’t really optimize over bifurcations very well it seems so if the model doesn’t converge, just try to run again.
# We start with some initial parameter and points
my_theta = jnp.array([[0.0, 1.0], [-1.0, 0.0]]) + np.random.normal(0, .4, size=(2,2))
initial_point = init_true + np.random.normal(0, .2, size=(2,))
# Custom gradient descent; no need for Adams in this case hopefully
learning_rate = 0.1
gradL = jit(grad(L))
batch_size = 3
epochs = 1000
for epoch in tqdm(range(epochs)):
if epoch z = predict(initial_point, my_theta)
plt.plot(z[:, 0], z[:, 1], alpha=(epoch / epochs) * .7 + 0.3 )
plt.xlim([-.9, 0.9])
plt.ylim([-1.0, 1.4])
plt.title(f"Epoch {epoch}; Loss {L(z, all_data)}")
display.clear_output(wait=True)
display.display(plt.gcf())
# Once we get close enough it's fine
if L(z, all_data) < 1e-4:
break
# Choose batch of time
times_indices = np.random.choice(np.arange(1, len(all_data)), size=batch_size, replace=False)
for ind in times_indices:
my_point = ODESolve_lax(initial_point, f, 0, ind * total_time / num_data, my_theta)
# Get gradient # Need to batch this ultimately at some point
_, plpz0, plptheta = algorithm1(
my_theta,
0.0,
ind * total_time / num_data,
my_point,
gradL(my_point, all_data[ind, :])
)
# Apply gradient
initial_point = initial_point - learning_rate * plpz0
my_theta = my_theta - learning_rate * plptheta.reshape((2,2))
plt.plot(all_data[:, 0], all_data[:, 1], '-o', linestyle='dashed', alpha=.9)
Hitting the Slopes
I recently had an epitome the other day: having fun skiing entails skiing less.
The logic is simple. A tired skier is an unhappy skier. A tired skier is a an injury-prone skier. A tired skier is a miserable skier.
Another particular benefit of skiing less is I can save money, since many ski parks offer half-day passes for a discount. By only opting for the afternoon session, I also get to experience a warmer day with less crowds since après ski tend to set in around lunch time. There’s also the fact that I could sleep in more, and not contend with the early morning rush.
I finally understand less is more now.
A Walk in the Woods
My third Bill Bryson book after a Short History of Nearly Everything and At Home. As usual, I loved his writing style and humor. I found myself laughing and snickering more while reading this book about struggling through the Appalachian Trail then the Anxious People novel.
Pros:
- Lots of facts, and his usage of vocabulary has always been strong.
- Dry, British humor
Cons:
- No sources; there is a fact about the average American walking less than ____ miles a week which I couldn’t verify.
The novel was published back in 1998, I do wonder if anything has changed since. There were quite a few sections where Bryson crucified certain governmental actions (or lack thereof) which I suspect will have gained a lot of attention due to the book/movie. It’s just so depressing sometimes to read about the accelerating decline of nature.
My Castle
Nowadays, home is not only where the heart is, but also where my office is. I spend over twenty one hours per day on average inside the confines of the four walls. This is partly why I splurged for the two bedroom instead of the one, and why I’ve been sort of obsessing over air quality recently.
Ever since college, I’ve developed some sort of allergic reaction whenever I visit my parent’s house in Florida. I’m convinced it’s due to some sort of ragweed floating around, or the sudden changes in stress leads to a floundering immune system. When there’s nobody around, I just stuck tissues up my nose… luckily, my sinus passage has so far enjoyed New Mexico… for the most part.
I’ve noticed in the afternoons that my nose are constantly flared, and that repeated vacuuming of my carpets still leads to large amounts of dust piling in the bin. The rays of sunshine streaming into my apartments laid bare the amount of dust floating around. It makes sense; the dust comes from, well, the twenty one hours I spend at home and the fact that the dry, high desert environment readily kicks up the sand and particulates from the ground. Since no home is sealed, those little particulate matter diffuse into my castle.
Ever since I got my Airmega 200M air purifier, that nose flaring has gone away for the most part. A side effect is that I have to dust my home less with that HEPA filter running. It has also made me more aware of the effects on indoor air quality of cooking and vacuuming. The filter plus my humidifier really makes my indoor air quality much better. This translates to a happier Marshall.
Now, I’m concerned about my water quality…
Anxious People
I wasn’t sure what this book was supposed to be.
It started off as a bank robbery gone awry. A few dozen pages later, it become a surprisingly heartfelt discussions on morality of intentions, and depressive thoughts. At other times, it tried its best to be a comedy (though, I don’t think it really worked for me).
In the end, I thought it tried to hard to be all three. As a mystery, I thought the “twist” was not that inventive. I did like how the the individual stories from each character played into the overarching mystery though. And also to be fair, it was arguably one of the more realistic resolutions that one can imagine.
As a comedy, it just didn’t click with me. I’m not sure whether it’s the characters or just the way my humor works… but I just didn’t laugh that much at all. Compare this to the Walk in the Woods which I’m currently reading now which has me audibly snickering.
Finally, as a character study, I didn’t much care for the characters. I found them grating, and rather unenjoyable to be around.
I can see why some people would like the book (hence the Netflix adaptation), but it’s not for me.
Zzz
Why is the term adultnapped not a thing when adults get kidnapped?
Linear Piola Transforms
I’m too lazy to convert the LaTeX to WP friendly, so here it is as a PDF.
The Invisible Life of Addie LaRue
The Faustian bargain at the heart of the novel is intriguing: our protagonist Addie is allowed to live forever, but she is not allowed to make a “mark” on the world during her life. This means that everyone she meets will forget her as soon as she leaves the room. The curse is the embodiment of “out of sight, out of mind.”
Beyond the social aspects, she cannot draw, paint or write, for those can leave marks. Photographs of her develop to be stubbornly out of focus. Even her transient footprints get wiped away remarkably quickly. She is, in society’s eye, invisible.
While intriguing to discuss the consequences (such as how does one travel internationally in this day and age without a passport if one can be forgotten instantly with no records… or the fact that I think the author could’ve spent much more time in the “meat” of the time period rather than mostly near the beginning and end), the central driving force behind Addie is her desperation to be remembered. In time, she found that she can influence artists to create art inspired by her, supposedly remarkable, face and figure. I really liked this loophole for some odd reason.
Without spoiling the story too much, she meets a… remarkably… dull man who can remember her. Character traits notwithstanding, I did very much enjoy the writing in the last few chapters of this man. Speaking too much here would spoil the ending.
Overall, solid book. Decently interesting plot points. Fun read.
Odd (or ironic?) that esoteric is an esoteric word.