GPyTorch Regression With KeOps

Introduction

KeOps is a software package for fast kernel operations that integrates with PyTorch. We can use its efficient GPU kernel matrix multiplies and integrate them with the rest of GPyTorch.

In this tutorial, we’ll demonstrate how to integrate the kernel matmuls of KeOps with all of the bells and whistles of GPyTorch, including things like our preconditioning for conjugate gradients.

In this notebook, we will train an exact GP on 3droad, a dataset with hundreds of thousands of data points. The highly optimized matmuls of KeOps, combined with algorithmic speed improvements like preconditioning, allow us to train on a dataset of this size in a matter of minutes using only a single GPU.
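
The conjugate gradient and preconditioning behavior mentioned above is controlled through GPyTorch's settings context managers. As a minimal sketch (assuming gpytorch is imported; the values below are illustrative rather than the defaults used in this notebook), training or prediction code can be wrapped like this:

with gpytorch.settings.max_preconditioner_size(100), \
     gpytorch.settings.max_cholesky_size(128), \
     gpytorch.settings.cg_tolerance(0.01):
    # GP training or prediction code placed inside this block will use
    # these conjugate gradient / preconditioning settings
    pass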

[1]:
import math
import torch
import gpytorch
import tqdm.notebook as tqdm
from matplotlib import pyplot as plt

%matplotlib inline
%load_ext autoreload
%autoreload 2

Downloading Data

We will be using the 3droad UCI dataset, which contains a total of 434,874 data points. We will split the data in half for training and half for testing.

The next cell will download this dataset from Dropbox and load it.

[2]:
import urllib.request
import os.path
from scipy.io import loadmat
from math import floor

if not os.path.isfile('../3droad.mat'):
    print('Downloading \'3droad\' UCI dataset...')
    urllib.request.urlretrieve('https://www.dropbox.com/s/f6ow1i59oqx05pl/3droad.mat?dl=1', '../3droad.mat')

data = torch.Tensor(loadmat('../3droad.mat')['data'])
[3]:
import numpy as np

N = data.shape[0]
# make train/test split
n_train = int(0.5 * N)

train_x, train_y = data[:n_train, :-1], data[:n_train, -1]
test_x, test_y = data[n_train:, :-1], data[n_train:, -1]

# normalize features
mean = train_x.mean(dim=-2, keepdim=True)
std = train_x.std(dim=-2, keepdim=True) + 1e-6 # prevent dividing by 0
train_x = (train_x - mean) / std
test_x = (test_x - mean) / std

# normalize labels
mean, std = train_y.mean(),train_y.std()
train_y = (train_y - mean) / std
test_y = (test_y - mean) / std

# make contiguous
train_x, train_y = train_x.contiguous(), train_y.contiguous()
test_x, test_y = test_x.contiguous(), test_y.contiguous()

output_device = torch.device('cuda:0')

train_x, train_y = train_x.to(output_device), train_y.to(output_device)
test_x, test_y = test_x.to(output_device), test_y.to(output_device)

print(
    f"Num train: {train_y.size(-1)}\n"
    f"Num test: {test_y.size(-1)}"
)
Num train: 217437
Num test: 217437

Using KeOps with a GPyTorch Model

Using KeOps with one of our pre-built kernels is as straightforward as swapping the kernel out. For example, in the cell below, we copy the simple GP from our basic tutorial notebook and swap out gpytorch.kernels.MaternKernel for gpytorch.kernels.keops.MaternKernel.

[4]:
# We will use the simplest form of GP model, exact inference
class ExactGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super(ExactGPModel, self).__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()

        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.keops.MaternKernel(nu=2.5))

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

# initialize likelihood and model
likelihood = gpytorch.likelihoods.GaussianLikelihood().cuda()
model = ExactGPModel(train_x, train_y, likelihood).cuda()

# Because we know some properties about this dataset,
# we will initialize the lengthscale to be somewhat small
# This step isn't necessary, but it will help the model converge faster.
model.covar_module.base_kernel.lengthscale = 0.05
[5]:
# Find optimal model hyperparameters
model.train()
likelihood.train()

# Use the adam optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)  # Includes GaussianLikelihood parameters

# "Loss" for GPs - the marginal log likelihood
mll = gpytorch.mlls.ExactMarginalLogLikelihood(model.likelihood, model)

import time
training_iter = 25
iterator = tqdm.tqdm(range(training_iter), desc="Training")
for i in iterator:
    start_time = time.time()
    # Zero gradients from previous iteration
    optimizer.zero_grad()
    # Output from model
    output = model(train_x)
    # Calc loss and backprop gradients
    loss = -mll(output, train_y)
    print_values = dict(
        loss=loss.item(),
        ls=model.covar_module.base_kernel.lengthscale.norm().item(),
        os=model.covar_module.outputscale.item(),
        noise=model.likelihood.noise.item(),
        mu=model.mean_module.constant.item(),
    )
    iterator.set_postfix(**print_values)
    loss.backward()
    optimizer.step()
[6]:
# Get into evaluation (predictive posterior) mode
model.eval()

# Make predictions on the held-out test set by feeding the model output through the likelihood
with torch.no_grad(), gpytorch.settings.fast_pred_var():
    observed_pred = model.likelihood(model(test_x))

Compute RMSE

[7]:
rmse = (observed_pred.mean - test_y).square().mean().sqrt().item()
print(f"RMSE: {rmse:.3f}")
RMSE: 0.138
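
Because the labels were standardized during data preparation, this RMSE is in standardized units. As a small usage sketch, it can be rescaled to the original units with the training-label standard deviation (the std variable still holds the label statistics from the data-preparation cell):

# std still holds the training-label standard deviation
rmse_original_units = rmse * std.item()
print(f"RMSE (original units): {rmse_original_units:.3f}")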