Scalable Exact GP Posterior Sampling using Contour Integral Quadrature

This notebook demonstrates the simplest usage of contour integral quadrature (CIQ) with msMINRES, as described here, to sample from the predictive distribution of an exact GP.

Note that to reach problem sizes where Cholesky would run the GPU out of memory, you’ll either need to have KeOps installed (see our KeOps tutorial in this same folder) or use the checkpoint_kernel beta feature. Even so, on this relatively simple example with 1000 training points but sampling at 10000 test points in 1D (50000 with KeOps), we will achieve significant speedups over Cholesky.
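
As an aside, kernel checkpointing can be enabled roughly as in the sketch below. This is not used in this notebook and is only a sketch: it assumes your GPyTorch version still exposes the gpytorch.beta_features.checkpoint_kernel context manager (a beta feature that may change or be removed), and it reuses the model, mll, and training data defined later in this notebook.

# Sketch only (assumes gpytorch.beta_features.checkpoint_kernel is available in your version):
# partition the kernel into chunks of 10000 rows to trade extra compute for lower memory.
with gpytorch.beta_features.checkpoint_kernel(10000):
    output = model(train_x)
    loss = -mll(output, train_y)
    loss.backward()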

[1]:
import math
import torch
import gpytorch
from matplotlib import pyplot as plt

import warnings
warnings.simplefilter("ignore", gpytorch.utils.warnings.NumericalWarning)

%matplotlib inline
%load_ext autoreload
%autoreload 2
[2]:
# Training data is 1000 points regularly spaced in [0, 1]
train_x = torch.linspace(0, 1, 1000)
# True function is sin(2*pi*x) with Gaussian noise
train_y = torch.sin(train_x * (2 * math.pi)) + torch.randn(train_x.size()) * 0.2

Are we running with KeOps?

If you have KeOps installed, change the flag below to True to run with a significantly larger test set.

[3]:
HAVE_KEOPS = False
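
If you would rather detect KeOps automatically instead of toggling the flag by hand, a minimal sketch (assuming the KeOps Python bindings are importable as pykeops) could look like this:

# Sketch: set HAVE_KEOPS based on whether the pykeops bindings import cleanly.
try:
    import pykeops  # noqa: F401
    HAVE_KEOPS = True
except ImportError:
    HAVE_KEOPS = False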

Define an Exact GP Model and train

[4]:
class ExactGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super(ExactGPModel, self).__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()

        if HAVE_KEOPS:
            self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.keops.RBFKernel())
        else:
            self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

# initialize likelihood and model
likelihood = gpytorch.likelihoods.GaussianLikelihood()
model = ExactGPModel(train_x, train_y, likelihood)
[5]:
if torch.cuda.is_available():
    train_x = train_x.cuda()
    train_y = train_y.cuda()
    model = model.cuda()
    likelihood = likelihood.cuda()
[6]:
# this is for running the notebook in our testing framework
import os
smoke_test = ('CI' in os.environ)
training_iter = 2 if smoke_test else 50

# Find optimal model hyperparameters
model.train()
likelihood.train()

# Use the adam optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)  # Includes GaussianLikelihood parameters

# "Loss" for GPs - the marginal log likelihood
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

for i in range(training_iter):
    # Zero gradients from previous iteration
    optimizer.zero_grad()
    # Output from model
    output = model(train_x)
    # Calc loss and backprop gradients
    loss = -mll(output, train_y)
    loss.backward()
    print('Iter %d/%d - Loss: %.3f   lengthscale: %.3f   noise: %.3f' % (
        i + 1, training_iter, loss.item(),
        model.covar_module.base_kernel.lengthscale.item(),
        model.likelihood.noise.item()
    ))
    optimizer.step()
Iter 1/50 - Loss: 0.861   lengthscale: 0.693   noise: 0.693
Iter 2/50 - Loss: 0.814   lengthscale: 0.644   noise: 0.644
Iter 3/50 - Loss: 0.763   lengthscale: 0.598   noise: 0.598
Iter 4/50 - Loss: 0.718   lengthscale: 0.554   noise: 0.554
Iter 5/50 - Loss: 0.666   lengthscale: 0.513   noise: 0.513
Iter 6/50 - Loss: 0.618   lengthscale: 0.474   noise: 0.474
Iter 7/50 - Loss: 0.572   lengthscale: 0.439   noise: 0.437
Iter 8/50 - Loss: 0.530   lengthscale: 0.408   noise: 0.402
Iter 9/50 - Loss: 0.486   lengthscale: 0.380   noise: 0.370
Iter 10/50 - Loss: 0.452   lengthscale: 0.355   noise: 0.339
Iter 11/50 - Loss: 0.415   lengthscale: 0.334   noise: 0.311
Iter 12/50 - Loss: 0.376   lengthscale: 0.316   noise: 0.285
Iter 13/50 - Loss: 0.331   lengthscale: 0.301   noise: 0.261
Iter 14/50 - Loss: 0.293   lengthscale: 0.288   noise: 0.238
Iter 15/50 - Loss: 0.258   lengthscale: 0.276   noise: 0.217
Iter 16/50 - Loss: 0.219   lengthscale: 0.266   noise: 0.198
Iter 17/50 - Loss: 0.188   lengthscale: 0.258   noise: 0.181
Iter 18/50 - Loss: 0.146   lengthscale: 0.250   noise: 0.165
Iter 19/50 - Loss: 0.109   lengthscale: 0.244   noise: 0.150
Iter 20/50 - Loss: 0.081   lengthscale: 0.238   noise: 0.136
Iter 21/50 - Loss: 0.042   lengthscale: 0.234   noise: 0.124
Iter 22/50 - Loss: 0.007   lengthscale: 0.230   noise: 0.113
Iter 23/50 - Loss: -0.020   lengthscale: 0.227   noise: 0.103
Iter 24/50 - Loss: -0.043   lengthscale: 0.224   noise: 0.094
Iter 25/50 - Loss: -0.075   lengthscale: 0.222   noise: 0.085
Iter 26/50 - Loss: -0.094   lengthscale: 0.221   noise: 0.078
Iter 27/50 - Loss: -0.108   lengthscale: 0.221   noise: 0.071
Iter 28/50 - Loss: -0.149   lengthscale: 0.221   noise: 0.065
Iter 29/50 - Loss: -0.160   lengthscale: 0.221   noise: 0.060
Iter 30/50 - Loss: -0.174   lengthscale: 0.222   noise: 0.055
Iter 31/50 - Loss: -0.191   lengthscale: 0.223   noise: 0.050
Iter 32/50 - Loss: -0.205   lengthscale: 0.224   noise: 0.047
Iter 33/50 - Loss: -0.205   lengthscale: 0.226   noise: 0.043
Iter 34/50 - Loss: -0.211   lengthscale: 0.228   noise: 0.040
Iter 35/50 - Loss: -0.211   lengthscale: 0.230   noise: 0.038
Iter 36/50 - Loss: -0.208   lengthscale: 0.232   noise: 0.035
Iter 37/50 - Loss: -0.219   lengthscale: 0.235   noise: 0.034
Iter 38/50 - Loss: -0.199   lengthscale: 0.237   noise: 0.032
Iter 39/50 - Loss: -0.198   lengthscale: 0.240   noise: 0.031
Iter 40/50 - Loss: -0.201   lengthscale: 0.243   noise: 0.030
Iter 41/50 - Loss: -0.204   lengthscale: 0.246   noise: 0.029
Iter 42/50 - Loss: -0.196   lengthscale: 0.249   noise: 0.028
Iter 43/50 - Loss: -0.196   lengthscale: 0.252   noise: 0.028
Iter 44/50 - Loss: -0.197   lengthscale: 0.254   noise: 0.028
Iter 45/50 - Loss: -0.183   lengthscale: 0.256   noise: 0.028
Iter 46/50 - Loss: -0.183   lengthscale: 0.258   noise: 0.028
Iter 47/50 - Loss: -0.195   lengthscale: 0.261   noise: 0.028
Iter 48/50 - Loss: -0.197   lengthscale: 0.263   noise: 0.029
Iter 49/50 - Loss: -0.193   lengthscale: 0.265   noise: 0.029
Iter 50/50 - Loss: -0.203   lengthscale: 0.268   noise: 0.030

Define test set

If we have KeOps installed, we’ll test on 50000 points instead of 10000.

[7]:
if HAVE_KEOPS:
    test_n = 50000
else:
    test_n = 10000

test_x = torch.linspace(0, 1, test_n)
if torch.cuda.is_available():
    test_x = test_x.cuda()
print(test_x.shape)
torch.Size([10000])

Draw a sample with CIQ

To do this, we simply wrap the rsample call in the ciq_samples setting. We additionally demonstrate all of the relevant settings for controlling Contour Integral Quadrature:

  • The ciq_samples setting determines whether or not to use CIQ.
  • The num_contour_quadrature setting controls the number of quadrature sites (Q in the paper).
  • The minres_tolerance setting controls the error we tolerate from minres (here, <0.01%).

Note that, of these settings, increasing num_contour_quadrature is unlikely to improve performance: as Theorem 1 from the paper demonstrates, virtually all of the error in this method is controlled by minres_tolerance. Here, we use quite a tight tolerance for minres.

[8]:
import time

# Get into evaluation (predictive posterior) mode
model.eval()
likelihood.eval()

# Test points are regularly spaced along [0,1]
# Make predictions by feeding model through likelihood

with torch.no_grad():
    observed_pred = likelihood(model(test_x))

    # All relevant settings for using CIQ.
    #   ciq_samples(True) - Use CIQ for sampling
    #   num_contour_quadrature(10) -- Use 10 quadrature sites (Q in the paper)
    #   minres_tolerance -- error tolerance from minres (here, <0.01%).
    print("Running with CIQ")
    with gpytorch.settings.ciq_samples(True), gpytorch.settings.num_contour_quadrature(10), gpytorch.settings.minres_tolerance(1e-4):
        %time y_samples = observed_pred.rsample()

    print("Running with Cholesky")
    # Make sure we use Cholesky
    with gpytorch.settings.fast_computations(covar_root_decomposition=False):
        %time y_samples = observed_pred.rsample()
Running with CIQ
CPU times: user 711 ms, sys: 38.4 ms, total: 749 ms
Wall time: 677 ms
Running with Cholesky
CPU times: user 1min 36s, sys: 892 ms, total: 1min 37s
Wall time: 30.9 s
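
As a quick visual sanity check (not part of the timed comparison above), one could plot a CIQ sample against the predictive mean and the training data. This is only a sketch: it reuses the model, tensors, and observed_pred defined earlier, and simply redraws one sample under the same CIQ settings.

# Sketch: draw one CIQ posterior sample and plot it with the predictive mean and training data.
with torch.no_grad(), gpytorch.settings.ciq_samples(True), \
        gpytorch.settings.num_contour_quadrature(10), gpytorch.settings.minres_tolerance(1e-4):
    ciq_sample = observed_pred.rsample()  # one sample of shape [test_n]

with torch.no_grad():
    f, ax = plt.subplots(1, 1, figsize=(10, 4))
    ax.plot(train_x.cpu().numpy(), train_y.cpu().numpy(), 'k*', markersize=2, label='Training data')
    ax.plot(test_x.cpu().numpy(), observed_pred.mean.cpu().numpy(), 'b', label='Predictive mean')
    ax.plot(test_x.cpu().numpy(), ciq_sample.cpu().numpy(), 'g', linewidth=0.5, label='CIQ sample')
    ax.legend()
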
[ ]: