# Source code for gpytorch.variational.ciq_variational_strategy

#!/usr/bin/env python3

from typing import Optional, Tuple

import torch
from linear_operator import to_linear_operator
from linear_operator.operators import DiagLinearOperator, MatmulLinearOperator, SumLinearOperator
from linear_operator.utils import linear_cg

from .. import settings
from ..distributions import Delta, MultivariateNormal
from ..module import Module
from ..utils.memoize import cached
from ._variational_strategy import _VariationalStrategy
from .natural_variational_distribution import NaturalVariationalDistribution

"""
This function takes in

- the kernel interpolation term K_ZZ^{-1/2} k_ZX
- the natural parameters of the variational distribution

and returns

- the predictive distribution mean/covariance
- the inducing KL divergence KL( q(u) || p(u))

However, the gradients will be with respect to the **canonical parameters**
of the variational distribution, rather than the **natural parameters**.
This corresponds to performing natural gradient descent on the variational distribution.
"""

@staticmethod
def forward(
    ctx, interp_term: torch.Tensor, natural_vec: torch.Tensor, natural_mat: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Compute the NGD predictive interpolation terms from the whitened
    cross-covariance term and the natural parameters of q(u).

    :param ctx: autograd context; tensors needed by backward are stashed here.
    :param interp_term: the K_ZZ^{-1/2} k_ZX term (presumably ... x M x N —
        confirm against CiqVariationalStrategy.forward).
    :param natural_vec: first natural parameter of q(u) (... x M).
    :param natural_mat: second natural parameter of q(u) (... x M x M).
    :return: (interp_mean, interp_var, kl_div). kl_div is returned as zeros —
        the KL value is deliberately not computed in the forward pass (see the
        comment below); only its gradient is produced in backward.
    """
    # Compute precision
    # Solving prec @ x = natural_vec below yields S @ natural_vec = m, so prec
    # acts as S^{-1} (i.e. natural_mat is treated as -1/2 S^{-1}).
    prec = natural_mat.mul(-2.0)
    # Diagonal of the precision, used as a cheap Jacobi preconditioner for CG.
    diag = prec.diagonal(dim1=-1, dim2=-2).unsqueeze(-1)

    # Make sure that interp_term and natural_vec are the same batch shape
    batch_shape = torch.broadcast_shapes(interp_term.shape[:-2], natural_vec.shape[:-1])
    expanded_interp_term = interp_term.expand(*batch_shape, *interp_term.shape[-2:])
    expanded_natural_vec = natural_vec.expand(*batch_shape, natural_vec.size(-1))

    # Compute necessary solves with the precision. We need
    # m = expec_vec = S * natural_vec
    # S K^{-1/2} k
    # Both right-hand sides are concatenated into one CG call (column 0 is the
    # natural_vec solve; remaining columns are the interp_term solves).
    solves = linear_cg(
        prec.matmul,
        torch.cat([expanded_natural_vec.unsqueeze(-1), expanded_interp_term], dim=-1),
        n_tridiag=0,
        max_iter=settings.max_cg_iterations.value(),
        tolerance=min(settings.eval_cg_tolerance.value(), settings.cg_tolerance.value()),
        preconditioner=lambda x: x / diag,
    )
    expec_vec = solves[..., 0]
    s_times_interp_term = solves[..., 1:]

    # Compute the interpolated mean
    # k^T K^{-1/2} m
    # (S K^{-1/2} k)^T natural_vec == k^T K^{-1/2} (S natural_vec) == k^T K^{-1/2} m
    interp_mean = (s_times_interp_term.transpose(-1, -2) @ natural_vec.unsqueeze(-1)).squeeze(-1)

    # Compute the interpolated variance
    # k^T K^{-1/2} S K^{-1/2} k = k^T K^{-1/2} (expec_mat - expec_vec expec_vec^T) K^{-1/2} k
    interp_var = (s_times_interp_term * interp_term).sum(dim=-2)

    # Let's not bother actually computing the KL-div in the forward pass
    # 1/2 ( -log | S | + tr(S) + m^T m - len(m) )
    # = 1/2 ( -log | expec_mat - expec_vec expec_vec^T | + tr(expec_mat) - len(m) )
    kl_div = torch.zeros_like(interp_mean[..., 0])

    # We're done!
    # NOTE: the save order here is relied upon by backward's unpacking.
    ctx.save_for_backward(interp_term, s_times_interp_term, interp_mean, natural_vec, expec_vec, prec)
    return interp_mean, interp_var, kl_div

@staticmethod
def backward(
    ctx, interp_mean_grad: torch.Tensor, interp_var_grad: torch.Tensor, kl_div_grad: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Backward pass for the NGD interpolation terms.

    Given gradients w.r.t. the forward outputs (interp_mean, interp_var, kl_div),
    returns gradients for the forward inputs (interp_term, natural_vec, natural_mat).

    NOTE(review): as described in the module docstring, the returned gradients
    for the natural parameters are deliberately taken with respect to the
    *canonical* (expectation) parameters m and (mm^T + S) instead — this makes
    plain gradient descent on the natural parameters equivalent to natural
    gradient descent on the variational distribution.
    """
    # Get the saved terms (order must match forward's ctx.save_for_backward)
    interp_term, s_times_interp_term, interp_mean, natural_vec, expec_vec, prec = ctx.saved_tensors

    # Expand data-dependent gradients so they broadcast against the
    # (... x M x N) interpolation terms
    interp_mean_grad = interp_mean_grad.unsqueeze(-2)
    interp_var_grad = interp_var_grad.unsqueeze(-2)

    # Compute gradient of interp term (K^{-1/2} k)
    # interp_mean component: m
    # interp_var component: S K^{-1/2} k
    # kl component: 0
    interp_term_grad = (interp_var_grad * s_times_interp_term).mul(2.0) + (
        interp_mean_grad * expec_vec.unsqueeze(-1)
    )

    # Compute gradient of expected vector (m)
    # interp_mean component: K^{-1/2} k
    # interp_var component: (k^T K^{-1/2} m) K^{-1/2} k
    # kl component: S^{-1} m  (== natural_vec)
    expec_vec_grad = sum(
        [
            (interp_var_grad * interp_mean.unsqueeze(-2) * interp_term).sum(dim=-1).mul(-2),
            (interp_mean_grad * interp_term).sum(dim=-1),
            (kl_div_grad.unsqueeze(-1) * natural_vec),
        ]
    )

    # Compute gradient of expected matrix (mm^T + S)
    # interp_mean component: 0
    # interp_var component: K^{-1/2} k k^T K^{-1/2}
    # kl component: 1/2 ( I - S^{-1} )  (prec == S^{-1})
    eye = torch.eye(expec_vec.size(-1), device=expec_vec.device, dtype=expec_vec.dtype)
    expec_mat_grad = torch.add(
        (interp_var_grad * interp_term) @ interp_term.transpose(-1, -2),
        (kl_div_grad.unsqueeze(-1).unsqueeze(-1) * (eye - prec).mul(0.5)),
    )

    # We're done!
    return interp_term_grad, expec_vec_grad, expec_mat_grad, None  # Extra "None" for the kwarg

class CiqVariationalStrategy(_VariationalStrategy):
    r"""
    Similar to :class:`~gpytorch.variational.VariationalStrategy`,
    except the whitening operation is performed using Contour Integral Quadrature
    rather than Cholesky (see `Pleiss et al. (2020)`_ for more info).
    See the `CIQ-SVGP tutorial`_ for an example.

    Contour Integral Quadrature uses iterative matrix-vector multiplication to approximate
    the :math:`\mathbf K_{\mathbf Z \mathbf Z}^{-1/2}` matrix used for the whitening operation.
    This can be more efficient than the standard variational strategy for large numbers
    of inducing points (e.g. :math:`M > 1000`) or when the inducing points have structure
    (e.g. they lie on an evenly-spaced grid).

    .. note::

        It is recommended that this object is used in conjunction with
        :obj:`~gpytorch.variational.NaturalVariationalDistribution` and
        natural gradient descent.

    :param ~gpytorch.models.ApproximateGP model: Model this strategy is applied to.
        Typically passed in when the VariationalStrategy is created in the
        __init__ method of the user defined model.
    :param torch.Tensor inducing_points: Tensor containing a set of inducing
        points to use for variational inference.
    :param ~gpytorch.variational.VariationalDistribution variational_distribution: A
        VariationalDistribution object that represents the form of the variational distribution :math:`q(\mathbf u)`
    :param learn_inducing_locations: (Default True): Whether or not
        the inducing point locations :math:`\mathbf Z` should be learned (i.e. are they
        parameters of the model).
    :type learn_inducing_locations: bool, optional

    .. _Pleiss et al. (2020):
        https://arxiv.org/pdf/2006.11267.pdf
    .. _CIQ-SVGP tutorial:
        examples/04_Variational_and_Approximate_GPs/SVGP_CIQ.html
    """

    def _ngd(self) -> bool:
        """Whether the NGD-specialized code paths should be used (i.e. the
        variational distribution is natural-parameterized)."""
        return isinstance(self._variational_distribution, NaturalVariationalDistribution)

    @property
    @cached(name="prior_distribution_memo")
    def prior_distribution(self) -> MultivariateNormal:
        r"""
        The whitened prior over the inducing values:
        :math:`p(\mathbf u) = \mathcal N(\mathbf 0, \mathbf I)`.
        """
        zeros = torch.zeros(
            self._variational_distribution.shape(),
            dtype=self._variational_distribution.dtype,
            device=self._variational_distribution.device,
        )
        ones = torch.ones_like(zeros)
        return MultivariateNormal(zeros, DiagLinearOperator(ones))

    @property
    @cached(name="variational_distribution_memo")
    def variational_distribution(self):
        """The explicit variational distribution q(u). Not available under NGD,
        where q(u) is never materialized (to avoid O(M^3) work)."""
        if self._ngd():
            raise RuntimeError(
                "Variational distribution for NGD-CIQ should be computed during forward calls. "
                "This is probably a bug in GPyTorch."
            )
        return super().variational_distribution

    def forward(
        self,
        x: torch.Tensor,
        inducing_points: torch.Tensor,
        inducing_values: torch.Tensor,
        variational_inducing_covar: Optional[MultivariateNormal] = None,
        **kwargs,
    ) -> MultivariateNormal:
        """
        Compute the predictive distribution q(f) at the inputs ``x``, whitening
        with K_ZZ^{-1/2} via Contour Integral Quadrature.

        :param x: data inputs.
        :param inducing_points: inducing point locations Z.
        :param inducing_values: mean of q(u) (ignored — None — under NGD).
        :param variational_inducing_covar: covariance of q(u), if any.
        :return: the predictive MultivariateNormal q(f).
        """
        # Compute full prior distribution over [Z; X] jointly
        full_inputs = torch.cat([inducing_points, x], dim=-2)
        full_output = self.model.forward(full_inputs)
        full_covar = full_output.lazy_covariance_matrix

        # Covariance terms (jitter keeps the CIQ/root solves stable)
        num_induc = inducing_points.size(-2)
        test_mean = full_output.mean[..., num_induc:]
        induc_induc_covar = full_covar[..., :num_induc, :num_induc].evaluate_kernel().add_jitter(1e-2)
        induc_data_covar = full_covar[..., :num_induc, num_induc:].to_dense()
        data_data_covar = full_covar[..., num_induc:, num_induc:].add_jitter(1e-4)

        # Compute interpolation term K_ZZ^{-1/2} k_ZX via CIQ
        with settings.max_preconditioner_size(0):  # Turn off preconditioning for CIQ
            interp_term = to_linear_operator(induc_induc_covar).sqrt_inv_matmul(induc_data_covar)

        # Compute interpolated mean and variance terms
        # We have separate computation rules for NGD versus standard GD
        if self._ngd():
            # NGD: the custom autograd Function produces q(f) terms directly
            # from the natural parameters, with canonical-parameter gradients.
            interp_mean, interp_var, kl_div = _NgdInterpTerms.apply(
                interp_term,
                self._variational_distribution.natural_vec,
                self._variational_distribution.natural_mat,
            )

            # Compute the (diagonal) covariance of q(f)
            predictive_var = data_data_covar.diagonal(dim1=-1, dim2=-2) - interp_term.pow(2).sum(dim=-2) + interp_var
            predictive_var = torch.clamp_min(predictive_var, settings.min_variance.value(predictive_var.dtype))
            predictive_covar = DiagLinearOperator(predictive_var)

            # Also compute and cache the KL divergence (retrieved by kl_divergence())
            if not hasattr(self, "_memoize_cache"):
                self._memoize_cache = dict()
            self._memoize_cache["kl"] = kl_div

        else:
            # Compute interpolated mean term: k^T K_ZZ^{-1/2} (m - \mu_Z)
            interp_mean = torch.matmul(
                interp_term.transpose(-1, -2), (inducing_values - self.prior_distribution.mean).unsqueeze(-1)
            ).squeeze(-1)

            # Compute the covariance of q(f):
            # K_XX + k^T K_ZZ^{-1/2} (S - I) K_ZZ^{-1/2} k
            middle_term = self.prior_distribution.lazy_covariance_matrix.mul(-1)
            if variational_inducing_covar is not None:
                middle_term = SumLinearOperator(variational_inducing_covar, middle_term)
            # BUG FIX: the data_data_covar term was missing, so the predictive
            # covariance omitted K_XX entirely.
            predictive_covar = SumLinearOperator(
                data_data_covar,
                MatmulLinearOperator(interp_term.transpose(-1, -2), middle_term @ interp_term),
            )

        # Compute the mean of q(f)
        # k_XZ K_ZZ^{-1/2} (m - K_ZZ^{-1/2} \mu_Z) + \mu_X
        predictive_mean = interp_mean + test_mean

        # Return the distribution
        return MultivariateNormal(predictive_mean, predictive_covar)

    def kl_divergence(self) -> torch.Tensor:
        r"""
        Compute the KL divergence between the variational inducing distribution
        :math:`q(\mathbf u)` and the prior inducing distribution :math:`p(\mathbf u)`.

        :rtype: torch.Tensor
        """
        if self._ngd():
            # Under NGD the KL is computed (and cached) during forward(); it
            # cannot be recomputed here without materializing q(u).
            if hasattr(self, "_memoize_cache") and "kl" in self._memoize_cache:
                return self._memoize_cache["kl"]
            else:
                raise RuntimeError(
                    "KL divergence for NGD-CIQ should be computed during forward calls. "
                    "This is probably a bug in GPyTorch."
                )
        else:
            return super().kl_divergence()

    def __call__(self, x: torch.Tensor, prior: bool = False, **kwargs) -> MultivariateNormal:
        # This is mostly the same as _VariationalStrategy.__call__()
        # but with special rules for natural gradient descent (to prevent O(M^3) computation)

        # If we're in prior mode, then we're done!
        if prior:
            return self.model.forward(x)

        # Delete previously cached items from the training distribution
        if self.training:
            self._clear_cache()

        # (Maybe) initialize variational distribution
        if not self.variational_params_initialized.item():
            if self._ngd():
                # Initialize natural params near the whitened prior N(0, I):
                # natural_vec ~ 0 (small noise), natural_mat = -1/2 I
                noise = torch.randn_like(self.prior_distribution.mean).mul_(1e-3)
                eye = torch.eye(noise.size(-1), dtype=noise.dtype, device=noise.device).mul(-0.5)
                self._variational_distribution.natural_vec.data.copy_(noise)
                self._variational_distribution.natural_mat.data.copy_(eye)
                self.variational_params_initialized.fill_(1)
            else:
                prior_dist = self.prior_distribution
                self._variational_distribution.initialize_variational_distribution(prior_dist)
                self.variational_params_initialized.fill_(1)

        # Ensure inducing_points and x are the same size
        inducing_points = self.inducing_points
        if inducing_points.shape[:-2] != x.shape[:-2]:
            x, inducing_points = self._expand_inputs(x, inducing_points)

        # Get q(f)
        if self._ngd():
            # NGD: forward() computes everything from the natural params
            return Module.__call__(
                self,
                x,
                inducing_points,
                inducing_values=None,
                variational_inducing_covar=None,
                **kwargs,
            )
        else:
            # Get p(u)/q(u)
            variational_dist_u = self.variational_distribution

            if isinstance(variational_dist_u, MultivariateNormal):
                return Module.__call__(
                    self,
                    x,
                    inducing_points,
                    inducing_values=variational_dist_u.mean,
                    variational_inducing_covar=variational_dist_u.lazy_covariance_matrix,
                    **kwargs,
                )
            elif isinstance(variational_dist_u, Delta):
                return Module.__call__(
                    self,
                    x,
                    inducing_points,
                    inducing_values=variational_dist_u.mean,
                    variational_inducing_covar=None,
                    ngd=False,
                    **kwargs,
                )
            else:
                raise RuntimeError(
                    f"Invalid variational distribution ({type(variational_dist_u)}). "
                    "Expected a multivariate normal or a delta distribution."
                )