Sparse Gaussian Process Regression (SGPR)

Overview

In this notebook, we'll give an overview of how to use SGPR, in which the inducing point locations are learned during training.
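As a quick refresher, here is a sketch of the standard SGPR (Titsias, 2009) variational lower bound that SGPR maximizes, assuming a zero prior mean, Gaussian noise variance $\sigma^2$, and kernel matrices between the $n$ training inputs and the $m$ inducing inputs:

$$
\log p(\mathbf{y}) \;\geq\; \log \mathcal{N}\!\left(\mathbf{y} \,\middle|\, \mathbf{0},\; \mathbf{Q}_{nn} + \sigma^2 \mathbf{I}\right) \;-\; \frac{1}{2\sigma^2} \operatorname{tr}\!\left(\mathbf{K}_{nn} - \mathbf{Q}_{nn}\right),
\qquad
\mathbf{Q}_{nn} = \mathbf{K}_{nm}\mathbf{K}_{mm}^{-1}\mathbf{K}_{mn}.
$$

Because the inducing point locations enter this bound only through the kernel matrices, they can be optimized jointly with the kernel hyperparameters using gradient-based methods.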

[1]:
import math
import torch
import gpytorch
import tqdm.notebook as tqdm
from matplotlib import pyplot as plt

# Make plots inline
%matplotlib inline

For this example notebook, we'll be using the elevators UCI dataset used in the paper. Running the next cell downloads a copy of the dataset that has already been scaled and normalized appropriately. We'll simply split the data, using the first 80% as training and the last 20% as testing.

Note: Running the next cell will attempt to download a ~400 KB dataset file and save it as '../elevators.mat' (one directory above the current one).

[2]:
import urllib.request
import os
from scipy.io import loadmat
from math import floor


# this is for running the notebook in our testing framework
smoke_test = ('CI' in os.environ)


if not smoke_test and not os.path.isfile('../elevators.mat'):
    print('Downloading \'elevators\' UCI dataset...')
    urllib.request.urlretrieve('https://drive.google.com/uc?export=download&id=1jhWL3YUHvXIaftia4qeAyDwVxo6j1alk', '../elevators.mat')


if smoke_test:  # this is for running the notebook in our testing framework
    X, y = torch.randn(1000, 3), torch.randn(1000)
else:
    data = torch.Tensor(loadmat('../elevators.mat')['data'])
    X = data[:, :-1]
    X = X - X.min(0)[0]
    X = 2 * (X / X.max(0)[0]) - 1
    y = data[:, -1]


train_n = int(floor(0.8 * len(X)))
train_x = X[:train_n, :].contiguous()
train_y = y[:train_n].contiguous()

test_x = X[train_n:, :].contiguous()
test_y = y[train_n:].contiguous()

if torch.cuda.is_available():
    train_x, train_y, test_x, test_y = train_x.cuda(), train_y.cuda(), test_x.cuda(), test_y.cuda()
[3]:
X.size()
[3]:
torch.Size([16599, 18])

Defining the SGPR Model

We now define the GP model. For more details on the use of GP models, see our simpler examples. This model constructs a base scaled RBF kernel and then wraps it in an InducingPointKernel, which holds the learnable inducing point locations. Other than this, everything should look the same as in the simple GP models. (An optional sketch of an alternative inducing-point initialization follows the model definition below.)

[4]:
from gpytorch.means import ConstantMean
from gpytorch.kernels import ScaleKernel, RBFKernel, InducingPointKernel
from gpytorch.distributions import MultivariateNormal

class GPRegressionModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super(GPRegressionModel, self).__init__(train_x, train_y, likelihood)
        self.mean_module = ConstantMean()
        self.base_covar_module = ScaleKernel(RBFKernel())
        self.covar_module = InducingPointKernel(self.base_covar_module, inducing_points=train_x[:500, :].clone(), likelihood=likelihood)

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return MultivariateNormal(mean_x, covar_x)
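The number and initialization of the inducing points is a modeling choice; the model above simply uses the first 500 training inputs. As an illustrative sketch (the `num_inducing` value and the random-subset strategy are assumptions, not something prescribed by this notebook), one could instead initialize the inducing points from a random subset of the training data:

# Illustrative alternative: initialize inducing points from a random subset of training inputs.
# num_inducing is a hypothetical choice; the model above simply uses train_x[:500, :].
num_inducing = 500
perm = torch.randperm(train_x.size(0))[:num_inducing]
alternative_inducing_points = train_x[perm, :].clone()

These could then be passed to `InducingPointKernel` via its `inducing_points` argument in place of `train_x[:500, :]`.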
[5]:
likelihood = gpytorch.likelihoods.GaussianLikelihood()
model = GPRegressionModel(train_x, train_y, likelihood)

if torch.cuda.is_available():
    model = model.cuda()
    likelihood = likelihood.cuda()

Training the Model

[6]:
training_iterations = 2 if smoke_test else 100

# Find optimal model hyperparameters
model.train()
likelihood.train()

# Use the adam optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# "Loss" for GPs - the marginal log likelihood
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

def train():
    iterator = tqdm.tqdm(range(training_iterations), desc="Train")

    for i in iterator:
        # Zero backprop gradients
        optimizer.zero_grad()
        # Get output from model
        output = model(train_x)
        # Calc loss and backprop derivatives
        loss = -mll(output, train_y)
        loss.backward()
        iterator.set_postfix(loss=loss.item())
        optimizer.step()
        torch.cuda.empty_cache()

%time train()
CPU times: user 2.7 s, sys: 852 ms, total: 3.55 s
Wall time: 3.58 s
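Because `InducingPointKernel` registers the inducing point locations as a learnable parameter, Adam has been updating them alongside the kernel hyperparameters during the training loop above. As an optional sanity check (a sketch that reuses the model from the cells above), we can inspect them directly:

# The learned inducing point locations live on the kernel as a parameter.
# For the elevators data above, this should be of shape [500, 18].
print(model.covar_module.inducing_points.shape)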

Making Predictions

[7]:
model.eval()
likelihood.eval()
with torch.no_grad():
    # Predictive distribution over the test inputs; wrapping in the likelihood adds observation noise
    preds = model.likelihood(model(test_x))
    print('Test MAE: {}'.format(torch.mean(torch.abs(preds.mean - test_y))))
    print('Test NLL: {}'.format(-preds.to_data_independent_dist().log_prob(test_y).mean().item()))
Test MAE: 0.07258129864931107
Test NLL: 0.3463870584964752
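Since matplotlib was imported at the top of the notebook but not used yet, an optional qualitative check is to scatter the predictive means against the held-out targets. This is a sketch that reuses the `preds` object from the cell above:

# Optional: compare predictive means to the true test targets.
with torch.no_grad():
    fig, ax = plt.subplots(figsize=(4, 4))
    ax.scatter(test_y.cpu().numpy(), preds.mean.cpu().numpy(), s=2, alpha=0.3)
    ax.set_xlabel('True test value')
    ax.set_ylabel('Predicted mean')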
[ ]: