Deep Sigma Point Processes

In this notebook, we provide a GPyTorch implementation of Deep Sigma Point Processes (DSPPs), as described in Jankowiak et al., 2020 (http://www.auai.org/uai2020/proceedings/339_main_paper.pdf).

It will be useful to compare and contrast this notebook with our standard Deep GP notebook, as the computational structures of a Deep GP and a DSPP are quite similar.

[1]:
import gpytorch
import torch
from gpytorch.likelihoods import GaussianLikelihood
from gpytorch.means import ConstantMean, LinearMean
from gpytorch.kernels import ScaleKernel, MaternKernel
from gpytorch.variational import VariationalStrategy, BatchDecoupledVariationalStrategy
from gpytorch.variational import MeanFieldVariationalDistribution
from gpytorch.models.deep_gps.dspp import DSPPLayer, DSPP
import gpytorch.settings as settings

Basic settings

In the next cell, we define some basic settings that can be tuned. The only DSPP-specific hyperparameter is num_quadrature_sites, which effectively determines the number of mixture components in the output distribution. hidden_dim controls the width (i.e., the number of GPs) of the hidden GP layer. The remaining settings are standard optimization hyperparameters.

[2]:
import os

# this is for running the notebook in our testing framework
smoke_test = ('CI' in os.environ)

batch_size = 1000                 # Size of minibatch
milestones = [20, 150, 300]       # Epochs at which we will lower the learning rate by a factor of 0.1
num_inducing_pts = 300            # Number of inducing points in each hidden layer
num_epochs = 400                  # Number of epochs to train for
initial_lr = 0.01                 # Initial learning rate
hidden_dim = 3                    # Number of GPs (i.e., the width) in the hidden layer.
num_quadrature_sites = 8          # Number of quadrature sites (see the paper for details; 5-10 generally works well).

## Modified settings for smoke test purposes
num_epochs = num_epochs if not smoke_test else 1

Loading Data

For this example notebook, we'll use the 'bike' UCI dataset from the paper. Running the next cell downloads a copy of the dataset. We use the same normalization, randomization, and train/test splitting scheme as the paper, although for this demo notebook we do not use a validation set, since we won't be tuning any hyperparameters.

[3]:
import urllib.request
from scipy.io import loadmat
from math import floor


if not smoke_test and not os.path.isfile('../bike.mat'):
    print('Downloading \'bike\' UCI dataset...')
    urllib.request.urlretrieve('https://drive.google.com/uc?export=download&id=1pR1H9ee4U89C1y_uYe9qAypKsHs1EL5I', '../bike.mat')

if smoke_test:  # this is for running the notebook in our testing framework
    X, y = torch.randn(1000, 3), torch.randn(1000)
else:
    data = torch.Tensor(loadmat('../bike.mat')['data'])

    # Map features to [-1, 1]
    X = data[:, :-1]
    X = X - X.min(0)[0]
    X = 2.0 * (X / X.max(0)[0]) - 1.0

    # Z-score labels
    y = data[:, -1]
    y -= y.mean()
    y /= y.std()

shuffled_indices = torch.randperm(X.size(0))
X = X[shuffled_indices, :]
y = y[shuffled_indices]

train_n = int(floor(0.75 * X.size(0)))

train_x = X[:train_n, :].contiguous()
train_y = y[:train_n].contiguous()
test_x = X[train_n:, :].contiguous()
test_y = y[train_n:].contiguous()

if torch.cuda.is_available():
    train_x, train_y, test_x, test_y = train_x.cuda(), train_y.cuda(), test_x.cuda(), test_y.cuda()

print(train_x.shape, train_y.shape, test_x.shape, test_y.shape)
torch.Size([13034, 17]) torch.Size([13034]) torch.Size([4345, 17]) torch.Size([4345])

Create PyTorch DataLoader objects

As we will be training and predicting on minibatches, we use the standard PyTorch TensorDataset and DataLoader framework to handle getting batches of data.

[4]:
from torch.utils.data import TensorDataset, DataLoader
train_dataset = TensorDataset(train_x, train_y)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

test_dataset = TensorDataset(test_x, test_y)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

Initialize Hidden Layer Inducing Points

To match the setup of the paper, we initialize the inducing points for each GP in the hidden layer using k-means clustering. The DSPPHiddenLayer class defined below can also take num_inducing=300, inducing_points=None as arguments to initialize the inducing points randomly. However, we find that k-means initialization improves optimization in most cases.

[5]:
from scipy.cluster.vq import kmeans2

# Use k-means to initialize inducing points (only helpful for the first layer)
inducing_points = (train_x[torch.randperm(min(1000 * 100, train_n))[0:num_inducing_pts], :])
inducing_points = inducing_points.clone().data.cpu().numpy()
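# Run k-means on the training inputs; minit='matrix' tells scipy to use the randomly selected points above as the initial centroids.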
inducing_points = torch.tensor(kmeans2(train_x.data.cpu().numpy(),
                               inducing_points, minit='matrix')[0])

if torch.cuda.is_available():
    inducing_points = inducing_points.cuda()

Create The DSPPHiddenLayer Class

The next cell is the most important in the notebook. It will likely be instructive to compare the code below to the analogous code cell in our Deep GP notebook, as they are essentially the same. The only difference is some code at the start to handle the fact that we may pass in prespecified inducing point locations, rather than always initializing them randomly.

Regardless, the best way to think of a DSPP (or DGP) hidden layer class is as a standard GPyTorch variational GP class that differs in two key ways:

  1. It has a batch shape equal to output_dims. In other words, the way we handle a layer of multiple GPs is with a batch dimension. This means that inducing_points, kernel hyperparameters, etc. should all have batch_shape=torch.Size([output_dims]).
  2. It extends DSPPLayer rather than ApproximateGP.

These are really the only two differences. A DSPPLayer / DGPLayer still defines a variational distribution and strategy, a prior mean and covariance function, and its forward method is still responsible for returning the prior.
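
As a small standalone illustration of the first point (the toy_kernel name below is just for this example and is not part of the model), a batch kernel for a layer of three GPs over five-dimensional inputs keeps a separate set of ARD lengthscales for each GP:

    toy_kernel = ScaleKernel(
        MaternKernel(batch_shape=torch.Size([3]), ard_num_dims=5),
        batch_shape=torch.Size([3]),
    )
    # One lengthscale vector per GP in the layer
    print(toy_kernel.base_kernel.lengthscale.shape)  # torch.Size([3, 1, 5])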

[6]:
class DSPPHiddenLayer(DSPPLayer):
    def __init__(self, input_dims, output_dims, num_inducing=300, inducing_points=None, mean_type='constant', Q=8):
        if inducing_points is not None and output_dims is not None and inducing_points.dim() == 2:
            # The inducing points were passed in, but the shape doesn't match the number of GPs in this layer.
            # Let's assume we wanted to use the same inducing point initialization for each GP in the layer,
            # and expand the inducing points to match this.
            inducing_points = inducing_points.unsqueeze(0).expand((output_dims,) + inducing_points.shape)
            inducing_points = inducing_points.clone() + 0.01 * torch.randn_like(inducing_points)
        if inducing_points is None:
            # No inducing points were specified, let's just initialize them randomly.
            if output_dims is None:
                # An output_dims of None implies there is only one GP in this layer
                # (e.g., the last layer for univariate regression).
                inducing_points = torch.randn(num_inducing, input_dims)
            else:
                inducing_points = torch.randn(output_dims, num_inducing, input_dims)
        else:
            # Get the number of inducing points from the ones passed in.
            num_inducing = inducing_points.size(-2)

        # Let's use mean field / diagonal covariance structure.
        variational_distribution = MeanFieldVariationalDistribution(
            num_inducing_points=num_inducing,
            batch_shape=torch.Size([output_dims]) if output_dims is not None else torch.Size([])
        )

        # Standard variational inference.
        variational_strategy = VariationalStrategy(
            self,
            inducing_points,
            variational_distribution,
            learn_inducing_locations=True
        )

        batch_shape = torch.Size([]) if output_dims is None else torch.Size([output_dims])

        super(DSPPHiddenLayer, self).__init__(variational_strategy, input_dims, output_dims, Q)

        if mean_type == 'constant':
            # We'll use a constant mean for the final output layer.
            self.mean_module = ConstantMean(batch_shape=batch_shape)
        elif mean_type == 'linear':
            # As in Salimbeni et al. 2017, we find that using a linear mean for the hidden layer improves performance.
            self.mean_module = LinearMean(input_dims, batch_shape=batch_shape)

        self.covar_module = ScaleKernel(MaternKernel(batch_shape=batch_shape, ard_num_dims=input_dims),
                                        batch_shape=batch_shape, ard_num_dims=None)

    def forward(self, x, mean_input=None, **kwargs):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

Create the DSPP Class

The cell below creates a DSPP container that is virtually identical to the one used in the DGP setting. All it is responsible for is instantiating the layers of the DSPP (in this case, one hidden layer and one output layer) and defining a forward method that passes data through both layers.

As in the DGP example notebook, we also define a predict method purely for convenience that takes a DataLoader and returns predictions for every example in that DataLoader.
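
Schematically, the output of the DSPP at an input x is a mixture of Q Gaussians, one per quadrature site, weighted by learned quadrature weights w_q (stored in log space as model.quad_weights), so the test log likelihood of a point y is

    log p(y | x) = logsumexp_q [ log w_q + log N(y | mu_q(x), sigma_q^2(x)) ]

where sigma_q^2(x) includes the observation noise from the Gaussian likelihood. The predict method below computes this by adding the log weights to the per-component log marginals and taking a logsumexp over the Q mixture components.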

[7]:
class TwoLayerDSPP(DSPP):
    def __init__(self, train_x_shape, inducing_points, num_inducing, hidden_dim=3, Q=3):
        hidden_layer = DSPPHiddenLayer(
            input_dims=train_x_shape[-1],
            output_dims=hidden_dim,
            mean_type='linear',
            inducing_points=inducing_points,
            Q=Q,
        )
        last_layer = DSPPHiddenLayer(
            input_dims=hidden_layer.output_dims,
            output_dims=None,
            mean_type='constant',
            inducing_points=None,
            num_inducing=num_inducing,
            Q=Q,
        )

        likelihood = GaussianLikelihood()

        super().__init__(Q)
        self.likelihood = likelihood
        self.last_layer = last_layer
        self.hidden_layer = hidden_layer

    def forward(self, inputs, **kwargs):
        hidden_rep1 = self.hidden_layer(inputs, **kwargs)
        output = self.last_layer(hidden_rep1, **kwargs)
        return output

    def predict(self, loader):
        with settings.fast_computations(log_prob=False, solves=False), torch.no_grad():
            mus, variances, lls = [], [], []
            for x_batch, y_batch in loader:
                preds = self.likelihood(self(x_batch, mean_input=x_batch))
                mus.append(preds.mean.cpu())
                variances.append(preds.variance.cpu())

                # Compute the test log probability. The output of a DSPP is a weighted mixture of Q Gaussians,
                # with the Q (log) weights stored in self.quad_weights. The code below computes the log probability
                # of each test point under this mixture.

                # Step 1: Get log marginal for each Gaussian in the output mixture.
                base_batch_ll = self.likelihood.log_marginal(y_batch, self(x_batch))

                # Step 2: Weight each log marginal by its quadrature weight in log space.
                deep_batch_ll = self.quad_weights.unsqueeze(-1) + base_batch_ll

                # Step 3: Take logsumexp over the mixture dimension, getting test log prob for each datapoint in the batch.
                batch_log_prob = deep_batch_ll.logsumexp(dim=0)
                lls.append(batch_log_prob.cpu())

        return torch.cat(mus, dim=-1), torch.cat(variances, dim=-1), torch.cat(lls, dim=-1)
[8]:
model = TwoLayerDSPP(
    train_x.shape,
    inducing_points,
    num_inducing=num_inducing_pts,
    hidden_dim=hidden_dim,
    Q=num_quadrature_sites
)

if torch.cuda.is_available():
    model.cuda()

model.train()
[8]:
TwoLayerDSPP(
  (likelihood): GaussianLikelihood(
    (noise_covar): HomoskedasticNoise(
      (raw_noise_constraint): GreaterThan(1.000E-04)
    )
  )
  (last_layer): DSPPHiddenLayer(
    (variational_strategy): VariationalStrategy(
      (_variational_distribution): MeanFieldVariationalDistribution()
    )
    (mean_module): ConstantMean()
    (covar_module): ScaleKernel(
      (base_kernel): MaternKernel(
        (raw_lengthscale_constraint): Positive()
      )
      (raw_outputscale_constraint): Positive()
    )
  )
  (hidden_layer): DSPPHiddenLayer(
    (variational_strategy): VariationalStrategy(
      (_variational_distribution): MeanFieldVariationalDistribution()
    )
    (mean_module): LinearMean()
    (covar_module): ScaleKernel(
      (base_kernel): MaternKernel(
        (raw_lengthscale_constraint): Positive()
      )
      (raw_outputscale_constraint): Positive()
    )
  )
)
[9]:
from gpytorch.mlls import DeepPredictiveLogLikelihood

adam = torch.optim.Adam([{'params': model.parameters()}], lr=initial_lr, betas=(0.9, 0.999))
sched = torch.optim.lr_scheduler.MultiStepLR(adam, milestones=milestones, gamma=0.1)


# The "beta" parameter here corresponds to \beta_{reg} from the paper, and represents a scaling factor on the KL divergence
# portion of the loss.
objective = DeepPredictiveLogLikelihood(model.likelihood, model, num_data=train_n, beta=0.05)
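
Following the paper, this objective (up to minibatch rescaling) maximizes the log probability of each training point under the Q-component mixture output by the DSPP, minus the scaled KL penalty on the inducing point posteriors:

    sum_i log sum_q w_q N(y_i | mu_q(x_i), sigma_q^2(x_i) + sigma_obs^2)  -  beta_reg * KL(q(u) || p(u))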

Train the Model

Below is a standard minibatch training loop.

[ ]:
import tqdm.notebook  # import the notebook submodule explicitly so that tqdm.notebook.tqdm is available

epochs_iter = tqdm.notebook.tqdm(range(num_epochs), desc="Epoch")

for i in epochs_iter:
    minibatch_iter = tqdm.notebook.tqdm(train_loader, desc="Minibatch", leave=False)
    for x_batch, y_batch in minibatch_iter:
        adam.zero_grad()
        output = model(x_batch)
        loss = -objective(output, y_batch)
        loss.backward()
        adam.step()
    sched.step()

Make Predictions, Compute RMSE and Test NLL

[26]:
model.eval()
means, vars, ll = model.predict(test_loader)
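# means and vars each contain one row per mixture component, i.e. they have shape [num_quadrature_sites, num_test].
# model.quad_weights holds the log mixture weights, so we exponentiate to recover the weights themselves.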
weights = model.quad_weights.unsqueeze(-1).exp().cpu()
# `means` currently contains the predictive output from each Gaussian in the mixture.
# To get the total mean output, we take a weighted sum of these means over the quadrature weights.
rmse = ((weights * means).sum(0) - test_y.cpu()).pow(2.0).mean().sqrt().item()
ll = ll.mean().item()

print('RMSE: ', rmse, 'Test NLL: ', -ll)
RMSE:  0.04274941563606262 Test NLL:  -1.690010404586792