Source code for gpytorch.mlls.exact_marginal_log_likelihood

#!/usr/bin/env python3

from ..distributions import MultivariateNormal
from ..likelihoods import _GaussianLikelihoodBase
from .marginal_log_likelihood import MarginalLogLikelihood


class ExactMarginalLogLikelihood(MarginalLogLikelihood):
    """
    The exact marginal log likelihood (MLL) for an exact Gaussian process with a
    Gaussian likelihood.

    .. note::
        This module will not work with anything other than a :obj:`~gpytorch.likelihoods.GaussianLikelihood`
        and a :obj:`~gpytorch.models.ExactGP`. It also cannot be used in conjunction with
        stochastic optimization.

    :param ~gpytorch.likelihoods.GaussianLikelihood likelihood: The Gaussian likelihood for the model
    :param ~gpytorch.models.ExactGP model: The exact GP model
    :raises RuntimeError: if ``likelihood`` is not a Gaussian likelihood.

    Example:
        >>> # model is a gpytorch.models.ExactGP
        >>> # likelihood is a gpytorch.likelihoods.GaussianLikelihood
        >>> mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)
        >>>
        >>> output = model(train_x)
        >>> loss = -mll(output, train_y)
        >>> loss.backward()
    """

    def __init__(self, likelihood, model):
        # Only Gaussian likelihoods (subclasses of _GaussianLikelihoodBase) are accepted
        # for exact inference; anything else is rejected up front.
        if not isinstance(likelihood, _GaussianLikelihoodBase):
            raise RuntimeError("Likelihood must be Gaussian for exact inference")
        super().__init__(likelihood, model)

    def _add_other_terms(self, res, params):
        """
        Add the model's extra loss terms and the log prior terms to ``res``.

        :param torch.Tensor res: The (unnormalized) marginal log likelihood.
        :param tuple params: Extra positional arguments, forwarded to each added loss term's ``loss``.
        :rtype: torch.Tensor
        :return: ``res`` plus every additional term. ``res`` itself is never modified in place.
        """
        # Add additional terms (SGPR / learned inducing points, heteroskedastic likelihood models)
        for added_loss_term in self.model.added_loss_terms():
            res = res.add(added_loss_term.loss(*params))

        # Add log probs of priors on the (functions of) parameters.
        # Out-of-place add (like the loop above) so we never mutate a tensor that is
        # already part of the autograd graph.
        # NOTE(review): ``.sum()`` collapses every dimension, so for batched models the
        # prior summed over all batches is added to each batch entry -- confirm intended.
        for _, prior, closure, _ in self.named_priors():
            res = res.add(prior.log_prob(closure()).sum())

        return res

    def forward(self, function_dist, target, *params):
        r"""
        Computes the MLL given :math:`p(\mathbf f)` and :math:`\mathbf y`.

        :param ~gpytorch.distributions.MultivariateNormal function_dist: :math:`p(\mathbf f)`
            the outputs of the latent function (the :obj:`gpytorch.models.ExactGP`)
        :param torch.Tensor target: :math:`\mathbf y` The target values
        :param params: Extra arguments forwarded to the likelihood and to any added loss terms.
        :rtype: torch.Tensor
        :return: Exact MLL, divided by the number of values in one event of ``function_dist``.
            Output shape corresponds to batch shape of the model/input data.
        :raises RuntimeError: if ``function_dist`` is not a
            :obj:`~gpytorch.distributions.MultivariateNormal`.
        """
        if not isinstance(function_dist, MultivariateNormal):
            raise RuntimeError("ExactMarginalLogLikelihood can only operate on Gaussian random variables")

        # Get the log prob of the marginal distribution
        output = self.likelihood(function_dist, *params)
        res = output.log_prob(target)
        res = self._add_other_terms(res, params)

        # Scale by the amount of data we have: the number of values in one event.
        # ``target.size(-1)`` is only correct for single-output models; for a
        # distribution with event shape (n, t) it would give the task count t
        # instead of the n * t observed values.
        num_data = function_dist.event_shape.numel()
        return res.div(num_data)

    def pyro_factor(self, output, target, *params):
        """
        Compute the MLL and register it with Pyro as a ``pyro.factor`` site named ``"gp_mll"``.

        :param output: The model output, passed through to this module's ``forward``.
        :param torch.Tensor target: The target values.
        :param params: Extra arguments forwarded to ``forward``.
        :rtype: torch.Tensor
        :return: The MLL tensor (the same value as calling this module directly).
        """
        # Imported locally so pyro is only needed when this method is actually called.
        import pyro

        mll = self(output, target, *params)
        pyro.factor("gp_mll", mll)
        return mll