gpytorch.mlls¶
These are modules to compute (or approximate/bound) the marginal log likelihood (MLL) of the GP model when applied to data. I.e., given a GP \(f \sim \mathcal{GP}(\mu, K)\) and data \(\mathbf X, \mathbf y\), these modules compute/approximate
\[\mathcal{L} = \log p(\mathbf y \mid \mathbf X) = \log \int p(\mathbf y \mid f(\mathbf X)) \: p(f) \: \mathrm{d}f\]
This is computed exactly when the GP inference is computed exactly (e.g. regression w/ a Gaussian likelihood). It is approximated/bounded for GP models that use approximate inference.
These modules are typically used as the “loss” functions for training GP models (though note that the output of these functions must be negated for optimization).
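For example (a sketch only, assuming a GP model, a likelihood, training data train_x/train_y, and one of the MLL modules below have already been constructed), a typical training loop negates the MLL to obtain a minimization objective:
>>> mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)  # or any MLL module below
>>> optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
>>>
>>> for _ in range(50):
>>>     optimizer.zero_grad()
>>>     output = model(train_x)
>>>     loss = -mll(output, train_y)  # negate the MLL so that minimizing the loss maximizes the MLL
>>>     loss.backward()
>>>     optimizer.step()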
Exact GP Inference¶
These are MLLs for use with ExactGP modules. They compute the MLL exactly.
ExactMarginalLogLikelihood¶
- class gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)[source]¶
The exact marginal log likelihood (MLL) for an exact Gaussian process with a Gaussian likelihood.
Note
This module will not work with anything other than a GaussianLikelihood and an ExactGP. It also cannot be used in conjunction with stochastic optimization.
- Parameters:
likelihood (GaussianLikelihood) – The Gaussian likelihood for the model
model (ExactGP) – The exact GP model
Example
>>> # model is a gpytorch.models.ExactGP
>>> # likelihood is a gpytorch.likelihoods.Likelihood
>>> mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)
>>>
>>> output = model(train_x)
>>> loss = -mll(output, train_y)
>>> loss.backward()
- forward(function_dist, target, *params, **kwargs)[source]¶
Computes the MLL given \(p(\mathbf f)\) and \(\mathbf y\).
- Parameters:
function_dist (MultivariateNormal) – \(p(\mathbf f)\), the outputs of the latent function (the gpytorch.models.ExactGP)
target (torch.Tensor) – \(\mathbf y\), the target values
- Return type:
Tensor
- Returns:
Exact MLL. Output shape corresponds to batch shape of the model/input data.
LeaveOneOutPseudoLikelihood¶
- class gpytorch.mlls.LeaveOneOutPseudoLikelihood(likelihood, model)[source]¶
The leave-one-out cross-validation (LOO-CV) likelihood from RW 5.4.2 for an exact Gaussian process with a Gaussian likelihood. This offers an alternative to the exact marginal log likelihood, where we instead maximize the sum of the leave-one-out log probabilities \(\log p(y_i | X, y_{-i}, \theta)\).
Naively, this is \(\mathcal{O}(n^4)\) with Cholesky, as we need to compute \(n\) Cholesky factorizations. Fortunately, given the Cholesky factorization of the full kernel matrix (without any points removed), we can compute both the mean and variance of each removed point via a bordered system formulation, making the total complexity \(\mathcal{O}(n^3)\).
The LOO-CV approach can be more robust against model mis-specification, as it gives an estimate of the (log) predictive probability whether or not the assumptions of the model are fulfilled.
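As an illustration of the bordered-system idea (a sketch only, not this module's internal implementation), the LOO means, variances, and log probabilities can all be obtained from a single Cholesky factorization of the noisy train covariance matrix K:
>>> # K: n x n train covariance (kernel matrix plus observational noise); y: training targets
>>> L = torch.linalg.cholesky(K)
>>> K_inv = torch.cholesky_solve(torch.eye(K.size(-1), dtype=K.dtype), L)
>>> K_inv_y = torch.cholesky_solve(y.unsqueeze(-1), L).squeeze(-1)
>>> loo_var = 1.0 / K_inv.diagonal()            # LOO predictive variances
>>> loo_mean = y - K_inv_y / K_inv.diagonal()   # LOO predictive means
>>> loo = torch.distributions.Normal(loo_mean, loo_var.sqrt()).log_prob(y).sum()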
Note
This module will not work with anything other than a GaussianLikelihood and an ExactGP. It also cannot be used in conjunction with stochastic optimization.
- Parameters:
likelihood (GaussianLikelihood) – The Gaussian likelihood for the model
model (ExactGP) – The exact GP model
Example
>>> # model is a gpytorch.models.ExactGP
>>> # likelihood is a gpytorch.likelihoods.Likelihood
>>> loocv = gpytorch.mlls.LeaveOneOutPseudoLikelihood(likelihood, model)
>>>
>>> output = model(train_x)
>>> loss = -loocv(output, train_y)
>>> loss.backward()
- forward(function_dist, target, *params)[source]¶
Computes the leave-one-out likelihood given \(p(\mathbf f)\) and \(\mathbf y\).
- Parameters:
function_dist (MultivariateNormal) – \(p(\mathbf f)\), the outputs of the latent function (the GP)
target (Tensor) – \(\mathbf y\), the target values
kwargs (dict) – Additional arguments to pass to the likelihood’s forward function.
- Return type:
Tensor
Approximate GP Inference¶
These are MLLs for use with ApproximateGP modules. They are designed for when exact inference is intractable (either when the likelihood is non-Gaussian, or when there is too much data for an ExactGP model).
VariationalELBO¶
- class gpytorch.mlls.VariationalELBO(likelihood, model, num_data, beta=1.0, combine_terms=True)[source]¶
The variational evidence lower bound (ELBO). This is used to optimize variational Gaussian processes (with or without stochastic optimization).
\[\begin{split}\begin{align*} \mathcal{L}_\text{ELBO} &= \mathbb{E}_{p_\text{data}( y, \mathbf x )} \left[ \mathbb{E}_{p(f \mid \mathbf u, \mathbf x) q(\mathbf u)} \left[ \log p( y \! \mid \! f) \right] \right] - \beta \: \text{KL} \left[ q( \mathbf u) \Vert p( \mathbf u) \right] \\ &\approx \sum_{i=1}^N \mathbb{E}_{q( f_i)} \left[ \log p( y_i \! \mid \! f_i) \right] - \beta \: \text{KL} \left[ q( \mathbf u) \Vert p( \mathbf u) \right] \end{align*}\end{split}\]
where \(N\) is the number of datapoints, \(q(\mathbf u)\) is the variational distribution for the inducing function values, \(q(f_i)\) is the marginal of \(p(f_i \mid \mathbf u, \mathbf x_i) q(\mathbf u)\), and \(p(\mathbf u)\) is the prior distribution for the inducing function values.
\(\beta\) is a scaling constant that reduces the regularization effect of the KL divergence. Setting \(\beta=1\) (default) results in the true variational ELBO.
For more information on this derivation, see Scalable Variational Gaussian Process Classification (Hensman et al., 2015).
- Parameters:
likelihood (Likelihood) – The likelihood for the model
model (ApproximateGP) – The approximate GP model
num_data (int) – The total number of training data points (necessary for SGD)
beta (float) – (optional, default=1.) A multiplicative factor for the KL divergence term. Setting it to 1 (default) recovers true variational inference (as derived in Scalable Variational Gaussian Process Classification). Setting it to anything less than 1 reduces the regularization effect of the model (similarly to what was proposed in the beta-VAE paper).
combine_terms (bool) – (optional, default=True) Whether or not to sum the expected NLL with the KL divergence terms.
Example
>>> # model is a gpytorch.models.ApproximateGP
>>> # likelihood is a gpytorch.likelihoods.Likelihood
>>> mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=100, beta=0.5)
>>>
>>> output = model(train_x)
>>> loss = -mll(output, train_y)
>>> loss.backward()
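Because the objective decomposes as a sum over data points, it also supports mini-batch (stochastic) optimization. A sketch, assuming train_loader is a torch.utils.data.DataLoader over the full training set of num_data points:
>>> mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=num_data)
>>> optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
>>>
>>> for x_batch, y_batch in train_loader:
>>>     optimizer.zero_grad()
>>>     output = model(x_batch)
>>>     loss = -mll(output, y_batch)  # num_data rescales the mini-batch likelihood term
>>>     loss.backward()
>>>     optimizer.step()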
- forward(variational_dist_f, target, **kwargs)[source]¶
Computes the Variational ELBO given \(q(\mathbf f)\) and \(\mathbf y\). Calling this function will call the likelihood’s expected_log_prob() function.
- Parameters:
variational_dist_f (MultivariateNormal) – \(q(\mathbf f)\), the outputs of the latent function (the gpytorch.models.ApproximateGP)
target (torch.Tensor) – \(\mathbf y\), the target values
kwargs – Additional arguments passed to the likelihood’s expected_log_prob() function.
- Return type:
Tensor
- Returns:
Variational ELBO. Output shape corresponds to batch shape of the model/input data.
PredictiveLogLikelihood¶
- class gpytorch.mlls.PredictiveLogLikelihood(likelihood, model, num_data, beta=1.0, combine_terms=True)[source]¶
An alternative objective function for approximate GPs, proposed in Jankowiak et al., 2020. It typically produces better predictive variances than the gpytorch.mlls.VariationalELBO objective.
\[\begin{split}\begin{align*} \mathcal{L}_\text{PLL} &= \mathbb{E}_{p_\text{data}( y, \mathbf x )} \left[ \log p( y \! \mid \! \mathbf x) \right] - \beta \: \text{KL} \left[ q( \mathbf u) \Vert p( \mathbf u) \right] \\ &\approx \sum_{i=1}^N \log \mathbb{E}_{q(\mathbf u)} \left[ \int p( y_i \! \mid \! f_i) p(f_i \! \mid \! \mathbf u, \mathbf x_i) \: d f_i \right] - \beta \: \text{KL} \left[ q( \mathbf u) \Vert p( \mathbf u) \right] \end{align*}\end{split}\]
where \(N\) is the total number of datapoints, \(q(\mathbf u)\) is the variational distribution for the inducing function values, and \(p(\mathbf u)\) is the prior distribution for the inducing function values.
\(\beta\) is a scaling constant that reduces the regularization effect of the KL divergence. Setting \(\beta=1\) (default) results in an objective that can be motivated by a connection to Stochastic Expectation Propagation (see Jankowiak et al., 2020 for details).
Note
This objective is very similar to the variational ELBO. The only difference is that the \(\log\) occurs outside the expectation \(\mathbb{E}_{q(\mathbf u)}\). This difference results in very different predictive performance (see Jankowiak et al., 2020).
- Parameters:
likelihood (Likelihood) – The likelihood for the model
model (ApproximateGP) – The approximate GP model
num_data (int) – The total number of training data points (necessary for SGD)
beta (float) – (optional, default=1.) A multiplicative factor for the KL divergence term. Setting it to anything less than 1 reduces the regularization effect of the model (similarly to what was proposed in the beta-VAE paper).
combine_terms (bool) – (optional, default=True) Whether or not to sum the expected NLL with the KL divergence terms.
Example
>>> # model is a gpytorch.models.ApproximateGP
>>> # likelihood is a gpytorch.likelihoods.Likelihood
>>> mll = gpytorch.mlls.PredictiveLogLikelihood(likelihood, model, num_data=100, beta=0.5)
>>>
>>> output = model(train_x)
>>> loss = -mll(output, train_y)
>>> loss.backward()
- forward(approximate_dist_f, target, **kwargs)[source]¶
Computes the predictive cross entropy given \(q(\mathbf f)\) and \(\mathbf y\). Calling this function will call the likelihood’s forward() function.
- Parameters:
approximate_dist_f (MultivariateNormal) – \(q(\mathbf f)\), the outputs of the latent function (the gpytorch.models.ApproximateGP)
target (torch.Tensor) – \(\mathbf y\), the target values
kwargs – Additional arguments passed to the likelihood’s forward() function.
- Return type:
Tensor
- Returns:
Predictive log likelihood. Output shape corresponds to batch shape of the model/input data.
GammaRobustVariationalELBO¶
- class gpytorch.mlls.GammaRobustVariationalELBO(likelihood, model, gamma=1.03, *args, **kwargs)[source]¶
An alternative to the variational evidence lower bound (ELBO), proposed by Knoblauch, 2019. It is derived by replacing the log-likelihood term in the ELBO with a gamma divergence:
\[\begin{align*} \mathcal{L}_{\gamma} &= \sum_{i=1}^N \mathbb{E}_{q( \mathbf u)} \left[ -\frac{\gamma}{\gamma - 1} \frac{ p( y_i \! \mid \! \mathbf u, x_i)^{\gamma - 1} }{ \int p(y \mid \mathbf u, x_i)^{\gamma} \: dy } \right] - \beta \: \text{KL} \left[ q( \mathbf u) \Vert p( \mathbf u) \right] \end{align*}\]
where \(N\) is the number of datapoints, \(\gamma\) is a hyperparameter, \(q(\mathbf u)\) is the variational distribution for the inducing function values, and \(p(\mathbf u)\) is the prior distribution for the inducing function values.
\(\beta\) is a scaling constant for the KL divergence.
Note
This module will only work with a GaussianLikelihood.
- Parameters:
likelihood (GaussianLikelihood) – The likelihood for the model
model (ApproximateGP) – The approximate GP model
num_data (int) – The total number of training data points (necessary for SGD)
beta (float) – (optional, default=1.) A multiplicative factor for the KL divergence term. Setting it to anything less than 1 reduces the regularization effect of the model (similarly to what was proposed in the beta-VAE paper).
gamma (float) – (optional, default=1.03) The \(\gamma\)-divergence hyperparameter.
combine_terms (bool) – (optional, default=True) Whether or not to sum the expected NLL with the KL divergence terms.
Example
>>> # model is a gpytorch.models.ApproximateGP
>>> # likelihood is a gpytorch.likelihoods.Likelihood
>>> mll = gpytorch.mlls.GammaRobustVariationalELBO(likelihood, model, num_data=100, beta=0.5, gamma=1.03)
>>>
>>> output = model(train_x)
>>> loss = -mll(output, train_y)
>>> loss.backward()
DeepApproximateMLL¶
- class gpytorch.mlls.DeepApproximateMLL(base_mll)[source]¶
A wrapper to make GPyTorch approximate marginal log likelihoods compatible with Deep GPs.
Example
>>> deep_mll = gpytorch.mlls.DeepApproximateMLL(
>>>     gpytorch.mlls.VariationalELBO(likelihood, model, num_data=1000)
>>> )
- Parameters:
base_mll (_ApproximateMarginalLogLikelihood) – The base approximate MLL
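Once constructed, the wrapped objective is used like any other approximate MLL (a sketch, assuming model is a gpytorch.models.deep_gps.DeepGP and likelihood, train_x, and train_y are already defined):
>>> output = model(train_x)
>>> loss = -deep_mll(output, train_y)
>>> loss.backward()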
Modifications to Objective Functions¶
- class gpytorch.mlls.AddedLossTerm[source]¶
AddedLossTerms are registered onto GPyTorch models (or their child gpytorch.Module objects).
If a model (or any of its child modules) has an added loss term, then all optimization objective functions (e.g. ExactMarginalLogLikelihood, VariationalELBO, etc.) will be amended to include an additive term defined by the loss() method.
As an example, consider the following toy AddedLossTerm that adds a random number to any objective function:
class RandomNumberAddedLoss(gpytorch.mlls.AddedLossTerm):
    # Adds a random number to the loss
    def __init__(self, dtype, device):
        self.dtype, self.device = dtype, device

    def loss(self):
        # This dynamically defines the added loss term
        return torch.randn(torch.Size([]), dtype=self.dtype, device=self.device)


class MyExactGP(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y):
        super().__init__(train_x, train_y, gpytorch.likelihoods.GaussianLikelihood())
        self.mean_module = gpytorch.means.ZeroMean()
        self.covar_module = gpytorch.kernels.RBFKernel()
        # Create the added loss term
        self.register_added_loss_term("random_added_loss")

    def forward(self, x):
        # Update loss term
        new_added_loss_term = RandomNumberAddedLoss(dtype=x.dtype, device=x.device)
        self.update_added_loss_term("random_added_loss", new_added_loss_term)
        # Run the remainder of the forward method
        return gpytorch.distributions.MultivariateNormal(self.mean_module(x), self.covar_module(x))


train_x = torch.randn(100, 2)
train_y = torch.randn(100)

model = MyExactGP(train_x, train_y)
model.train()
mll = gpytorch.mlls.ExactMarginalLogLikelihood(model.likelihood, model)
mll(model(train_x), train_y)  # Returns log marginal likelihood + a random number
To use an AddedLossTerm:
- A model (or a child module where the AddedLossTerm should live) should register an additive loss term with the register_added_loss_term() method. All AddedLossTerms have an identifying name associated with them.
- The forward() function of the model (or the child module) should instantiate the appropriate AddedLossTerm, calling the update_added_loss_term() method.
InducingPointKernelAddedLossTerm¶
- class gpytorch.mlls.InducingPointKernelAddedLossTerm(prior_dist, variational_dist, likelihood)[source]¶
An added loss term that computes the additional “regularization trace term” of the SGPR objective function.
\[-\frac{1}{2 \sigma^2} \text{Tr} \left( \mathbf K_{\mathbf X \mathbf X} - \mathbf Q \right)\]
where \(\mathbf Q = \mathbf K_{\mathbf X \mathbf Z} \mathbf K_{\mathbf Z \mathbf Z}^{-1} \mathbf K_{\mathbf Z \mathbf X}\) is the Nyström approximation of \(\mathbf K_{\mathbf X \mathbf X}\) given by inducing points \(\mathbf Z\), and \(\sigma^2\) is the observational noise of the Gaussian likelihood.
See Titsias, 2009, Eq. 9 for more information.
- Parameters:
prior_dist (MultivariateNormal) – A multivariate normal \(\mathcal N ( \mathbf 0, \mathbf K_{\mathbf X \mathbf X} )\) with covariance matrix \(\mathbf K_{\mathbf X \mathbf X}\).
variational_dist (MultivariateNormal) – A multivariate normal \(\mathcal N ( \mathbf 0, \mathbf Q )\) with covariance matrix \(\mathbf Q = \mathbf K_{\mathbf X \mathbf Z} \mathbf K_{\mathbf Z \mathbf Z}^{-1} \mathbf K_{\mathbf Z \mathbf X}\).
likelihood (GaussianLikelihood) – The Gaussian likelihood with observational noise \(\sigma^2\).
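Conceptually, only the diagonals of \(\mathbf K_{\mathbf X \mathbf X}\) and \(\mathbf Q\) enter the trace, so the term can be sketched as follows (illustration only, not the library’s actual implementation):
>>> # prior_dist, variational_dist: the MultivariateNormals described above; likelihood: a GaussianLikelihood
>>> diag_diff = prior_dist.variance - variational_dist.variance  # diag(K_XX) - diag(Q)
>>> trace_term = -0.5 * (diag_diff / likelihood.noise).sum(-1)   # -Tr(K_XX - Q) / (2 sigma^2)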
KLGaussianAddedLossTerm¶
- class gpytorch.mlls.KLGaussianAddedLossTerm(q_x, p_x, n, data_dim)[source]¶
This class is used by variational GPLVM models. It adds the KL divergence between two multivariate Gaussian distributions, scaled by the size of the data and the number of output dimensions.
\[D_\text{KL} \left( q(\mathbf x) \Vert p(\mathbf x) \right)\]
- Parameters:
q_x (MultivariateNormal) – The MVN distribution \(q(\mathbf x)\).
p_x (MultivariateNormal) – The MVN distribution \(p(\mathbf x)\).
n (int) – Size of the latent space.
data_dim (int) – Dimensionality of the \(\mathbf Y\) values.
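A rough sketch of the scaled KL term (illustrative only; the exact broadcasting is determined by the n and data_dim arguments):
>>> from torch.distributions import kl_divergence
>>> # q_x, p_x: the latent MVN distributions; n, data_dim as described above
>>> kl = kl_divergence(q_x, p_x).sum()  # total KL over the latent points
>>> added_loss = kl / (n * data_dim)    # scaled by data size and output dimensionality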