Converting Exact GP Models to TorchScript
In this notebook, we’ll demonstrate converting an Exact GP model to TorchScript. In general, this is the same as for standard PyTorch models, where we’ll use torch.jit.trace, but there are three peculiarities to keep in mind for GPyTorch:
The first time you make predictions with a GPyTorch model (exact or approximate), we cache certain computations. These computations can’t be traced, but their results can be. Therefore, we’ll need to pass data through the untraced model once, and then trace the model.
For exact GPs, we can’t trace models unless gpytorch.settings.fast_pred_var is used. This is a technical issue that may not be possible to overcome due to limitations on what can be traced in PyTorch; however, if you really need to trace a GP but can’t use this setting, open an issue so we know there is demand for it.
You can’t trace models that return Distribution objects. Therefore, we’ll write a simple wrapper that unpacks the MultivariateNormal our GPs return into just a mean tensor and a variance tensor.
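To see the Distribution restriction in isolation, here is a minimal pure-PyTorch sketch. It assumes a hypothetical toy module, DistModel, that returns a torch.distributions.Normal instead of a GPyTorch MultivariateNormal. Tracing that module directly fails because a Distribution is not a tensor output, while a wrapper returning a (mean, variance) tuple traces cleanly. The exact exception type and message may differ between PyTorch versions.

```python
import torch
from torch.distributions import Normal


class DistModel(torch.nn.Module):
    """Hypothetical stand-in for a GP: returns a distribution object, not tensors."""

    def __init__(self):
        super().__init__()
        self.scale = torch.nn.Parameter(torch.tensor(0.5))

    def forward(self, x):
        # Mean is sin(x); the standard deviation is a learned positive scale.
        return Normal(torch.sin(x), self.scale.abs().expand_as(x), validate_args=False)


class MeanVarWrapper(torch.nn.Module):
    """Unpacks the distribution into plain tensors so the output is traceable."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        dist = self.model(x)
        return dist.mean, dist.variance


x = torch.linspace(0, 1, 5)
model = DistModel().eval()

# Tracing the raw model fails: a Distribution is not a valid traced output.
try:
    torch.jit.trace(model, x)
    raw_trace_failed = False
except Exception as err:  # exception type may vary across PyTorch versions
    raw_trace_failed = True
    print("tracing the raw model failed with", type(err).__name__)

# Tracing the wrapper succeeds, and its outputs match eager mode.
with torch.no_grad():
    traced = torch.jit.trace(MeanVarWrapper(model), x)
    mean, var = traced(x)
    expected = model(x)
```

The wrapper is the same idea as the MeanVarModelWrapper used later in this notebook: only tensors (or tuples, lists, and dicts of tensors) leave the traced function.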
Define and train an exact GP
In the next cell, we define some data, define a GP model, and train it. Nothing here is specific to TorchScript, so you can skim it and move on to the cell after this one.
import math
import torch
import gpytorch
from matplotlib import pyplot as plt

%matplotlib inline
%load_ext autoreload
%autoreload 2

# Training data is 100 points in [0,1] inclusive regularly spaced
train_x = torch.linspace(0, 1, 100)
# True function is sin(2*pi*x) with Gaussian noise
train_y = torch.sin(train_x * (2 * math.pi)) + torch.randn(train_x.size()) * 0.2

# We will use the simplest form of GP model, exact inference
class ExactGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super(ExactGPModel, self).__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

# initialize likelihood and model
likelihood = gpytorch.likelihoods.GaussianLikelihood()
model = ExactGPModel(train_x, train_y, likelihood)

# this is for running the notebook in our testing framework
import os
smoke_test = ('CI' in os.environ)
training_iter = 2 if smoke_test else 50

# Find optimal model hyperparameters
model.train()
likelihood.train()

# Use the adam optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)  # Includes GaussianLikelihood parameters

# "Loss" for GPs - the marginal log likelihood
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

for i in range(training_iter):
    # Zero gradients from previous iteration
    optimizer.zero_grad()
    # Output from model
    output = model(train_x)
    # Calc loss and backprop gradients
    loss = -mll(output, train_y)
    loss.backward()
    optimizer.step()
Trace the Model
In the next cell, we trace our GP model. Because we can’t trace Modules that return Distributions, we write a wrapper Module that unpacks the GP output into a mean and a variance.
Additionally, we’ll need to run with the gpytorch.settings.trace_mode setting enabled, because PyTorch can’t trace custom autograd Functions. With this setting, GPyTorch avoids those Functions, which results in some inefficiencies.
Then, before calling torch.jit.trace, we first call the model on test_x. This step is required because it performs the precomputation (and caching) described above, which uses torch functionality that cannot be traced.
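The "call once, then trace" pattern can be illustrated without GPyTorch. The sketch below uses a hypothetical CachedKernelRegressor, which is only loosely analogous to GPyTorch's prediction cache and is not its actual implementation. On its first call it solves a linear system in eager mode and stores the result in a buffer, so the subsequent trace records only the cheap prediction path. The RBF kernel, lengthscale, and noise values are arbitrary assumptions.

```python
import math
import torch


class CachedKernelRegressor(torch.nn.Module):
    """Hypothetical kernel-regression mean with a lazily computed cache."""

    def __init__(self, train_x, train_y, lengthscale=0.2, noise=0.04):
        super().__init__()
        self.register_buffer("train_x", train_x)
        self.register_buffer("train_y", train_y)
        self.register_buffer("alpha", None)  # filled on the first forward call
        self.lengthscale = lengthscale
        self.noise = noise

    def kernel(self, a, b):
        # RBF kernel between 1-D inputs a (n,) and b (m,), giving an (n, m) matrix.
        sq_dist = (a.unsqueeze(-1) - b.unsqueeze(0)) ** 2
        return torch.exp(-0.5 * sq_dist / self.lengthscale ** 2)

    def forward(self, x):
        if self.alpha is None:
            # One-off precomputation in eager mode: solve (K + noise * I) alpha = y.
            k_train = self.kernel(self.train_x, self.train_x)
            k_train = k_train + self.noise * torch.eye(len(self.train_x))
            self.alpha = torch.linalg.solve(k_train, self.train_y.unsqueeze(-1)).squeeze(-1)
        # Cheap prediction path that reuses the cached alpha.
        return self.kernel(x, self.train_x) @ self.alpha


torch.manual_seed(0)
train_x = torch.linspace(0, 1, 20)
train_y = torch.sin(2 * math.pi * train_x) + 0.1 * torch.randn(20)
test_x = torch.linspace(0, 1, 7)

model = CachedKernelRegressor(train_x, train_y)
with torch.no_grad():
    eager_mean = model(test_x)               # first call fills the cache
    traced = torch.jit.trace(model, test_x)  # the trace sees alpha as an existing buffer
    traced_mean = traced(test_x)
```

Because the cache already exists when tracing starts, the Python `if` is never taken during the trace, and the solve does not appear in the traced graph.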
class MeanVarModelWrapper(torch.nn.Module):
    def __init__(self, gp):
        super().__init__()
        self.gp = gp

    def forward(self, x):
        output_dist = self.gp(x)
        return output_dist.mean, output_dist.variance

with torch.no_grad(), gpytorch.settings.fast_pred_var(), gpytorch.settings.trace_mode():
    model.eval()
    test_x = torch.linspace(0, 1, 51)
    pred = model(test_x)  # Do precomputation
    traced_model = torch.jit.trace(MeanVarModelWrapper(model), test_x)
Compare Predictions from TorchScript model and Torch model
with torch.no_grad():
    traced_mean, traced_var = traced_model(test_x)
    print(torch.norm(traced_mean - pred.mean))
    print(torch.norm(traced_var - pred.variance))
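A traced model is typically serialized so it can be loaded without the Python class definitions. The sketch below shows one way to do that and to compare outputs with a tolerance instead of inspecting a norm. It assumes a hypothetical ToyMeanVar module standing in for the wrapped GP (it simply returns a mean and a variance tensor) and uses an in-memory buffer in place of a file path.

```python
import io
import math
import torch


class ToyMeanVar(torch.nn.Module):
    """Hypothetical stand-in for a wrapped GP: returns (mean, variance) tensors."""

    def __init__(self):
        super().__init__()
        self.log_noise = torch.nn.Parameter(torch.tensor(-2.0))

    def forward(self, x):
        mean = torch.sin(2 * math.pi * x)
        variance = torch.exp(self.log_noise) + 0.1 * x ** 2
        return mean, variance


x = torch.linspace(0, 1, 51)
eager = ToyMeanVar().eval()

with torch.no_grad():
    traced = torch.jit.trace(eager, x)

    # Serialize to an in-memory buffer (a file path works the same way) and reload.
    buffer = io.BytesIO()
    torch.jit.save(traced, buffer)
    buffer.seek(0)
    loaded = torch.jit.load(buffer)

    ref_mean, ref_var = eager(x)
    new_mean, new_var = loaded(x)

# Compare with an explicit tolerance rather than eyeballing the norm of the difference.
print(torch.allclose(new_mean, ref_mean, atol=1e-6), torch.allclose(new_var, ref_var, atol=1e-6))
```

For the GP itself, the same torch.jit.save and torch.jit.load calls would apply to traced_model; whether a tolerance of 1e-6 is appropriate there depends on the approximations that fast_pred_var introduces.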