I recently had to interpolate some data while training a machine learning model written in TensorFlow. It turns out that the SciPy interpolator doesn't work inside compiled TensorFlow functions (those decorated with tf.function), which are usually recommended for faster execution.
This is quite annoying, but it has a simple solution: wrap the interpolator call in TensorFlow's tf.numpy_function.
import tensorflow as tf
import numpy as np
from scipy.interpolate import NearestNDInterpolator

# Sample data: z = sin(x + y) at 10 points along the diagonal x = y.
x = np.linspace(0, 10, 10)
y = np.linspace(0, 10, 10)
z = np.sin(x + y)
interp = NearestNDInterpolator(list(zip(x, y)), z)

points = tf.constant([[1, 2],
                      [3, 4],
                      [5, 6]], dtype=tf.float64)

@tf.function
def test(x):
    # Calls the SciPy interpolator directly on a symbolic tensor.
    return interp(x)

@tf.function
def test_2(x):
    # Wraps the interpolator so that it receives a NumPy array at run time.
    return tf.numpy_function(interp, inp=[x], Tout=tf.float64)

print(test_2(points))  # Works
print(test(points))  # Doesn't work
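One caveat with this approach: inside a compiled function, the tensor returned by tf.numpy_function has no static shape, and the declared Tout has to match the dtype that the wrapped function actually returns. Below is a minimal sketch of a slightly more defensive wrapper, not a definitive implementation. The names interp_np and interpolate are my own, and I am assuming the interpolator returns one value per query point, so the output is a rank-1 tensor. Also keep in mind that gradients do not flow through tf.numpy_function, so this only fits when the interpolated values are treated as fixed inputs.

```python
import numpy as np
import tensorflow as tf
from scipy.interpolate import NearestNDInterpolator

# Same sample data as in the example above.
x = np.linspace(0, 10, 10)
y = np.linspace(0, 10, 10)
z = np.sin(x + y)
interp = NearestNDInterpolator(list(zip(x, y)), z)

def interp_np(pts):
    # Hypothetical helper: pts arrives as a NumPy array of shape (n, 2).
    # Cast the result explicitly so it matches Tout=tf.float64.
    return np.asarray(interp(pts), dtype=np.float64)

@tf.function
def interpolate(pts):
    out = tf.numpy_function(interp_np, inp=[pts], Tout=tf.float64)
    # The static shape is unknown after numpy_function, so declare
    # (and check at run time) that the result is a rank-1 tensor.
    return tf.ensure_shape(out, [None])

points = tf.constant([[1, 2],
                      [3, 4],
                      [5, 6]], dtype=tf.float64)
result = interpolate(points)
print(result)
```

Restoring the shape matters mostly when the result feeds into layers or ops that need to know the rank when the graph is built.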