VNNGP: Variational Nearest Neighbor Gaussian Processes

Overview

In this notebook, we will give an overview of how to use variational nearest neighbor Gaussian processes (VNNGP) (https://arxiv.org/abs/2202.01694) to rapidly train on the elevators UCI dataset.

Similar to SVGP (https://arxiv.org/abs/1309.6835), VNNGP is a variational inducing-point approach. Unlike SVGP, which is typically limited to thousands of inducing points, VNNGP makes an additional approximation: it assumes that every inducing point and data point depends on at most \(K\) other inducing points. This is advantageous for multiple reasons:

  • The variational KL divergence term affords an unbiased stochastic estimate from a mini-batch of inducing points.
  • Consequently, an unbiased estimate of the ELBO can be computed in \(O(K^3)\) time.
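
To make the cost claim concrete, here is a schematic form of the objective, a sketch following the paper's construction (the notation \(n(j)\) for the \(K\) nearest neighbors of inducing point \(j\), and the batch symbols below, are ours and not taken verbatim from the paper):

\[
\mathcal{L} \;=\; \sum_{i=1}^{N} \mathbb{E}_{q(f(\mathbf{x}_i))}\big[\log p(y_i \mid f(\mathbf{x}_i))\big] \;-\; \sum_{j=1}^{M} \mathbb{E}\big[\mathrm{KL}\big(q(u_j)\,\|\,p(u_j \mid \mathbf{u}_{n(j)})\big)\big]
\]

Both sums decompose over individual data points and inducing points, so an unbiased mini-batch estimate simply rescales them by \(N/|\mathcal{B}|\) and \(M/|\mathcal{I}|\) for a data batch \(\mathcal{B}\) and an inducing-point batch \(\mathcal{I}\); each summand involves at most \(K\) neighboring inducing points, and factorizing the corresponding \(K \times K\) covariance gives the \(O(K^3)\) cost.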

With this scalability, we recommend using \(M=N\) inducing points, placing an inducing point at every observed input.

[1]:
import tqdm
import math
import torch
import gpytorch
from matplotlib import pyplot as plt

# Make plots inline
%matplotlib inline
[2]:
import urllib.request
import os
from scipy.io import loadmat
from math import floor


# this is for running the notebook in our testing framework
smoke_test = ('CI' in os.environ)

if not smoke_test and not os.path.isfile('../elevators.mat'):
    print('Downloading \'elevators\' UCI dataset...')
    urllib.request.urlretrieve('https://drive.google.com/uc?export=download&id=1jhWL3YUHvXIaftia4qeAyDwVxo6j1alk', '../elevators.mat')


if smoke_test:  # this is for running the notebook in our testing framework
    X, y = torch.randn(100, 3), torch.randn(100)
else:
    data = torch.Tensor(loadmat('../elevators.mat')['data'])
    X = data[:1000, :-1]
    X = X - X.min(0)[0]
    X = 2 * (X / X.max(0)[0].clamp_min(1e-6)) - 1
    y = data[:1000, -1]
    y = y.sub(y.mean()).div(y.std())


train_n = int(floor(0.8 * len(X)))
train_x = X[:train_n, :].contiguous()
train_y = y[:train_n].contiguous()

test_x = X[train_n:, :].contiguous()
test_y = y[train_n:].contiguous()

if torch.cuda.is_available():
    train_x, train_y, test_x, test_y = train_x.cuda(), train_y.cuda(), test_x.cuda(), test_y.cuda()

Creating a VNNGP model

Creating a VNNGP model is similar to creating other variational models. Here are some things that are specific to VNNGP:

  1. You need to use the gpytorch.variational.NNVariationalStrategy variational strategy class.
  2. You need to use the gpytorch.variational.MeanFieldVariationalDistribution variational distribution class.
  3. inducing_points should be set to train_x. This results in fast optimization and accurate predictions.
  4. There are two hyperparameters that you need to specify:
  • k: the number of nearest neighbors used. The larger k is, the more accurate the approximation, but the higher the computational cost. The default value is 256.
  • training_batch_size: the mini-batch size of inducing points used in stochastic optimization. Default value is 256.
[3]:
from gpytorch.models import ApproximateGP
from gpytorch.variational.nearest_neighbor_variational_strategy import NNVariationalStrategy


class GPModel(ApproximateGP):
    def __init__(self, inducing_points, likelihood, k=256, training_batch_size=256):

        m, d = inducing_points.shape
        self.m = m
        self.k = k

        variational_distribution = gpytorch.variational.MeanFieldVariationalDistribution(m)

        if torch.cuda.is_available():
            inducing_points = inducing_points.cuda()

        variational_strategy = NNVariationalStrategy(self, inducing_points, variational_distribution, k=k,
                                                     training_batch_size=training_batch_size)
        super(GPModel, self).__init__(variational_strategy)
        self.mean_module = gpytorch.means.ZeroMean()
        self.covar_module = gpytorch.kernels.MaternKernel(nu=2.5, ard_num_dims=d)

        self.likelihood = likelihood

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

    def __call__(self, x, prior=False, **kwargs):
        if x is not None:
            if x.dim() == 1:
                x = x.unsqueeze(-1)
        return self.variational_strategy(x=x, prior=prior, **kwargs)

if smoke_test:
    k = 32
    training_batch_size = 32
else:
    k = 256
    training_batch_size = 64

likelihood = gpytorch.likelihoods.GaussianLikelihood()
# Note: one should use the full training set as inducing points!
model = GPModel(inducing_points=train_x, likelihood=likelihood, k=k, training_batch_size=training_batch_size)

if torch.cuda.is_available():
    likelihood = likelihood.cuda()
    model = model.cuda()

Training the Model (mode 1, recommended)

The cell below trains the model above, learning the hyperparameters of the Gaussian process and the variational parameters in an end-to-end fashion. VNNGP’s training objective is to minimize the negative ELBO. As noted in the overview, VNNGP is amenable to stochastic optimization over both data points and inducing points: in each iteration, we sample a mini-batch of data points and a mini-batch of inducing points, compute the stochastic ELBO estimate, and then take a gradient step to update the model parameters.

There are two training modes available. In this section we introduce mode 1, which is what we recommend in practice and what is implemented in the experiments of the original paper. Since VNNGP places the inducing points at the observed input locations, the inducing points are essentially train_x. Therefore, there is no need to iterate separately over training data and inducing points: we can simply sample a mini-batch of inducing points and treat it as a mini-batch of training data as well. In this case, the mini-batch size for training data is the same as that for inducing points, namely the training_batch_size we set above.

While we recommend this training mode because it yields faster training, we also provide a second mode that lets users draw separate mini-batches of training data and inducing points. See the last part of the notebook.

[4]:
num_epochs = 1 if smoke_test else 20
# Number of inducing-point mini-batches per epoch (roughly M / training_batch_size)
num_batches = model.variational_strategy._total_training_batches


model.train()
likelihood.train()

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Our loss object. We're using the VariationalELBO
mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=train_y.size(0))


epochs_iter = tqdm.notebook.tqdm(range(num_epochs), desc="Epoch")
for epoch in epochs_iter:
    minibatch_iter = tqdm.notebook.tqdm(range(num_batches), desc="Minibatch", leave=False)

    for i in minibatch_iter:
        optimizer.zero_grad()
        # Passing x=None lets the strategy draw a mini-batch of inducing points
        # and treat it as the mini-batch of training inputs (mode 1)
        output = model(x=None)
        # Obtain the indices of the sampled mini-batch
        current_training_indices = model.variational_strategy.current_training_indices
        # Index train_y with these indices; train_x and train_y must share the same ordering
        y_batch = train_y[..., current_training_indices]
        if torch.cuda.is_available():
            y_batch = y_batch.cuda()
        loss = -mll(output, y_batch)
        minibatch_iter.set_postfix(loss=loss.item())
        loss.backward()
        optimizer.step()


Making predictions

[5]:
from torch.utils.data import TensorDataset, DataLoader


test_dataset = TensorDataset(test_x, test_y)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
[6]:
model.eval()
likelihood.eval()
means = torch.tensor([0.])
test_mse = 0
with torch.no_grad():
    for x_batch, y_batch in test_loader:
        preds = model(x_batch)
        means = torch.cat([means, preds.mean.cpu()])

        diff = torch.pow(preds.mean - y_batch, 2)
        diff = diff.sum(dim=-1) / test_x.size(0)  # sum squared errors over the batch, normalized by the test set size
        diff = diff.mean()  # reduce any remaining dimensions to a scalar
        test_mse += diff
means = means[1:]  # drop the dummy first entry
test_rmse = test_mse.sqrt().item()
[9]:
print(test_rmse)
0.8139828443527222

Training the Model (mode 2)

In this mode, users are able to sample separate mini-batches for training data and inducing points. Note that this yields slower training, since every iteration requires finding the set of inducing points that matches the current mini-batch of training data in order to compute the ELBO.

[10]:
# instantiate the model
if smoke_test:
    k = 32
    training_batch_size = 32
else:
    k = 256
    training_batch_size = 256
likelihood = gpytorch.likelihoods.GaussianLikelihood()
model = GPModel(inducing_points=train_x, likelihood=likelihood, k=k, training_batch_size=training_batch_size)

if torch.cuda.is_available():
    likelihood = likelihood.cuda()
    model = model.cuda()
[11]:
# prepare the training data loader
train_dataset = TensorDataset(train_x, train_y)
# this batch size does not need to match the training_batch_size specified above
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)

num_epochs = 1 if smoke_test else 20


model.train()
likelihood.train()

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Our loss object. We're using the VariationalELBO
mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=train_y.size(0))


epochs_iter = tqdm.notebook.tqdm(range(num_epochs), desc="Epoch")
for i in epochs_iter:
    # Within each iteration, we will go over each minibatch of data
    minibatch_iter = tqdm.notebook.tqdm(train_loader, desc="Minibatch", leave=False)
    for x_batch, y_batch in minibatch_iter:
        optimizer.zero_grad()
        output = model(x_batch)
        loss = -mll(output, y_batch)
        minibatch_iter.set_postfix(loss=loss.item())
        loss.backward()
        optimizer.step()

Making predictions

[12]:
model.eval()
likelihood.eval()
means = torch.tensor([0.])
test_mse = 0
with torch.no_grad():
    for x_batch, y_batch in test_loader:
        preds = model(x_batch)
        means = torch.cat([means, preds.mean.cpu()])

        diff = torch.pow(preds.mean - y_batch, 2)
        diff = diff.sum(dim=-1) / test_x.size(0)  # sum squared errors over the batch, normalized by the test set size
        diff = diff.mean()  # reduce any remaining dimensions to a scalar
        test_mse += diff
means = means[1:]  # drop the dummy first entry
test_rmse = test_mse.sqrt().item()
[13]:
print(test_rmse)
0.8090996742248535