gpytorch.variational

There are many possible variants of variational/approximate GPs. GPyTorch makes use of three composable objects that make it possible to implement most GP approximations:

  • VariationalDistribution objects, which define the form of the approximate inducing value posterior \(q(\mathbf u)\).
  • VariationalStrategy objects, which define how to compute \(q(\mathbf f(\mathbf X))\) from \(q(\mathbf u)\).
  • _ApproximateMarginalLogLikelihood objects, which define the objective function used to learn the approximate posterior (e.g. the variational ELBO).

All three of these objects should be used in conjunction with a gpytorch.models.ApproximateGP model.
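
The minimal sketch below shows how the three objects fit together in an ApproximateGP model. It assumes torch and gpytorch are imported and that train_x and train_y hold the training data; the class name SVGPModel and the choice of 100 inducing points are illustrative only.

>>> class SVGPModel(gpytorch.models.ApproximateGP):
>>>     def __init__(self, inducing_points):
>>>         # q(u): a full-covariance Gaussian over the inducing values
>>>         variational_distribution = gpytorch.variational.CholeskyVariationalDistribution(
>>>             inducing_points.size(-2)
>>>         )
>>>         # The strategy defines how q(f(X)) is computed from q(u)
>>>         variational_strategy = gpytorch.variational.VariationalStrategy(
>>>             self, inducing_points, variational_distribution, learn_inducing_locations=True
>>>         )
>>>         super().__init__(variational_strategy)
>>>         self.mean_module = gpytorch.means.ConstantMean()
>>>         self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())
>>>
>>>     def forward(self, x):
>>>         # The GP prior p(f(x))
>>>         return gpytorch.distributions.MultivariateNormal(
>>>             self.mean_module(x), self.covar_module(x)
>>>         )
>>>
>>> model = SVGPModel(inducing_points=train_x[:100, :])
>>> likelihood = gpytorch.likelihoods.GaussianLikelihood()
>>> # The objective function (an _ApproximateMarginalLogLikelihood)
>>> mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=train_y.size(0))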

Variational Strategies

VariationalStrategy objects control how certain aspects of variational inference should be performed. In particular, they define two methods that get used during variational inference:

  • The prior_distribution() method determines how to compute the GP prior distribution of the inducing points, e.g. \(p(u) \sim N(\mu(X_u), K(X_u, X_u))\). Most commonly, this is done simply by calling the user defined GP prior on the inducing point data directly.
  • The forward() method determines how to marginalize out the inducing point function values. Specifically, forward defines how to transform a variational distribution over the inducing point values, \(q(u)\), into a variational distribution over the function values at specified locations \(x\), \(q(f|x)\), by integrating \(\int p(f|x, u) q(u) \: du\).

In GPyTorch, we currently support two categories of this latter functionality. In scenarios where the inducing points are learned (or set to be exactly the training data), we apply the derivation in Hensman et al., 2015 to exactly marginalize out the variational distribution. When the inducing points are constrained to a grid, we apply the derivation in Wilson et al., 2016 and exploit a deterministic relationship between \(\mathbf f\) and \(\mathbf u\).

_VariationalStrategy

class gpytorch.variational._VariationalStrategy(model, inducing_points, variational_distribution, learn_inducing_locations=True)[source]

Abstract base class for all Variational Strategies.

forward(x, inducing_points, inducing_values, variational_inducing_covar=None, **kwargs)[source]

The forward() method determines how to marginalize out the inducing point function values. Specifically, forward defines how to transform a variational distribution over the inducing point values, \(q(u)\), into a variational distribution over the function values at specified locations \(x\), \(q(f|x)\), by integrating \(\int p(f|x, u) q(u) \: du\).

Parameters:
  • x (torch.Tensor) – Locations \(\mathbf X\) to get the variational posterior of the function values at.
  • inducing_points (torch.Tensor) – Locations \(\mathbf Z\) of the inducing points
  • inducing_values (torch.Tensor) – Samples of the inducing function values \(\mathbf u\) (or the mean of the distribution \(q(\mathbf u)\) if q is a Gaussian).
  • variational_inducing_covar (LazyTensor) – If the distribution \(q(\mathbf u)\) is Gaussian, then this variable is the covariance matrix of that Gaussian. Otherwise, it will be None.
Return type:

MultivariateNormal

Returns:

The distribution \(q( \mathbf f(\mathbf X))\)

get_fantasy_model(inputs, targets, mean_module=None, covar_module=None, **kwargs)[source]

Performs the online variational conditioning (OVC) strategy of Maddox et al., ‘21 to return an exact GP model that incorporates the inputs and targets alongside the variational model’s inducing points and targets.

Currently, rather than updating the variational parameters (and inducing points) directly, we return an ExactGP model instead of an updated variational GP model. This is done primarily for numerical stability.

Unlike the ExactGP’s call for get_fantasy_model, we enable options for mean_module and covar_module that allow specification of the mean / covariance. We expect that either the mean and covariance modules are attributes of the model itself called mean_module and covar_module respectively OR that you pass them into this method explicitly.

Parameters:
  • inputs (torch.Tensor) – (b1 x … x bk x m x d or f x b1 x … x bk x m x d) Locations of fantasy observations.
  • targets (torch.Tensor) – (b1 x … x bk x m or f x b1 x … x bk x m) Labels of fantasy observations.
  • mean_module (torch.nn.Module) – torch module describing the mean function of the GP model. Optional if mean_module is already an attribute of the variational GP.
  • covar_module (torch.nn.Module) – torch module describing the covariance function of the GP model. Optional if covar_module is already an attribute of the variational GP.
Returns:

An ExactGP model with k + m training examples, where the m fantasy examples have been added and all test-time caches have been updated. We assume that there are k inducing points in this variational GP. Note that we return an ExactGP rather than a variational GP.

Return type:

ExactGP

Reference: “Conditioning Sparse Variational Gaussian Processes for Online Decision-Making,”
Maddox, Stanton, Wilson, NeurIPS, ‘21 https://papers.nips.cc/paper/2021/hash/325eaeac5bef34937cfdc1bd73034d17-Abstract.html
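
A hedged usage sketch: model is assumed to be a trained variational GP whose mean and covariance modules are stored as mean_module and covar_module; new_x, new_y, and test_x are hypothetical tensors of fantasy inputs, fantasy targets, and test inputs.

>>> # Condition the trained variational GP on newly observed data via OVC
>>> fantasy_model = model.variational_strategy.get_fantasy_model(new_x, new_y)
>>> fantasy_model.eval()
>>> with torch.no_grad():
>>>     fantasy_pred = fantasy_model(test_x)
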
kl_divergence()[source]

Compute the KL divergence between the variational inducing distribution \(q(\mathbf u)\) and the prior inducing distribution \(p(\mathbf u)\).

Return type: torch.Tensor
prior_distribution

The prior_distribution() method determines how to compute the GP prior distribution of the inducing points, e.g. \(p(u) \sim N(\mu(X_u), K(X_u, X_u))\). Most commonly, this is done simply by calling the user defined GP prior on the inducing point data directly.

Return type: MultivariateNormal
Returns: The distribution \(p( \mathbf u)\)

VariationalStrategy

class gpytorch.variational.VariationalStrategy(model, inducing_points, variational_distribution, learn_inducing_locations=True)[source]

The standard variational strategy, as defined by Hensman et al. (2015). This strategy takes a set of \(m \ll n\) inducing points \(\mathbf Z\) and applies an approximate distribution \(q( \mathbf u)\) over their function values. (Here, we use the common notation \(\mathbf u = f(\mathbf Z)\).) The approximate function distribution for any arbitrary input \(\mathbf X\) is given by:

\[q( f(\mathbf X) ) = \int p( f(\mathbf X) \mid \mathbf u) q(\mathbf u) \: d\mathbf u\]

This variational strategy uses “whitening” to accelerate the optimization of the variational parameters. See Matthews (2017) for more info.

Parameters:
  • model (ApproximateGP) – Model this strategy is applied to. Typically passed in when the VariationalStrategy is created in the __init__ method of the user defined model.
  • inducing_points (torch.Tensor) – Tensor containing a set of inducing points to use for variational inference.
  • variational_distribution (VariationalDistribution) – A VariationalDistribution object that represents the form of the variational distribution \(q(\mathbf u)\)
  • learn_inducing_locations (bool, optional) – (Default True): Whether or not the inducing point locations \(\mathbf Z\) should be learned (i.e. are they parameters of the model).
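
A minimal mini-batch training sketch, assuming the illustrative model, likelihood, and mll from the overview above, plus a hypothetical torch.utils.data.DataLoader called train_loader over (train_x, train_y):

>>> model.train()
>>> likelihood.train()
>>> optimizer = torch.optim.Adam(
>>>     list(model.parameters()) + list(likelihood.parameters()), lr=0.01
>>> )
>>>
>>> for x_batch, y_batch in train_loader:
>>>     optimizer.zero_grad()
>>>     output = model(x_batch)          # q(f(X)) for this mini-batch
>>>     loss = -mll(output, y_batch)     # negative variational ELBO
>>>     loss.backward()
>>>     optimizer.step()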

BatchDecoupledVariationalStrategy

class gpytorch.variational.BatchDecoupledVariationalStrategy(model, inducing_points, variational_distribution, learn_inducing_locations=True, mean_var_batch_dim=None)[source]

A VariationalStrategy that uses a different set of inducing points for the variational mean and variational covariance. It follows the “decoupled” model proposed by Jankowiak et al. (2020) (which is roughly based on the strategies proposed by Cheng et al. (2017)).

Let \(\mathbf Z_\mu\) and \(\mathbf Z_\sigma\) be the mean/variance inducing points. The variational distribution for an input \(\mathbf x\) is given by:

\[\begin{split}\begin{align*} \mathbb E[ f(\mathbf x) ] &= \mathbf k_{\mathbf Z_\mu \mathbf x}^\top \mathbf K_{\mathbf Z_\mu \mathbf Z_\mu}^{-1} \mathbf m \\ \text{Var}[ f(\mathbf x) ] &= k_{\mathbf x \mathbf x} - \mathbf k_{\mathbf Z_\sigma \mathbf x}^\top \mathbf K_{\mathbf Z_\sigma \mathbf Z_\sigma}^{-1} \left( \mathbf K_{\mathbf Z_\sigma \mathbf Z_\sigma} - \mathbf S \right) \mathbf K_{\mathbf Z_\sigma \mathbf Z_\sigma}^{-1} \mathbf k_{\mathbf Z_\sigma \mathbf x} \end{align*}\end{split}\]

where \(\mathbf m\) and \(\mathbf S\) are the variational parameters. Unlike the original proposed implementation, \(\mathbf Z_\mu\) and \(\mathbf Z_\sigma\) have the same number of inducing points, which allows us to perform batched operations.

Additionally, you can use a different set of kernel hyperparameters for the mean and the variance function. We recommend using this feature only with the PredictiveLogLikelihood objective function, as proposed in “Parametric Gaussian Process Regressors” (Jankowiak et al. (2020)). Use the mean_var_batch_dim parameter to indicate which batch dimension corresponds to the different mean/var kernels.

Note

We recommend using the “right-most” batch dimension (i.e. mean_var_batch_dim=-1) for the dimension that corresponds to the different mean/variance kernel parameters.

Assuming you want b1 many independent GPs, the _VariationalDistribution objects should have a batch shape of b1, and the mean/covar modules of the GP should have a batch shape of b1 x 2. (The 2 corresponds to the mean/variance hyperparameters.)

See also

OrthogonallyDecoupledVariationalStrategy (a variant proposed by Salimbeni et al. (2018) that uses orthogonal projections.)

Parameters:
  • model (ApproximateGP) – Model this strategy is applied to. Typically passed in when the VariationalStrategy is created in the __init__ method of the user defined model.
  • inducing_points (torch.Tensor) – Tensor containing a set of inducing points to use for variational inference.
  • variational_distribution (VariationalDistribution) – A VariationalDistribution object that represents the form of the variational distribution \(q(\mathbf u)\)
  • learn_inducing_locations (bool, optional) – (Default True): Whether or not the inducing point locations \(\mathbf Z\) should be learned (i.e. are they parameters of the model).
  • mean_var_batch_dim (int, optional) – (Default None): Set this parameter (ideally to -1) to indicate which dimension corresponds to different kernel hyperparameters for the mean/variance functions.
Example (different hypers for mean/variance):
>>> class MeanFieldDecoupledModel(gpytorch.models.ApproximateGP):
>>>     '''
>>>     A batch of 3 independent MeanFieldDecoupled PPGPR models.
>>>     '''
>>>     def __init__(self, inducing_points):
>>>         # The variational parameters have a batch_shape of [3]
>>>         variational_distribution = gpytorch.variational.MeanFieldVariationalDistribution(
>>>             inducing_points.size(-2), batch_shape=torch.Size([3]),
>>>         )
>>>         variational_strategy = gpytorch.variational.BatchDecoupledVariationalStrategy(
>>>             self, inducing_points, variational_distribution, learn_inducing_locations=True,
>>>             mean_var_batch_dim=-1
>>>         )
>>>
>>>         # The mean/covar modules have a batch_shape of [3, 2]
>>>         # where the last batch dim corresponds to the mean & variance hyperparameters
>>>         super().__init__(variational_strategy)
>>>         self.mean_module = gpytorch.means.ConstantMean(batch_shape=torch.Size([3, 2]))
>>>         self.covar_module = gpytorch.kernels.ScaleKernel(
>>>             gpytorch.kernels.RBFKernel(batch_shape=torch.Size([3, 2])),
>>>             batch_shape=torch.Size([3, 2]),
>>>         )
Example (shared hypers for mean/variance):
>>> class MeanFieldDecoupledModel(gpytorch.models.ApproximateGP):
>>>     '''
>>>     A batch of 3 independent MeanFieldDecoupled PPGPR models.
>>>     '''
>>>     def __init__(self, inducing_points):
>>>         # The variational parameters have a batch_shape of [3]
>>>         variational_distribution = gpytorch.variational.MeanFieldVariationalDistribution(
>>>             inducing_points.size(-2), batch_shape=torch.Size([3]),
>>>         )
>>>         variational_strategy = gpytorch.variational.BatchDecoupledVariationalStrategy(
>>>             self, inducing_points, variational_distribution, learn_inducing_locations=True,
>>>         )
>>>
>>>         # The mean/covar modules have a batch_shape of [3, 1]
>>>         # where the singleton dimension corresponds to the shared mean/variance hyperparameters
>>>         super().__init__(variational_strategy)
>>>         self.mean_module = gpytorch.means.ConstantMean(batch_shape=torch.Size([3, 1]))
>>>         self.covar_module = gpytorch.kernels.ScaleKernel(
>>>             gpytorch.kernels.RBFKernel(batch_shape=torch.Size([3, 1])),
>>>             batch_shape=torch.Size([3, 1]),
>>>         )

CiqVariationalStrategy

class gpytorch.variational.CiqVariationalStrategy(model, inducing_points, variational_distribution, learn_inducing_locations=True)[source]

Similar to VariationalStrategy, except the whitening operation is performed using Contour Integral Quadrature rather than Cholesky (see Pleiss et al. (2020) for more info). See the CIQ-SVGP tutorial for an example.

Contour Integral Quadrature uses iterative matrix-vector multiplication to approximate the \(\mathbf K_{\mathbf Z \mathbf Z}^{-1/2}\) matrix used for the whitening operation. This can be more efficient than the standard variational strategy for large numbers of inducing points (e.g. \(M > 1000\)) or when the inducing points have structure (e.g. they lie on an evenly-spaced grid).

Note

It is recommended that this object is used in conjunction with NaturalVariationalDistribution and natural gradient descent.

Parameters:
  • model (ApproximateGP) – Model this strategy is applied to. Typically passed in when the VariationalStrategy is created in the __init__ method of the user defined model.
  • inducing_points (torch.Tensor) – Tensor containing a set of inducing points to use for variational inference.
  • variational_distribution (VariationalDistribution) – A VariationalDistribution object that represents the form of the variational distribution \(q(\mathbf u)\)
  • learn_inducing_locations (bool, optional) – (Default True): Whether or not the inducing point locations \(\mathbf Z\) should be learned (i.e. are they parameters of the model).
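
A hedged sketch following the note above; the variable names (num_inducing, inducing_points, train_y) and the learning rates are assumptions or illustrative values only:

>>> # Inside the model's __init__: pair CIQ whitening with a natural parameterization
>>> variational_distribution = gpytorch.variational.NaturalVariationalDistribution(num_inducing)
>>> variational_strategy = gpytorch.variational.CiqVariationalStrategy(
>>>     self, inducing_points, variational_distribution, learn_inducing_locations=True
>>> )
>>>
>>> # At training time: natural gradient descent for the variational parameters
>>> variational_ngd_optimizer = gpytorch.optim.NGD(
>>>     model.variational_parameters(), num_data=train_y.size(0), lr=0.1
>>> )
>>> hyperparameter_optimizer = torch.optim.Adam(model.hyperparameters(), lr=0.01)
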
kl_divergence()[source]

Compute the KL divergence between the variational inducing distribution \(q(\mathbf u)\) and the prior inducing distribution \(p(\mathbf u)\).

Return type:torch.Tensor

NNVariationalStrategy

class gpytorch.variational.NNVariationalStrategy(model, inducing_points, variational_distribution, k, training_batch_size)[source]

This strategy sets all inducing point locations to observed inputs, and employs a \(k\)-nearest-neighbor approximation. It was introduced as the Variational Nearest Neighbor Gaussian Processes (VNNGP) in Wu et al. (2022). See the VNNGP tutorial for an example.

VNNGP assumes a k-nearest-neighbor generative process for inducing points \(\mathbf u\), \(q(\mathbf u) = \prod_{j=1}^M q(u_j \mid \mathbf u_{n(j)})\), where \(n(j)\) denotes the indices of the \(k\) nearest neighbors of \(u_j\) among \(u_1, \cdots, u_{j-1}\). For any test observation \(\mathbf f\), VNNGP makes predictive inference conditioned on its \(k\) nearest inducing points \(\mathbf u_{n(f)}\), i.e. \(p(f \mid \mathbf u_{n(f)})\).

VNNGP’s objective factorizes over inducing points and observations, making stochastic optimization over both immediately available. After a one-time cost of computing the \(k\)-nearest neighbor structure, the training and inference complexity is \(O(k^3)\). Since VNNGP uses observations as inducing points, it is a user choice to either (1) use the same mini-batch of inducing points and observations (recommended), or (2) use different mini-batches of inducing points and observations. See the VNNGP tutorial for implementation and comparison.

Note

The current implementation only supports MeanFieldVariationalDistribution.

We recommend installing the faiss library (which requires a separate package installation) for nearest neighbor search, as it is significantly faster than the scikit-learn nearest neighbor search. GPyTorch will automatically use faiss if it is installed, but will revert to scikit-learn otherwise.

Different inducing point orderings will produce different nearest neighbor approximations.

Parameters:
  • model (ApproximateGP) – Model this strategy is applied to. Typically passed in when the VariationalStrategy is created in the __init__ method of the user defined model.
  • inducing_points (torch.Tensor) – Tensor containing a set of inducing points to use for variational inference.
  • variational_distribution (VariationalDistribution) – A VariationalDistribution object that represents the form of the variational distribution \(q(\mathbf u)\)
  • k (int) – Number of nearest neighbors used in the approximation.
  • training_batch_size (int) – Number of data points in each mini-batch during training.
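
A hedged construction sketch: train_x is assumed to be the full set of training inputs (which double as the inducing points), and the values of k and training_batch_size are illustrative only.

>>> # Inside the model's __init__: inducing points are the training inputs themselves
>>> variational_distribution = gpytorch.variational.MeanFieldVariationalDistribution(
>>>     train_x.size(-2)
>>> )
>>> variational_strategy = gpytorch.variational.NNVariationalStrategy(
>>>     self, train_x, variational_distribution, k=256, training_batch_size=256
>>> )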

OrthogonallyDecoupledVariationalStrategy

class gpytorch.variational.OrthogonallyDecoupledVariationalStrategy(covar_variational_strategy, inducing_points, variational_distribution)[source]

Implements orthogonally decoupled VGPs as defined in Salimbeni et al. (2018). This variational strategy uses a different set of inducing points for the mean and covariance functions. The idea is to use more inducing points for the (computationally efficient) mean and fewer inducing points for the (computationally expensive) covariance.

This variational strategy defines the inducing points/_VariationalDistribution for the mean function. It then wraps a different _VariationalStrategy which defines the covariance inducing points.

Parameters:
  • covar_variational_strategy (_VariationalStrategy) – The variational strategy for the covariance term.
  • inducing_points (torch.Tensor) – Tensor containing a set of inducing points to use for variational inference.
  • variational_distribution (VariationalDistribution) – A VariationalDistribution object that represents the form of the variational distribution \(q(\mathbf u)\)

Example

>>> mean_inducing_points = torch.randn(1000, train_x.size(-1), dtype=train_x.dtype, device=train_x.device)
>>> covar_inducing_points = torch.randn(100, train_x.size(-1), dtype=train_x.dtype, device=train_x.device)
>>>
>>> covar_variational_strategy = gpytorch.variational.VariationalStrategy(
>>>     model, covar_inducing_points,
>>>     gpytorch.variational.CholeskyVariationalDistribution(covar_inducing_points.size(-2)),
>>>     learn_inducing_locations=True
>>> )
>>>
>>> variational_strategy = gpytorch.variational.OrthogonallyDecoupledVariationalStrategy(
>>>     covar_variational_strategy, mean_inducing_points,
>>>     gpytorch.variational.DeltaVariationalDistribution(mean_inducing_points.size(-2)),
>>> )

UnwhitenedVariationalStrategy

class gpytorch.variational.UnwhitenedVariationalStrategy(model, inducing_points, variational_distribution, learn_inducing_locations=True)[source]

Similar to VariationalStrategy, but does not perform the whitening operation. In almost all cases VariationalStrategy is preferable, with a few exceptions:

  • When the inducing points are exactly equal to the training points (i.e. \(\mathbf Z = \mathbf X\)). Unwhitened models are faster in this case.
  • When the number of inducing points is very large (e.g. >2000). Unwhitened models can use CG for faster computation.
Parameters:
  • model (ApproximateGP) – Model this strategy is applied to. Typically passed in when the VariationalStrategy is created in the __init__ method of the user defined model.
  • inducing_points (torch.Tensor) – Tensor containing a set of inducing points to use for variational inference.
  • variational_distribution (VariationalDistribution) – A VariationalDistribution object that represents the form of the variational distribution \(q(\mathbf u)\)
  • learn_inducing_locations (bool, optional) – (Default True): Whether or not the inducing point locations \(\mathbf Z\) should be learned (i.e. are they parameters of the model).
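
A hedged construction sketch for the first case above, where the inducing points are fixed to the training inputs (train_x is assumed to hold the training inputs):

>>> # Inside the model's __init__: Z = X, so do not learn the inducing locations
>>> variational_distribution = gpytorch.variational.CholeskyVariationalDistribution(
>>>     train_x.size(-2)
>>> )
>>> variational_strategy = gpytorch.variational.UnwhitenedVariationalStrategy(
>>>     self, train_x, variational_distribution, learn_inducing_locations=False
>>> )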

GridInterpolationVariationalStrategy

class gpytorch.variational.GridInterpolationVariationalStrategy(model, grid_size, grid_bounds, variational_distribution)[source]

This strategy constrains the inducing points to a grid and applies a deterministic relationship between \(\mathbf f\) and \(\mathbf u\). It was introduced by Wilson et al. (2016).

Here, the inducing points are not learned. Instead, the strategy automatically creates inducing points based on a set of grid sizes and grid bounds.

Parameters:
  • model (ApproximateGP) – Model this strategy is applied to. Typically passed in when the VariationalStrategy is created in the __init__ method of the user defined model.
  • grid_size (int) – Size of the grid
  • grid_bounds (list) – Bounds of each dimension of the grid (should be a list of (float, float) tuples)
  • variational_distribution (VariationalDistribution) – A VariationalDistribution object that represents the form of the variational distribution \(q(\mathbf u)\)
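
A hedged construction sketch for a one-dimensional input on \([-1, 1]\); the grid size of 64 (and hence 64 inducing values) is illustrative only:

>>> # Inside the model's __init__: inducing points are created automatically from the grid
>>> variational_distribution = gpytorch.variational.CholeskyVariationalDistribution(64)
>>> variational_strategy = gpytorch.variational.GridInterpolationVariationalStrategy(
>>>     self, grid_size=64, grid_bounds=[(-1.0, 1.0)],
>>>     variational_distribution=variational_distribution,
>>> )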

Variational Strategies for Multi-Output Functions

These are special _VariationalStrategy objects that return MultitaskMultivariateNormal distributions. Each of these objects acts on a batch of approximate GPs.

LMCVariationalStrategy

class gpytorch.variational.LMCVariationalStrategy(base_variational_strategy, num_tasks, num_latents=1, latent_dim=-1)[source]

LMCVariationalStrategy is an implementation of the “Linear Model of Coregionalization” for multitask GPs. This model assumes that there are \(Q\) latent functions \(\mathbf g(\cdot) = [g^{(1)}(\cdot), \ldots, g^{(Q)}(\cdot)]\), each of which is modelled by a GP. The output functions (tasks) are linear combinations of the latent functions:

\[f_{\text{task } i}( \mathbf x) = \sum_{q=1}^Q a_i^{(q)} g^{(q)} ( \mathbf x )\]

LMCVariationalStrategy wraps an existing VariationalStrategy. The output will either be a MultitaskMultivariateNormal distribution (if we wish to evaluate all tasks for each input) or a MultivariateNormal (if we wish to evaluate a single task for each input).

The base variational strategy is assumed to operate on a multi-batch of GPs, where one of the batch dimensions corresponds to the latent function dimension.

Note

The batch shape of the base VariationalStrategy does not necessarily have to correspond to the batch shape of the underlying GP objects.

For example, if the base variational strategy has a batch shape of [3] (corresponding to 3 latent functions), the GP kernel object could have a batch shape of [3] or no batch shape. This would correspond to each of the latent functions having different kernels or the same kernel, respectively.

Example

>>> class LMCMultitaskGP(gpytorch.models.ApproximateGP):
>>>     '''
>>>     3 latent functions
>>>     5 output dimensions (tasks)
>>>     '''
>>>     def __init__(self):
>>>         # Each latent function shares the same inducing points
>>>         # We'll have 32 inducing points, and let's assume the input dimensionality is 2
>>>         inducing_points = torch.randn(32, 2)
>>>
>>>         # The variational parameters have a batch_shape of [3] - for 3 latent functions
>>>         variational_distribution = gpytorch.variational.MeanFieldVariationalDistribution(
>>>             inducing_points.size(-2), batch_shape=torch.Size([3]),
>>>         )
>>>         variational_strategy = gpytorch.variational.LMCVariationalStrategy(
>>>             gpytorch.variational.VariationalStrategy(
>>>                 self, inducing_points, variational_distribution, learn_inducing_locations=True,
>>>             ),
>>>             num_tasks=5,
>>>             num_latents=3,
>>>             latent_dim=-1,
>>>         )
>>>
>>>         # Each latent function has its own mean/kernel function
>>>         super().__init__(variational_strategy)
>>>         self.mean_module = gpytorch.means.ConstantMean(batch_shape=torch.Size([3]))
>>>         self.covar_module = gpytorch.kernels.ScaleKernel(
>>>             gpytorch.kernels.RBFKernel(batch_shape=torch.Size([3])),
>>>             batch_shape=torch.Size([3]),
>>>         )
>>>
Parameters:
  • base_variational_strategy (VariationalStrategy) – Base variational strategy
  • num_tasks (int) – The total number of tasks (output functions)
  • num_latents (int) – The total number of latent functions in each group
  • latent_dim (int < 0) – (Default: -1) Which batch dimension corresponds to the latent function batch. Must be negative indexed
__call__(x, task_indices=None, prior=False, **kwargs)[source]

Computes the variational (or prior) distribution \(q( \mathbf f \mid \mathbf X)\) (or \(p( \mathbf f \mid \mathbf X)\)). There are two modes:

  1. Compute all tasks for all inputs. If this is the case, the task_indices attribute should be None. The return type will be a (… x N x num_tasks) MultitaskMultivariateNormal.
  2. Compute one task per input. If this is the case, the (… x N) task_indices tensor should contain the indices of each input’s assigned task. The return type will be a (… x N) MultivariateNormal.
Parameters:
  • x (torch.Tensor (.. x N x D)) – Input locations to evaluate variational strategy
  • task_indices (torch.Tensor (.. x N), optional) – (Default: None) Task index associated with each input. If this is not provided, then the returned distribution evaluates every input on every task (returns MultitaskMultivariateNormal). If this is provided, then the returned distribution evaluates each input only on its assigned task. (returns MultivariateNormal).
  • prior (bool) – (Default: False) If False, returns the variational distribution \(q( \mathbf f \mid \mathbf X)\). If True, returns the prior distribution \(p( \mathbf f \mid \mathbf X)\).
Returns:

\(q( \mathbf f \mid \mathbf X)\) (or the prior), either for all tasks (if task_indices == None) or for a specific task (if task_indices != None).

Return type:

MultitaskMultivariateNormal (.. x N x num_tasks) or MultivariateNormal (.. x N)
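
A hedged usage sketch: model is an instance of the LMCMultitaskGP class above, x is a hypothetical N x D input tensor, and task_idx is a hypothetical N-dimensional tensor of task indices; it is assumed that keyword arguments such as task_indices are forwarded through the ApproximateGP call to this strategy.

>>> all_tasks = model(x)                              # MultitaskMultivariateNormal, N x 5
>>> one_task_each = model(x, task_indices=task_idx)   # MultivariateNormal, N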

IndependentMultitaskVariationalStrategy

class gpytorch.variational.IndependentMultitaskVariationalStrategy(base_variational_strategy, num_tasks, task_dim=-1)[source]

IndependentMultitaskVariationalStrategy wraps an existing VariationalStrategy to produce vector-valued (multi-task) output distributions. Each task will be independent of one another.

The output will either be a MultitaskMultivariateNormal distribution (if we wish to evaluate all tasks for each input) or a MultivariateNormal (if we wish to evaluate a single task for each input).

The base variational strategy is assumed to operate on a batch of GPs. One of the batch dimensions corresponds to the multiple tasks.

Parameters:
  • base_variational_strategy (VariationalStrategy) – Base variational strategy
  • num_tasks (int) – Number of tasks. Should correspond to the batch size of task_dim.
  • task_dim (int) – (Default: -1) Which batch dimension is the task dimension
__call__(x, task_indices=None, prior=False, **kwargs)[source]

See LMCVariationalStrategy.

Variational Distributions

VariationalDistribution objects represent the variational distribution \(q(\mathbf u)\) over a set of inducing points for GPs. Typically the distributions are some sort of parameterization of a multivariate normal distribution.

_VariationalDistribution

class gpytorch.variational._VariationalDistribution(num_inducing_points, batch_shape=torch.Size([]), mean_init_std=0.001)[source]

Abstract base class for all Variational Distributions.

Variables:
  • dtype (torch.dtype) – The dtype of the VariationalDistribution parameters
  • device (torch.device) – The device of the VariationalDistribution parameters
forward()[source]

Constructs and returns the variational distribution

Return type: MultivariateNormal
Returns: The distribution \(q(\mathbf u)\)
initialize_variational_distribution(prior_dist)[source]

Method for initializing the variational distribution, based on the prior distribution.

Parameters: prior_dist (Distribution) – The prior distribution \(p(\mathbf u)\).
shape() → torch.Size[source]

Event + batch shape of VariationalDistribution object

Return type: torch.Size

CholeskyVariationalDistribution

class gpytorch.variational.CholeskyVariationalDistribution(num_inducing_points, batch_shape=torch.Size([]), mean_init_std=0.001, **kwargs)[source]

A _VariationalDistribution that is defined to be a multivariate normal distribution with a full covariance matrix.

The most common way this distribution is defined is to parameterize it in terms of a mean vector and a covariance matrix. In order to ensure that the covariance matrix remains positive definite, we only consider the lower triangle.

Parameters:
  • num_inducing_points (int) – Size of the variational distribution. This implies that the variational mean should be this size, and the variational covariance matrix should have this many rows and columns.
  • batch_shape (torch.Size, optional) – Specifies an optional batch size for the variational parameters. This is useful for example when doing additive variational inference.
  • mean_init_std (float) – (Default: 1e-3) Standard deviation of gaussian noise to add to the mean initialization.
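
A minimal usage sketch; the count of 50 inducing points is illustrative, and it is assumed that calling the distribution object invokes forward() to build the current \(q(\mathbf u)\):

>>> variational_distribution = gpytorch.variational.CholeskyVariationalDistribution(50)
>>> q_u = variational_distribution()   # a MultivariateNormal over the 50 inducing values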

DeltaVariationalDistribution

class gpytorch.variational.DeltaVariationalDistribution(num_inducing_points, batch_shape=torch.Size([]), mean_init_std=0.001, **kwargs)[source]

This _VariationalDistribution object replaces a variational distribution with a single particle. It is equivalent to doing MAP inference.

Parameters:
  • num_inducing_points (int) – Size of the variational distribution. This implies that the variational mean should be this size, and the variational covariance matrix should have this many rows and columns.
  • batch_shape (torch.Size, optional) – Specifies an optional batch size for the variational parameters. This is useful for example when doing additive variational inference.
  • mean_init_std (float) – (Default: 1e-3) Standard deviation of gaussian noise to add to the mean initialization.

MeanFieldVariationalDistribution

class gpytorch.variational.MeanFieldVariationalDistribution(num_inducing_points, batch_shape=torch.Size([]), mean_init_std=0.001, **kwargs)[source]

A _VariationalDistribution that is defined to be a multivariate normal distribution with a diagonal covariance matrix. This will not be as flexible/expressive as a CholeskyVariationalDistribution.

Parameters:
  • num_inducing_points (int) – Size of the variational distribution. This implies that the variational mean should be this size, and the variational covariance matrix should have this many rows and columns.
  • batch_shape (torch.Size, optional) – Specifies an optional batch size for the variational parameters. This is useful for example when doing additive variational inference.
  • mean_init_std (float) – (Default: 1e-3) Standard deviation of gaussian noise to add to the mean initialization.

Variational Distributions for Natural Gradient Descent

Special _VariationalDistribution objects designed specifically for use with natural gradient descent techniques. See the natural gradient descent tutorial for examples using these objects.

If the variational distribution is defined by \(\mathcal{N}(\mathbf m, \mathbf S)\), then a NaturalVariationalDistribution uses the parameterization:

\[\begin{split}\begin{align*} \boldsymbol \theta_\text{vec} &= \mathbf S^{-1} \mathbf m \\ \boldsymbol \Theta_\text{mat} &= -\frac{1}{2} \mathbf S^{-1}. \end{align*}\end{split}\]

The gradients with respect to the variational parameters calculated by this class are instead the natural gradients. Thus, optimising its parameters using gradient descent (e.g. torch.optim.SGD) becomes natural gradient descent (see e.g. Salimbeni et al., 2018).

GPyTorch offers several _NaturalVariationalDistribution classes, each of which uses a different representation of the natural parameters. The different parameterizations trade off speed and stability.

Note

Natural gradient descent is very stable with variational regression, and fast: if the hyperparameters are fixed, the variational parameters converge in 1 iteration. However, it can be unstable with non-conjugate likelihoods and alternative objective functions.
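
A hedged training sketch under stated assumptions: model is an ApproximateGP whose variational distribution is a NaturalVariationalDistribution, and likelihood, mll, train_y, and train_loader are assumed to be defined as in the SVGP sketches above.

>>> # Natural gradient descent on the variational parameters,
>>> # ordinary Adam on the kernel/likelihood hyperparameters
>>> variational_ngd_optimizer = gpytorch.optim.NGD(
>>>     model.variational_parameters(), num_data=train_y.size(0), lr=0.1
>>> )
>>> hyperparameter_optimizer = torch.optim.Adam(
>>>     list(model.hyperparameters()) + list(likelihood.parameters()), lr=0.01
>>> )
>>>
>>> for x_batch, y_batch in train_loader:
>>>     variational_ngd_optimizer.zero_grad()
>>>     hyperparameter_optimizer.zero_grad()
>>>     loss = -mll(model(x_batch), y_batch)
>>>     loss.backward()
>>>     variational_ngd_optimizer.step()
>>>     hyperparameter_optimizer.step()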

NaturalVariationalDistribution

class gpytorch.variational.NaturalVariationalDistribution(num_inducing_points, batch_shape=torch.Size([]), mean_init_std=0.001, **kwargs)[source]

A multivariate normal _VariationalDistribution, parameterized by natural parameters.

Note

The NaturalVariationalDistribution can only be used with gpytorch.optim.NGD, or other optimizers that follow exactly the gradient direction. Failure to do so will cause the natural matrix \(\mathbf \Theta_\text{mat}\) to stop being positive definite, and a RuntimeError will be raised.

See also

The natural gradient descent tutorial for use instructions.

The TrilNaturalVariationalDistribution for a more numerically stable parameterization, at the cost of needing more iterations to make variational regression converge.

Parameters:
  • num_inducing_points (int) – Size of the variational distribution. This implies that the variational mean should be this size, and the variational covariance matrix should have this many rows and columns.
  • batch_shape (torch.Size, optional) – Specifies an optional batch size for the variational parameters. This is useful for example when doing additive variational inference.
  • mean_init_std (float) – (Default: 1e-3) Standard deviation of gaussian noise to add to the mean initialization.

TrilNaturalVariationalDistribution

class gpytorch.variational.TrilNaturalVariationalDistribution(num_inducing_points, batch_shape=torch.Size([]), mean_init_std=0.001, **kwargs)[source]

A multivariate normal _VariationalDistribution, parameterized by the natural vector, and a triangular decomposition of the natural matrix (which is not the Cholesky).

Note

The TrilNaturalVariationalDistribution should only be used with gpytorch.optim.NGD, or other optimizers that follow exactly the gradient direction.

See also

The natural gradient descent tutorial for use instructions.

The NaturalVariationalDistribution, which needs fewer iterations to make variational regression converge, at the cost of introducing numerical instability.

Note

The relationship of the parameter \(\mathbf \Theta_\text{tril_mat}\) to the natural parameter \(\mathbf \Theta_\text{mat}\) from NaturalVariationalDistribution is \(\mathbf \Theta_\text{mat} = -1/2 {\mathbf \Theta_\text{tril_mat}}^T {\mathbf \Theta_\text{tril_mat}}\). Note that this is not the form of the Cholesky decomposition of \(\boldsymbol \Theta_\text{mat}\).

Parameters:
  • num_inducing_points (int) – Size of the variational distribution. This implies that the variational mean should be this size, and the variational covariance matrix should have this many rows and columns.
  • batch_shape (torch.Size, optional) – Specifies an optional batch size for the variational parameters. This is useful for example when doing additive variational inference.
  • mean_init_std (float) – (Default: 1e-3) Standard deviation of gaussian noise to add to the mean initialization.