#!/usr/bin/env python3
from ..distributions import MultivariateNormal
from ..likelihoods import _GaussianLikelihoodBase
from .marginal_log_likelihood import MarginalLogLikelihood
class ExactMarginalLogLikelihood(MarginalLogLikelihood):
    r"""
    The exact marginal log likelihood (MLL) for an exact Gaussian process with a
    Gaussian likelihood.

    .. note::
        This module will not work with anything other than a :obj:`~gpytorch.likelihoods.GaussianLikelihood`
        and a :obj:`~gpytorch.models.ExactGP`. It also cannot be used in conjunction with
        stochastic optimization.

    :param ~gpytorch.likelihoods.GaussianLikelihood likelihood: The Gaussian likelihood for the model
    :param ~gpytorch.models.ExactGP model: The exact GP model

    Example:
        >>> # model is a gpytorch.models.ExactGP
        >>> # likelihood is a gpytorch.likelihoods.Likelihood
        >>> mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)
        >>>
        >>> output = model(train_x)
        >>> loss = -mll(output, train_y)
        >>> loss.backward()
    """

    def __init__(self, likelihood, model):
        # Exact inference is only valid with a Gaussian observation model — fail fast
        # rather than producing silently wrong marginals.
        if not isinstance(likelihood, _GaussianLikelihoodBase):
            raise RuntimeError("Likelihood must be Gaussian for exact inference")
        super().__init__(likelihood, model)

    def _add_other_terms(self, res, params):
        """Fold auxiliary loss terms and prior log-probabilities into ``res``.

        :param torch.Tensor res: The raw marginal log probability (may be batched).
        :param params: Extra positional arguments forwarded to each added loss term.
        :return: ``res`` with all additional terms accumulated in.
        """
        # Add additional terms (SGPR / learned inducing points, heteroskedastic likelihood models)
        for added_loss_term in self.model.added_loss_terms():
            res = res.add(added_loss_term.loss(*params))

        # Add log probs of priors on the (functions of) parameters
        res_ndim = res.ndim
        for _name, module, prior, closure, _ in self.model.named_priors():
            prior_term = prior.log_prob(closure(module))
            # Collapse any event dims of the prior beyond res's batch dims into one
            # trailing dim and sum it, so each batch element receives a scalar term.
            res.add_(prior_term.view(*prior_term.shape[:res_ndim], -1).sum(dim=-1))
        return res

    def forward(self, function_dist, target, *params):
        r"""
        Computes the MLL given :math:`p(\mathbf f)` and :math:`\mathbf y`.

        :param ~gpytorch.distributions.MultivariateNormal function_dist: :math:`p(\mathbf f)`
            the outputs of the latent function (the :obj:`gpytorch.models.ExactGP`)
        :param torch.Tensor target: :math:`\mathbf y` The target values
        :rtype: torch.Tensor
        :return: Exact MLL. Output shape corresponds to batch shape of the model/input data.
        """
        if not isinstance(function_dist, MultivariateNormal):
            raise RuntimeError("ExactMarginalLogLikelihood can only operate on Gaussian random variables")

        # Get the log prob of the marginal distribution p(y) = ∫ p(y|f) p(f) df
        output = self.likelihood(function_dist, *params)
        res = output.log_prob(target)
        res = self._add_other_terms(res, params)

        # Scale by the amount of data we have (num_data = product of event-shape dims)
        num_data = function_dist.event_shape.numel()
        return res.div_(num_data)