Saving and Loading Models

In this bite-sized notebook, we’ll go over how to save and load models. In general, the process is the same as for any PyTorch module.

[2]:
import math
import torch
import gpytorch
from matplotlib import pyplot as plt

Saving a Simple Model

First, we define a GP Model that we’d like to save. The model used below is the same as the model from our Simple GP Regression tutorial.

[3]:
train_x = torch.linspace(0, 1, 100)
train_y = torch.sin(train_x * (2 * math.pi)) + torch.randn(train_x.size()) * 0.2
[4]:
# We will use the simplest form of GP model, exact inference
class ExactGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super(ExactGPModel, self).__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

# initialize likelihood and model
likelihood = gpytorch.likelihoods.GaussianLikelihood()
model = ExactGPModel(train_x, train_y, likelihood)

Change Model State

To demonstrate model saving, we first change the hyperparameters from their default values. For more information on what is happening here, see our tutorial notebook on Initializing Hyperparameters.

[6]:
model.covar_module.outputscale = 1.2
model.covar_module.base_kernel.lengthscale = 2.2
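
To verify, we can read back the actual (constrained) values, which should match what we just assigned:

[ ]:
# The outputscale and lengthscale properties return the constrained values
print(model.covar_module.outputscale.item())              # 1.2
print(model.covar_module.base_kernel.lengthscale.item())  # 2.2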

Getting Model State

To get the full state of a GPyTorch model, call state_dict as you would on any PyTorch model. Note that the state dict contains raw parameter values, because these are the actual torch.nn.Parameters that GPyTorch learns; the actual hyperparameter values are obtained by applying constraint transforms on top of them. Again, see our notebook on hyperparameters for more information on this.

[11]:
model.state_dict()
[11]:
OrderedDict([('likelihood.noise_covar.raw_noise', tensor([0.])),
             ('mean_module.constant', tensor([0.])),
             ('covar_module.raw_outputscale', tensor(0.8416)),
             ('covar_module.base_kernel.raw_lengthscale', tensor([[2.0826]]))])
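
The raw values map to the actual hyperparameters through each parameter's constraint (a softplus transform under the default Positive constraint); for instance, softplus(0.8416) ≈ 1.2, the outputscale we assigned above. A quick sketch of how to inspect this mapping, assuming the default constraints:

[ ]:
# Each raw parameter has an associated constraint registered alongside it
raw = model.state_dict()['covar_module.raw_outputscale']
constraint = model.covar_module.raw_outputscale_constraint
print(constraint.transform(raw))  # ~1.2, the actual outputscale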

Saving Model State

The state dictionary above represents all trainable parameters for the model. Therefore, we can save it to a file as follows:

[12]:
torch.save(model.state_dict(), 'model_state.pth')
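
Note that state_dict() captures only parameters and buffers. If you also want to resume training later, a common PyTorch pattern is to bundle the model state with other state in one checkpoint dictionary. A minimal sketch, where the optimizer is hypothetical and would come from a training loop not shown in this notebook:

[ ]:
# `optimizer` is hypothetical here; it would come from a training loop
torch.save({
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
}, 'checkpoint.pth')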

Loading Model State

Next, we load this state into a new model and demonstrate that the parameters were updated correctly.

[13]:
state_dict = torch.load('model_state.pth')
model = ExactGPModel(train_x, train_y, likelihood)  # Create a new GP model

model.load_state_dict(state_dict)
[13]:
IncompatibleKeys(missing_keys=[], unexpected_keys=[])
[15]:
model.state_dict()
[15]:
OrderedDict([('likelihood.noise_covar.raw_noise', tensor([0.])),
             ('mean_module.constant', tensor([0.])),
             ('covar_module.raw_outputscale', tensor(0.8416)),
             ('covar_module.base_kernel.raw_lengthscale', tensor([[2.0826]]))])
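
If the state was saved on one device and is being loaded on another (for example, GPU to CPU), pass map_location to torch.load so the tensors are remapped to the target device:

[ ]:
# Remap GPU-saved tensors onto the CPU when loading
state_dict = torch.load('model_state.pth', map_location=torch.device('cpu'))
model.load_state_dict(state_dict)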

A More Complex Example

Next, we demonstrate the same principle on a more complex exact GP that includes a simple feed-forward neural network feature extractor as part of the model.

[22]:
class GPWithNNFeatureExtractor(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super(GPWithNNFeatureExtractor, self).__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

        self.feature_extractor = torch.nn.Sequential(
            torch.nn.Linear(1, 2),
            torch.nn.BatchNorm1d(2),
            torch.nn.ReLU(),
            torch.nn.Linear(2, 2),
            torch.nn.BatchNorm1d(2),
            torch.nn.ReLU(),
        )

    def forward(self, x):
        x = self.feature_extractor(x)
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

# initialize likelihood and model
likelihood = gpytorch.likelihoods.GaussianLikelihood()
model = GPWithNNFeatureExtractor(train_x, train_y, likelihood)

Getting Model State

In the next cell, we once again print the model state via model.state_dict(). As you can see, the state is substantially more complex, as the model now includes our neural network parameters; the BatchNorm layers' running statistics (buffers rather than learnable parameters) appear as well. Nevertheless, saving and loading is straightforward.

[23]:
model.state_dict()
[23]:
OrderedDict([('likelihood.noise_covar.raw_noise', tensor([0.])),
             ('mean_module.constant', tensor([0.])),
             ('covar_module.raw_outputscale', tensor(0.)),
             ('covar_module.base_kernel.raw_lengthscale', tensor([[0.]])),
             ('feature_extractor.0.weight', tensor([[-0.9135],
                      [-0.5942]])),
             ('feature_extractor.0.bias', tensor([ 0.9119, -0.0663])),
             ('feature_extractor.1.weight', tensor([0.2263, 0.2209])),
             ('feature_extractor.1.bias', tensor([0., 0.])),
             ('feature_extractor.1.running_mean', tensor([0., 0.])),
             ('feature_extractor.1.running_var', tensor([1., 1.])),
             ('feature_extractor.1.num_batches_tracked', tensor(0)),
             ('feature_extractor.3.weight', tensor([[-0.6375, -0.6466],
                      [-0.0563, -0.4695]])),
             ('feature_extractor.3.bias', tensor([-0.1247,  0.0803])),
             ('feature_extractor.4.weight', tensor([0.0466, 0.7248])),
             ('feature_extractor.4.bias', tensor([0., 0.])),
             ('feature_extractor.4.running_mean', tensor([0., 0.])),
             ('feature_extractor.4.running_var', tensor([1., 1.])),
             ('feature_extractor.4.num_batches_tracked', tensor(0))])
[24]:
torch.save(model.state_dict(), 'my_gp_with_nn_model.pth')
state_dict = torch.load('my_gp_with_nn_model.pth')
model = GPWithNNFeatureExtractor(train_x, train_y, likelihood)
model.load_state_dict(state_dict)
[24]:
IncompatibleKeys(missing_keys=[], unexpected_keys=[])
[25]:
model.state_dict()
[25]:
OrderedDict([('likelihood.noise_covar.raw_noise', tensor([0.])),
             ('mean_module.constant', tensor([0.])),
             ('covar_module.raw_outputscale', tensor(0.)),
             ('covar_module.base_kernel.raw_lengthscale', tensor([[0.]])),
             ('feature_extractor.0.weight', tensor([[-0.9135],
                      [-0.5942]])),
             ('feature_extractor.0.bias', tensor([ 0.9119, -0.0663])),
             ('feature_extractor.1.weight', tensor([0.2263, 0.2209])),
             ('feature_extractor.1.bias', tensor([0., 0.])),
             ('feature_extractor.1.running_mean', tensor([0., 0.])),
             ('feature_extractor.1.running_var', tensor([1., 1.])),
             ('feature_extractor.1.num_batches_tracked', tensor(0)),
             ('feature_extractor.3.weight', tensor([[-0.6375, -0.6466],
                      [-0.0563, -0.4695]])),
             ('feature_extractor.3.bias', tensor([-0.1247,  0.0803])),
             ('feature_extractor.4.weight', tensor([0.0466, 0.7248])),
             ('feature_extractor.4.bias', tensor([0., 0.])),
             ('feature_extractor.4.running_mean', tensor([0., 0.])),
             ('feature_extractor.4.running_var', tensor([1., 1.])),
             ('feature_extractor.4.num_batches_tracked', tensor(0))])
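
Because the feature extractor is an ordinary torch.nn module, its state can also be saved and restored independently of the GP hyperparameters, which is useful for reusing pretrained network weights. A minimal sketch:

[ ]:
# Save only the feature extractor's parameters and buffers
torch.save(model.feature_extractor.state_dict(), 'feature_extractor.pth')

# Load them into a fresh model, leaving the GP hyperparameters at their defaults
new_model = GPWithNNFeatureExtractor(train_x, train_y, likelihood)
new_model.feature_extractor.load_state_dict(torch.load('feature_extractor.pth'))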