# Clustered Multitask GP (w/ Pyro/GPyTorch High-Level Interface)

## Introduction

In this example, we use the Pyro integration for a GP model with additional latent variables.

In this example we model a multitask GP. Rather than assuming a linear correlation among the tasks, we assume that the tasks have cluster structure: there are $$k$$ clusters, and each task belongs to one of them. The generative model for task $$i$$ is:

$p(\mathbf y_i \mid \mathbf x_i) = \int \sum_{z_i=1}^k p(\mathbf y_i \mid \mathbf f (\mathbf x_i), z_i) \: p(z_i) \: p(\mathbf f (\mathbf x_i) ) \: d \mathbf f$

where $$z_i$$ is the cluster assignment for task $$i$$. There are therefore $$k$$ latent functions $$\mathbf f = [f_1 \ldots f_k]$$, each modelled by a GP and each representing one cluster.
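The sketch below simulates this generative story using only the standard library. It makes several assumptions: $$k = 2$$ fixed curves stand in for GP draws, the prior over clusters is uniform, the noise is Gaussian, and all tasks share the same inputs. cluster_functions and sample_task are hypothetical names and are not part of GPyTorch.

```python
import math
import random

# Hypothetical stand-ins for two GP draws f_1 and f_2 (assumption: fixed, known curves)
cluster_functions = [
    lambda x: math.sin(2 * math.pi * x),  # f_1
    lambda x: math.cos(2 * math.pi * x),  # f_2
]


def sample_task(xs, noise_std, rng):
    """Draw z_i uniformly, then y_i = f_{z_i}(x) + Gaussian noise (0-based cluster index)."""
    z = rng.randrange(len(cluster_functions))  # uniform prior p(z_i)
    f = cluster_functions[z]
    ys = [f(x) + rng.gauss(0.0, noise_std) for x in xs]
    return z, ys


rng = random.Random(0)
xs = [i / 9 for i in range(10)]  # shared inputs for every task (assumption)
tasks = [sample_task(xs, noise_std=0.1, rng=rng) for _ in range(4)]
for i, (z, ys) in enumerate(tasks):
    print(f"task {i}: cluster {z + 1}, first y = {ys[0]:.3f}")
```

Tasks that draw the same $$z_i$$ are noisy copies of the same curve. This shared structure is what the model tries to recover.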

Our goal is therefore to infer:

• The latent functions $$f_1 \ldots f_k$$
• The cluster assignments $$z_i$$ for each task

```python
import math
import time  # used for the model's name_prefix below
import torch
import pyro
import gpytorch
from matplotlib import pyplot as plt

%matplotlib inline

# this is for running the notebook in our testing framework
import os
smoke_test = ('CI' in os.environ)
```


The standard GPyTorch variational objects will take care of inferring the latent functions $$f_1 \ldots f_k$$. However, we also need to add the latent variables $$z_i$$ to the model. We do this by creating a custom likelihood that models:

$\sum_{z_i=1}^k p(\mathbf y_i \mid \mathbf f (\mathbf x_i), z_i) \: p(z_i)$
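The snippet below shows what this sum computes. It evaluates $$\log \sum_{z_i=1}^k p(\mathbf y_i \mid \mathbf f (\mathbf x_i), z_i) \: p(z_i)$$ for a single task using a log-sum-exp. It assumes the latent function values are already given, so it conditions on $$\mathbf f$$ rather than integrating it out. It also assumes a uniform $$p(z_i) = 1/k$$ and i.i.d. Gaussian noise. log_mixture_likelihood is a hypothetical helper. The likelihood defined below handles $$z_i$$ differently: it samples it with pyro.sample instead of computing this sum explicitly.

```python
import math


def log_mixture_likelihood(y, f_values, noise_var):
    """log sum_z p(z) p(y | f_z(x), z) for one task, with uniform p(z) = 1/k.

    y         -- list of n observations for the task
    f_values  -- list of k lists, f_values[z][n] = f_z(x_n) (assumed given)
    noise_var -- Gaussian noise variance (assumed known)
    """
    k = len(f_values)
    log_terms = []
    for f_z in f_values:
        # log p(y | f_z): sum of independent Gaussian log densities
        ll = sum(
            -0.5 * math.log(2 * math.pi * noise_var) - (yn - fn) ** 2 / (2 * noise_var)
            for yn, fn in zip(y, f_z)
        )
        log_terms.append(math.log(1.0 / k) + ll)
    # log-sum-exp over clusters for numerical stability
    m = max(log_terms)
    return m + math.log(sum(math.exp(t - m) for t in log_terms))


f_values = [[0.0, 1.0, 0.0], [1.0, 0.0, 1.0]]  # k = 2 clusters, n = 3 inputs
y = [0.05, 0.95, -0.02]  # lies close to cluster 1
print(log_mixture_likelihood(y, f_values, noise_var=0.01))
```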

GPyTorch’s likelihoods are capable of modeling additional latent variables. Our custom likelihood needs to define the following three functions:

• pyro_model (needs to call through to super().pyro_model at the end), which defines the prior distribution for additional latent variables
• pyro_guide (needs to call through to super().pyro_guide at the end), which defines the variational (guide) distribution for additional latent variables
• forward, which defines the observation distributions conditioned on $$\mathbf f (\mathbf x_i)$$ and any additional latent variables.
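The call-through pattern can be shown without Pyro. In the sketch below, BaseLikelihood is a hypothetical stand-in for gpytorch.likelihoods.Likelihood and does not reproduce GPyTorch's implementation. Its pyro_model passes any extra keyword arguments on to forward, which is how a subclass hands an additional latent variable to its observation model. The one-hot assignment is hard-coded here rather than sampled.

```python
class BaseLikelihood:
    """Hypothetical stand-in for a likelihood base class (not GPyTorch's code)."""

    def pyro_model(self, function_samples, target, **kwargs):
        # The base class calls forward() with whatever extra kwargs the subclass supplied
        return self.forward(function_samples, **kwargs)

    def forward(self, function_samples, **kwargs):
        raise NotImplementedError


class ToyClusterLikelihood(BaseLikelihood):
    def pyro_model(self, function_samples, target):
        cluster_assignment = [1, 0]  # stand-in for a sampled one-hot z_i
        # Call through to the base class at the end, passing the extra latent variable
        return super().pyro_model(function_samples, target, cluster_assignment=cluster_assignment)

    def forward(self, function_samples, cluster_assignment=None):
        # For each data point, pick out the latent function selected by the one-hot vector
        return [sum(f * c for f, c in zip(row, cluster_assignment)) for row in function_samples]


lik = ToyClusterLikelihood()
print(lik.pyro_model([[0.5, -1.0], [2.0, 3.0]], target=None))  # [0.5, 2.0]
```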

### The pyro_model function

For each task, we will model the cluster assignment with a OneHotCategorical variable, where each cluster has equal probability. The pyro_model function will make a pyro.sample call to this prior distribution and then call the super method:

```python
# self.prior_cluster_logits = torch.zeros(num_tasks, num_clusters)

def pyro_model(self, function_dist, target):
    cluster_assignment_samples = pyro.sample(
        self.name_prefix + ".cluster_logits",  # self.name_prefix is added by PyroGP
        pyro.distributions.OneHotCategorical(logits=self.prior_cluster_logits).to_event(1)
    )
    return super().pyro_model(
        function_dist,
        target,
        cluster_assignment_samples=cluster_assignment_samples
    )
```


Note that we are adding an additional argument cluster_assignment_samples to the super().pyro_model call. The base class passes this keyword argument on to the forward call, so the observation distribution can be conditioned on the sampled cluster assignments. This is necessary for inference.
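The all-zero prior_cluster_logits in the comment above encode a uniform prior: after a softmax, every cluster gets probability $$1/k$$. The sketch below checks this for an assumed $$k = 3$$. It also draws a one-hot vector, which is what a OneHotCategorical sample looks like for a single task. softmax and sample_one_hot are hypothetical helpers, not Pyro functions.

```python
import math
import random


def softmax(logits):
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sample_one_hot(logits, rng):
    """Draw one category with softmax(logits) probabilities; return it as a one-hot list."""
    probs = softmax(logits)
    idx = rng.choices(range(len(probs)), weights=probs)[0]
    return [1.0 if j == idx else 0.0 for j in range(len(probs))]


prior_logits = [0.0, 0.0, 0.0]  # one task, k = 3 clusters (assumption)
print(softmax(prior_logits))  # each entry is 1/3
print(sample_one_hot(prior_logits, random.Random(0)))
```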

### The pyro_guide function

For each task, the variational (guide) distribution is also a OneHotCategorical variable, defined by the parameter self.variational_cluster_logits. The pyro_guide function makes a pyro.sample call to this variational distribution and then calls the super method:

```python
def pyro_guide(self, function_dist, target):
    pyro.sample(
        self.name_prefix + ".cluster_logits",  # self.name_prefix is added by PyroGP
        pyro.distributions.OneHotCategorical(logits=self.variational_cluster_logits).to_event(1)
    )
    return super().pyro_guide(function_dist, target)
```
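The model and the guide sample the same site name, so Pyro pairs the prior and the variational distribution for each task's cluster assignment. In expectation, the ELBO then contains the term $$-\mathrm{KL}(q(z_i) \,\|\, p(z_i))$$ for every task. The sketch below computes that KL for one task under the uniform prior, using plain lists of logits. categorical_kl is a hypothetical helper. A sampling-based objective such as Pyro's Trace_ELBO estimates this term from samples rather than computing it in closed form.

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]


def categorical_kl(q_logits, p_logits):
    """KL(q || p) between two categorical distributions given by logits."""
    q = softmax(q_logits)
    p = softmax(p_logits)
    # Entries with q_j = 0 contribute nothing to the sum
    return sum(qj * (math.log(qj) - math.log(pj)) for qj, pj in zip(q, p) if qj > 0)


uniform = [0.0, 0.0, 0.0]  # prior logits for k = 3 clusters
print(categorical_kl([0.0, 0.0, 0.0], uniform))   # 0.0: q equals the prior
print(categorical_kl([10.0, 0.0, 0.0], uniform))  # close to log(3): q is nearly one-hot
```

The KL term grows as the guide commits more confidently to a single cluster. The data-fit term has to justify that commitment.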



### The forward function

The pyro_model function passes the additional keyword argument cluster_assignment_samples to the forward call. Our forward method therefore defines the conditional probability $$p(\mathbf y_i \mid \mathbf f(\mathbf x), z_i)$$, where $$\mathbf f(\mathbf x)$$ corresponds to the variable function_samples and $$z_i$$ corresponds to the variable cluster_assignment_samples.

In our example, $$p(\mathbf y_i \mid \mathbf f(\mathbf x), z_i)$$ is a Gaussian noise model centred on the latent function $$f_{z_i}$$ selected by the cluster assignment.

```python
# self.raw_noise is the (unconstrained) Gaussian noise parameter;
#   softplus(raw_noise) is the noise variance, so its sqrt is the standard deviation
# function_samples is n x k
# cluster_assignment_samples is t x k (one one-hot row per task), where t is the number of tasks

def forward(self, function_samples, cluster_assignment_samples):
    return pyro.distributions.Normal(
        loc=(function_samples.unsqueeze(-2) * cluster_assignment_samples).sum(-1),  # n x t
        scale=torch.nn.functional.softplus(self.raw_noise).sqrt()
    ).to_event(1)
    # The to_event call is necessary because we are returning a multitask distribution,
    # where each task dimension corresponds to each of the t tasks
```
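To see what the loc expression computes, the sketch below spells out the same arithmetic with nested lists. It assumes $$n = 2$$ data points, $$k = 2$$ latent functions and $$t = 3$$ tasks. Each task's row of cluster_assignment_samples is one-hot, so summing over the last dimension simply copies the function value of that task's cluster. cluster_loc is a hypothetical name.

```python
def cluster_loc(function_samples, cluster_assignment_samples):
    """Plain-list version of (function_samples.unsqueeze(-2) * cluster_assignment_samples).sum(-1).

    function_samples[n][k]           -- value of latent function k at data point n
    cluster_assignment_samples[t][k] -- one-hot cluster indicator for task t
    returns an n x t nested list
    """
    return [
        [sum(f * c for f, c in zip(f_row, z_row)) for z_row in cluster_assignment_samples]
        for f_row in function_samples
    ]


function_samples = [[1.0, 10.0], [2.0, 20.0]]          # n x k
cluster_assignment_samples = [[1, 0], [0, 1], [1, 0]]  # t x k
print(cluster_loc(function_samples, cluster_assignment_samples))
# [[1.0, 10.0, 1.0], [2.0, 20.0, 2.0]]
```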


This is all we need for inference! However, when we use this model to make predictions, the cluster_assignment_samples keyword argument will not be passed to forward. We therefore need to make sure that forward can handle both inference and prediction:

```python
def forward(self, function_samples, cluster_assignment_samples=None):
    if cluster_assignment_samples is None:
        # We'll get here at prediction time
        # We'll use the variational distribution when making predictions
        cluster_assignment_samples = pyro.sample(
            self.name_prefix + ".cluster_logits",
            pyro.distributions.OneHotCategorical(logits=self.variational_cluster_logits).to_event(1)
        )

    return pyro.distributions.Normal(
        loc=(function_samples.unsqueeze(-2) * cluster_assignment_samples).sum(-1),
        scale=torch.nn.functional.softplus(self.raw_noise).sqrt()
    ).to_event(1)
```
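Because the assignment is sampled inside forward at prediction time, each call uses a single cluster drawn from the variational distribution. Averaging predictions over many calls therefore approaches a $$q(z_i)$$-weighted mixture of the latent functions. The sketch below compares the exact mixture mean $$\sum_z q(z) f_z(x)$$ with a Monte Carlo average, for one task at one input. The function values and logits are assumed toy numbers, and both helpers are hypothetical.

```python
import math
import random


def softmax(logits):
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]


def exact_predictive_mean(f_at_x, variational_logits):
    """sum_z q(z) f_z(x) for a single task at a single input."""
    return sum(q * f for q, f in zip(softmax(variational_logits), f_at_x))


def monte_carlo_predictive_mean(f_at_x, variational_logits, num_draws, rng):
    """Average of f_z(x) over num_draws clusters z drawn from q, one per simulated call."""
    q = softmax(variational_logits)
    draws = rng.choices(range(len(q)), weights=q, k=num_draws)
    return sum(f_at_x[z] for z in draws) / num_draws


f_at_x = [0.0, 1.0]          # assumed values of f_1(x), f_2(x)
logits = [0.0, math.log(3)]  # gives q(z) = [0.25, 0.75]
print(exact_predictive_mean(f_at_x, logits))
print(monte_carlo_predictive_mean(f_at_x, logits, 10_000, random.Random(0)))
```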


```python
class ClusterGaussianLikelihood(gpytorch.likelihoods.Likelihood):
    def __init__(self, num_tasks, num_clusters):
        super().__init__()

        # These are parameters/buffers for the cluster assignment latent variables
        # Uniform prior: all-zero logits, one row per task
        self.register_buffer("prior_cluster_logits", torch.zeros(num_tasks, num_clusters))
        # Learnable variational logits (randomly initialized)
        self.register_parameter(
            "variational_cluster_logits", torch.nn.Parameter(torch.randn(num_tasks, num_clusters))
        )

        # The Gaussian observational noise
        self.register_parameter("raw_noise", torch.nn.Parameter(torch.tensor(0.0)))

        # Other info
        self.num_tasks = num_tasks
        self.num_clusters = num_clusters
        self.max_plate_nesting = 1

    def pyro_guide(self, function_dist, target):
        # Here we add the extra variational distribution for the cluster latent variable
        pyro.sample(
            self.name_prefix + ".cluster_logits",  # self.name_prefix is added by PyroGP
            pyro.distributions.OneHotCategorical(logits=self.variational_cluster_logits).to_event(1)
        )
        return super().pyro_guide(function_dist, target)

    def pyro_model(self, function_dist, target):
        # Here we add the extra prior distribution for the cluster latent variable
        cluster_assignment_samples = pyro.sample(
            self.name_prefix + ".cluster_logits",  # self.name_prefix is added by PyroGP
            pyro.distributions.OneHotCategorical(logits=self.prior_cluster_logits).to_event(1)
        )
        return super().pyro_model(function_dist, target, cluster_assignment_samples=cluster_assignment_samples)

    def forward(self, function_samples, cluster_assignment_samples=None):
        # For inference, cluster_assignment_samples will be passed in
        # This bit of code is for when we use the likelihood in the predictive mode
        if cluster_assignment_samples is None:
            cluster_assignment_samples = pyro.sample(
                self.name_prefix + ".cluster_logits",
                pyro.distributions.OneHotCategorical(logits=self.variational_cluster_logits).to_event(1)
            )

        # Now we return the observational distribution, based on the function_samples and cluster_assignment_samples
        res = pyro.distributions.Normal(
            loc=(function_samples.unsqueeze(-2) * cluster_assignment_samples).sum(-1),
            scale=torch.nn.functional.softplus(self.raw_noise).sqrt()
        ).to_event(1)
        return res
```


## Constructing the PyroGP model

The PyroGP model is essentially the same as the model we used in the simple example, except for two changes:

• We now will use our more complicated ClusterGaussianLikelihood
• The latent function should be vector-valued, to correspond to the k latent functions. As a result, we learn a batched variational distribution and use an IndependentMultitaskVariationalStrategy to convert the batched variational distribution into a MultitaskMultivariateNormal distribution.
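The layout change described in the second bullet can be pictured with plain lists. The sketch below only mirrors the shape bookkeeping for the means. It is not GPyTorch code and it ignores covariances. The batched variational GP produces one length-$$n$$ mean per latent function ($$k \times n$$), while the likelihood's forward expects function_samples laid out as $$n \times k$$. batch_to_multitask is a hypothetical helper.

```python
def batch_to_multitask(batched_means):
    """Rearrange a k x n nested list (one row per latent function) into n x k."""
    return [list(col) for col in zip(*batched_means)]


batched_means = [[0.1, 0.2, 0.3], [1.1, 1.2, 1.3]]  # k = 2 functions, n = 3 inputs
print(batch_to_multitask(batched_means))
# [[0.1, 1.1], [0.2, 1.2], [0.3, 1.3]]
```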

```python
class ClusterMultitaskGPModel(gpytorch.models.pyro.PyroGP):
    def __init__(self, train_x, train_y, num_functions=2, reparam=False):
        num_data = train_y.size(-2)

        # Define all the variational stuff
        inducing_points = torch.linspace(0, 1, 64).unsqueeze(-1)
        variational_distribution = gpytorch.variational.CholeskyVariationalDistribution(
            num_inducing_points=inducing_points.size(-2),
            batch_shape=torch.Size([num_functions])
        )

        # Here we're using an IndependentMultitaskVariationalStrategy - so that the output of the
        # GP latent function is a MultitaskMultivariateNormal
        variational_strategy = gpytorch.variational.IndependentMultitaskVariationalStrategy(
            gpytorch.variational.VariationalStrategy(self, inducing_points, variational_distribution),
            num_tasks=num_functions,
        )

        # Standard initialization
        likelihood = ClusterGaussianLikelihood(train_y.size(-1), num_functions)
        super().__init__(variational_strategy, likelihood, num_data=num_data, name_prefix=str(time.time()))
        self.likelihood = likelihood
        self.num_functions = num_functions

        # Mean, covar
        self.mean_module = gpytorch.means.ZeroMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        res = gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
        return res
```


This model can now be used to perform inference on cluster assignments, as well as make predictions using the inferred cluster assignments!
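After training, one way to summarize the inferred assignments is to take, for each task, the most probable cluster under the variational distribution. The sketch below assumes the learned variational_cluster_logits have been copied into nested Python lists. The logits shown are made-up illustrative values, not results from this model, and summarize_assignments is a hypothetical helper.

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]


def summarize_assignments(variational_cluster_logits):
    """Return (most likely cluster index, its variational probability) for each task."""
    summary = []
    for task_logits in variational_cluster_logits:
        probs = softmax(task_logits)
        best = max(range(len(probs)), key=probs.__getitem__)
        summary.append((best, probs[best]))
    return summary


example_logits = [[2.0, -1.0], [-0.5, 3.0], [0.1, 0.0]]  # 3 tasks, 2 clusters (illustrative)
for task, (cluster, prob) in enumerate(summarize_assignments(example_logits)):
    print(f"task {task}: cluster {cluster} (q = {prob:.2f})")
```

On the tensor itself, likelihood.variational_cluster_logits.softmax(-1).argmax(-1) should give the same hard assignments. Tasks whose top probability is close to $$1/k$$, such as the third one above, are ones the model has not confidently assigned.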