Stochastic Variational GP Regression with Contour Integral Quadrature

Overview

This notebook demonstrates how to perform stochastic variational GP regression using contour integral quadrature (CIQ) with msMINRES, as described in Pleiss et al., 2020. CIQ-SVGP can be used in place of standard SVGP when:

  • There are many inducing points (e.g. M > 5000)
  • The inducing points have special structure (e.g. lie on a grid)

We'll give an overview of how to use CIQ-SVGP stochastic variational regression (https://arxiv.org/pdf/1411.2005.pdf) to rapidly train on the 3droad UCI dataset using minibatches.
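Roughly speaking (a sketch; see Pleiss et al., 2020 for the full derivation), CIQ approximates the inverse root of the inducing covariance matrix $K_{ZZ}$, which is needed for the whitened variational parameterization, as a weighted sum of shifted linear solves:

$$K_{ZZ}^{-1/2} \mathbf{b} \approx \sum_{q=1}^{Q} w_q \left( K_{ZZ} + t_q I \right)^{-1} \mathbf{b},$$

where the weights $w_q$ and shifts $t_q$ come from a quadrature rule applied to a contour integral. All $Q$ shifted solves are computed simultaneously by the msMINRES iteration, which only requires matrix-vector products with $K_{ZZ}$ and therefore avoids a Cholesky decomposition.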

[1]:
import tqdm.notebook  # explicitly import the submodule used for progress bars below
import math
import torch
import gpytorch
from matplotlib import pyplot as plt

# Make plots inline
%matplotlib inline
[2]:
import urllib.request
import os
from scipy.io import loadmat
from math import floor


# this is for running the notebook in our testing framework
smoke_test = ('CI' in os.environ)


if not smoke_test and not os.path.isfile('../3droad.mat'):
    print('Downloading \'3droad\' UCI dataset...')
    urllib.request.urlretrieve('https://www.dropbox.com/s/f6ow1i59oqx05pl/3droad.mat?dl=1', '../3droad.mat')

if smoke_test:
    X, y = torch.randn(10, 2), torch.randn(10)
else:
    data = torch.Tensor(loadmat('../3droad.mat')['data'])
    X = data[:, :-2]
    X = X - X.min(0)[0]
    X = 2 * (X / X.max(0)[0]) - 1
    y = data[:, -1]
    y.sub_(y.mean(0)).div_(y.std(0))

    # Let's subsample the data
    indices = torch.randperm(X.size(0))[:10000]
    X = X[indices]
    y = y[indices]


train_n = int(floor(0.8 * len(X)))
train_x = X[:train_n, :].contiguous()
train_y = y[:train_n].contiguous()

test_x = X[train_n:, :].contiguous()
test_y = y[train_n:].contiguous()

if torch.cuda.is_available():
    train_x, train_y, test_x, test_y = train_x.cuda(), train_y.cuda(), test_x.cuda(), test_y.cuda()

DataLoaders with CIQ-SVGP

CIQ offers computational speedups only when the minibatch size is much smaller than the number of inducing points. We find that a minibatch size of 256 often works well.

[3]:
from torch.utils.data import TensorDataset, DataLoader
train_dataset = TensorDataset(train_x, train_y)
train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)
# Smaller batch sizes are better for CIQ

test_dataset = TensorDataset(test_x, test_y)
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)

Number of inducing points

CIQ offers computational speedups when there are many inducing points. Here, we choose 2000 inducing points, randomly subsampled from the training inputs.

[4]:
inducing_points = train_x[torch.randperm(train_x.size(0))[:2000]]
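Since msMINRES only needs matrix-vector products with the inducing covariance matrix, CIQ can also exploit inducing points with special structure, such as the grid structure mentioned in the overview. As a hypothetical alternative to random subsampling, here is a sketch that places roughly 2000 inducing points on a regular 45 x 45 grid over the normalized inputs (the grid size is an arbitrary illustrative choice):

# Alternative (sketch): inducing points on a regular 45 x 45 grid over [-1, 1]^2,
# matching the normalized input range. 45 * 45 = 2025 points, close to the
# 2000 used above. Note: the indexing= keyword requires PyTorch >= 1.10.
grid_size = 45
grid_1d = torch.linspace(-1., 1., grid_size)
grid_inducing_points = torch.stack(
    torch.meshgrid(grid_1d, grid_1d, indexing="ij"), dim=-1
).reshape(-1, 2)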

CIQ-SVGP models

To use contour integral quadrature, simply replace VariationalStrategy with CiqVariationalStrategy.

In this example, we are using a NaturalVariationalDistribution, as CIQ works best with natural gradient descent. (See the NGD tutorial for more details.)

[5]:
class GPModel(gpytorch.models.ApproximateGP):
    def __init__(self, inducing_points):
        variational_distribution = gpytorch.variational.NaturalVariationalDistribution(inducing_points.size(0))
        variational_strategy = gpytorch.variational.CiqVariationalStrategy(
            self, inducing_points, variational_distribution, learn_inducing_locations=True
        )
        super(GPModel, self).__init__(variational_strategy)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(
            gpytorch.kernels.MaternKernel(nu=2.5, ard_num_dims=2)
        )
        self.covar_module.base_kernel.initialize(lengthscale=0.01)  # Specific to the 3droad dataset

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)


model = GPModel(inducing_points=inducing_points)
likelihood = gpytorch.likelihoods.GaussianLikelihood()

if torch.cuda.is_available():
    model = model.cuda()
    likelihood = likelihood.cuda()
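Natural gradient descent updates the variational parameters, while ordinary gradient descent updates everything else, so we create two separate optimizers: gpytorch.optim.NGD for the variational parameters, and torch.optim.Adam for the kernel and likelihood hyperparameters.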
[6]:
variational_ngd_optimizer = gpytorch.optim.NGD(model.variational_parameters(), num_data=train_y.size(0), lr=0.01)

hyperparameter_optimizer = torch.optim.Adam([
    {'params': model.hyperparameters()},
    {'params': likelihood.parameters()},
], lr=0.01)
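Optionally, the accuracy/speed trade-off of CIQ can be adjusted through GPyTorch's global settings. The sketch below assumes your GPyTorch version exposes the num_contour_quadrature and minres_tolerance context managers (the values shown are illustrative); any CIQ-based computation run inside the context uses these settings:

# Optional (sketch): adjust the CIQ accuracy/speed trade-off.
# num_contour_quadrature sets the number of quadrature sites Q;
# minres_tolerance sets the convergence tolerance of the msMINRES solves.
# (Assumes these settings exist in your GPyTorch version; values are illustrative.)
with gpytorch.settings.num_contour_quadrature(15), gpytorch.settings.minres_tolerance(1e-4):
    _ = model(train_x[:256])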
[7]:
model.train()
likelihood.train()
mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=train_y.size(0))

num_epochs = 1 if smoke_test else 4
epochs_iter = tqdm.notebook.tqdm(range(num_epochs), desc="Epoch")
for i in epochs_iter:
    minibatch_iter = tqdm.notebook.tqdm(train_loader, desc="Minibatch", leave=False)

    for x_batch, y_batch in minibatch_iter:
        variational_ngd_optimizer.zero_grad()
        hyperparameter_optimizer.zero_grad()
        output = model(x_batch)
        loss = -mll(output, y_batch)
        minibatch_iter.set_postfix(loss=loss.item())
        loss.backward()
        variational_ngd_optimizer.step()
        hyperparameter_optimizer.step()

[8]:
model.eval()
likelihood.eval()
means = []
with torch.no_grad():
    for x_batch, y_batch in test_loader:
        preds = model(x_batch)
        means.append(preds.mean.cpu())
means = torch.cat(means)
print('Test MAE: {}'.format(torch.mean(torch.abs(means - test_y.cpu()))))
Test MAE: 0.6400326490402222