Interpolation in TensorFlow

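I recently had to interpolate some data while training a machine learning model written in TensorFlow. It turns out that SciPy's interpolators, such as NearestNDInterpolator, don't work inside compiled TensorFlow functions (functions decorated with @tf.function), which are usually recommended for faster execution. The reason is that while a tf.function is being traced, its arguments are symbolic tensors with no concrete values yet, so SciPy cannot convert them to NumPy arrays.
This is quite annoying, but there is a simple solution: wrap the call in tf.numpy_function, which runs the NumPy-based function on the concrete input values when the graph is executed.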

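import tensorflow as tf
import numpy as np
from scipy.interpolate import NearestNDInterpolator

# Sample data: 10 points along the diagonal x == y, with values sin(x + y)
x = np.linspace(0, 10, 10)
y = np.linspace(0, 10, 10)
z = np.sin(x + y)
interp = NearestNDInterpolator(list(zip(x, y)), z)

# Query points, one (x, y) pair per row
points = tf.constant([[1, 2],
                      [3, 4],
                      [5, 6]], dtype=tf.float64)

@tf.function
def test(x):
    # Fails: while tracing, x is a symbolic tensor that SciPy
    # cannot convert to a NumPy array
    return interp(x)

@tf.function
def test_2(x):
    # Works: numpy_function calls interp on the concrete NumPy value
    # of x at run time and returns the result as a float64 tensor
    return tf.numpy_function(interp, inp=[x], Tout=tf.float64)

print(test_2(points))  # Works
print(test(points))  # Doesn't work: raises an error during tracing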
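
One thing to watch for: inside a tf.function, the tensor returned by tf.numpy_function has no static shape, which can confuse downstream layers that expect a known rank. Below is a minimal sketch of a reusable wrapper that restores the shape, assuming the interpolator maps an (n, d) array of query points to n values (true for NearestNDInterpolator here). The names make_tf_interpolator, np_call, tf_interp and predict are hypothetical helpers introduced for illustration, not part of TensorFlow or SciPy.

```python
import numpy as np
import tensorflow as tf
from scipy.interpolate import NearestNDInterpolator

# Same toy data as above: 10 sample points on the diagonal x == y
x = np.linspace(0, 10, 10)
y = np.linspace(0, 10, 10)
z = np.sin(x + y)
interp = NearestNDInterpolator(list(zip(x, y)), z)


def make_tf_interpolator(scipy_interp, dtype=tf.float64):
    """Hypothetical helper: wrap a SciPy interpolator for use in tf.function.

    Assumes scipy_interp maps an (n, d) array of query points to n values.
    """
    def np_call(q):
        # Runs eagerly on a concrete NumPy array; cast so the result
        # always matches the dtype declared in Tout below.
        return np.asarray(scipy_interp(q), dtype=dtype.as_numpy_dtype)

    def call(query):
        out = tf.numpy_function(np_call, inp=[query], Tout=dtype)
        # numpy_function drops static shape information inside a graph;
        # restore it: one interpolated value per query point (row).
        out.set_shape(query.shape[:1])
        return out

    return call


tf_interp = make_tf_interpolator(interp)


@tf.function
def predict(points):
    return tf_interp(points)


points = tf.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=tf.float64)
result = predict(points)
print(result.shape)  # (3,)
```

Because the wrapped function runs as ordinary Python, TensorFlow cannot compute gradients through it, and its body is not serialized with the graph, so a model that relies on it cannot be saved and restored in a different environment. For nearest-neighbour interpolation the missing gradient matters little, since the output is piecewise constant anyway.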
