SVDKL (Stochastic Variational Deep Kernel Learning) on CIFAR10/100

In this notebook, we’ll demonstrate the steps necessary to train a medium-sized DenseNet (https://arxiv.org/abs/1608.06993) on either of two popular benchmark datasets in computer vision (CIFAR10 and CIFAR100). We’ll train the DKL model entirely end-to-end using the standard 300-epoch training schedule and SGD.

This notebook is largely for tutorial purposes. If your goal is just to get (for example) a trained DKL + CIFAR100 model, we recommend that you move this code to a simple python script and run that, rather than training directly out of a python notebook. We find that training is just a bit faster out of a python script. We also of course recommend that you increase the size of the DenseNet to a full-sized model if you would like to achieve state-of-the-art performance.

Furthermore, because this notebook involves training a reasonably large neural network, we strongly recommend that you have a decent GPU available, as with any large deep learning model.

[1]:
from torch.optim import SGD, Adam
from torch.optim.lr_scheduler import MultiStepLR
import torch.nn.functional as F
from torch import nn
import torch
import os
import torchvision.datasets as dset
import torchvision.transforms as transforms
import gpytorch
import math
import tqdm.notebook  # import the submodule explicitly so tqdm.notebook.tqdm is available below

Set up data augmentation

The first thing we’ll do is set up some data augmentation transformations to use during training, as well as some basic normalization to use during both training and testing. Specifically, we’ll use random crops and horizontal flips to augment the training data, and apply the same normalization at both training time and test time. To accomplish these transformations, we use standard torchvision transforms.

[2]:
normalize = transforms.Normalize(mean=[0.5071, 0.4867, 0.4408], std=[0.2675, 0.2565, 0.2761])
aug_trans = [transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip()]
common_trans = [transforms.ToTensor(), normalize]
train_compose = transforms.Compose(aug_trans + common_trans)
test_compose = transforms.Compose(common_trans)
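
As a quick optional sanity check (an addition for this writeup, not part of the original pipeline), we can apply the training transform to a dummy PIL image and confirm that it produces a normalized 3x32x32 tensor:

from PIL import Image
import numpy as np

# A random uint8 HxWx3 array stands in for a CIFAR image here.
dummy_image = Image.fromarray(np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8))
print(train_compose(dummy_image).shape)  # torch.Size([3, 32, 32])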

Create DataLoaders

Next, we create dataloaders for the selected dataset using the built-in torchvision datasets. The cell below will download either the cifar10 or cifar100 dataset, depending on which choice is made. The default here is cifar10; training is just as fast on either dataset.

After downloading the datasets, we create standard torch.utils.data.DataLoaders for each dataset that we will be using to get minibatches of augmented data.

[3]:
dataset = "cifar10"


if ('CI' in os.environ):  # this is for running the notebook in our testing framework
    train_set = torch.utils.data.TensorDataset(torch.randn(8, 3, 32, 32), torch.rand(8).round().long())
    test_set = torch.utils.data.TensorDataset(torch.randn(4, 3, 32, 32), torch.rand(4).round().long())
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=4, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=2, shuffle=False)
    num_classes = 2
elif dataset == 'cifar10':
    train_set = dset.CIFAR10('data', train=True, transform=train_compose, download=True)
    test_set = dset.CIFAR10('data', train=False, transform=test_compose)
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=256, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=256, shuffle=False)
    num_classes = 10
elif dataset == 'cifar100':
    train_set = dset.CIFAR100('data', train=True, transform=train_compose, download=True)
    test_set = dset.CIFAR100('data', train=False, transform=test_compose)
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=256, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=256, shuffle=False)
    num_classes = 100
else:
    raise RuntimeError('dataset must be one of "cifar100" or "cifar10"')
Files already downloaded and verified
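
Before moving on, we can optionally pull a single minibatch from the train loader to confirm the shapes we will be feeding the model:

images, labels = next(iter(train_loader))
print(images.shape, labels.shape)  # e.g. torch.Size([256, 3, 32, 32]) torch.Size([256])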

Creating the DenseNet Model

With the data loaded, we can move on to defining our DKL model. A DKL model consists of three components: the neural network, the Gaussian process layer used after the neural network, and the Softmax likelihood.

The first step is defining the neural network architecture. To do this, we use a slightly modified version of a standard DenseNet implementation (imported below from a local densenet module). Specifically, we modify it to remove the softmax layer, since we’ll only be needing the final features extracted from the neural network.

[4]:
from densenet import DenseNet

class DenseNetFeatureExtractor(DenseNet):
    def forward(self, x):
        features = self.features(x)
        out = F.relu(features, inplace=True)
        out = F.avg_pool2d(out, kernel_size=self.avgpool_size).view(features.size(0), -1)
        return out

feature_extractor = DenseNetFeatureExtractor(block_config=(6, 6, 6), num_classes=num_classes)
num_features = feature_extractor.classifier.in_features
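
Optionally, we can verify the feature extractor’s output shape; assuming the local densenet module follows the standard CIFAR DenseNet layout (three dense blocks, 8x8 final feature map), a batch of images maps to a flat (batch_size, num_features) matrix:

with torch.no_grad():
    dummy_batch = torch.randn(2, 3, 32, 32)
    print(feature_extractor(dummy_batch).shape)  # expected: torch.Size([2, num_features])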

Creating the GP Layer

In the next cell, we create the layer of Gaussian process models that is applied after the neural network. In this case, we’ll be using one GP per feature, as in the SV-DKL paper. The outputs of these Gaussian processes will then be mixed in the softmax likelihood.
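
Concretely (assuming the default mixing_weights=True behavior of gpytorch’s SoftmaxLikelihood), samples f of the num_features GP outputs are mixed into class logits through a learned matrix A of shape num_classes x num_features, so that, schematically, p(y | f) = softmax(A f).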

[5]:
class GaussianProcessLayer(gpytorch.models.ApproximateGP):
    def __init__(self, num_dim, grid_bounds=(-10., 10.), grid_size=64):
        variational_distribution = gpytorch.variational.CholeskyVariationalDistribution(
            num_inducing_points=grid_size, batch_shape=torch.Size([num_dim])
        )

        # Our base variational strategy is a GridInterpolationVariationalStrategy,
        # which places variational inducing points on a Grid
        # We wrap it with a IndependentMultitaskVariationalStrategy so that our output is a vector-valued GP
        variational_strategy = gpytorch.variational.IndependentMultitaskVariationalStrategy(
            gpytorch.variational.GridInterpolationVariationalStrategy(
                self, grid_size=grid_size, grid_bounds=[grid_bounds],
                variational_distribution=variational_distribution,
            ), num_tasks=num_dim,
        )
        super().__init__(variational_strategy)

        self.covar_module = gpytorch.kernels.ScaleKernel(
            gpytorch.kernels.RBFKernel(
                lengthscale_prior=gpytorch.priors.SmoothedBoxPrior(
                    math.exp(-1), math.exp(1), sigma=0.1, transform=torch.exp
                )
            )
        )
        self.mean_module = gpytorch.means.ConstantMean()
        self.grid_bounds = grid_bounds

    def forward(self, x):
        mean = self.mean_module(x)
        covar = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean, covar)

Creating the full SVDKL Model

With both the DenseNet feature extractor and GP layer defined, we can put them together in a single module that simply calls one and then the other, much like building any Sequential neural network in PyTorch. This completes defining our DKL model.

[6]:
class DKLModel(gpytorch.Module):
    def __init__(self, feature_extractor, num_dim, grid_bounds=(-10., 10.)):
        super(DKLModel, self).__init__()
        self.feature_extractor = feature_extractor
        self.gp_layer = GaussianProcessLayer(num_dim=num_dim, grid_bounds=grid_bounds)
        self.grid_bounds = grid_bounds
        self.num_dim = num_dim

        # This module will scale the NN features so that they fall within the grid bounds
        self.scale_to_bounds = gpytorch.utils.grid.ScaleToBounds(self.grid_bounds[0], self.grid_bounds[1])

    def forward(self, x):
        features = self.feature_extractor(x)
        features = self.scale_to_bounds(features)
        # Reshape from (batch, num_dim) to (num_dim, batch, 1) so that we learn a separate GP per feature
        features = features.transpose(-1, -2).unsqueeze(-1)
        res = self.gp_layer(features)
        return res

model = DKLModel(feature_extractor, num_dim=num_features)
likelihood = gpytorch.likelihoods.SoftmaxLikelihood(num_features=model.num_dim, num_classes=num_classes)


# If you run this example without CUDA, I hope you like waiting!
if torch.cuda.is_available():
    model = model.cuda()
    likelihood = likelihood.cuda()
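
Before training, a quick smoke test (an optional addition, not in the original tutorial) can confirm the plumbing: applying the likelihood to the model output should yield class probabilities with a leading sample dimension, i.e. shape (num_samples, batch_size, num_classes), which is what the test loop below averages over:

with torch.no_grad(), gpytorch.settings.num_likelihood_samples(4):
    dummy_batch = torch.randn(2, 3, 32, 32)
    if torch.cuda.is_available():
        dummy_batch = dummy_batch.cuda()
    # probs come from 4 likelihood samples over a batch of 2 images
    print(likelihood(model(dummy_batch)).probs.shape)  # expected: (4, 2, num_classes)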

Defining Training and Testing Code

Next, we define the basic optimization loop and testing code. This code is entirely analogous to the standard PyTorch training loop. We create a torch.optim.SGD optimizer with the parameters of the neural network, to which we apply the standard amount of weight decay suggested in the paper; the parameters of the Gaussian process, for which we omit weight decay, as L2 regularization on top of variational inference is not necessary; and the mixing parameters of the Softmax likelihood.

We use the standard learning rate schedule from the paper, decreasing the learning rate by a factor of ten at 50% of the way through training, and again at 75% of the way through training.
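
The loss we optimize is the negated variational ELBO constructed in the next cell. Schematically, GPyTorch’s VariationalELBO computes E_q[ log p(y | f) ] - KL( q(u) || p(u) ), where q(u) is the variational distribution over the inducing values; passing num_data lets GPyTorch rescale the minibatch log-likelihood term so that the stochastic estimate of the objective is unbiased.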

[7]:
n_epochs = 1  # set to 300 to reproduce the full training schedule described above
lr = 0.1
optimizer = SGD([
    {'params': model.feature_extractor.parameters(), 'weight_decay': 1e-4},
    {'params': model.gp_layer.hyperparameters(), 'lr': lr * 0.01},
    {'params': model.gp_layer.variational_parameters()},
    {'params': likelihood.parameters()},
], lr=lr, momentum=0.9, nesterov=True, weight_decay=0)
scheduler = MultiStepLR(optimizer, milestones=[0.5 * n_epochs, 0.75 * n_epochs], gamma=0.1)
mll = gpytorch.mlls.VariationalELBO(likelihood, model.gp_layer, num_data=len(train_loader.dataset))


def train(epoch):
    model.train()
    likelihood.train()

    minibatch_iter = tqdm.notebook.tqdm(train_loader, desc=f"(Epoch {epoch}) Minibatch")
    with gpytorch.settings.num_likelihood_samples(8):
        for data, target in minibatch_iter:
            if torch.cuda.is_available():
                data, target = data.cuda(), target.cuda()
            optimizer.zero_grad()
            output = model(data)
            loss = -mll(output, target)
            loss.backward()
            optimizer.step()
            minibatch_iter.set_postfix(loss=loss.item())

def test():
    model.eval()
    likelihood.eval()

    correct = 0
    with torch.no_grad(), gpytorch.settings.num_likelihood_samples(16):
        for data, target in test_loader:
            if torch.cuda.is_available():
                data, target = data.cuda(), target.cuda()
            output = likelihood(model(data))  # This gives us 16 samples from the predictive distribution
            pred = output.probs.mean(0).argmax(-1)  # Taking the mean over all of the samples we've drawn
            correct += pred.eq(target.view_as(pred)).cpu().sum()
    print('Test set: Accuracy: {}/{} ({}%)'.format(
        correct, len(test_loader.dataset), 100. * correct / float(len(test_loader.dataset))
    ))

We are now ready to train the model. At the end of each epoch we report the current test accuracy, and we save a checkpoint model out to a file.

[8]:
for epoch in range(1, n_epochs + 1):
    with gpytorch.settings.use_toeplitz(False):
        train(epoch)
        test()
    scheduler.step()
    state_dict = model.state_dict()
    likelihood_state_dict = likelihood.state_dict()
    torch.save({'model': state_dict, 'likelihood': likelihood_state_dict}, 'dkl_cifar_checkpoint.dat')

Test set: Accuracy: 4583/10000 (45.83000183105469%)

We cut training short here. With more training, the network will achieve accuracy comparable to a standard neural network.
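
To resume from or evaluate the saved checkpoint later, a minimal sketch (assuming the model and likelihood have been constructed exactly as earlier in this notebook):

checkpoint = torch.load('dkl_cifar_checkpoint.dat')
model.load_state_dict(checkpoint['model'])
likelihood.load_state_dict(checkpoint['likelihood'])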