Sparse Gaussian Process Regression (SGPR)

Overview

In this notebook, we'll give an overview of how to use SGPR, in which the inducing point locations are learned along with the kernel hyperparameters.
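
Briefly, SGPR (Titsias, 2009) approximates the full n × n training kernel matrix with a low-rank Nyström approximation built from m ≪ n inducing locations Z, and treats Z as free parameters that are optimized jointly with the kernel hyperparameters by maximizing a lower bound on the marginal log likelihood:

K_{XX} \approx Q_{XX} = K_{XZ} K_{ZZ}^{-1} K_{ZX},
\qquad
\log p(\mathbf{y}) \;\ge\; \log \mathcal{N}\!\left(\mathbf{y} \mid \mathbf{0},\, Q_{XX} + \sigma^2 I\right) \;-\; \frac{1}{2\sigma^2}\,\mathrm{tr}\!\left(K_{XX} - Q_{XX}\right).

This is the approximation implemented by the InducingPointKernel used below.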

[1]:
import math
import torch
import gpytorch
from matplotlib import pyplot as plt

# Make plots inline
%matplotlib inline

For this example notebook, we'll be using the elevators UCI dataset. Running the next cell downloads a copy of the dataset that has already been scaled and normalized appropriately, and splits it into a training set (the first 80% of the data) and a test set (the last 20%).

Note: Running the next cell will attempt to download a ~400 KB dataset file to ../elevators.mat, one directory above the notebook.

[2]:
import urllib.request
import os
from scipy.io import loadmat
from math import floor


# this is for running the notebook in our testing framework
smoke_test = ('CI' in os.environ)


if not smoke_test and not os.path.isfile('../elevators.mat'):
    print('Downloading \'elevators\' UCI dataset...')
    urllib.request.urlretrieve('https://drive.google.com/uc?export=download&id=1jhWL3YUHvXIaftia4qeAyDwVxo6j1alk', '../elevators.mat')


if smoke_test:  # this is for running the notebook in our testing framework
    X, y = torch.randn(1000, 3), torch.randn(1000)
else:
    data = torch.Tensor(loadmat('../elevators.mat')['data'])
    X = data[:, :-1]
    X = X - X.min(0)[0]
    X = 2 * (X / X.max(0)[0]) - 1
    y = data[:, -1]


train_n = int(floor(0.8 * len(X)))
train_x = X[:train_n, :].contiguous()
train_y = y[:train_n].contiguous()

test_x = X[train_n:, :].contiguous()
test_y = y[train_n:].contiguous()

if torch.cuda.is_available():
    train_x, train_y, test_x, test_y = train_x.cuda(), train_y.cuda(), test_x.cuda(), test_y.cuda()
[3]:
X.size()
[3]:
torch.Size([16599, 18])

Defining the SGPR Model

We now define the GP model. For more details on the use of GP models, see our simpler examples. This model constructs a base RBF kernel wrapped in a ScaleKernel, and then wraps that base kernel in an InducingPointKernel, which applies the sparse approximation and holds the learnable inducing point locations. Other than this, everything should look the same as in the simpler GP regression examples.

[4]:
from gpytorch.means import ConstantMean
from gpytorch.kernels import ScaleKernel, RBFKernel, InducingPointKernel
from gpytorch.distributions import MultivariateNormal

class GPRegressionModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super(GPRegressionModel, self).__init__(train_x, train_y, likelihood)
        self.mean_module = ConstantMean()
        self.base_covar_module = ScaleKernel(RBFKernel())
        # Initialize 500 inducing points from the training inputs; their locations are
        # registered as learnable parameters and refined during training.
        self.covar_module = InducingPointKernel(self.base_covar_module, inducing_points=train_x[:500, :], likelihood=likelihood)

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return MultivariateNormal(mean_x, covar_x)
[5]:
likelihood = gpytorch.likelihoods.GaussianLikelihood()
model = GPRegressionModel(train_x, train_y, likelihood)

if torch.cuda.is_available():
    model = model.cuda()
    likelihood = likelihood.cuda()
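
As a quick optional check (a minimal sketch, assuming the standard InducingPointKernel attribute names), the inducing locations are registered as an ordinary learnable parameter, so the optimizer will update them along with the other hyperparameters:

[ ]:
# The inducing locations are a learnable parameter on the kernel, so Adam will
# move them during training along with the other hyperparameters.
print(model.covar_module.inducing_points.shape)  # torch.Size([500, 18]) for elevators
print([name for name, _ in model.named_parameters() if 'inducing' in name])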

Training the model

[6]:
training_iterations = 2 if smoke_test else 50

# Find optimal model hyperparameters
model.train()
likelihood.train()

# Use the adam optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# "Loss" for GPs - the marginal log likelihood
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

def train():
    for i in range(training_iterations):
        # Zero backprop gradients
        optimizer.zero_grad()
        # Get output from model
        output = model(train_x)
        # Compute the loss and backprop gradients
        loss = -mll(output, train_y)
        loss.backward()
        print('Iter %d/%d - Loss: %.3f' % (i + 1, training_iterations, loss.item()))
        optimizer.step()
        torch.cuda.empty_cache()

%time train()
Iter 1/50 - Loss: 0.794
Iter 2/50 - Loss: 0.782
Iter 3/50 - Loss: 0.770
Iter 4/50 - Loss: 0.758
Iter 5/50 - Loss: 0.746
Iter 6/50 - Loss: 0.734
Iter 7/50 - Loss: 0.721
Iter 8/50 - Loss: 0.708
Iter 9/50 - Loss: 0.695
Iter 10/50 - Loss: 0.681
Iter 11/50 - Loss: 0.667
Iter 12/50 - Loss: 0.654
Iter 13/50 - Loss: 0.641
Iter 14/50 - Loss: 0.626
Iter 15/50 - Loss: 0.613
Iter 16/50 - Loss: 0.598
Iter 17/50 - Loss: 0.584
Iter 18/50 - Loss: 0.571
Iter 19/50 - Loss: 0.555
Iter 20/50 - Loss: 0.541
Iter 21/50 - Loss: 0.526
Iter 22/50 - Loss: 0.510
Iter 23/50 - Loss: 0.495
Iter 24/50 - Loss: 0.481
Iter 25/50 - Loss: 0.465
Iter 26/50 - Loss: 0.449
Iter 27/50 - Loss: 0.435
Iter 28/50 - Loss: 0.417
Iter 29/50 - Loss: 0.401
Iter 30/50 - Loss: 0.384
Iter 31/50 - Loss: 0.369
Iter 32/50 - Loss: 0.351
Iter 33/50 - Loss: 0.336
Iter 34/50 - Loss: 0.319
Iter 35/50 - Loss: 0.303
Iter 36/50 - Loss: 0.286
Iter 37/50 - Loss: 0.269
Iter 38/50 - Loss: 0.253
Iter 39/50 - Loss: 0.236
Iter 40/50 - Loss: 0.217
Iter 41/50 - Loss: 0.200
Iter 42/50 - Loss: 0.181
Iter 43/50 - Loss: 0.167
Iter 44/50 - Loss: 0.149
Iter 45/50 - Loss: 0.132
Iter 46/50 - Loss: 0.112
Iter 47/50 - Loss: 0.096
Iter 48/50 - Loss: 0.078
Iter 49/50 - Loss: 0.061
Iter 50/50 - Loss: 0.044
CPU times: user 2min 47s, sys: 7.87 s, total: 2min 55s
Wall time: 34.6 s
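
Optionally, we can inspect the learned hyperparameters after training (a brief sketch, assuming the standard ScaleKernel / RBFKernel / GaussianLikelihood attribute names):

[ ]:
# Inspect the learned kernel hyperparameters and observation noise.
print('Lengthscale: %.3f' % model.base_covar_module.base_kernel.lengthscale.item())
print('Outputscale: %.3f' % model.base_covar_module.outputscale.item())
print('Noise:       %.3f' % likelihood.noise.item())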

Making Predictions

The next cell makes predictions with the trained SGPR model. We also demonstrate increasing the max preconditioner size. Increasing the preconditioner size on this dataset is not necessary, but it can make a big difference in final test performance, and is often preferable to increasing the number of CG iterations if you can afford the space.

[7]:
model.eval()
likelihood.eval()
with gpytorch.settings.max_preconditioner_size(10), torch.no_grad():
    preds = model(test_x)
[8]:
print('Test MAE: {}'.format(torch.mean(torch.abs(preds.mean - test_y))))
Test MAE: 0.07271435856819153
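
The MAE above only uses the predictive mean. The predictive distribution also carries uncertainty estimates; below is a minimal sketch of how one might use them (an optional extra; wrapping the model output in the likelihood adds the learned observation noise to the latent predictions):

[ ]:
# Optional: predictions that include observation noise, plus RMSE and an
# approximate 95% confidence interval (mean +/- 2 standard deviations).
with torch.no_grad(), gpytorch.settings.fast_pred_var():
    observed_pred = likelihood(model(test_x))

rmse = torch.sqrt(torch.mean((observed_pred.mean - test_y) ** 2))
lower, upper = observed_pred.confidence_region()
print('Test RMSE: {:.4f}'.format(rmse.item()))
print('Mean CI width: {:.4f}'.format((upper - lower).mean().item()))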
[ ]: