Converting Exact GP Models to TorchScript

In this notebook, we’ll demonstrate converting an Exact GP model to TorchScript. In general, this works the same as for standard PyTorch models, where we’ll use torch.jit.trace, but there are a few peculiarities to keep in mind for GPyTorch:

  1. The first time you make predictions with a GPyTorch model (exact or approximate), certain computations are cached. These computations can’t be traced, but their results can be. Therefore, we’ll need to pass data through the untraced model once, and then trace the model.
  2. For exact GPs, we can’t trace models unless gpytorch.settings.fast_pred_var is used. This is a technical limitation that may not be possible to overcome, given what PyTorch can trace; however, if you really need to trace a GP but can’t use the above setting, open an issue so that we know there is demand for this.
  3. You can’t trace models that return Distribution objects. Therefore, we’ll write a simple wrapper that unpacks the MultivariateNormal our GPs return into a mean tensor and a variance tensor.

Define and train an exact GP

In the next cell, we define some data, define a GP model and train it. Nothing new here – pretty much just move on to the next cell after this one.

[1]:
import math
import torch
import gpytorch
from matplotlib import pyplot as plt

%matplotlib inline
%load_ext autoreload
%autoreload 2

# Training data is 100 points in [0,1] inclusive regularly spaced
train_x = torch.linspace(0, 1, 100)
# True function is sin(2*pi*x) with Gaussian noise
train_y = torch.sin(train_x * (2 * math.pi)) + torch.randn(train_x.size()) * 0.2

# We will use the simplest form of GP model, exact inference
class ExactGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super(ExactGPModel, self).__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

# initialize likelihood and model
likelihood = gpytorch.likelihoods.GaussianLikelihood()
model = ExactGPModel(train_x, train_y, likelihood)

# this is for running the notebook in our testing framework
import os
smoke_test = ('CI' in os.environ)
training_iter = 2 if smoke_test else 50


# Find optimal model hyperparameters
model.train()
likelihood.train()

# Use the adam optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)  # Includes GaussianLikelihood parameters

# "Loss" for GPs - the marginal log likelihood
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

for i in range(training_iter):
    # Zero gradients from previous iteration
    optimizer.zero_grad()
    # Output from model
    output = model(train_x)
    # Calc loss and backprop gradients
    loss = -mll(output, train_y)
    loss.backward()
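    # Optionally print progress here to monitor convergence; uncomment if desired
    # print('Iter %d/%d - Loss: %.3f' % (i + 1, training_iter, loss.item()))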
    optimizer.step()

Trace the Model

In the next cell, we trace our GP model. To overcome the fact that we can’t trace Modules that return Distributions, we write a wrapper Module that unpacks the GP output into a mean and a variance.

Additionally, we’ll need to run with the gpytorch.settings.trace_mode setting enabled, because PyTorch can’t trace custom autograd Functions. Note that this results in some inefficiencies.

Then, before calling torch.jit.trace, we first call the model on test_x. This step is required because it performs some precomputation using torch functionality that cannot be traced.

[8]:
class MeanVarModelWrapper(torch.nn.Module):
    def __init__(self, gp):
        super().__init__()
        self.gp = gp

    def forward(self, x):
        output_dist = self.gp(x)
        return output_dist.mean, output_dist.variance

with torch.no_grad(), gpytorch.settings.fast_pred_var(), gpytorch.settings.trace_mode():
    model.eval()
    test_x = torch.linspace(0, 1, 51)
    pred = model(test_x)  # Do precomputation
    traced_model = torch.jit.trace(MeanVarModelWrapper(model), test_x)

Compare Predictions from TorchScript model and Torch model

[6]:
with torch.no_grad():
    traced_mean, traced_var = traced_model(test_x)

print(torch.norm(traced_mean - pred.mean))
print(torch.norm(traced_var - pred.variance))
tensor(0.)
tensor(0.)
[7]:
traced_model.save('traced_exact_gp.pt')
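
As a quick sanity check (a minimal sketch, not part of the original notebook), the saved module can be loaded back with torch.jit.load and called like any other TorchScript module, without re-instantiating the GPyTorch model class. The filename matches the one saved above; for safety we query it on inputs of the same shape used during tracing.

[ ]:
# Minimal sketch: load the exported TorchScript module and query it.
# Loading does not require the ExactGPModel class or GPyTorch to be defined.
loaded_model = torch.jit.load('traced_exact_gp.pt')

with torch.no_grad():
    # Keep the same input shape as the one used when tracing (51 points)
    new_x = torch.linspace(0, 1, 51)
    loaded_mean, loaded_var = loaded_model(new_x)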
[ ]: