[1]:
import math
import torch
import gpytorch
import tqdm.notebook
import matplotlib.pyplot as plt
%matplotlib inline

Modifying the Variational Strategy/Variational Distribution

The predictive distribution for approximate GPs is given by

\[p( f(\mathbf x^*) ) = \int_{\mathbf u} p( f(\mathbf x^*) \mid \mathbf u) \: q(\mathbf u) \: d\mathbf u, \quad q(\mathbf u) = \mathcal N( \mathbf m, \mathbf S).\]

\(\mathbf u\) represents the function values at the \(m\) inducing points. Here, \(\mathbf m \in \mathbb R^m\) and \(\mathbf S \in \mathbb R^{m \times m}\) are learnable parameters.

If \(m\) (the number of inducing points) is large, the number of learnable parameters in \(\mathbf S\) (which grows quadratically in \(m\)) can become unwieldy. Furthermore, a large \(m\) can make some of the computations rather slow. Here we show a few ways to use different variational distributions and variational strategies to reduce the number of parameters and speed up computation.
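
To get a rough feel for the difference, here is a quick (purely illustrative) way to count the learnable parameters of the three variational distributions used later in this notebook, for \(m = 128\) inducing points; the exact counts depend on how each class parameterizes its covariance.

[ ]:
# Illustrative only: compare the number of learnable parameters in the three
# variational distributions used below, for m = 128 inducing points.
m = 128
for dist_cls in [
    gpytorch.variational.CholeskyVariationalDistribution,   # full S: grows quadratically in m
    gpytorch.variational.MeanFieldVariationalDistribution,  # diagonal S: grows linearly in m
    gpytorch.variational.DeltaVariationalDistribution,      # no S: only the variational mean
]:
    dist = dist_cls(m)
    num_params = sum(p.numel() for p in dist.parameters())
    print(f"{dist_cls.__name__}: {num_params} learnable parameters")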

Experimental setup

We’re going to train an approximate GP on a medium-sized regression dataset, taken from the UCI repository.

[2]:
import urllib.request
import os
from scipy.io import loadmat
from math import floor


# this is for running the notebook in our testing framework
smoke_test = ('CI' in os.environ)


if not smoke_test and not os.path.isfile('../elevators.mat'):
    print('Downloading \'elevators\' UCI dataset...')
    urllib.request.urlretrieve('https://drive.google.com/uc?export=download&id=1jhWL3YUHvXIaftia4qeAyDwVxo6j1alk', '../elevators.mat')


if smoke_test:  # this is for running the notebook in our testing framework
    X, y = torch.randn(1000, 3), torch.randn(1000)
else:
    data = torch.Tensor(loadmat('../elevators.mat')['data'])
    X = data[:, :-1]
    X = X - X.min(0)[0]
    X = 2 * (X / X.max(0)[0]) - 1
    y = data[:, -1]


train_n = int(floor(0.8 * len(X)))
train_x = X[:train_n, :].contiguous()
train_y = y[:train_n].contiguous()

test_x = X[train_n:, :].contiguous()
test_y = y[train_n:].contiguous()

if torch.cuda.is_available():
    train_x, train_y, test_x, test_y = train_x.cuda(), train_y.cuda(), test_x.cuda(), test_y.cuda()
[3]:
from torch.utils.data import TensorDataset, DataLoader
train_dataset = TensorDataset(train_x, train_y)
train_loader = DataLoader(train_dataset, batch_size=500, shuffle=True)

test_dataset = TensorDataset(test_x, test_y)
test_loader = DataLoader(test_dataset, batch_size=500, shuffle=False)

Some quick training/testing code

This will allow us to train/test different model classes.

[17]:
# this is for running the notebook in our testing framework
num_epochs = 1 if smoke_test else 10


# Our testing script takes in a GPyTorch model (ApproximateGP) class
# and then trains/tests an instance of it on the supplied dataset

def train_and_test_approximate_gp(model_cls):
    inducing_points = torch.randn(128, train_x.size(-1), dtype=train_x.dtype, device=train_x.device)
    model = model_cls(inducing_points)
    likelihood = gpytorch.likelihoods.GaussianLikelihood()
    mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=train_y.numel())
    optimizer = torch.optim.Adam(list(model.parameters()) + list(likelihood.parameters()), lr=0.1)

    if torch.cuda.is_available():
        model = model.cuda()
        likelihood = likelihood.cuda()

    # Training
    model.train()
    likelihood.train()
    epochs_iter = tqdm.notebook.tqdm(range(num_epochs), desc=f"Training {model_cls.__name__}")
    for i in epochs_iter:
        # Within each iteration, we will go over each minibatch of data
        for x_batch, y_batch in train_loader:
            optimizer.zero_grad()
            output = model(x_batch)
            loss = -mll(output, y_batch)
            epochs_iter.set_postfix(loss=loss.item())
            loss.backward()
            optimizer.step()

    # Testing
    model.eval()
    likelihood.eval()
    means = []
    with torch.no_grad():
        for x_batch, y_batch in test_loader:
            preds = model(x_batch)
            means.append(preds.mean.cpu())
    means = torch.cat(means)
    error = torch.mean(torch.abs(means - test_y.cpu()))
    print(f"Test {model_cls.__name__} MAE: {error.item()}")

The Standard Approach

As a baseline, we’ll use the default VariationalStrategy class with a CholeskyVariationalDistribution. The CholeskyVariationalDistribution class allows \(\mathbf S\) to be any positive semidefinite matrix. This is the most general/expressive option for approximate GPs.

[18]:
class StandardApproximateGP(gpytorch.models.ApproximateGP):
    def __init__(self, inducing_points):
        variational_distribution = gpytorch.variational.CholeskyVariationalDistribution(inducing_points.size(-2))
        variational_strategy = gpytorch.variational.VariationalStrategy(
            self, inducing_points, variational_distribution, learn_inducing_locations=True
        )
        super().__init__(variational_strategy)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
[19]:
train_and_test_approximate_gp(StandardApproximateGP)

Test StandardApproximateGP MAE: 0.10098349303007126

Reducing parameters

MeanFieldVariationalDistribution: a diagonal \(\mathbf S\) matrix

One way to reduce the number of parameters is to restrict \(\mathbf S\) to be diagonal. This is less expressive, but the number of parameters is now linear in \(m\) instead of quadratic.

All we have to do is take the previous example, and change CholeskyVariationalDistribution (full \(\mathbf S\) matrix) to MeanFieldVariationalDistribution (diagonal \(\mathbf S\) matrix).

[20]:
class MeanFieldApproximateGP(gpytorch.models.ApproximateGP):
    def __init__(self, inducing_points):
        variational_distribution = gpytorch.variational.MeanFieldVariationalDistribution(inducing_points.size(-2))
        variational_strategy = gpytorch.variational.VariationalStrategy(
            self, inducing_points, variational_distribution, learn_inducing_locations=True
        )
        super().__init__(variational_strategy)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
[21]:
train_and_test_approximate_gp(MeanFieldApproximateGP)

Test MeanFieldApproximateGP MAE: 0.07848489284515381

DeltaVariationalDistribution: no \(\mathbf S\) matrix

A more extreme way to reduce parameters is to get rid of \(\mathbf S\) entirely. This corresponds to learning a delta distribution (\(\mathbf u = \mathbf m\)) for \(\mathbf u\) rather than a multivariate Normal distribution. In other words, we perform MAP estimation of the inducing values rather than variational inference.

In GPyTorch, getting rid of \(\mathbf S\) can be accomplished by using a DeltaVariationalDistribution.

[22]:
class MAPApproximateGP(gpytorch.models.ApproximateGP):
    def __init__(self, inducing_points):
        variational_distribution = gpytorch.variational.DeltaVariationalDistribution(inducing_points.size(-2))
        variational_strategy = gpytorch.variational.VariationalStrategy(
            self, inducing_points, variational_distribution, learn_inducing_locations=True
        )
        super().__init__(variational_strategy)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
[26]:
train_and_test_approximate_gp(MAPApproximateGP)

Test MAPApproximateGP MAE: 0.08846496045589447

Reducing computation (through decoupled inducing points)

One way to reduce the computational complexity is to use separate inducing points for the mean and covariance computations. The Orthogonally Decoupled Variational Gaussian Processes method of Salimbeni et al. (2018) uses more inducing points for the (computationally easy) mean computations and fewer inducing points for the (computationally intensive) covariance computations.
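
As a rough sketch of why this helps, consider the standard sparse GP predictive equations (with \(\mathbf k_{*u}\) denoting the covariance between a test point \(\mathbf x^*\) and the inducing points):

\[\mathbb E[ f(\mathbf x^*) ] = \mathbf k_{*u} \mathbf K_{uu}^{-1} \mathbf m, \qquad \mathrm{Var}[ f(\mathbf x^*) ] = k_{**} + \mathbf k_{*u} \mathbf K_{uu}^{-1} ( \mathbf S - \mathbf K_{uu} ) \mathbf K_{uu}^{-1} \mathbf k_{u*}.\]

Once the solves against \(\mathbf K_{uu}\) are cached, the mean costs roughly \(\mathcal O(m)\) per test point while the variance costs \(\mathcal O(m^2)\), so it pays to spend a large \(m\) on the mean and a small \(m\) on the covariance.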

In GPyTorch we implement this method in a modular way. The OrthogonallyDecoupledVariationalStrategy defines the variational strategy for the mean inducing points. It wraps an existing variational strategy/distribution that defines the covariance inducing points:

[28]:
def make_orthogonal_vs(model, train_x):
    mean_inducing_points = torch.randn(1000, train_x.size(-1), dtype=train_x.dtype, device=train_x.device)
    covar_inducing_points = torch.randn(100, train_x.size(-1), dtype=train_x.dtype, device=train_x.device)

    covar_variational_strategy = gpytorch.variational.VariationalStrategy(
        model, covar_inducing_points,
        gpytorch.variational.CholeskyVariationalDistribution(covar_inducing_points.size(-2)),
        learn_inducing_locations=True
    )

    variational_strategy = gpytorch.variational.OrthogonallyDecoupledVariationalStrategy(
        covar_variational_strategy, mean_inducing_points,
        gpytorch.variational.DeltaVariationalDistribution(mean_inducing_points.size(-2)),
    )
    return variational_strategy

Putting it all together, we have:

[29]:
class OrthDecoupledApproximateGP(gpytorch.models.ApproximateGP):
    def __init__(self, inducing_points):
        # Note: the inducing_points argument is unused here; make_orthogonal_vs
        # defines its own mean and covariance inducing points (see above).
        variational_strategy = make_orthogonal_vs(self, train_x)
        super().__init__(variational_strategy)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
[30]:
train_and_test_approximate_gp(OrthDecoupledApproximateGP)

Test OrthDecoupledApproximateGP MAE: 0.08162340521812439