Hyperparameters in GPyTorch

The purpose of this notebook is to explain how GP hyperparameters in GPyTorch work, how they are handled, what options are available for constraints and priors, and how things may differ from other packages.

Note: This is a basic introduction to hyperparameters in GPyTorch. If you want to use GPyTorch hyperparameters with things like Pyro distributions, that will be covered in a less “basic usage” tutorial.

[1]:
# smoke_test (this makes sure this example notebook gets tested)

import math
import torch
import gpytorch
from matplotlib import pyplot as plt

from IPython.display import Markdown, display
def printmd(string):
    display(Markdown(string))

Defining an example model

In the next cell, we define our simple exact GP from the Simple GP Regression tutorial. We’ll be using this model to demonstrate certain aspects of hyperparameter creation.

[2]:
train_x = torch.linspace(0, 1, 100)
train_y = torch.sin(train_x * (2 * math.pi)) + torch.randn(train_x.size()) * 0.2

# We will use the simplest form of GP model, exact inference
class ExactGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super(ExactGPModel, self).__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

# initialize likelihood and model
likelihood = gpytorch.likelihoods.GaussianLikelihood()
model = ExactGPModel(train_x, train_y, likelihood)

Viewing model hyperparameters

Let’s take a look at the model parameters. By “parameters”, here I mean explicitly objects of type torch.nn.Parameter that will have gradients filled in by autograd. There are two ways to access these in torch. One is model.state_dict(), whose use we demonstrate for saving models here.
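
If you just want to inspect what the state_dict contains, a minimal sketch is to print its keys (depending on the GPyTorch version, buffers such as constraint bounds may appear alongside the raw parameters):

[ ]:
# Keys of the state_dict: raw parameters (and possibly constraint buffers)
for param_name in model.state_dict().keys():
    print(param_name)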

In the next cell we demonstrate another way to do this, by looping over the model.named_parameters() generator:

[3]:
for param_name, param in model.named_parameters():
    print(f'Parameter name: {param_name:42} value = {param.item()}')
Parameter name: likelihood.noise_covar.raw_noise           value = 0.0
Parameter name: mean_module.constant                       value = 0.0
Parameter name: covar_module.raw_outputscale               value = 0.0
Parameter name: covar_module.base_kernel.raw_lengthscale   value = 0.0

Raw vs Actual Parameters

The most important thing to note here is that the learned parameters of the model are things like raw_noise, raw_outputscale, raw_lengthscale, etc. The reason for this is that the actual values (noise, outputscale, lengthscale) must be positive, so the optimizer works with unconstrained raw parameters instead. This brings us to our next topic: constraints, and the difference between raw parameters and actual parameters.

In order to enforce positivity and other constraints on hyperparameters, GPyTorch has raw parameters (e.g., model.covar_module.raw_outputscale) that are transformed to actual values via a constraint. Let’s take a look at the raw outputscale, its constraint, and the final value:

[4]:
raw_outputscale = model.covar_module.raw_outputscale
print('raw_outputscale, ', raw_outputscale)

# Three ways of accessing the raw outputscale constraint
print('\nraw_outputscale_constraint1', model.covar_module.raw_outputscale_constraint)

printmd('\n\n**Printing all model constraints...**\n')
for constraint_name, constraint in model.named_constraints():
    print(f'Constraint name: {constraint_name:55} constraint = {constraint}')

printmd('\n**Getting raw outputscale constraint from model...**')
print(model.constraint_for_parameter_name("covar_module.raw_outputscale"))


printmd('\n**Getting raw outputscale constraint from model.covar_module...**')
print(model.covar_module.constraint_for_parameter_name("raw_outputscale"))
raw_outputscale,  Parameter containing:
tensor(0., requires_grad=True)

raw_outputscale_constraint1 Positive()
Printing all model constraints…
Constraint name: likelihood.noise_covar.raw_noise_constraint             constraint = GreaterThan(1.000E-04)
Constraint name: covar_module.raw_outputscale_constraint                 constraint = Positive()
Constraint name: covar_module.base_kernel.raw_lengthscale_constraint     constraint = Positive()
Getting raw outputscale constraint from model…
Positive()
Getting raw outputscale constraint from model.covar_module…
Positive()

How do constraints work?

Constraints define transform and inverse_transform methods that turn raw parameters into actual (constrained) values. For a positive constraint, we expect the transformed values to always be positive. Let’s see:

[5]:
raw_outputscale = model.covar_module.raw_outputscale
constraint = model.covar_module.raw_outputscale_constraint

print('Transformed outputscale', constraint.transform(raw_outputscale))
print(constraint.inverse_transform(constraint.transform(raw_outputscale)))
print(torch.equal(constraint.inverse_transform(constraint.transform(raw_outputscale)), raw_outputscale))

print('Transform a bunch of negative tensors: ', constraint.transform(torch.tensor([-1., -2., -3.])))
Transformed outputscale tensor(0.6931, grad_fn=<SoftplusBackward>)
tensor(0., grad_fn=<LogBackward>)
True
Transform a bunch of negative tensors:  tensor([0.3133, 0.1269, 0.0486])
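
Under the hood, the default Positive() constraint uses a softplus transform, log(1 + exp(x)), which is why the raw value 0 above maps to roughly 0.6931 (= log 2). A quick check of this (assuming the softplus default):

[ ]:
import torch.nn.functional as F

raw_values = torch.tensor([-1., 0., 1.])
print(constraint.transform(raw_values))  # via the Positive() constraint
print(F.softplus(raw_values))            # should match (up to numerical precision)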

Convenience Getters/Setters for Transformed Values

Because dealing with raw parameter values is annoying (e.g., we might know what a noise variance of 0.01 means, but maybe not a raw_noise of -2.791), virtually all built-in GPyTorch modules that define raw parameters also define convenience getters and setters for dealing with transformed values directly.

In the next cells, we demonstrate the “inconvenient” way and the “convenient” way of getting and setting the outputscale.

[6]:
# Recreate model to reset outputscale
model = ExactGPModel(train_x, train_y, likelihood)

# Inconvenient way of getting true outputscale
raw_outputscale = model.covar_module.raw_outputscale
constraint = model.covar_module.raw_outputscale_constraint
outputscale = constraint.transform(raw_outputscale)
print(f'Actual outputscale: {outputscale.item()}')

# Inconvenient way of setting true outputscale
model.covar_module.raw_outputscale.data.fill_(constraint.inverse_transform(torch.tensor(2.)))
raw_outputscale = model.covar_module.raw_outputscale
outputscale = constraint.transform(raw_outputscale)
print(f'Actual outputscale after setting: {outputscale.item()}')
Actual outputscale: 0.6931471824645996
Actual outputscale after setting: 2.0

Ouch, that is ugly! Fortunately, there is a better way:

[7]:
# Recreate model to reset outputscale
model = ExactGPModel(train_x, train_y, likelihood)

# Convenient way of getting true outputscale
print(f'Actual outputscale: {model.covar_module.outputscale}')

# Convenient way of setting true outputscale
model.covar_module.outputscale = 2.
print(f'Actual outputscale after setting: {model.covar_module.outputscale}')
Actual outputscale: 0.6931471824645996
Actual outputscale after setting: 2.0
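
The same convenience getters and setters exist for the other hyperparameters; for example, a quick sketch for the lengthscale of the RBF kernel:

[ ]:
# Convenient getter/setter for the lengthscale, analogous to the outputscale above
print(f'Actual lengthscale: {model.covar_module.base_kernel.lengthscale.item()}')
model.covar_module.base_kernel.lengthscale = 0.5
print(f'Actual lengthscale after setting: {model.covar_module.base_kernel.lengthscale.item()}')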

Changing Parameter Constraints

If we look at the actual noise value of the model, we can see that GPyTorch defines a default lower bound of 1e-4 for the noise variance:

[8]:
print(f'Actual noise value: {likelihood.noise}')
print(f'Noise constraint: {likelihood.noise_covar.raw_noise_constraint}')
Actual noise value: tensor([0.6932], grad_fn=<AddBackward0>)
Noise constraint: GreaterThan(1.000E-04)

We can change the noise constraint either on the fly or when the likelihood is created:

[9]:
likelihood = gpytorch.likelihoods.GaussianLikelihood(noise_constraint=gpytorch.constraints.GreaterThan(1e-3))
print(f'Noise constraint: {likelihood.noise_covar.raw_noise_constraint}')

## Changing the constraint after the module has been created
likelihood.noise_covar.register_constraint("raw_noise", gpytorch.constraints.Positive())
print(f'Noise constraint: {likelihood.noise_covar.raw_noise_constraint}')
Noise constraint: GreaterThan(1.000E-03)
Noise constraint: Positive()
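
The same register_constraint mechanism works for any raw parameter. For instance, here is a sketch of constraining the RBF lengthscale to an interval (the Interval constraint maps raw values through a scaled sigmoid between the bounds):

[ ]:
# Constrain the lengthscale to lie between 0.01 and 1.0
model.covar_module.base_kernel.register_constraint(
    "raw_lengthscale", gpytorch.constraints.Interval(0.01, 1.0)
)
print(model.covar_module.base_kernel.raw_lengthscale_constraint)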

Priors

In GPyTorch, priors are objects you register with the model that act on an arbitrary function of any parameter. Like constraints, they can usually be defined either when you create an object (like a Kernel or Likelihood), or set afterwards on the fly.

Here are some examples:

[10]:
# Registers a prior on the sqrt of the noise parameter
# (e.g., a prior for the noise standard deviation instead of variance)
likelihood.noise_covar.register_prior(
    "noise_std_prior",
    gpytorch.priors.NormalPrior(0, 1),
    lambda module: module.noise.sqrt()
)

# Create a GaussianLikelihood with a normal prior for the noise
likelihood = gpytorch.likelihoods.GaussianLikelihood(
    noise_constraint=gpytorch.constraints.GreaterThan(1e-3),
    noise_prior=gpytorch.priors.NormalPrior(0, 1)
)
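
Priors can also be registered on other modules after they have been created, using the same closure-based mechanism as above. For example, a sketch that places a Gamma prior on the (transformed) outputscale of the kernel defined earlier:

[ ]:
# Register a prior on the transformed outputscale of the existing ScaleKernel
model.covar_module.register_prior(
    "outputscale_prior",
    gpytorch.priors.GammaPrior(2.0, 0.15),
    lambda module: module.outputscale
)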

Putting it Together

In the next cell, we augment our ExactGP definition to place several priors over the hyperparameters, and to apply a tighter noise constraint when creating the likelihood.

[11]:
# We will use the simplest form of GP model, exact inference
class FancyGPWithPriors(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super(FancyGPWithPriors, self).__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()

        lengthscale_prior = gpytorch.priors.GammaPrior(3.0, 6.0)
        outputscale_prior = gpytorch.priors.GammaPrior(2.0, 0.15)

        self.covar_module = gpytorch.kernels.ScaleKernel(
            gpytorch.kernels.RBFKernel(
                lengthscale_prior=lengthscale_prior,
            ),
            outputscale_prior=outputscale_prior
        )

        # Initialize lengthscale and outputscale to mean of priors
        self.covar_module.base_kernel.lengthscale = lengthscale_prior.mean
        self.covar_module.outputscale = outputscale_prior.mean

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

likelihood = gpytorch.likelihoods.GaussianLikelihood(
    noise_constraint=gpytorch.constraints.GreaterThan(1e-2),
)

model = FancyGPWithPriors(train_x, train_y, likelihood)
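
When this model is fit by maximizing the marginal log likelihood, the registered priors contribute extra log-probability terms to the objective, so they act as a soft regularizer on the hyperparameters. A minimal, illustrative training-loop sketch, along the lines of the Simple GP Regression tutorial:

[ ]:
# Illustrative training loop: the priors registered above enter the loss
# through the marginal log likelihood's prior terms.
model.train()
likelihood.train()

optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

for i in range(50):
    optimizer.zero_grad()
    output = model(train_x)
    loss = -mll(output, train_y)
    loss.backward()
    optimizer.step()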

Initializing hyperparameters in One Call

For convenience, GPyTorch modules also define an initialize method that allows you to update a full dictionary of parameters across submodules in one call. For example:

[12]:
hypers = {
    'likelihood.noise_covar.noise': torch.tensor(1.),
    'covar_module.base_kernel.lengthscale': torch.tensor(0.5),
    'covar_module.outputscale': torch.tensor(2.),
}

model.initialize(**hypers)
print(
    model.likelihood.noise_covar.noise.item(),
    model.covar_module.base_kernel.lengthscale.item(),
    model.covar_module.outputscale.item()
)
1.0000001192092896 0.4999999701976776 2.0
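
As a side note, initialize can also be called on a submodule directly; a small sketch, assuming the same mechanism applies at the submodule level:

[ ]:
# Set the lengthscale directly on the kernel submodule
model.covar_module.base_kernel.initialize(lengthscale=torch.tensor(1.))
print(model.covar_module.base_kernel.lengthscale.item())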