Variational GPs w/ Multiple Outputs

Introduction

In this example, we will demonstrate how to construct approximate/variational GPs that can model vector-valued functions (e.g. multitask/multi-output GPs).

[1]:
import math
import torch
import gpytorch
import tqdm.notebook
from matplotlib import pyplot as plt

%matplotlib inline
%load_ext autoreload
%autoreload 2

Set up training data

In the next cell, we set up the training data for this example. We'll be using 100 regularly spaced points on [0,1], at which we evaluate the functions and then add Gaussian noise to get the training labels.

We'll have four functions, all of which are some sort of sinusoid. Our training labels train_y will therefore have two dimensions: the first indexing the training points and the second indexing the tasks.

[2]:
train_x = torch.linspace(0, 1, 100)

train_y = torch.stack([
    torch.sin(train_x * (2 * math.pi)) + torch.randn(train_x.size()) * 0.2,
    torch.cos(train_x * (2 * math.pi)) + torch.randn(train_x.size()) * 0.2,
    torch.sin(train_x * (2 * math.pi)) + 2 * torch.cos(train_x * (2 * math.pi)) + torch.randn(train_x.size()) * 0.2,
    -torch.cos(train_x * (2 * math.pi)) + torch.randn(train_x.size()) * 0.2,
], -1)

print(train_x.shape, train_y.shape)
torch.Size([100]) torch.Size([100, 4])

Define a multitask model

We are going to construct a batch variational GP, using a CholeskyVariationalDistribution and a VariationalStrategy. Each batch dimension will correspond to one latent function (or, for the independent model further below, one output). In addition, we will wrap the variational strategy so that the output appears as a MultitaskMultivariateNormal distribution. Here are the changes that we'll need to make:

  1. Our inducing points will need to have shape num_latents x m x 1 (where m is the number of inducing points). This ensures that we learn a different set of inducing points for each latent function.
  2. The CholeskyVariationalDistribution, mean module, and covariance module will all need to include a batch_shape=torch.Size([num_latents]) argument. This ensures that we learn a different set of variational parameters and hyperparameters for each latent function (see the shape-check sketch just after this list).
  3. The VariationalStrategy object should be wrapped by a variational strategy that handles multitask models. We describe them below:
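
Before defining the full model, here is a minimal, standalone shape-check sketch (the sizes and the batched input x below are purely illustrative and not part of the tutorial code). It shows what the batch_shape arguments buy us: one set of inducing points, one mean, and one set of kernel hyperparameters per latent function.

import torch
import gpytorch

num_latents, m = 3, 16

# One set of inducing point locations per latent GP: num_latents x m x 1
inducing_points = torch.rand(num_latents, m, 1)

# Batched mean and kernel modules: independent (hyper)parameters per latent GP
mean_module = gpytorch.means.ConstantMean(batch_shape=torch.Size([num_latents]))
covar_module = gpytorch.kernels.ScaleKernel(
    gpytorch.kernels.RBFKernel(batch_shape=torch.Size([num_latents])),
    batch_shape=torch.Size([num_latents]),
)

# Roughly mimic the batched inputs that the variational strategy produces internally
x = torch.rand(num_latents, 10, 1)
print(mean_module(x).shape)   # expected: torch.Size([3, 10])  -- one mean vector per latent
print(covar_module(x).shape)  # expected: torch.Size([3, 10, 10])  -- one covariance per latent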

Types of Variational Multitask Models

The most general purpose multitask model is the Linear Model of Coregionalization (LMC), which assumes that each output dimension (task) is a linear combination of \(Q\) latent functions \(\mathbf g(\cdot) = [g^{(1)}(\cdot), \ldots, g^{(Q)}(\cdot)]\):

\[f_{\text{task } k}(\mathbf x) = \sum_{q=1}^Q a_k^{(q)} \, g^{(q)}(\mathbf x),\]

where the mixing weights \(a_k^{(q)}\) are learnable parameters, with a separate set for each task \(k\).
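
To make the mixing step concrete, here is a minimal tensor sketch (the weight matrix A, the sampled values, and all sizes are purely illustrative and not part of the tutorial code): each task is a weighted sum of the latent function values, with one row of weights per task.

import torch

Q, num_tasks, n = 3, 4, 100

latent_values = torch.randn(Q, n)   # g^{(1)}(x), ..., g^{(Q)}(x) evaluated at n inputs
A = torch.randn(num_tasks, Q)       # illustrative mixing weights a_k^{(q)}, one row per task

# f_task_k(x_j) = sum_q a_k^{(q)} g^{(q)}(x_j), for all tasks and inputs at once
task_values = A @ latent_values
print(task_values.shape)            # torch.Size([4, 100])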

[3]:
num_latents = 3
num_tasks = 4

class MultitaskGPModel(gpytorch.models.ApproximateGP):
    def __init__(self):
        # Let's use a different set of inducing points for each latent function
        inducing_points = torch.rand(num_latents, 16, 1)

        # We have to mark the CholeskyVariationalDistribution as batch
        # so that we learn a variational distribution for each task
        variational_distribution = gpytorch.variational.CholeskyVariationalDistribution(
            inducing_points.size(-2), batch_shape=torch.Size([num_latents])
        )

        # We have to wrap the VariationalStrategy in a LMCVariationalStrategy
        # so that the output will be a MultitaskMultivariateNormal rather than a batch output
        variational_strategy = gpytorch.variational.LMCVariationalStrategy(
            gpytorch.variational.VariationalStrategy(
                self, inducing_points, variational_distribution, learn_inducing_locations=True
            ),
            num_tasks=num_tasks,
            num_latents=num_latents,
            latent_dim=-1
        )

        super().__init__(variational_strategy)

        # The mean and covariance modules should be marked as batch
        # so we learn a different set of hyperparameters
        self.mean_module = gpytorch.means.ConstantMean(batch_shape=torch.Size([num_latents]))
        self.covar_module = gpytorch.kernels.ScaleKernel(
            gpytorch.kernels.RBFKernel(batch_shape=torch.Size([num_latents])),
            batch_shape=torch.Size([num_latents])
        )

    def forward(self, x):
        # The forward function should be written as if we were dealing with each output
        # dimension in batch
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)


model = MultitaskGPModel()
likelihood = gpytorch.likelihoods.MultitaskGaussianLikelihood(num_tasks=num_tasks)

With all of the batch_shape arguments, it may look like we're learning a batch of GPs. However, LMCVariationalStrategy converts this batch dimension into a (non-batch) MultitaskMultivariateNormal over the four tasks.

[4]:
likelihood(model(train_x)).rsample().shape
[4]:
torch.Size([100, 4])
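
The mixing weights themselves are learned by the wrapping strategy. As a quick way to inspect them, recent GPyTorch versions store them in an lmc_coefficients parameter on the LMCVariationalStrategy (this attribute name is an assumption; check your installed version):

# Assumption: LMCVariationalStrategy exposes its mixing weights as
# `lmc_coefficients`, with one row of task weights per latent function
print(model.variational_strategy.lmc_coefficients.shape)  # expected: torch.Size([3, 4])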

The LMC model allows there to be linear dependencies between outputs/tasks. Alternatively, if we want independent output dimensions, we can replace LMCVariationalStrategy with IndependentMultitaskVariationalStrategy:

[5]:
class IndependentMultitaskGPModel(gpytorch.models.ApproximateGP):
    def __init__(self):
        # Let's use a different set of inducing points for each task
        inducing_points = torch.rand(num_tasks, 16, 1)

        # We have to mark the CholeskyVariationalDistribution as batch
        # so that we learn a variational distribution for each task
        variational_distribution = gpytorch.variational.CholeskyVariationalDistribution(
            inducing_points.size(-2), batch_shape=torch.Size([num_tasks])
        )

        variational_strategy = gpytorch.variational.IndependentMultitaskVariationalStrategy(
            gpytorch.variational.VariationalStrategy(
                self, inducing_points, variational_distribution, learn_inducing_locations=True
            ),
            num_tasks=num_tasks,
        )

        super().__init__(variational_strategy)

        # The mean and covariance modules should be marked as batch
        # so we learn a different set of hyperparameters
        self.mean_module = gpytorch.means.ConstantMean(batch_shape=torch.Size([num_tasks]))
        self.covar_module = gpytorch.kernels.ScaleKernel(
            gpytorch.kernels.RBFKernel(batch_shape=torch.Size([num_tasks])),
            batch_shape=torch.Size([num_tasks])
        )

    def forward(self, x):
        # The forward function should be written as if we were dealing with each output
        # dimension in batch
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

Note that all the batch sizes for IndependentMultitaskVariationalStrategy are now num_tasks rather than num_latents.
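
As a quick sanity check (a sketch that reuses the train_x defined above), the independent model also produces a MultitaskMultivariateNormal over all four tasks; the tasks simply share no information:

independent_model = IndependentMultitaskGPModel()
independent_output = independent_model(train_x)
print(independent_output.__class__.__name__, independent_output.event_shape)
# expected: MultitaskMultivariateNormal torch.Size([100, 4])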

Output modes

By default, LMCVariationalStrategy and IndependentMultitaskVariationalStrategy produce vector-valued outputs. In other words, they return a MultitaskMultivariateNormal distribution – containing all task values for each input.

This is similar to the ExactGP model described in the multitask GP regression tutorial.

[6]:
output = model(train_x)
print(output.__class__.__name__, output.event_shape)
MultitaskMultivariateNormal torch.Size([100, 4])

Alternatively, if each input is only associated with a single task, passing in the task_indices argument will specify which task to return for each input. The result will be a standard MultivariateNormal distribution – where each output corresponds to each input’s specified task.

This is similar to the ExactGP model described in the Hadamard multitask GP regression tutorial.

[7]:
x = train_x[..., :5]
task_indices = torch.LongTensor([0, 1, 3, 2, 2])
output = model(x, task_indices=task_indices)
print(output.__class__.__name__, output.event_shape)
MultivariateNormal torch.Size([5])

Train the model

This code should look similar to the training code in the standard SVGP regression example.

[8]:
# this is for running the notebook in our testing framework
import os
smoke_test = ('CI' in os.environ)
num_epochs = 1 if smoke_test else 500


model.train()
likelihood.train()

optimizer = torch.optim.Adam([
    {'params': model.parameters()},
    {'params': likelihood.parameters()},
], lr=0.1)

# Our loss object. We're using the VariationalELBO, which essentially just computes the ELBO
mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=train_y.size(0))

epochs_iter = tqdm.notebook.tqdm(range(num_epochs), desc="Epoch")
for i in epochs_iter:
    # Each iteration computes the loss on the full training set (no minibatching here)
    optimizer.zero_grad()
    output = model(train_x)
    loss = -mll(output, train_y)
    epochs_iter.set_postfix(loss=loss.item())
    loss.backward()
    optimizer.step()
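The loop above uses the full training set at every step, which is fine for 100 points. For larger datasets, the usual SVGP minibatch pattern applies; here is a sketch assuming a batch size of 32 (the DataLoader setup is not part of the original notebook). Note that num_data in the VariationalELBO should remain the total number of training points.

from torch.utils.data import TensorDataset, DataLoader

# Wrap the training data in a DataLoader (batch size chosen arbitrarily here)
train_dataset = TensorDataset(train_x, train_y)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

for epoch in range(num_epochs):
    for x_batch, y_batch in train_loader:
        optimizer.zero_grad()
        output = model(x_batch)
        loss = -mll(output, y_batch)
        loss.backward()
        optimizer.step()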

Make predictions with the model

[9]:
# Set into eval mode
model.eval()
likelihood.eval()

# Initialize plots
fig, axs = plt.subplots(1, num_tasks, figsize=(4 * num_tasks, 3))

# Make predictions
with torch.no_grad(), gpytorch.settings.fast_pred_var():
    test_x = torch.linspace(0, 1, 51)
    predictions = likelihood(model(test_x))
    mean = predictions.mean
    lower, upper = predictions.confidence_region()

for task, ax in enumerate(axs):
    # Plot training data as black stars
    ax.plot(train_x.detach().numpy(), train_y[:, task].detach().numpy(), 'k*')
    # Predictive mean as blue line
    ax.plot(test_x.numpy(), mean[:, task].numpy(), 'b')
    # Shade in confidence
    ax.fill_between(test_x.numpy(), lower[:, task].numpy(), upper[:, task].numpy(), alpha=0.5)
    ax.set_ylim([-3, 3])
    ax.legend(['Observed Data', 'Mean', 'Confidence'])
    ax.set_title(f'Task {task + 1}')

fig.tight_layout()
None
[Figure: training data (black stars), predictive mean (blue), and shaded confidence region for each of the four tasks]