Scalable Exact GP Posterior Sampling using Contour Integral Quadrature
This notebook demonstrates the simplest usage of contour integral quadrature (CIQ) with msMINRES, as described in the CIQ paper, to sample from the predictive distribution of an exact GP.
Note that to reach problem sizes where Cholesky would run the GPU out of memory, you'll need to have KeOps installed (see our KeOps tutorial in this same folder). Even without KeOps, this relatively simple example, with 1000 training points and 10000 test points in 1D (50000 with KeOps), achieves a significant speedup over Cholesky: roughly 0.7 s versus 31 s of wall time in the run below.
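Both Cholesky and CIQ serve the same purpose here: drawing from N(μ, K) requires a matrix R with R Rᵀ = K, after which μ + R ε with ε ~ N(0, I) is a valid sample. Cholesky factors K explicitly in O(n³) time and O(n²) memory, while CIQ targets the symmetric square root K^{1/2} and needs only matrix-vector products with K. The following is a minimal NumPy sketch of the "any valid root gives a valid sample" idea, not GPyTorch's implementation; the helper `rbf_kernel`, its lengthscale, and the 1e-4 jitter are illustrative assumptions.

```python
import numpy as np


def rbf_kernel(x1, x2, lengthscale=0.2, outputscale=1.0):
    # Hypothetical helper: squared-exponential (RBF) covariance between 1D point sets
    diff = x1[:, None] - x2[None, :]
    return outputscale * np.exp(-0.5 * (diff / lengthscale) ** 2)


rng = np.random.default_rng(0)
x = np.linspace(0, 1, 200)
# Small diagonal jitter (an assumption) keeps K numerically positive definite
K = rbf_kernel(x, x) + 1e-4 * np.eye(len(x))
mean = np.zeros(len(x))

# Root 1: lower-triangular Cholesky factor, K = L @ L.T
L = np.linalg.cholesky(K)

# Root 2: symmetric square root, K = S @ S, via an eigendecomposition
evals, evecs = np.linalg.eigh(K)
S = evecs @ np.diag(np.sqrt(np.clip(evals, 0.0, None))) @ evecs.T

# Either root maps standard normal noise to a draw from N(mean, K)
eps = rng.standard_normal(len(x))
sample_chol = mean + L @ eps
sample_sqrt = mean + S @ eps
```

The two samples differ for the same ε, but both are exact draws from N(mean, K); the roots differ only in how expensive they are to compute, which is what the timings at the end of this notebook compare.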
import math
import torch
import gpytorch
# Training data is 1000 regularly spaced points in [0,1] inclusive
train_x = torch.linspace(0, 1, 1000)
# True function is sin(2*pi*x) with Gaussian noise
train_y = torch.sin(train_x * (2 * math.pi)) + torch.randn(train_x.size()) * 0.2
Are we running with KeOps?
If you have KeOps installed, change the flag below to True to run with a significantly larger test set.
# Set to True if KeOps is installed
use_keops = False
Define an Exact GP Model and train
class ExactGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super(ExactGPModel, self).__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        # Use the KeOps RBF kernel when available; otherwise the standard RBF kernel
        if use_keops:
            self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.keops.RBFKernel())
        else:
            self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
# initialize likelihood and model
likelihood = gpytorch.likelihoods.GaussianLikelihood()
model = ExactGPModel(train_x, train_y, likelihood)
if torch.cuda.is_available():
    train_x = train_x.cuda()
    train_y = train_y.cuda()
    model = model.cuda()
    likelihood = likelihood.cuda()
# this is for running the notebook in our testing framework
import os
smoke_test = ('CI' in os.environ)
training_iter = 2 if smoke_test else 50
# Find optimal model hyperparameters
model.train()
likelihood.train()
# Use the adam optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)  # Includes GaussianLikelihood parameters
# "Loss" for GPs - the marginal log likelihood
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)
for i in range(training_iter):
    # Zero gradients from previous iteration
    optimizer.zero_grad()
    # Output from model
    output = model(train_x)
    # Calc loss and backprop gradients
    loss = -mll(output, train_y)
    loss.backward()
    print('Iter %d/%d - Loss: %.3f   lengthscale: %.3f   noise: %.3f' % (
        i + 1, training_iter, loss.item(),
        model.covar_module.base_kernel.lengthscale.item(),
        model.likelihood.noise.item()
    ))
    optimizer.step()
Define test set
If we have KeOps installed, we’ll test on 50000 points instead of 10000.
if use_keops:
    test_n = 50000
else:
    test_n = 10000
test_x = torch.linspace(0, 1, test_n)
if torch.cuda.is_available():
    test_x = test_x.cuda()
Draw a sample with CIQ
To do this, we just wrap the rsample call in the ciq_samples setting. We additionally demonstrate all relevant settings for controlling Contour Integral Quadrature:
- The ciq_samples setting determines whether or not to use CIQ.
- The num_contour_quadrature setting controls the number of quadrature sites (Q in the paper).
- The minres_tolerance setting controls the error we tolerate from minres (here, <0.01%).
Note that, of these settings, increasing num_contour_quadrature is unlikely to improve accuracy. As Theorem 1 from the paper demonstrates, virtually all of the error in this method is controlled by minres_tolerance. Here, we use a fairly tight tolerance for minres.
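To make the role of the quadrature sites concrete, here is a minimal NumPy sketch of the underlying idea, not the paper's or GPyTorch's implementation. It uses the identity A^{-1/2} = (2/π) ∫₀^{π/2} (sin²θ I + cos²θ A)^{-1} dθ, valid for symmetric positive definite A, discretized with plain Gauss-Legendre nodes and dense solves. CIQ instead uses a more efficient contour-based rule and solves all the shifted systems at once with msMINRES. The function name `inv_sqrt_matvec` and the test matrix are hypothetical.

```python
import numpy as np


def inv_sqrt_matvec(A, b, num_quad=20):
    """Approximate A^{-1/2} @ b for a symmetric positive definite matrix A.

    Discretizes A^{-1/2} = (2/pi) * int_0^{pi/2} (sin^2(t) I + cos^2(t) A)^{-1} dt
    with Gauss-Legendre quadrature (a simpler rule than the one used by CIQ).
    """
    nodes, weights = np.polynomial.legendre.leggauss(num_quad)
    # Map the nodes and weights from [-1, 1] to [0, pi/2]
    thetas = (nodes + 1.0) * (np.pi / 4)
    weights = weights * (np.pi / 4)
    eye = np.eye(A.shape[0])
    result = np.zeros(A.shape[0])
    for theta, w in zip(thetas, weights):
        # One linear solve per quadrature site; up to a scalar factor this is
        # the shifted system (tan^2(theta) I + A) x = b
        shifted = np.sin(theta) ** 2 * eye + np.cos(theta) ** 2 * A
        result += w * np.linalg.solve(shifted, b)
    return (2.0 / np.pi) * result


# Usage: a random SPD matrix with eigenvalues in [0.5, 2]
rng = np.random.default_rng(0)
Q, _ = np.linalg.qr(rng.standard_normal((50, 50)))
A = Q @ np.diag(np.linspace(0.5, 2.0, 50)) @ Q.T
b = rng.standard_normal(50)

approx = inv_sqrt_matvec(A, b, num_quad=20)

# Reference answer from an exact eigendecomposition
evals, evecs = np.linalg.eigh(A)
exact = evecs @ ((evecs.T @ b) / np.sqrt(evals))

# Multiplying by A turns A^{-1/2} b into A^{1/2} b
sqrt_b = A @ approx
print("max abs error:", np.max(np.abs(approx - exact)))
```

Here num_quad plays the role of Q. This simple rule needs more nodes as the condition number of A grows, so the eigenvalue range in the usage example is deliberately mild; the quadrature in the paper is built to stay accurate with few sites, which fits the advice above that minres_tolerance, not the number of sites, is the setting to tune.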
import time
# Get into evaluation (predictive posterior) mode
model.eval()
likelihood.eval()
# Test points are regularly spaced along [0,1]
# Make predictions by feeding model through likelihood
with torch.no_grad():
    observed_pred = likelihood(model(test_x))
    # All relevant settings for using CIQ.
    #   ciq_samples(True) -- Use CIQ for sampling
    #   num_contour_quadrature(10) -- Use 10 quadrature sites (Q in the paper)
    #   minres_tolerance(1e-4) -- error tolerance from minres (here, <0.01%).
    print("Running with CIQ")
    with gpytorch.settings.ciq_samples(True), gpytorch.settings.num_contour_quadrature(10), gpytorch.settings.minres_tolerance(1e-4):
        %time y_samples = observed_pred.rsample()
    print("Running with Cholesky")
    # Make sure we use Cholesky
    with gpytorch.settings.fast_computations(covar_root_decomposition=False):
        %time y_samples = observed_pred.rsample()
Running with CIQ
CPU times: user 711 ms, sys: 38.4 ms, total: 749 ms
Wall time: 677 ms
Running with Cholesky
CPU times: user 1min 36s, sys: 892 ms, total: 1min 37s
Wall time: 30.9 s
