# Source code for gpytorch.variational.ciq_variational_strategy

#!/usr/bin/env python3

from typing import Optional, Tuple

import torch
from linear_operator import to_linear_operator
from linear_operator.operators import DiagLinearOperator, LinearOperator, MatmulLinearOperator, SumLinearOperator
from linear_operator.utils import linear_cg
from torch import Tensor
from torch.autograd.function import FunctionCtx

from .. import settings
from ..distributions import Delta, Distribution, MultivariateNormal
from ..module import Module
from ..utils.memoize import cached
from ._variational_strategy import _VariationalStrategy
from .natural_variational_distribution import NaturalVariationalDistribution

class _NgdInterpTerms(torch.autograd.Function):
    """
    Fused autograd function used by the NGD path of :class:`CiqVariationalStrategy`.

    This function takes in

        - the kernel interpolation term K_ZZ^{-1/2} k_ZX
        - the natural parameters of the variational distribution

    and returns

        - the predictive distribution mean/covariance (interpolated mean and variance)
        - the inducing KL divergence KL( q(u) || p(u)) (returned as zeros; only its
          gradient is supplied in the backward pass)

    However, the gradients returned for the natural parameters are really the gradients
    with respect to the **expectation parameters** (m and mm^T + S) of the variational
    distribution. Taking an ordinary gradient step on the natural parameters with these
    values corresponds to performing natural gradient descent on the variational distribution.
    """

    @staticmethod
    def forward(
        ctx: FunctionCtx,
        interp_term: torch.Tensor,
        natural_vec: torch.Tensor,
        natural_mat: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # Precision S^{-1} = -2 * natural_mat
        prec = natural_mat.mul(-2.0)
        # Diagonal of the precision, used as a Jacobi preconditioner for CG below
        diag = prec.diagonal(dim1=-1, dim2=-2).unsqueeze(-1)

        # Make sure that interp_term and natural_vec are the same batch shape
        batch_shape = torch.broadcast_shapes(interp_term.shape[:-2], natural_vec.shape[:-1])
        expanded_interp_term = interp_term.expand(*batch_shape, *interp_term.shape[-2:])
        expanded_natural_vec = natural_vec.expand(*batch_shape, natural_vec.size(-1))

        # Compute necessary solves with the precision in a single batched CG call. We need
        # m = expec_vec = S * natural_vec   (first column of the RHS)
        # S K^{-1/2} k                      (remaining columns of the RHS)
        solves = linear_cg(
            prec.matmul,
            torch.cat([expanded_natural_vec.unsqueeze(-1), expanded_interp_term], dim=-1),
            n_tridiag=0,
            max_iter=settings.max_cg_iterations.value(),
            tolerance=min(settings.eval_cg_tolerance.value(), settings.cg_tolerance.value()),
            preconditioner=lambda x: x / diag,
        )
        expec_vec = solves[..., 0]
        s_times_interp_term = solves[..., 1:]

        # Compute the interpolated mean
        # k^T K^{-1/2} m  (= (S K^{-1/2} k)^T natural_vec, since S is symmetric)
        interp_mean = (s_times_interp_term.transpose(-1, -2) @ natural_vec.unsqueeze(-1)).squeeze(-1)

        # Compute the interpolated variance (diagonal only)
        # k^T K^{-1/2} S K^{-1/2} k = k^T K^{-1/2} (expec_mat - expec_vec expec_vec^T) K^{-1/2} k
        interp_var = (s_times_interp_term * interp_term).sum(dim=-2)

        # Let's not bother actually computing the KL-div in the forward pass
        # 1/2 ( -log | S | + tr(S) + m^T m - len(m) )
        # = 1/2 ( -log | expec_mat - expec_vec expec_vec^T | + tr(expec_mat) - len(m) )
        # Only its gradient (in backward) is needed for optimization.
        kl_div = torch.zeros_like(interp_mean[..., 0])

        # We're done!
        ctx.save_for_backward(interp_term, s_times_interp_term, interp_mean, natural_vec, expec_vec, prec)
        return interp_mean, interp_var, kl_div

    @staticmethod
    def backward(
        ctx: FunctionCtx, interp_mean_grad: torch.Tensor, interp_var_grad: torch.Tensor, kl_div_grad: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, None]:
        # Get the saved terms
        interp_term, s_times_interp_term, interp_mean, natural_vec, expec_vec, prec = ctx.saved_tensors

        # Expand data-dependent gradients so they broadcast over the inducing dimension
        interp_mean_grad = interp_mean_grad.unsqueeze(-2)
        interp_var_grad = interp_var_grad.unsqueeze(-2)

        # Compute gradient of interp term (K^{-1/2} k)
        # interp_mean component: m
        # interp_var component: 2 S K^{-1/2} k
        # kl component: 0
        interp_term_grad = (interp_var_grad * s_times_interp_term).mul(2.0) + (
            interp_mean_grad * expec_vec.unsqueeze(-1)
        )

        # Compute gradient of expected vector (m)
        # interp_mean component: K^{-1/2} k
        # interp_var component: -2 (k^T K^{-1/2} m) K^{-1/2} k
        # kl component: S^{-1} m (= natural_vec)
        expec_vec_grad = (
            (interp_var_grad * interp_mean.unsqueeze(-2) * interp_term).sum(dim=-1).mul(-2)
            + (interp_mean_grad * interp_term).sum(dim=-1)
            + (kl_div_grad.unsqueeze(-1) * natural_vec)
        )

        # Compute gradient of expected matrix (mm^T + S)
        # interp_mean component: 0
        # interp_var component: K^{-1/2} k k^T K^{-1/2}
        # kl component: 1/2 ( I - S^{-1} )
        eye = torch.eye(expec_vec.size(-1), device=expec_vec.device, dtype=expec_vec.dtype)
        expec_mat_grad = torch.add(
            (interp_var_grad * interp_term) @ interp_term.transpose(-1, -2),
            (kl_div_grad.unsqueeze(-1).unsqueeze(-1) * (eye - prec).mul(0.5)),
        )

        # We're done! The expectation-parameter gradients are returned in the slots of the
        # natural parameters (this is what turns the update into natural gradient descent).
        # NOTE(review): forward takes only three inputs; the trailing None is extra and is
        # expected to be ignored by autograd.
        return interp_term_grad, expec_vec_grad, expec_mat_grad, None

class CiqVariationalStrategy(_VariationalStrategy):
    r"""
    Similar to :class:`~gpytorch.variational.VariationalStrategy`,
    except the whitening operation is performed using Contour Integral Quadrature
    rather than Cholesky (see `Pleiss et al. (2020)`_ for more info).
    See the `CIQ-SVGP tutorial`_ for an example.

    Contour Integral Quadrature uses iterative matrix-vector multiplication to approximate
    the :math:`\mathbf K_{\mathbf Z \mathbf Z}^{-1/2}` matrix used for the whitening operation.
    This can be more efficient than the standard variational strategy for large numbers
    of inducing points (e.g. :math:`M > 1000`) or when the inducing points have structure
    (e.g. they lie on an evenly-spaced grid).

    .. note::

        It is recommended that this object is used in conjunction with
        :obj:`~gpytorch.variational.NaturalVariationalDistribution` and
        `natural gradient descent`_.

    :param model: Model this strategy is applied to.
        Typically passed in when the VariationalStrategy is created in the
        __init__ method of the user defined model.
    :param inducing_points: Tensor containing a set of inducing
        points to use for variational inference.
    :param variational_distribution: A
        VariationalDistribution object that represents the form of the variational distribution :math:`q(\mathbf u)`
    :param learn_inducing_locations: (Default True): Whether or not
        the inducing point locations :math:`\mathbf Z` should be learned (i.e. are they
        parameters of the model).
    :param jitter_val: Amount of diagonal jitter added to the inducing and data covariance
        matrices for numerical stability.

    .. _Pleiss et al. (2020):
        https://arxiv.org/pdf/2006.11267.pdf
    .. _CIQ-SVGP tutorial:
        examples/04_Variational_and_Approximate_GPs/SVGP_CIQ.html
    .. _natural gradient descent:
        examples/04_Variational_and_Approximate_GPs/Natural_Gradient_Descent.html
    """

    def _ngd(self) -> bool:
        # NGD mode: the variational distribution is parameterized by natural parameters and
        # q(u) / the KL term are only produced inside forward (via _NgdInterpTerms).
        return isinstance(self._variational_distribution, NaturalVariationalDistribution)

    @property
    @cached(name="prior_distribution_memo")
    def prior_distribution(self) -> MultivariateNormal:
        """Whitened prior p(u) = N(0, I), matching the variational distribution's shape/dtype/device."""
        zeros = torch.zeros(
            self._variational_distribution.shape(),
            dtype=self._variational_distribution.dtype,
            device=self._variational_distribution.device,
        )
        ones = torch.ones_like(zeros)
        res = MultivariateNormal(zeros, DiagLinearOperator(ones))
        return res

    @property
    @cached(name="variational_distribution_memo")
    def variational_distribution(self) -> Distribution:
        """
        The variational distribution q(u).

        :raises RuntimeError: in NGD mode, where q(u) is never materialized explicitly.
        """
        if self._ngd():
            raise RuntimeError(
                "Variational distribution for NGD-CIQ should be computed during forward calls. "
                "This is probably a bug in GPyTorch."
            )
        return super().variational_distribution

    def forward(
        self,
        x: torch.Tensor,
        inducing_points: torch.Tensor,
        inducing_values: torch.Tensor,
        variational_inducing_covar: Optional[LinearOperator] = None,
        *params,
        **kwargs,
    ) -> MultivariateNormal:
        """
        Compute the approximate predictive distribution q(f) at ``x``.

        :param x: Input locations.
        :param inducing_points: Inducing point locations Z (batch-expanded to match ``x``).
        :param inducing_values: Mean of q(u) (whitened); unused (None) in NGD mode.
        :param variational_inducing_covar: Covariance of q(u) (whitened), or None.
        :return: q(f) as a MultivariateNormal. In NGD mode the covariance is diagonal.
        """
        # Compute full prior distribution over [Z; X]
        full_inputs = torch.cat([inducing_points, x], dim=-2)
        full_output = self.model.forward(full_inputs, *params, **kwargs)
        full_covar = full_output.lazy_covariance_matrix

        # Covariance terms
        num_induc = inducing_points.size(-2)
        test_mean = full_output.mean[..., num_induc:]
        induc_induc_covar = full_covar[..., :num_induc, :num_induc].evaluate_kernel().add_jitter(self.jitter_val)
        induc_data_covar = full_covar[..., :num_induc, num_induc:].to_dense()
        # Jittered exactly once; both the NGD and GD branches below use this term as-is.
        data_data_covar = full_covar[..., num_induc:, num_induc:].add_jitter(self.jitter_val)

        # Compute interpolation terms
        # K_XZ K_ZZ^{-1} \mu_z
        # K_XZ K_ZZ^{-1/2} \mu_Z
        with settings.max_preconditioner_size(0):  # Turn off preconditioning for CIQ
            interp_term = to_linear_operator(induc_induc_covar).sqrt_inv_matmul(induc_data_covar)

        # Compute interpolated mean and variance terms
        # We have separate computation rules for NGD versus standard GD
        if self._ngd():
            interp_mean, interp_var, kl_div = _NgdInterpTerms.apply(
                interp_term,
                self._variational_distribution.natural_vec,
                self._variational_distribution.natural_mat,
            )

            # Compute the (diagonal) covariance of q(f)
            predictive_var = data_data_covar.diagonal(dim1=-1, dim2=-2) - interp_term.pow(2).sum(dim=-2) + interp_var
            predictive_var = torch.clamp_min(predictive_var, settings.min_variance.value(predictive_var.dtype))
            predictive_covar = DiagLinearOperator(predictive_var)

            # Also cache the KL divergence; kl_divergence() reads it back in NGD mode
            if not hasattr(self, "_memoize_cache"):
                self._memoize_cache = dict()
            self._memoize_cache["kl"] = kl_div
        else:
            # Compute interpolated mean term
            interp_mean = torch.matmul(
                interp_term.transpose(-1, -2), (inducing_values - self.prior_distribution.mean).unsqueeze(-1)
            ).squeeze(-1)

            # Compute the covariance of q(f): K_XX + k^T K^{-1/2} (S - I) K^{-1/2} k
            middle_term = self.prior_distribution.lazy_covariance_matrix.mul(-1)
            if variational_inducing_covar is not None:
                middle_term = SumLinearOperator(variational_inducing_covar, middle_term)
            predictive_covar = SumLinearOperator(
                data_data_covar,
                MatmulLinearOperator(interp_term.transpose(-1, -2), middle_term @ interp_term),
            )

        # Compute the mean of q(f)
        # k_XZ K_ZZ^{-1/2} (m - K_ZZ^{-1/2} \mu_Z) + \mu_X
        predictive_mean = interp_mean + test_mean

        # Return the distribution
        return MultivariateNormal(predictive_mean, predictive_covar)

    def kl_divergence(self) -> Tensor:
        r"""
        Compute the KL divergence between the variational inducing distribution :math:`q(\mathbf u)`
        and the prior inducing distribution :math:`p(\mathbf u)`.

        In NGD mode this returns the value cached by the most recent forward call.

        :rtype: torch.Tensor
        :raises RuntimeError: in NGD mode, if no forward call has cached the KL term.
        """
        if self._ngd():
            if hasattr(self, "_memoize_cache") and "kl" in self._memoize_cache:
                return self._memoize_cache["kl"]
            raise RuntimeError(
                "KL divergence for NGD-CIQ should be computed during forward calls. "
                "This is probably a bug in GPyTorch."
            )
        return super().kl_divergence()

    def __call__(self, x: torch.Tensor, prior: bool = False, *params, **kwargs) -> MultivariateNormal:
        """
        Return q(f) at ``x`` (or the model prior when ``prior=True``).

        This is mostly the same as _VariationalStrategy.__call__(), but with special rules
        for natural gradient descent (to prevent O(M^3) computation).
        """
        # If we're in prior mode, then we're done!
        if prior:
            return self.model.forward(x)

        # Delete previously cached items from the training distribution
        if self.training:
            self._clear_cache()

        # (Maybe) initialize variational distribution
        if not self.variational_params_initialized.item():
            if self._ngd():
                # q(u) ~= N(small noise, I): natural_vec = S^{-1} m = m and
                # natural_mat = -1/2 S^{-1} = -1/2 I (cf. prec = -2 * natural_mat in _NgdInterpTerms)
                noise = torch.randn_like(self.prior_distribution.mean).mul_(1e-3)
                eye = torch.eye(noise.size(-1), dtype=noise.dtype, device=noise.device).mul(-0.5)
                self._variational_distribution.natural_vec.data.copy_(noise)
                self._variational_distribution.natural_mat.data.copy_(eye)
                self.variational_params_initialized.fill_(1)
            else:
                prior_dist = self.prior_distribution
                self._variational_distribution.initialize_variational_distribution(prior_dist)
                self.variational_params_initialized.fill_(1)

        # Ensure inducing_points and x are the same size
        inducing_points = self.inducing_points
        if inducing_points.shape[:-2] != x.shape[:-2]:
            x, inducing_points = self._expand_inputs(x, inducing_points)

        # Get q(f)
        if self._ngd():
            # inducing_values / variational_inducing_covar are passed positionally (as None) so
            # that *params follow them; passing them as keywords alongside a non-empty *params
            # would raise "got multiple values for argument".
            return Module.__call__(self, x, inducing_points, None, None, *params, **kwargs)

        # Get p(u)/q(u)
        variational_dist_u = self.variational_distribution
        if isinstance(variational_dist_u, MultivariateNormal):
            return Module.__call__(
                self,
                x,
                inducing_points,
                inducing_values=variational_dist_u.mean,
                variational_inducing_covar=variational_dist_u.lazy_covariance_matrix,
                **kwargs,
            )
        elif isinstance(variational_dist_u, Delta):
            return Module.__call__(
                self,
                x,
                inducing_points,
                inducing_values=variational_dist_u.mean,
                variational_inducing_covar=None,
                **kwargs,
            )
        else:
            raise RuntimeError(
                f"Invalid variational distribution ({type(variational_dist_u)}). "
                "Expected a multivariate normal or a delta distribution."
            )