Source code for gpytorch.variational._variational_strategy

#!/usr/bin/env python3

from abc import ABC, abstractproperty

import torch

from .. import settings
from ..distributions import Delta, MultivariateNormal
from ..module import Module
from ..utils.broadcasting import _mul_broadcast_shape
from ..utils.memoize import cached, clear_cache_hook

[docs]class _VariationalStrategy(Module, ABC): """ Abstract base class for all Variational Strategies. """ def __init__(self, model, inducing_points, variational_distribution, learn_inducing_locations=True): super().__init__() # Model object.__setattr__(self, "model", model) # Inducing points inducing_points = inducing_points.clone() if inducing_points.dim() == 1: inducing_points = inducing_points.unsqueeze(-1) if learn_inducing_locations: self.register_parameter(name="inducing_points", parameter=torch.nn.Parameter(inducing_points)) else: self.register_buffer("inducing_points", inducing_points) # Variational distribution self._variational_distribution = variational_distribution self.register_buffer("variational_params_initialized", torch.tensor(0)) def _clear_cache(self): clear_cache_hook(self) def _expand_inputs(self, x, inducing_points): """ Pre-processing step in __call__ to make x the same batch_shape as the inducing points """ batch_shape = _mul_broadcast_shape(inducing_points.shape[:-2], x.shape[:-2]) inducing_points = inducing_points.expand(*batch_shape, *inducing_points.shape[-2:]) x = x.expand(*batch_shape, *x.shape[-2:]) return x, inducing_points @abstractproperty @cached(name="prior_distribution_memo") def prior_distribution(self): r""" The :func:`~gpytorch.variational.VariationalStrategy.prior_distribution` method determines how to compute the GP prior distribution of the inducing points, e.g. :math:`p(u) \sim N(\mu(X_u), K(X_u, X_u))`. Most commonly, this is done simply by calling the user defined GP prior on the inducing point data directly. :rtype: :obj:`~gpytorch.distributions.MultivariateNormal` :return: The distribution :math:`p( \mathbf u)` """ raise NotImplementedError @property @cached(name="variational_distribution_memo") def variational_distribution(self): return self._variational_distribution()
[docs] def forward(self, x, inducing_points, inducing_values, variational_inducing_covar=None): r""" The :func:`~gpytorch.variational.VariationalStrategy.forward` method determines how to marginalize out the inducing point function values. Specifically, forward defines how to transform a variational distribution over the inducing point values, :math:`q(u)`, in to a variational distribution over the function values at specified locations x, :math:`q(f|x)`, by integrating :math:`\int p(f|x, u)q(u)du` :param torch.Tensor x: Locations :math:`\mathbf X` to get the variational posterior of the function values at. :param torch.Tensor inducing_points: Locations :math:`\mathbf Z` of the inducing points :param torch.Tensor inducing_values: Samples of the inducing function values :math:`\mathbf u` (or the mean of the distribution :math:`q(\mathbf u)` if q is a Gaussian. :param ~gpytorch.lazy.LazyTensor variational_inducing_covar: If the distribuiton :math:`q(\mathbf u)` is Gaussian, then this variable is the covariance matrix of that Gaussian. Otherwise, it will be :attr:`None`. :rtype: :obj:`~gpytorch.distributions.MultivariateNormal` :return: The distribution :math:`q( \mathbf f(\mathbf X))` """ raise NotImplementedError
[docs] def kl_divergence(self): r""" Compute the KL divergence between the variational inducing distribution :math:`q(\mathbf u)` and the prior inducing distribution :math:`p(\mathbf u)`. :rtype: torch.Tensor """ with settings.max_preconditioner_size(0): kl_divergence = torch.distributions.kl.kl_divergence(self.variational_distribution, self.prior_distribution) return kl_divergence
def __call__(self, x, prior=False): # If we're in prior mode, then we're done! if prior: return self.model.forward(x) # Delete previously cached items from the training distribution if self._clear_cache() # (Maybe) initialize variational distribution if not self.variational_params_initialized.item(): prior_dist = self.prior_distribution self._variational_distribution.initialize_variational_distribution(prior_dist) self.variational_params_initialized.fill_(1) # Ensure inducing_points and x are the same size inducing_points = self.inducing_points if inducing_points.shape[:-2] != x.shape[:-2]: x, inducing_points = self._expand_inputs(x, inducing_points) # Get p(u)/q(u) variational_dist_u = self.variational_distribution # Get q(f) if isinstance(variational_dist_u, MultivariateNormal): return super().__call__( x, inducing_points, inducing_values=variational_dist_u.mean, variational_inducing_covar=variational_dist_u.lazy_covariance_matrix, ) elif isinstance(variational_dist_u, Delta): return super().__call__( x, inducing_points, inducing_values=variational_dist_u.mean, variational_inducing_covar=None ) else: raise RuntimeError( f"Invalid variational distribuition ({type(variational_dist_u)}). " "Expected a multivariate normal or a delta distribution." )