Scalable Kernel Interpolation for Product Kernels (SKIP)

Overview

In this notebook, we’ll give an overview of how to use SKIP, a method that exploits product structure in some kernels to reduce the dependency of SKI on the data dimensionality from exponential to linear.

The most important practical consideration in this notebook is the use of gpytorch.settings.max_root_decomposition_size, which we explain right before the training loop cell.
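
Concretely, SKIP applies to kernels that factor as a product of one-dimensional kernels across the input dimensions (the RBF kernel used below is one such kernel):

k(x, x') = k_1(x_1, x'_1) * k_2(x_2, x'_2) * ... * k_d(x_d, x'_d)

With this structure, SKI only needs a one-dimensional interpolation grid of m points for each of the d dimensions (d * m grid points in total), rather than a single d-dimensional grid with m^d points.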

[1]:
import math
import torch
import gpytorch
from matplotlib import pyplot as plt

# Make plots inline
%matplotlib inline

For this example notebook, we’ll be using the elevators UCI dataset used in the paper. Running the next cell downloads a copy of the dataset that has already been scaled and normalized appropriately. For this notebook, we’ll simply split the data, using the first 80% as training and the last 20% as testing.

Note: Running the next cell will attempt to download a ~400 KB dataset file, saving it as '../elevators.mat'.

[2]:
import urllib.request
import os
from scipy.io import loadmat
from math import floor


# this is for running the notebook in our testing framework
smoke_test = ('CI' in os.environ)


if not smoke_test and not os.path.isfile('../elevators.mat'):
    print('Downloading \'elevators\' UCI dataset...')
    urllib.request.urlretrieve('https://drive.google.com/uc?export=download&id=1jhWL3YUHvXIaftia4qeAyDwVxo6j1alk', '../elevators.mat')


if smoke_test:  # this is for running the notebook in our testing framework
    X, y = torch.randn(1000, 3), torch.randn(1000)
else:
    data = torch.Tensor(loadmat('../elevators.mat')['data'])
    X = data[:, :-1]
    X = X - X.min(0)[0]
    X = 2 * (X / X.max(0)[0]) - 1
    y = data[:, -1]


train_n = int(floor(0.8 * len(X)))
train_x = X[:train_n, :].contiguous()
train_y = y[:train_n].contiguous()

test_x = X[train_n:, :].contiguous()
test_y = y[train_n:].contiguous()

if torch.cuda.is_available():
    train_x, train_y, test_x, test_y = train_x.cuda(), train_y.cuda(), test_x.cuda(), test_y.cuda()
[3]:
X.size()
[3]:
torch.Size([16599, 18])

Defining the SKIP GP Model

We now define the GP model. For more details on the use of GP models, see our simpler examples. This model uses a GridInterpolationKernel (SKI) with an RBF base kernel. To use SKIP, we make the following changes:

  • First, we define our base_covar_module to have a batch_shape equal to the dimensionality of the data. We make this change because we will use the base_covar_module to construct a batch of univariate kernels, which we will then multiply together using SKIP.

  • We use only a 1-dimensional GridInterpolationKernel (i.e., by passing num_dims=1). The idea of SKIP is to use a product of 1-dimensional GridInterpolationKernels instead of a single d-dimensional one.

  • In the forward call, we reshape x to be d x n x 1 before passing it through the covar_module. Because our covar_module produces a batch of univariate kernels, each data dimension of x must be treated as a batch dimension.

  • After constructing our univariate covariance matrices, we multiply them all together by calling .prod(dim=-3). The short sketch after this list traces the shapes involved.

For more details on this construction, see the Kernels with Additive or Product Structure tutorial.
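
As a quick sanity check of this construction, here is a small standalone sketch (not part of the original notebook, and omitting the SKI interpolation) that traces the shapes described above and verifies that a product of univariate RBF kernels matches the full d-dimensional RBF kernel:

import torch
from gpytorch.kernels import RBFKernel

n, d = 5, 3
x = torch.randn(n, d)

full_kernel = RBFKernel()                                    # one d-dimensional RBF kernel
univariate_kernels = RBFKernel(batch_shape=torch.Size([d]))  # batch of d univariate RBF kernels

with torch.no_grad():
    K_full = full_kernel(x).to_dense()                   # n x n
    x_batched = x.mT.unsqueeze(-1)                       # d x n x 1
    K_batch = univariate_kernels(x_batched).to_dense()   # d x n x n
    K_prod = K_batch.prod(dim=-3)                        # n x n

# True: with matching (default) lengthscales, the RBF kernel factorizes across dimensions
print(torch.allclose(K_full, K_prod, atol=1e-5))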

[4]:
from gpytorch.means import ConstantMean
from gpytorch.kernels import ScaleKernel, RBFKernel, GridInterpolationKernel
from gpytorch.distributions import MultivariateNormal

class GPRegressionModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super(GPRegressionModel, self).__init__(train_x, train_y, likelihood)
        self.mean_module = ConstantMean()
        self.base_covar_module = RBFKernel(batch_shape=torch.Size([train_x.size(-1)]))
        self.covar_module = ScaleKernel(
            GridInterpolationKernel(self.base_covar_module, grid_size=100, num_dims=1)
        )

    def forward(self, x):
        mean_x = self.mean_module(x)
        univariate_covars = self.covar_module(x.mT.unsqueeze(-1))
        covar_x = univariate_covars.prod(dim=-3)
        return MultivariateNormal(mean_x, covar_x)
[5]:
likelihood = gpytorch.likelihoods.GaussianLikelihood()
model = GPRegressionModel(train_x, train_y, likelihood)

if torch.cuda.is_available():
    model = model.cuda()
    likelihood = likelihood.cuda()

Training the model

The training loop for SKIP has one main new feature we haven’t seen before: we specify max_root_decomposition_size. This controls how many Lanczos iterations we use for SKIP; larger values cost more time and, more importantly, more memory. Realistically, the goal should be to set this as high as possible without running out of memory.

In some sense, this parameter is the main trade-off of SKIP. Many inducing point methods are most sensitive to the number of inducing points; because SKIP approximates one-dimensional kernels, it can do so very well with relatively few inducing points, and the main source of approximation error is instead the Lanczos decompositions we perform.
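
To build some intuition for what this setting controls, here is a minimal, hypothetical sketch (not part of the original notebook; it assumes a recent GPyTorch built on the linear_operator backend) that forms a Lanczos root decomposition K ≈ R R^T of a small dense RBF covariance matrix at a few values of max_root_decomposition_size and reports the approximation error:

import torch
import gpytorch
from gpytorch.kernels import RBFKernel
from linear_operator import to_linear_operator

# Illustrative sketch only: how max_root_decomposition_size bounds the rank of
# the Lanczos root decomposition K ~= R @ R.T used internally by SKIP.
with torch.no_grad():
    x_demo = torch.randn(500, 1)
    K_dense = RBFKernel()(x_demo).to_dense()   # a small exact covariance matrix
    K_op = to_linear_operator(K_dense)

    for rank in [5, 15, 30]:
        with gpytorch.settings.max_root_decomposition_size(rank):
            # Force the Lanczos path (small matrices would otherwise use Cholesky)
            R = K_op.root_decomposition(method="lanczos").root.to_dense()
        err = (torch.norm(K_dense - R @ R.mT) / torch.norm(K_dense)).item()
        print(f'max_root_decomposition_size={rank}: relative error {err:.2e}')

For smooth kernels the error typically drops quickly with the rank, which is why relatively small values of this setting can work well; in practice the limiting factor is usually memory rather than accuracy.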

[6]:
training_iterations = 2 if smoke_test else 50

# Find optimal model hyperparameters
model.train()
likelihood.train()

# Use the adam optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.05)

# "Loss" for GPs - the marginal log likelihood
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

def train():
    for i in range(training_iterations):
        # Zero backprop gradients
        optimizer.zero_grad()
        with gpytorch.settings.use_toeplitz(False), gpytorch.settings.max_root_decomposition_size(30):
            # Get output from model
            output = model(train_x)
            # Calc loss and backprop derivatives
            loss = -mll(output, train_y)
            loss.backward()
        print('Iter %d/%d - Loss: %.3f' % (i + 1, training_iterations, loss.item()))
        optimizer.step()
        torch.cuda.empty_cache()

%time train()
Iter 1/50 - Loss: 0.782
Iter 2/50 - Loss: 0.767
Iter 3/50 - Loss: 0.749
Iter 4/50 - Loss: 0.733
Iter 5/50 - Loss: 0.717
Iter 6/50 - Loss: 0.699
Iter 7/50 - Loss: 0.682
Iter 8/50 - Loss: 0.665
Iter 9/50 - Loss: 0.648
Iter 10/50 - Loss: 0.631
Iter 11/50 - Loss: 0.613
Iter 12/50 - Loss: 0.596
Iter 13/50 - Loss: 0.578
Iter 14/50 - Loss: 0.561
Iter 15/50 - Loss: 0.544
Iter 16/50 - Loss: 0.526
Iter 17/50 - Loss: 0.509
Iter 18/50 - Loss: 0.491
Iter 19/50 - Loss: 0.474
Iter 20/50 - Loss: 0.457
Iter 21/50 - Loss: 0.439
Iter 22/50 - Loss: 0.422
Iter 23/50 - Loss: 0.405
Iter 24/50 - Loss: 0.388
Iter 25/50 - Loss: 0.372
Iter 26/50 - Loss: 0.355
Iter 27/50 - Loss: 0.339
Iter 28/50 - Loss: 0.322
Iter 29/50 - Loss: 0.306
Iter 30/50 - Loss: 0.291
Iter 31/50 - Loss: 0.276
Iter 32/50 - Loss: 0.261
Iter 33/50 - Loss: 0.246
Iter 34/50 - Loss: 0.232
Iter 35/50 - Loss: 0.218
Iter 36/50 - Loss: 0.204
Iter 37/50 - Loss: 0.191
Iter 38/50 - Loss: 0.179
Iter 39/50 - Loss: 0.167
Iter 40/50 - Loss: 0.155
Iter 41/50 - Loss: 0.144
Iter 42/50 - Loss: 0.134
Iter 43/50 - Loss: 0.124
Iter 44/50 - Loss: 0.114
Iter 45/50 - Loss: 0.106
Iter 46/50 - Loss: 0.097
Iter 47/50 - Loss: 0.089
Iter 48/50 - Loss: 0.082
Iter 49/50 - Loss: 0.075
Iter 50/50 - Loss: 0.068
CPU times: user 53.2 s, sys: 3.96 s, total: 57.1 s
Wall time: 1min 16s

Making Predictions

The next cell makes predictions with SKIP. We use the same max_root_decomposition_size, and we also demonstrate increasing the max preconditioner size. Increasing the preconditioner size is not necessary on this dataset, but it can make a big difference in final test performance, and it is often preferable to increasing the number of CG iterations if you can afford the space.

[7]:
model.eval()
likelihood.eval()
with gpytorch.settings.max_preconditioner_size(10), torch.no_grad():
    with gpytorch.settings.max_root_decomposition_size(30), gpytorch.settings.fast_pred_var():
        preds = model(test_x)
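
If you can afford the memory, you can experiment with both of the knobs mentioned above. The following is a hypothetical variation of the prediction cell (the setting values and the preds_alt name are illustrative, not tuned for this dataset):

# Illustrative only: a larger preconditioner is often a better first knob than
# more CG iterations, but both settings are available.
with torch.no_grad(), gpytorch.settings.fast_pred_var(), \
        gpytorch.settings.max_root_decomposition_size(30), \
        gpytorch.settings.max_preconditioner_size(15), \
        gpytorch.settings.max_cg_iterations(2000):
    preds_alt = model(test_x)
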
[8]:
print('Test MAE: {}'.format(torch.mean(torch.abs(preds.mean - test_y))))
Test MAE: 0.18244513869285583