Source code for gpytorch.mlls.exact_marginal_log_likelihood

#!/usr/bin/env python3

from ..distributions import MultivariateNormal
from ..likelihoods import _GaussianLikelihoodBase
from .marginal_log_likelihood import MarginalLogLikelihood

class ExactMarginalLogLikelihood(MarginalLogLikelihood):
    """
    Exact marginal log likelihood (MLL) of an exact Gaussian process that uses a Gaussian likelihood.

    .. note::
        This module only works with a :obj:`~gpytorch.likelihoods.GaussianLikelihood` paired with a
        :obj:`~gpytorch.models.ExactGP`. It also cannot be used in conjunction with
        stochastic optimization.

    :param ~gpytorch.likelihoods.GaussianLikelihood likelihood: The Gaussian likelihood for the model
    :param ~gpytorch.models.ExactGP model: The exact GP model

    Example:
        >>> # model is a gpytorch.models.ExactGP
        >>> # likelihood is a gpytorch.likelihoods.Likelihood
        >>> mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)
        >>>
        >>> output = model(train_x)
        >>> loss = -mll(output, train_y)
        >>> loss.backward()
    """

    def __init__(self, likelihood, model):
        """
        Validate the likelihood type, then defer to the parent constructor.

        :raises RuntimeError: If ``likelihood`` is not an instance of ``_GaussianLikelihoodBase``.
        """
        is_gaussian = isinstance(likelihood, _GaussianLikelihoodBase)
        if not is_gaussian:
            raise RuntimeError("Likelihood must be Gaussian for exact inference")
        super().__init__(likelihood, model)

    def _add_other_terms(self, res, params):
        """
        Fold the model's added loss terms and prior log-probabilities into ``res``.

        :param res: Running value that the extra terms are added to.
        :param params: Extra positional arguments forwarded to each added loss term's ``loss``.
        :return: ``res`` with every added loss term and every prior term included.
            Prior contributions are accumulated in place.
        """
        # Additional terms (SGPR / learned inducing points, heteroskedastic likelihood models)
        for extra_term in self.model.added_loss_terms():
            extra_loss = extra_term.loss(*params)
            res = res.add(extra_loss)

        # Log probs of the priors placed on the (functions of) parameters.
        # The number of leading dims to keep is fixed before the loop; the in-place
        # adds below do not change the shape of ``res``.
        n_keep = res.ndim
        for _name, module, prior, closure, _ in self.model.named_priors():
            log_prior = prior.log_prob(closure(module))
            kept_shape = log_prior.shape[:n_keep]
            # Keep the first ``n_keep`` dims, flatten everything else into one trailing dim, and sum it out
            flattened = log_prior.view(*kept_shape, -1)
            res.add_(flattened.sum(dim=-1))

        return res

    def forward(self, function_dist, target, *params):
        r"""
        Computes the MLL given :math:`p(\mathbf f)` and :math:`\mathbf y`.

        :param ~gpytorch.distributions.MultivariateNormal function_dist: :math:`p(\mathbf f)`
            the outputs of the latent function (the :obj:`gpytorch.models.ExactGP`)
        :param torch.Tensor target: :math:`\mathbf y` The target values
        :param params: Extra positional arguments passed through to the likelihood call and
            to the added loss terms.
        :rtype: torch.Tensor
        :return: Exact MLL. Output shape corresponds to batch shape of the model/input data.
        :raises RuntimeError: If ``function_dist`` is not a ``MultivariateNormal``.
        """
        if not isinstance(function_dist, MultivariateNormal):
            raise RuntimeError("ExactMarginalLogLikelihood can only operate on Gaussian random variables")

        # Log prob of the target under the marginal distribution produced by the likelihood
        marginal = self.likelihood(function_dist, *params)
        log_marginal = marginal.log_prob(target)
        total = self._add_other_terms(log_marginal, params)

        # Scale by the amount of data we have (element count of the event shape); division is in place
        num_data = function_dist.event_shape.numel()
        return total.div_(num_data)