Exact DKL (Deep Kernel Learning) Regression w/ KISS-GP

Overview

In this notebook, we’ll give a brief tutorial on how to use deep kernel learning (DKL) for regression on a medium-scale dataset using SKI. It also demonstrates how to incorporate standard PyTorch modules into a Gaussian process model.

[1]:
import math
import tqdm.notebook  # makes sure tqdm.notebook.tqdm (used in training) is loaded
import torch
import gpytorch
from matplotlib import pyplot as plt

# Make plots inline
%matplotlib inline

Loading Data

For this example notebook, we’ll be using the elevators UCI dataset used in the paper. Running the next cell downloads a copy of the dataset that has already been scaled and normalized appropriately. We split the data in order, without shuffling: the first 80% is used for training and the last 20% for testing.

Note: Running the next cell will attempt to download a ~400 KB dataset file to the current directory.

[2]:
import urllib.request
import os
from scipy.io import loadmat
from math import floor


# this is for running the notebook in our testing framework
smoke_test = ('CI' in os.environ)


if not smoke_test and not os.path.isfile('../elevators.mat'):
    print('Downloading \'elevators\' UCI dataset...')
    urllib.request.urlretrieve('https://drive.google.com/uc?export=download&id=1jhWL3YUHvXIaftia4qeAyDwVxo6j1alk', '../elevators.mat')


if smoke_test:  # this is for running the notebook in our testing framework
    X, y = torch.randn(2000, 3), torch.randn(2000)
else:
    data = torch.Tensor(loadmat('../elevators.mat')['data'])
    X = data[:, :-1]
    X = X - X.min(0)[0]
    X = 2 * (X / X.max(0)[0]) - 1
    y = data[:, -1]


train_n = int(floor(0.8 * len(X)))
train_x = X[:train_n, :].contiguous()
train_y = y[:train_n].contiguous()

test_x = X[train_n:, :].contiguous()
test_y = y[train_n:].contiguous()

if torch.cuda.is_available():
    train_x, train_y, test_x, test_y = train_x.cuda(), train_y.cuda(), test_x.cuda(), test_y.cuda()
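
The rescaling and split in the cell above can be checked on a toy tensor. The sketch below is a minimal illustration with made-up values, not the elevators data: it maps each column to [-1, 1] with the same two lines as the notebook and then takes the same ordered 80/20 split.

```python
import torch

# Toy stand-in for the feature matrix (values are made up for illustration).
X = torch.tensor([[0.0, 10.0],
                  [5.0, 20.0],
                  [10.0, 40.0]])

# Shift each column so its minimum is 0, rescale so its maximum is 1,
# then stretch [0, 1] to [-1, 1].
X = X - X.min(0)[0]
X = 2 * (X / X.max(0)[0]) - 1

# Ordered 80/20 split (no shuffling), as in the cell above.
train_n = int(0.8 * len(X))
train_x, test_x = X[:train_n], X[train_n:]
```

Note that a constant column would lead to a division by zero in the second scaling line, so this recipe assumes every feature takes at least two distinct values.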

Defining the DKL Feature Extractor

Next, we define the neural network feature extractor used to define the deep kernel. In this case, we use a fully connected network with the architecture d -> 1000 -> 500 -> 50 -> 2, as described in the original DKL paper. All of the code below uses standard PyTorch implementations of neural network layers.

[3]:
data_dim = train_x.size(-1)

class LargeFeatureExtractor(torch.nn.Sequential):
    def __init__(self):
        super(LargeFeatureExtractor, self).__init__()
        self.add_module('linear1', torch.nn.Linear(data_dim, 1000))
        self.add_module('relu1', torch.nn.ReLU())
        self.add_module('linear2', torch.nn.Linear(1000, 500))
        self.add_module('relu2', torch.nn.ReLU())
        self.add_module('linear3', torch.nn.Linear(500, 50))
        self.add_module('relu3', torch.nn.ReLU())
        self.add_module('linear4', torch.nn.Linear(50, 2))

feature_extractor = LargeFeatureExtractor()
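
As a quick sanity check of the architecture, the following sketch builds a network with the same layer sizes using torch.nn.Sequential directly and pushes a random batch through it. The input dimension of 18 is an assumption made only for this illustration; the notebook reads it from train_x.size(-1).

```python
import torch

# Assumed input dimension, for illustration only; the notebook uses
# data_dim = train_x.size(-1).
data_dim = 18

# Same layer sizes as the extractor above: d -> 1000 -> 500 -> 50 -> 2.
extractor = torch.nn.Sequential(
    torch.nn.Linear(data_dim, 1000),
    torch.nn.ReLU(),
    torch.nn.Linear(1000, 500),
    torch.nn.ReLU(),
    torch.nn.Linear(500, 50),
    torch.nn.ReLU(),
    torch.nn.Linear(50, 2),
)

batch = torch.randn(5, data_dim)   # five random inputs
features = extractor(batch)        # one 2-dimensional feature vector per input
num_params = sum(p.numel() for p in extractor.parameters())
```

Each input row is reduced to two features, which is what keeps the SKI grid in the GP model two-dimensional.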

Defining the DKL-GP Model

We now define the GP model. For more details on the use of GP models, see our simpler examples. This model uses a GridInterpolationKernel (SKI) with an RBF base kernel.

The forward method

In deep kernel learning, the forward method is where most of the interesting new stuff happens. Before calling the mean and covariance modules on the data as in the simple GP regression setting, we first pass the input data x through the neural network feature extractor. Then, to ensure that the output features of the neural network remain within the grid bounds expected by SKI, we scale the resulting features to lie between -1 and 1.

Only after this processing do we call the mean and covariance modules of the Gaussian process. This example also demonstrates the flexibility of defining GP models that allow for learned transformations of the data (in this case, via a neural network) before calling the mean and covariance functions. Because the neural network maps to only two output features, the SKI grid stays two-dimensional (100 points per dimension here), so we will have no problem using SKI.

[4]:
class GPRegressionModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super(GPRegressionModel, self).__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.GridInterpolationKernel(
            gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel(ard_num_dims=2)),
            num_dims=2, grid_size=100
        )
        self.feature_extractor = feature_extractor

        # This module will scale the NN features so that they're nice values
        self.scale_to_bounds = gpytorch.utils.grid.ScaleToBounds(-1., 1.)

    def forward(self, x):
        # We're first putting our data through a deep net (feature extractor)
        projected_x = self.feature_extractor(x)
        projected_x = self.scale_to_bounds(projected_x)  # Make the NN values "nice"

        mean_x = self.mean_module(projected_x)
        covar_x = self.covar_module(projected_x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
[5]:
likelihood = gpytorch.likelihoods.GaussianLikelihood()
model = GPRegressionModel(train_x, train_y, likelihood)

if torch.cuda.is_available():
    model = model.cuda()
    likelihood = likelihood.cuda()
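
To make the idea of a deep kernel concrete, the sketch below computes k(x, x') = k_RBF(g(x), g(x')) by hand in plain PyTorch, where g is a network followed by rescaling. Everything here is an assumption for illustration: the tiny network g, the helper names scale_to_bounds and rbf_kernel, and the simple min-max rescaling, which is a hand-written stand-in rather than the library's ScaleToBounds module. It also builds a dense kernel matrix instead of the SKI approximation used by the model above.

```python
import torch

torch.manual_seed(0)

# Hypothetical small extractor standing in for the notebook's network.
g = torch.nn.Sequential(
    torch.nn.Linear(3, 16),
    torch.nn.ReLU(),
    torch.nn.Linear(16, 2),
)

def scale_to_bounds(z, lower=-1.0, upper=1.0):
    # Simplified min-max rescaling over all feature entries (an assumption;
    # the library module may differ in detail).
    z_min, z_max = z.min(), z.max()
    return (z - z_min) / (z_max - z_min) * (upper - lower) + lower

def rbf_kernel(a, b, lengthscale=1.0, outputscale=1.0):
    # k(a, b) = s * exp(-||a - b||^2 / (2 * l^2))
    sq_dist = torch.cdist(a, b).pow(2)
    return outputscale * torch.exp(-0.5 * sq_dist / lengthscale ** 2)

x = torch.randn(6, 3)          # six raw inputs with three features each
z = scale_to_bounds(g(x))      # learned features, rescaled to [-1, 1]
K = rbf_kernel(z, z)           # 6 x 6 deep kernel matrix
```

Because the kernel is evaluated on g(x) rather than on x, gradients of any loss built from K flow back into the weights of g, which is what allows the network and the GP hyperparameters to be trained together.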

Training the model

The cell below trains the DKL model above, learning both the hyperparameters of the Gaussian process and the parameters of the neural network in an end-to-end fashion using Type-II MLE. We run 60 iterations of training using the Adam optimizer built into PyTorch. With a decent GPU, this should only take a few seconds.

[6]:
training_iterations = 2 if smoke_test else 60

# Find optimal model hyperparameters
model.train()
likelihood.train()

# Use the adam optimizer
optimizer = torch.optim.Adam([
    {'params': model.feature_extractor.parameters()},
    {'params': model.covar_module.parameters()},
    {'params': model.mean_module.parameters()},
    {'params': model.likelihood.parameters()},
], lr=0.01)

# "Loss" for GPs - the marginal log likelihood
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

def train():
    iterator = tqdm.notebook.tqdm(range(training_iterations))
    for i in iterator:
        # Zero backprop gradients
        optimizer.zero_grad()
        # Get output from model
        output = model(train_x)
        # Calc loss and backprop derivatives
        loss = -mll(output, train_y)
        loss.backward()
        iterator.set_postfix(loss=loss.item())
        optimizer.step()

%time train()

CPU times: user 45.6 s, sys: 4.29 s, total: 49.8 s
Wall time: 13.7 s
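
For reference, the quantity that Type-II MLE maximizes is the log marginal likelihood of the training targets. The sketch below evaluates it for a small dense GP with a zero mean, using a Cholesky factorization. It is a minimal illustration under stated assumptions: the toy data, the hyperparameter values, and the dense RBF covariance are all made up here, whereas the notebook relies on gpytorch.mlls.ExactMarginalLogLikelihood with a SKI kernel on learned features.

```python
import math
import torch

# Toy 1-D regression data (made up for illustration), in double precision.
n = 8
x = torch.linspace(-1, 1, n).unsqueeze(-1).double()
y = torch.sin(3 * x).squeeze(-1)

# Dense RBF covariance plus observation noise variance (arbitrary values).
lengthscale, noise = 0.5, 0.1
K = torch.exp(-0.5 * torch.cdist(x, x).pow(2) / lengthscale ** 2)
K_noisy = K + noise * torch.eye(n, dtype=torch.float64)

# log p(y) = -1/2 y^T K^{-1} y - 1/2 log|K| - n/2 log(2*pi),
# with log|K| = 2 * sum(log(diag(L))) for the Cholesky factor L.
L = torch.linalg.cholesky(K_noisy)
alpha = torch.cholesky_solve(y.unsqueeze(-1), L).squeeze(-1)
log_marginal = (
    -0.5 * y @ alpha
    - L.diagonal().log().sum()
    - 0.5 * n * math.log(2 * math.pi)
)
```

In the training loop above, the loss is the negative of this kind of quantity, so each Adam step moves the kernel, likelihood, and network parameters toward a higher marginal likelihood.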

Making Predictions

The next cell computes the predictive distribution for the test set using the standard SKI testing code; the predictive mean is stored in preds.mean. The context managers in the cell disable Toeplitz structure (use_toeplitz(False)) and enable fast predictive variances (fast_pred_var()).

[7]:
model.eval()
likelihood.eval()
with torch.no_grad(), gpytorch.settings.use_toeplitz(False), gpytorch.settings.fast_pred_var():
    preds = model(test_x)
[8]:
print('Test MAE: {}'.format(torch.mean(torch.abs(preds.mean - test_y))))
Test MAE: 0.07007455825805664