#!/usr/bin/env python3
from typing import Optional, Tuple
import torch
from linear_operator import to_linear_operator
from linear_operator.operators import DiagLinearOperator, LinearOperator, MatmulLinearOperator, SumLinearOperator
from linear_operator.utils import linear_cg
from torch import Tensor
from torch.autograd.function import FunctionCtx
from .. import settings
from ..distributions import Delta, Distribution, MultivariateNormal
from ..module import Module
from ..utils.memoize import cached
from ._variational_strategy import _VariationalStrategy
from .natural_variational_distribution import NaturalVariationalDistribution
class _NgdInterpTerms(torch.autograd.Function):
"""
This function takes in
- the kernel interpolation term K_ZZ^{-1/2} k_ZX
- the natural parameters of the variational distribution
and returns
- the predictive distribution mean/covariance
- the inducing KL divergence KL( q(u) || p(u))
    However, the gradients will be with respect to the **canonical (expectation) parameters**
of the variational distribution, rather than the **natural parameters**.
This corresponds to performing natural gradient descent on the variational distribution.
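
    For reference, a sketch of the parameterization conventions assumed by the code below,
    with :math:`q(\mathbf u) = \mathcal N(\mathbf m, \mathbf S)`::

        natural_vec = S^{-1} m          # first natural parameter
        natural_mat = -1/2 S^{-1}       # second natural parameter
        expec_vec   = m                 # first expectation (canonical) parameter
        expec_mat   = S + m m^T         # second expectation (canonical) parameter

    Differentiating with respect to the expectation parameters while applying the updates
    to the natural parameters is what produces natural gradient steps.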
"""
@staticmethod
def forward(
ctx: FunctionCtx,
interp_term: torch.Tensor,
natural_vec: torch.Tensor,
natural_mat: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# Compute precision
prec = natural_mat.mul(-2.0)
diag = prec.diagonal(dim1=-1, dim2=-2).unsqueeze(-1)
        # Make sure that interp_term and natural_vec have the same batch shape
batch_shape = torch.broadcast_shapes(interp_term.shape[:-2], natural_vec.shape[:-1])
expanded_interp_term = interp_term.expand(*batch_shape, *interp_term.shape[-2:])
expanded_natural_vec = natural_vec.expand(*batch_shape, natural_vec.size(-1))
# Compute necessary solves with the precision. We need
# m = expec_vec = S * natural_vec
# S K^{-1/2} k
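        # (prec = -2 * natural_mat = S^{-1}, so solving linear systems against prec
        # multiplies the right-hand sides by S)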
solves = linear_cg(
prec.matmul,
torch.cat([expanded_natural_vec.unsqueeze(-1), expanded_interp_term], dim=-1),
n_tridiag=0,
max_iter=settings.max_cg_iterations.value(),
tolerance=min(settings.eval_cg_tolerance.value(), settings.cg_tolerance.value()),
max_tridiag_iter=settings.max_lanczos_quadrature_iterations.value(),
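            # simple diagonal (Jacobi) preconditioner for the CG solves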
preconditioner=lambda x: x / diag,
)
expec_vec = solves[..., 0]
s_times_interp_term = solves[..., 1:]
# Compute the interpolated mean
# k^T K^{-1/2} m
interp_mean = (s_times_interp_term.transpose(-1, -2) @ natural_vec.unsqueeze(-1)).squeeze(-1)
# Compute the interpolated variance
# k^T K^{-1/2} S K^{-1/2} k = k^T K^{-1/2} (expec_mat - expec_vec expec_vec^T) K^{-1/2} k
interp_var = (s_times_interp_term * interp_term).sum(dim=-2)
        # Let's not bother actually computing the KL-div in the forward pass;
        # only its gradient (implemented in backward) is needed for optimization.
# 1/2 ( -log | S | + tr(S) + m^T m - len(m) )
# = 1/2 ( -log | expec_mat - expec_vec expec_vec^T | + tr(expec_mat) - len(m) )
kl_div = torch.zeros_like(interp_mean[..., 0])
# We're done!
ctx.save_for_backward(interp_term, s_times_interp_term, interp_mean, natural_vec, expec_vec, prec)
return interp_mean, interp_var, kl_div
@staticmethod
def backward(
ctx: FunctionCtx, interp_mean_grad: torch.Tensor, interp_var_grad: torch.Tensor, kl_div_grad: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, None]:
# Get the saved terms
interp_term, s_times_interp_term, interp_mean, natural_vec, expec_vec, prec = ctx.saved_tensors
        # Expand data-dependent gradients
interp_mean_grad = interp_mean_grad.unsqueeze(-2)
interp_var_grad = interp_var_grad.unsqueeze(-2)
# Compute gradient of interp term (K^{-1/2} k)
# interp_mean component: m
# interp_var component: S K^{-1/2} k
# kl component: 0
interp_term_grad = (interp_var_grad * s_times_interp_term).mul(2.0) + (
interp_mean_grad * expec_vec.unsqueeze(-1)
)
# Compute gradient of expected vector (m)
# interp_mean component: K^{-1/2} k
# interp_var component: (k^T K^{-1/2} m) K^{-1/2} k
# kl component: S^{-1} m
expec_vec_grad = (
(interp_var_grad * interp_mean.unsqueeze(-2) * interp_term).sum(dim=-1).mul(-2)
+ (interp_mean_grad * interp_term).sum(dim=-1)
+ (kl_div_grad.unsqueeze(-1) * natural_vec)
)
# Compute gradient of expected matrix (mm^T + S)
# interp_mean component: 0
# interp_var component: K^{-1/2} k k^T K^{-1/2}
# kl component: 1/2 ( I - S^{-1} )
eye = torch.eye(expec_vec.size(-1), device=expec_vec.device, dtype=expec_vec.dtype)
expec_mat_grad = torch.add(
(interp_var_grad * interp_term) @ interp_term.transpose(-1, -2),
(kl_div_grad.unsqueeze(-1).unsqueeze(-1) * (eye - prec).mul(0.5)),
)
# We're done!
return interp_term_grad, expec_vec_grad, expec_mat_grad, None # Extra "None" for the kwarg
class CiqVariationalStrategy(_VariationalStrategy):
r"""
Similar to :class:`~gpytorch.variational.VariationalStrategy`,
except the whitening operation is performed using Contour Integral Quadrature
rather than Cholesky (see `Pleiss et al. (2020)`_ for more info).
See the `CIQ-SVGP tutorial`_ for an example.
Contour Integral Quadrature uses iterative matrix-vector multiplication to approximate
the :math:`\mathbf K_{\mathbf Z \mathbf Z}^{-1/2}` matrix used for the whitening operation.
This can be more efficient than the standard variational strategy for large numbers
of inducing points (e.g. :math:`M > 1000`) or when the inducing points have structure
(e.g. they lie on an evenly-spaced grid).
.. note::
It is recommended that this object is used in conjunction with
:obj:`~gpytorch.variational.NaturalVariationalDistribution` and
`natural gradient descent`_.
:param model: Model this strategy is applied to.
Typically passed in when the VariationalStrategy is created in the
__init__ method of the user defined model.
:param inducing_points: Tensor containing a set of inducing
points to use for variational inference.
:param variational_distribution: A
VariationalDistribution object that represents the form of the variational distribution :math:`q(\mathbf u)`
:param learn_inducing_locations: (Default True): Whether or not
the inducing point locations :math:`\mathbf Z` should be learned (i.e. are they
parameters of the model).
:param jitter_val: Amount of diagonal jitter to add for Cholesky factorization numerical stability
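
    Example (a minimal sketch; ``train_y`` below is an illustrative placeholder):
        >>> import torch
        >>> import gpytorch
        >>>
        >>> class SVGPModel(gpytorch.models.ApproximateGP):
        ...     def __init__(self, inducing_points):
        ...         variational_distribution = gpytorch.variational.NaturalVariationalDistribution(
        ...             inducing_points.size(-2)
        ...         )
        ...         variational_strategy = gpytorch.variational.CiqVariationalStrategy(
        ...             self, inducing_points, variational_distribution, learn_inducing_locations=True
        ...         )
        ...         super().__init__(variational_strategy)
        ...         self.mean_module = gpytorch.means.ConstantMean()
        ...         self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())
        ...
        ...     def forward(self, x):
        ...         return gpytorch.distributions.MultivariateNormal(
        ...             self.mean_module(x), self.covar_module(x)
        ...         )
        >>>
        >>> model = SVGPModel(inducing_points=torch.randn(1000, 1))
        >>> # Natural gradient descent on the variational parameters (recommended above)
        >>> ngd_optimizer = gpytorch.optim.NGD(
        ...     model.variational_parameters(), num_data=train_y.size(0), lr=0.1
        ... )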
.. _Pleiss et al. (2020):
https://arxiv.org/pdf/2006.11267.pdf
.. _CIQ-SVGP tutorial:
examples/04_Variational_and_Approximate_GPs/SVGP_CIQ.html
.. _natural gradient descent:
examples/04_Variational_and_Approximate_GPs/Natural_Gradient_Descent.html
"""
def _ngd(self) -> bool:
return isinstance(self._variational_distribution, NaturalVariationalDistribution)
@property
@cached(name="prior_distribution_memo")
def prior_distribution(self) -> MultivariateNormal:
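        # With the whitened parameterization used here, the prior over the
        # (whitened) inducing values is N(0, I)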
zeros = torch.zeros(
self._variational_distribution.shape(),
dtype=self._variational_distribution.dtype,
device=self._variational_distribution.device,
)
ones = torch.ones_like(zeros)
res = MultivariateNormal(zeros, DiagLinearOperator(ones))
return res
@property
@cached(name="variational_distribution_memo")
def variational_distribution(self) -> Distribution:
if self._ngd():
raise RuntimeError(
"Variational distribution for NGD-CIQ should be computed during forward calls. "
"This is probably a bug in GPyTorch."
)
return super().variational_distribution
def forward(
self,
x: torch.Tensor,
inducing_points: torch.Tensor,
inducing_values: torch.Tensor,
variational_inducing_covar: Optional[LinearOperator] = None,
*params,
**kwargs,
) -> MultivariateNormal:
# Compute full prior distribution
full_inputs = torch.cat([inducing_points, x], dim=-2)
full_output = self.model.forward(full_inputs, *params, **kwargs)
full_covar = full_output.lazy_covariance_matrix
# Covariance terms
num_induc = inducing_points.size(-2)
test_mean = full_output.mean[..., num_induc:]
induc_induc_covar = full_covar[..., :num_induc, :num_induc].evaluate_kernel().add_jitter(self.jitter_val)
induc_data_covar = full_covar[..., :num_induc, num_induc:].to_dense()
data_data_covar = full_covar[..., num_induc:, num_induc:].add_jitter(self.jitter_val)
        # Compute the interpolation term
        # interp_term = K_ZZ^{-1/2} k_ZX
with settings.max_preconditioner_size(0): # Turn off preconditioning for CIQ
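            # sqrt_inv_matmul applies K_ZZ^{-1/2} via Contour Integral Quadrature
            # (see Pleiss et al. (2020))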
interp_term = to_linear_operator(induc_induc_covar).sqrt_inv_matmul(induc_data_covar)
# Compute interpolated mean and variance terms
# We have separate computation rules for NGD versus standard GD
if self._ngd():
interp_mean, interp_var, kl_div = _NgdInterpTerms.apply(
interp_term,
self._variational_distribution.natural_vec,
self._variational_distribution.natural_mat,
)
# Compute the covariance of q(f)
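            # predictive_var = diag( K_XX - k_XZ K_ZZ^{-1} k_ZX
            #                        + (K_ZZ^{-1/2} k_ZX)^T S (K_ZZ^{-1/2} k_ZX) )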
predictive_var = data_data_covar.diagonal(dim1=-1, dim2=-2) - interp_term.pow(2).sum(dim=-2) + interp_var
predictive_var = torch.clamp_min(predictive_var, settings.min_variance.value(predictive_var.dtype))
predictive_covar = DiagLinearOperator(predictive_var)
# Also compute and cache the KL divergence
if not hasattr(self, "_memoize_cache"):
self._memoize_cache = dict()
self._memoize_cache["kl"] = kl_div
else:
# Compute interpolated mean term
interp_mean = torch.matmul(
interp_term.transpose(-1, -2), (inducing_values - self.prior_distribution.mean).unsqueeze(-1)
).squeeze(-1)
# Compute the covariance of q(f)
middle_term = self.prior_distribution.lazy_covariance_matrix.mul(-1)
if variational_inducing_covar is not None:
middle_term = SumLinearOperator(variational_inducing_covar, middle_term)
predictive_covar = SumLinearOperator(
data_data_covar.add_jitter(self.jitter_val),
MatmulLinearOperator(interp_term.transpose(-1, -2), middle_term @ interp_term),
)
        # Compute the mean of q(f)
        # \mu_f = k_ZX^T K_ZZ^{-1/2} m + \mu_X   (the whitened prior mean is zero)
predictive_mean = interp_mean + test_mean
# Return the distribution
return MultivariateNormal(predictive_mean, predictive_covar)
    def kl_divergence(self) -> Tensor:
r"""
Compute the KL divergence between the variational inducing distribution :math:`q(\mathbf u)`
and the prior inducing distribution :math:`p(\mathbf u)`.
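
        For the whitened parameterization used here, :math:`p(\mathbf u) = \mathcal N(\mathbf 0, \mathbf I)`,
        so the divergence reduces to
        :math:`\frac{1}{2} \left( \operatorname{tr}(\mathbf S) + \mathbf m^\top \mathbf m - M - \log \det \mathbf S \right)`,
        where :math:`M` is the number of inducing points.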
:rtype: torch.Tensor
"""
if self._ngd():
if hasattr(self, "_memoize_cache") and "kl" in self._memoize_cache:
return self._memoize_cache["kl"]
else:
raise RuntimeError(
"KL divergence for NGD-CIQ should be computed during forward calls."
"This is probably a bug in GPyTorch."
)
else:
return super().kl_divergence()
def __call__(self, x: torch.Tensor, prior: bool = False, *params, **kwargs) -> MultivariateNormal:
# This is mostly the same as _VariationalStrategy.__call__()
# but with special rules for natural gradient descent (to prevent O(M^3) computation)
# If we're in prior mode, then we're done!
if prior:
return self.model.forward(x)
# Delete previously cached items from the training distribution
if self.training:
self._clear_cache()
# (Maybe) initialize variational distribution
if not self.variational_params_initialized.item():
if self._ngd():
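                # natural_mat = -1/2 I and natural_vec ~= 0 correspond to m ~= 0 and S = I,
                # i.e. q(u) is initialized near the whitened prior N(0, I)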
noise = torch.randn_like(self.prior_distribution.mean).mul_(1e-3)
eye = torch.eye(noise.size(-1), dtype=noise.dtype, device=noise.device).mul(-0.5)
self._variational_distribution.natural_vec.data.copy_(noise)
self._variational_distribution.natural_mat.data.copy_(eye)
self.variational_params_initialized.fill_(1)
else:
prior_dist = self.prior_distribution
self._variational_distribution.initialize_variational_distribution(prior_dist)
self.variational_params_initialized.fill_(1)
# Ensure inducing_points and x are the same size
inducing_points = self.inducing_points
if inducing_points.shape[:-2] != x.shape[:-2]:
x, inducing_points = self._expand_inputs(x, inducing_points)
# Get q(f)
if self._ngd():
return Module.__call__(
self,
x,
inducing_points,
inducing_values=None,
variational_inducing_covar=None,
*params,
**kwargs,
)
else:
# Get p(u)/q(u)
variational_dist_u = self.variational_distribution
if isinstance(variational_dist_u, MultivariateNormal):
return Module.__call__(
self,
x,
inducing_points,
inducing_values=variational_dist_u.mean,
variational_inducing_covar=variational_dist_u.lazy_covariance_matrix,
**kwargs,
)
elif isinstance(variational_dist_u, Delta):
return Module.__call__(
self,
x,
inducing_points,
inducing_values=variational_dist_u.mean,
variational_inducing_covar=None,
**kwargs,
)
else:
raise RuntimeError(
f"Invalid variational distribuition ({type(variational_dist_u)}). "
"Expected a multivariate normal or a delta distribution."
)