gpytorch.models

Models for Exact GP Inference

ExactGP

class gpytorch.models.ExactGP(train_inputs, train_targets, likelihood)[source]

The base class for any Gaussian process latent function to be used in conjunction with exact inference.

Parameters:
  • train_inputs (torch.Tensor) – (size n x d) The training features \(\mathbf X\).
  • train_targets (torch.Tensor) – (size n) The training targets \(\mathbf y\).
  • likelihood (GaussianLikelihood) – The Gaussian likelihood that defines the observational distribution. Since we’re using exact inference, the likelihood must be Gaussian.

The forward() function should describe how to compute the prior latent distribution on a given input. Typically, this will involve a mean and kernel function. The result must be a MultivariateNormal.

Calling this model will return the posterior of the latent Gaussian process when conditioned on the training data. The output will be a MultivariateNormal.

Example

>>> class MyGP(gpytorch.models.ExactGP):
>>>     def __init__(self, train_x, train_y, likelihood):
>>>         super().__init__(train_x, train_y, likelihood)
>>>         self.mean_module = gpytorch.means.ZeroMean()
>>>         self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())
>>>
>>>     def forward(self, x):
>>>         mean = self.mean_module(x)
>>>         covar = self.covar_module(x)
>>>         return gpytorch.distributions.MultivariateNormal(mean, covar)
>>>
>>> # train_x = ...; train_y = ...
>>> likelihood = gpytorch.likelihoods.GaussianLikelihood()
>>> model = MyGP(train_x, train_y, likelihood)
>>>
>>> # test_x = ...;
>>> model(test_x)  # Returns the GP latent function at test_x
>>> likelihood(model(test_x))  # Returns the (approximate) predictive posterior distribution at test_x
get_fantasy_model(inputs, targets, **kwargs)[source]

Returns a new GP model that incorporates the specified inputs and targets as new training data.

Using this method is more efficient than updating with set_train_data when the number of inputs is relatively small, because any computed test-time caches will be updated in linear time rather than computed from scratch.

Note

If targets is a batch (e.g. b x m), then the GP returned from this method will be a batch mode GP. If inputs is of the same (or lesser) dimension as targets, then it is assumed that the fantasy points are the same for each target batch.

Parameters:
  • inputs (torch.Tensor) – (b1 x … x bk x m x d or f x b1 x … x bk x m x d) Locations of fantasy observations.
  • targets (torch.Tensor) – (b1 x … x bk x m or f x b1 x … x bk x m) Labels of fantasy observations.
Returns:

An ExactGP model with n + m training examples, where the m fantasy examples have been added and all test-time caches have been updated.

Return type:

ExactGP
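
For illustration, a minimal usage sketch (assuming the trained MyGP model from the example above; new_x and new_y are hypothetical fantasy observations, not part of the API):

>>> model.eval()
>>> # new_x = ...; new_y = ...  (m x d inputs and m targets)
>>> fantasy_model = model.get_fantasy_model(new_x, new_y)
>>> fantasy_model(test_x)  # Posterior conditioned on all n + m examples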

local_load_samples(samples_dict, memo, prefix)[source]

Replace the model’s learned hyperparameters with samples from a posterior distribution.

set_train_data(inputs=None, targets=None, strict=True)[source]

Set training data (does not re-fit model hyper-parameters).

Parameters:
  • inputs (torch.Tensor) – The new training inputs.
  • targets (torch.Tensor) – The new training targets.
  • strict (bool) – (default True) If True, the new inputs and targets must have the same shape, dtype, and device as the current inputs and targets. If False, any shape, dtype, or device is allowed.
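
A minimal sketch of swapping in new data (new_train_x, new_train_y, and the bigger_* tensors are hypothetical):

>>> # Same shape/dtype/device as the current data: the default strict=True applies
>>> model.set_train_data(inputs=new_train_x, targets=new_train_y)
>>> # Different shape (e.g. more points): disable the strict check
>>> model.set_train_data(inputs=bigger_train_x, targets=bigger_train_y, strict=False)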

Models for Approximate GP Inference

ApproximateGP

class gpytorch.models.ApproximateGP(variational_strategy)[source]

The base class for any Gaussian process latent function to be used in conjunction with approximate inference (typically stochastic variational inference). This base class can be used to implement most inducing point methods where the variational parameters are learned directly.

Parameters:
  • variational_strategy (_VariationalStrategy) – The strategy that determines how the model marginalizes over the variational distribution (over inducing points) to produce the approximate posterior distribution (over data).

The forward() function should describe how to compute the prior latent distribution on a given input. Typically, this will involve a mean and kernel function. The result must be a MultivariateNormal.

Example

>>> class MyVariationalGP(gpytorch.models.ApproximateGP):
>>>     def __init__(self, variational_strategy):
>>>         super().__init__(variational_strategy)
>>>         self.mean_module = gpytorch.means.ZeroMean()
>>>         self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())
>>>
>>>     def forward(self, x):
>>>         mean = self.mean_module(x)
>>>         covar = self.covar_module(x)
>>>         return gpytorch.distributions.MultivariateNormal(mean, covar)
>>>
>>> # variational_strategy = ...
>>> model = MyVariationalGP(variational_strategy)
>>> likelihood = gpytorch.likelihoods.GaussianLikelihood()
>>>
>>> # optimization loop for variational parameters...
>>>
>>> # test_x = ...;
>>> model(test_x)  # Returns the approximate GP latent function at test_x
>>> likelihood(model(test_x))  # Returns the (approximate) predictive posterior distribution at test_x
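
The optimization loop elided above typically maximizes the variational ELBO. A minimal full-batch sketch (train_x, train_y, and n_iter are assumed to be defined; gpytorch.mlls.VariationalELBO is the standard objective for this setup):

>>> model.train()
>>> likelihood.train()
>>> optimizer = torch.optim.Adam(list(model.parameters()) + list(likelihood.parameters()), lr=0.01)
>>> mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=train_y.size(0))
>>> for _ in range(n_iter):
>>>     optimizer.zero_grad()
>>>     loss = -mll(model(train_x), train_y)  # negative ELBO
>>>     loss.backward()
>>>     optimizer.step()
>>> model.eval()
>>> likelihood.eval()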
get_fantasy_model(inputs, targets, **kwargs)[source]

Returns a new GP model that incorporates the specified inputs and targets as new training data using online variational conditioning (OVC).

This function first casts the inducing points and variational parameters into pseudo-points before returning an equivalent ExactGP model with a specialized likelihood.

Note

If targets is a batch (e.g. b x m), then the GP returned from this method will be a batch mode GP. If inputs is of the same (or lesser) dimension as targets, then it is assumed that the fantasy points are the same for each target batch.

Parameters:
  • inputs (torch.Tensor) – (b1 x … x bk x m x d or f x b1 x … x bk x m x d) Locations of fantasy observations.
  • targets (torch.Tensor) – (b1 x … x bk x m or f x b1 x … x bk x m) Labels of fantasy observations.
Returns:

An ExactGP model with n + m training examples, where the m fantasy examples have been added and all test-time caches have been updated.

Return type:

ExactGP

Reference: Maddox, Stanton, Wilson. “Conditioning Sparse Variational Gaussian Processes for Online Decision-Making.” NeurIPS 2021. https://papers.nips.cc/paper/2021/hash/325eaeac5bef34937cfdc1bd73034d17-Abstract.html
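
Usage mirrors ExactGP.get_fantasy_model; a minimal sketch (new_x and new_y are hypothetical fantasy observations):

>>> # model is a trained ApproximateGP; new_x is m x d, new_y is m
>>> fantasy_model = model.get_fantasy_model(new_x, new_y)
>>> fantasy_model(test_x)  # ExactGP posterior incorporating the fantasy data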
pyro_guide(input, beta=1.0, name_prefix='')[source]

(For Pyro integration only). The component of a pyro.guide that corresponds to drawing samples from the latent GP function.

Parameters:
  • input (torch.Tensor) – The inputs \(\mathbf X\).
  • beta (float) – (default=1.) How much to scale the \(\text{KL} [ q(\mathbf f) \Vert p(\mathbf f) ]\) term by.
  • name_prefix (str) – (default=””) A name prefix to prepend to pyro sample sites.
pyro_model(input, beta=1.0, name_prefix='')[source]

(For Pyro integration only). The component of a pyro.model that corresponds to drawing samples from the latent GP function.

Parameters:
  • input (torch.Tensor) – The inputs \(\mathbf X\).
  • beta (float) – (default=1.) How much to scale the \(\text{KL} [ q(\mathbf f) \Vert p(\mathbf f) ]\) term by.
  • name_prefix (str) – (default=””) A name prefix to prepend to pyro sample sites.
Returns:

samples from \(q(\mathbf f)\)

Return type:

torch.Tensor
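
For illustration, a hedged sketch of wiring these two hooks into a hand-written Pyro model/guide pair (gp is an ApproximateGP instance; the Gaussian observation model and its noise scale are assumptions of this sketch, not part of the API):

>>> import pyro
>>>
>>> def model_fn(x, y):
>>>     pyro.module("gp", gp)  # register the GP's parameters with Pyro
>>>     f = gp.pyro_model(x, name_prefix="gp")  # draws samples of the latent function
>>>     with pyro.plate("data", x.size(0)):
>>>         pyro.sample("obs", pyro.distributions.Normal(f, 0.1), obs=y)
>>>
>>> def guide_fn(x, y):
>>>     gp.pyro_guide(x, name_prefix="gp")  # matching variational sample sites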

Models for Deep GPs

deep_gps.DeepGP

class gpytorch.models.deep_gps.DeepGP[source]

A container module to build a DeepGP. This module should contain DeepGPLayer modules, and may contain other modules as well.

deep_gps.DeepGPLayer

class gpytorch.models.deep_gps.DeepGPLayer(variational_strategy, input_dims, output_dims)[source]

Represents a layer in a deep GP where inference is performed via the doubly stochastic method of Salimbeni et al., 2017. When called, instead of returning the variational distribution q(f), this layer returns samples from the variational distribution.

See the documentation for __call__ below for more details. Note that the behavior of __call__ will eventually change to handle multiple batch dimensions more elegantly; the interface, however, will remain essentially the same.

Parameters:
  • variational_strategy (VariationalStrategy) – Strategy for changing q(u) -> q(f) (see other VI docs)
  • input_dims (int) – Dimensionality of input data expected by each GP
  • output_dims (int) – (default None) Number of GPs in this layer, equivalent to output dimensionality. If set to None, then the output dimension will be squashed.

__call__(inputs, are_samples=False, **kwargs)[source]

Forward data through this hidden GP layer. The output is a MultitaskMultivariateNormal distribution (or a MultivariateNormal distribution if output_dims=None).

If the input is a >=2 dimensional Tensor (e.g. n x d), we pass the input through each hidden GP, resulting in an n x h multitask Gaussian distribution (where each of the h tasks represents an output dimension and is independent from the others). We then draw s samples from these Gaussians, resulting in an s x n x h MultitaskMultivariateNormal distribution.

If the input is a >=3 dimensional Tensor, and the are_samples=True kwarg is set, then we assume that the outermost batch dimension is a samples dimension. The output will have the same number of samples. For example, an s x b x n x d input will result in an s x b x n x h MultitaskMultivariateNormal distribution.

The goal of these last two points is that if you have a tensor x that is n x d, then

>>> hidden_gp2(hidden_gp(x))

will just work, and return a tensor of size s x n x h2, where h2 is the output dimensionality of hidden_gp2. In this way, hidden GP layers are easily composable.
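
Putting the two classes together: a minimal two-layer sketch in the style of the GPyTorch deep GP tutorial (the layer name, dimensions, and inducing-point counts below are illustrative, not part of the API):

>>> class HiddenLayer(gpytorch.models.deep_gps.DeepGPLayer):
>>>     def __init__(self, input_dims, output_dims, num_inducing=32):
>>>         batch_shape = torch.Size([]) if output_dims is None else torch.Size([output_dims])
>>>         inducing_points = torch.randn(*batch_shape, num_inducing, input_dims)
>>>         variational_distribution = gpytorch.variational.CholeskyVariationalDistribution(
>>>             num_inducing, batch_shape=batch_shape
>>>         )
>>>         variational_strategy = gpytorch.variational.VariationalStrategy(
>>>             self, inducing_points, variational_distribution, learn_inducing_locations=True
>>>         )
>>>         super().__init__(variational_strategy, input_dims, output_dims)
>>>         self.mean_module = gpytorch.means.ConstantMean(batch_shape=batch_shape)
>>>         self.covar_module = gpytorch.kernels.ScaleKernel(
>>>             gpytorch.kernels.RBFKernel(batch_shape=batch_shape), batch_shape=batch_shape
>>>         )
>>>
>>>     def forward(self, x):
>>>         return gpytorch.distributions.MultivariateNormal(self.mean_module(x), self.covar_module(x))
>>>
>>> class TwoLayerDeepGP(gpytorch.models.deep_gps.DeepGP):
>>>     def __init__(self, input_dims):
>>>         super().__init__()
>>>         self.hidden_gp = HiddenLayer(input_dims, output_dims=3)  # hidden layer with 3 output GPs
>>>         self.hidden_gp2 = HiddenLayer(3, output_dims=None)       # final layer, squashed output
>>>
>>>     def forward(self, x):
>>>         return self.hidden_gp2(self.hidden_gp(x))  # layers compose as described above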

Gaussian Process Latent Variable Models (GPLVM)

gplvm.BayesianGPLVM

class gpytorch.models.gplvm.BayesianGPLVM(X, variational_strategy)[source]

The Gaussian Process Latent Variable Model (GPLVM) class for unsupervised learning. The class supports

  1. Point estimates for latent X when prior_x = None
  2. MAP Inference for X when prior_x is not None and inference == ‘map’
  3. Gaussian variational distribution q(X) when prior_x is not None and inference == ‘variational’

See also

The GPLVM tutorial for usage instructions.

Parameters:
  • X (LatentVariable) – An instance of a sub-class of the LatentVariable class. One of PointLatentVariable, MAPLatentVariable, or VariationalLatentVariable, corresponding to modes 1, 2, and 3 above, respectively.
  • variational_strategy (_VariationalStrategy) – The strategy that determines how the model marginalizes over the variational distribution (over inducing points) to produce the approximate posterior distribution (over data)

gplvm.PointLatentVariable

class gpytorch.models.gplvm.PointLatentVariable(n, latent_dim, X_init)[source]

This class is used for GPLVM models to recover an MLE estimate of the latent variable \(\mathbf X\).

Parameters:
  • n (int) – Size of the latent space.
  • latent_dim (int) – Dimensionality of latent space.
  • X_init (torch.Tensor) – initialization for the point estimate of \(\mathbf X\)

gplvm.MAPLatentVariable

class gpytorch.models.gplvm.MAPLatentVariable(n, latent_dim, X_init, prior_x)[source]

This class is used for GPLVM models to recover a MAP estimate of the latent variable \(\mathbf X\), based on some supplied prior.

Parameters:
  • n (int) – Size of the latent space.
  • latent_dim (int) – Dimensionality of latent space.
  • X_init (torch.Tensor) – initialization for the point estimate of \(\mathbf X\)
  • prior_x (Prior) – prior for \(\mathbf X\)

gplvm.VariationalLatentVariable

class gpytorch.models.gplvm.VariationalLatentVariable(n, data_dim, latent_dim, X_init, prior_x)[source]

This class is used for GPLVM models to recover a variational approximation of the latent variable \(\mathbf X\). The variational approximation will be an isotropic Gaussian distribution.

Parameters:
  • n (int) – Size of the latent space.
  • data_dim (int) – Dimensionality of the \(\mathbf Y\) values.
  • latent_dim (int) – Dimensionality of latent space.
  • X_init (torch.Tensor) – initialization for the point estimate of \(\mathbf X\)
  • prior_x (Prior) – prior for \(\mathbf X\)
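
To make the three inference modes concrete, a hedged sketch constructing each latent-variable option (the sizes and initialization below are illustrative):

>>> n, data_dim, latent_dim = 100, 12, 2
>>> X_init = torch.randn(n, latent_dim)
>>> prior_x = gpytorch.priors.NormalPrior(torch.zeros(n, latent_dim), torch.ones(n, latent_dim))
>>>
>>> # 1. Point estimate (MLE) for X -- no prior required
>>> X = gpytorch.models.gplvm.PointLatentVariable(n, latent_dim, X_init)
>>> # 2. MAP inference for X
>>> X = gpytorch.models.gplvm.MAPLatentVariable(n, latent_dim, X_init, prior_x)
>>> # 3. Gaussian variational distribution q(X)
>>> X = gpytorch.models.gplvm.VariationalLatentVariable(n, data_dim, latent_dim, X_init, prior_x)
>>>
>>> # X is then passed to a BayesianGPLVM sub-class along with a variational strategy:
>>> # model = MyBayesianGPLVM(X, variational_strategy)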

Models for integrating with Pyro

PyroGP

class gpytorch.models.PyroGP(variational_strategy, likelihood, num_data, name_prefix='', beta=1.0)[source]

An ApproximateGP designed to work with Pyro.

This module makes it possible to include GP models within more complex probabilistic models, or to use likelihood functions with additional variational/approximate distributions.

The parameters of these models are learned using Pyro’s inference tools, unlike other GPyTorch models, which are optimized with respect to a MarginalLogLikelihood. See the Pyro examples for detailed usage instructions.

Parameters:
  • variational_strategy (VariationalStrategy) – The variational strategy that defines the variational distribution and the marginalization strategy.
  • likelihood (Likelihood) – The likelihood for the model
  • num_data (int) – The total number of training data points (necessary for SGD)
  • name_prefix (str, optional) – A prefix to put in front of pyro sample/plate sites
  • beta (float - default 1.) – A multiplicative factor for the KL divergence term. Setting it to 1 (default) recovers true variational inference (as derived in Scalable Variational Gaussian Process Classification). Setting it to anything less than 1 reduces the regularization effect of the model (similarly to what was proposed in the beta-VAE paper).

Example

>>> class MyVariationalGP(gpytorch.models.PyroGP):
>>>     # implementation
>>>
>>> # variational_strategy = ...
>>> likelihood = gpytorch.likelihoods.GaussianLikelihood()
>>> model = MyVariationalGP(variational_strategy, likelihood, train_y.size(0))
>>>
>>> optimizer = pyro.optim.Adam({"lr": 0.01})
>>> elbo = pyro.infer.Trace_ELBO(num_particles=64, vectorize_particles=True)
>>> svi = pyro.infer.SVI(model.model, model.guide, optimizer, elbo)
>>>
>>> # Optimize variational parameters
>>> for _ in range(n_iter):
>>>     loss = svi.step(train_x, train_y)
guide(input, target, *args, **kwargs)[source]

Guide function for Pyro inference. Includes the guide for the GP’s likelihood function as well.

Parameters:
  • input (torch.Tensor) – \(\mathbf X\) The input values
  • target (torch.Tensor) – \(\mathbf y\) The target values
  • args – Additional arguments passed to the likelihood’s forward function.
  • kwargs – Additional keyword arguments passed to the likelihood’s forward function.
model(input, target, *args, **kwargs)[source]

Model function for Pyro inference. Includes the model for the GP’s likelihood function as well.

Parameters:
  • input (torch.Tensor) – \(\mathbf X\) The input values
  • target (torch.Tensor) – \(\mathbf y\) The target values
  • args – Additional arguments passed to the likelihood’s forward function.
  • kwargs – Additional keyword arguments passed to the likelihood’s forward function.