# Source code for gpytorch.mlls.exact_marginal_log_likelihood

#!/usr/bin/env python3

from ..distributions import MultivariateNormal
from ..likelihoods import _GaussianLikelihoodBase
from .marginal_log_likelihood import MarginalLogLikelihood

[docs]class ExactMarginalLogLikelihood(MarginalLogLikelihood):
"""
The exact marginal log likelihood (MLL) for an exact Gaussian process with a
Gaussian likelihood.

.. note::
    This module will not work with anything other than a :obj:`~gpytorch.likelihoods.GaussianLikelihood`
    and a :obj:`~gpytorch.models.ExactGP`. It also cannot be used in conjunction with
    stochastic optimization.

:param ~gpytorch.likelihoods.GaussianLikelihood likelihood: The Gaussian likelihood for the model
:param ~gpytorch.models.ExactGP model: The exact GP model

Example:
>>> # model is a gpytorch.models.ExactGP
>>> # likelihood is a gpytorch.likelihoods.GaussianLikelihood
>>> mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)
>>>
>>> output = model(train_x)
>>> loss = -mll(output, train_y)
>>> loss.backward()
"""

def __init__(self, likelihood, model):
    """
    Validate the likelihood type, then defer to the base-class constructor.

    :param likelihood: Must be an instance of ``_GaussianLikelihoodBase``.
    :param model: Passed to the parent constructor together with ``likelihood``.
    :raises RuntimeError: If ``likelihood`` is not a ``_GaussianLikelihoodBase`` instance.
    """
    # Reject unsupported likelihoods before any base-class setup runs.
    likelihood_is_gaussian = isinstance(likelihood, _GaussianLikelihoodBase)
    if not likelihood_is_gaussian:
        raise RuntimeError("Likelihood must be Gaussian for exact inference")
    super().__init__(likelihood, model)

# NOTE(review): the statements from here down to ``return res`` have no enclosing ``def`` line in
# this copy, the ``for`` loop below has no body, and ``res`` is never bound in this span. They read
# as the remainder of a routine that folds extra terms into a running result ``res`` -- not as
# constructor logic. Restore the missing header and loop bodies from the canonical source; do not
# reconstruct them by guesswork.

# Add additional terms (SGPR / learned inducing points, heteroskedastic likelihood models)

# Add log probs of priors on the (functions of) parameters
for _, prior, closure, _ in self.named_priors():

return res

[docs]    def forward(self, function_dist, target, *params):
r"""
Computes the MLL given :math:`p(\mathbf f)` and :math:`\mathbf y`.

:param ~gpytorch.distributions.MultivariateNormal function_dist: :math:`p(\mathbf f)`
    the outputs of the latent function (the :obj:`gpytorch.models.ExactGP`)
:param torch.Tensor target: :math:`\mathbf y` The target values
:param params: Extra positional arguments, forwarded unchanged to the likelihood call.
:rtype: torch.Tensor
:return: Exact MLL. Output shape corresponds to batch shape of the model/input data.
:raises RuntimeError: If ``function_dist`` is not a
    :obj:`~gpytorch.distributions.MultivariateNormal`.
"""
# Only a MultivariateNormal input is accepted; anything else is rejected up front.
if not isinstance(function_dist, MultivariateNormal):
raise RuntimeError("ExactMarginalLogLikelihood can only operate on Gaussian random variables")

# Get the log prob of the marginal distribution
# (the likelihood is called on the latent distribution, then ``target`` is scored under its output)
output = self.likelihood(function_dist, *params)
res = output.log_prob(target)
# NOTE(review): this excerpt ends here without returning ``res``; confirm the remainder of the
# method (and its return statement) against the full file.