Source code for gpytorch.mlls.deep_approximate_mll

from ._approximate_mll import _ApproximateMarginalLogLikelihood


[docs]class DeepApproximateMLL(_ApproximateMarginalLogLikelihood): """ A wrapper to make a GPyTorch approximate marginal log likelihoods compatible with Deep GPs. Example: >>> deep_mll = gpytorch.mlls.DeepApproximateMLL( >>> gpytorch.mlls.VariationalELBO(likelihood, model, num_data=1000) >>> ) :param ~gpytorch.mlls._ApproximateMarginalLogLikelihood base_mll: The base approximate MLL """ def __init__(self, base_mll): if not base_mll.combine_terms: raise ValueError( "The base marginal log likelihood object should combine terms " "when used in conjunction with a DeepApproximateMLL." ) super().__init__(base_mll.likelihood, base_mll.model, num_data=base_mll.num_data, beta=base_mll.beta) self.base_mll = base_mll def _log_likelihood_term(self, approximate_dist_f, target, **kwargs): return self.base_mll._log_likelihood_term(approximate_dist_f, target, **kwargs).mean(0) def forward(self, approximate_dist_f, target, **kwargs): return self.base_mll.forward(approximate_dist_f, target, **kwargs).mean(0)