#!/usr/bin/env python3
from typing import Any, Optional
import torch
from jaxtyping import Float
from linear_operator import to_dense
from linear_operator.operators import DiagLinearOperator, LinearOperator, TriangularLinearOperator
from linear_operator.utils.cholesky import psd_safe_cholesky
from torch import LongTensor, Tensor
from ..distributions import MultivariateNormal
from ..models import ApproximateGP, ExactGP
from ..module import Module
from ..utils.errors import CachingError
from ..utils.memoize import add_to_cache, cached, pop_from_cache
from ..utils.nearest_neighbors import NNUtil
from ._variational_distribution import _VariationalDistribution
from .mean_field_variational_distribution import MeanFieldVariationalDistribution
from .unwhitened_variational_strategy import UnwhitenedVariationalStrategy
class NNVariationalStrategy(UnwhitenedVariationalStrategy):
r"""
This strategy sets all inducing point locations to observed inputs,
and employs a :math:`k`-nearest-neighbor approximation. It was introduced as the
`Variational Nearest Neighbor Gaussian Processes (VNNGP)` in `Wu et al (2022)`_.
See the `VNNGP tutorial`_ for an example.
VNNGP assumes a k-nearest-neighbor generative process for inducing points :math:`\mathbf u`,
:math:`q(\mathbf u) = \prod_{j=1}^M q(u_j \mid \mathbf u_{n(j)})`
where :math:`n(j)` denotes the indices of the :math:`k` nearest neighbors of :math:`u_j` among
:math:`u_1, \cdots, u_{j-1}`. For any test observation :math:`\mathbf f`,
VNNGP makes predictive inference conditioned on its :math:`k` nearest inducing points
:math:`\mathbf u_{n(f)}`, i.e. :math:`p(f \mid \mathbf u_{n(f)})`.
VNNGP's objective factorizes over inducing points and observations, making stochastic optimization over both
immediately available. After a one-time cost of computing the :math:`k`-nearest neighbor structure,
the training and inference complexity is :math:`O(k^3)`.
Since VNNGP uses observations as inducing points, it is a user choice to either (1)
use the same mini-batch of inducing points and observations (recommended),
or (2) use different mini-batches of inducing points and observations. See the `VNNGP tutorial`_ for
implementation and comparison.
.. note::
The current implementation only supports :obj:`~gpytorch.variational.MeanFieldVariationalDistribution`.
We recommend installing the `faiss`_ library (which requires a separate package installation)
for nearest neighbor search, as it is significantly faster than the `scikit-learn` nearest neighbor search.
GPyTorch will automatically use `faiss` if it is installed, but will revert to `scikit-learn` otherwise.
Different inducing point orderings will produce different nearest neighbor approximations.
:param ~gpytorch.models.ApproximateGP model: Model this strategy is applied to.
Typically passed in when the VariationalStrategy is created in the
__init__ method of the user defined model.
:param inducing_points: Tensor containing a set of inducing
points to use for variational inference.
:param variational_distribution: A
VariationalDistribution object that represents the form of the variational distribution :math:`q(\mathbf u)`
:param k: Number of nearest neighbors.
:param training_batch_size: The number of data points in each training mini-batch (defaults to the total number of inducing points, i.e. no mini-batching).
:param jitter_val: Amount of diagonal jitter to add for covariance matrix numerical stability.
:param compute_full_kl: Whether to compute the full KL divergence or a stochastic (mini-batch) estimate.
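
A minimal usage sketch (following the pattern of the `VNNGP tutorial`_; the model class
``VNNGPModel``, the tensor ``train_x``, and the hyperparameter values below are illustrative
placeholders):

Example:
    >>> class VNNGPModel(gpytorch.models.ApproximateGP):
    >>>     def __init__(self, inducing_points, k=64, training_batch_size=256):
    >>>         variational_distribution = gpytorch.variational.MeanFieldVariationalDistribution(
    >>>             inducing_points.size(-2)
    >>>         )
    >>>         variational_strategy = gpytorch.variational.NNVariationalStrategy(
    >>>             self, inducing_points, variational_distribution,
    >>>             k=k, training_batch_size=training_batch_size,
    >>>         )
    >>>         super().__init__(variational_strategy)
    >>>         self.mean_module = gpytorch.means.ZeroMean()
    >>>         self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.MaternKernel(nu=2.5))
    >>>
    >>>     def forward(self, x):
    >>>         mean = self.mean_module(x)
    >>>         covar = self.covar_module(x)
    >>>         return gpytorch.distributions.MultivariateNormal(mean, covar)
    >>>
    >>> # VNNGP uses the (full set of) training inputs as its inducing points
    >>> model = VNNGPModel(inducing_points=train_x, k=64, training_batch_size=256)
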
.. _Wu et al (2022):
https://arxiv.org/pdf/2202.01694.pdf
.. _VNNGP tutorial:
examples/04_Variational_and_Approximate_GPs/VNNGP.html
.. _faiss:
https://github.com/facebookresearch/faiss
"""
def __init__(
self,
model: ApproximateGP,
inducing_points: Float[Tensor, "... M D"],
variational_distribution: Float[_VariationalDistribution, "... M"],
k: int,
training_batch_size: Optional[int] = None,
jitter_val: Optional[float] = 1e-3,
compute_full_kl: Optional[bool] = False,
):
assert isinstance(
variational_distribution, MeanFieldVariationalDistribution
), "Currently, NNVariationalStrategy only supports MeanFieldVariationalDistribution."
super().__init__(
model, inducing_points, variational_distribution, learn_inducing_locations=False, jitter_val=jitter_val
)
# Model
object.__setattr__(self, "model", model)
self.inducing_points = inducing_points
self.M, self.D = inducing_points.shape[-2:]
self.k = k
assert self.k < self.M, (
f"Number of nearest neighbors k must be smaller than the number of inducing points, "
f"but got k = {k}, M = {self.M}."
)
self._inducing_batch_shape: torch.Size = inducing_points.shape[:-2]
self._model_batch_shape: torch.Size = self._variational_distribution.variational_mean.shape[:-1]
self._batch_shape: torch.Size = torch.broadcast_shapes(self._inducing_batch_shape, self._model_batch_shape)
self.nn_util: NNUtil = NNUtil(
k, dim=self.D, batch_shape=self._inducing_batch_shape, device=inducing_points.device
)
self._compute_nn()
# If no training batch size is specified, default to using all M inducing points per training batch
self.training_batch_size = training_batch_size if training_batch_size is not None else self.M
self._set_training_iterator()
self.compute_full_kl = compute_full_kl
@property
@cached(name="prior_distribution_memo")
def prior_distribution(self) -> Float[MultivariateNormal, "... M"]:
out = self.model.forward(self.inducing_points)
res = MultivariateNormal(out.mean, out.lazy_covariance_matrix.add_jitter(self.jitter_val))
return res
def _cholesky_factor(
self, induc_induc_covar: Float[LinearOperator, "... M M"]
) -> Float[TriangularLinearOperator, "... M M"]:
# Uncached version
L = psd_safe_cholesky(to_dense(induc_induc_covar))
return TriangularLinearOperator(L)
def __call__(
self, x: Float[Tensor, "... N D"], prior: bool = False, **kwargs: Any
) -> Float[MultivariateNormal, "... N"]:
# If we're in prior mode, then we're done!
if prior:
return self.model.forward(x, **kwargs)
if x is not None:
# Make sure x and inducing points have the same batch shape
if not (self.inducing_points.shape[:-2] == x.shape[:-2]):
try:
x = x.expand(*self.inducing_points.shape[:-2], *x.shape[-2:]).contiguous()
except RuntimeError:
raise RuntimeError(
f"x batch shape must match or broadcast with the inducing points' batch shape, "
f"but got x batch shape = {x.shape[:-2]}, "
f"inducing points batch shape = {self.inducing_points.shape[:-2]}."
)
# Delete previously cached items from the training distribution
if self.training:
self._clear_cache()
# (Maybe) initialize variational distribution
if not self.variational_params_initialized.item():
prior_dist = self.prior_distribution
self._variational_distribution.variational_mean.data.copy_(prior_dist.mean)
self._variational_distribution.variational_mean.data.add_(
torch.randn_like(prior_dist.mean), alpha=self._variational_distribution.mean_init_std
)
# initialize with a small variational stddev for quicker convergence of the KL divergence
self._variational_distribution._variational_stddev.data.copy_(torch.tensor(1e-2))
self.variational_params_initialized.fill_(1)
return self.forward(
x, self.inducing_points, inducing_values=None, variational_inducing_covar=None, **kwargs
)
else:
# x is None (training mode): forward() will sample a mini-batch of training data from the inducing points
inducing_points = self.inducing_points
return self.forward(x, inducing_points, inducing_values=None, variational_inducing_covar=None, **kwargs)
def forward(
self,
x: Float[Tensor, "... N D"],
inducing_points: Float[Tensor, "... M D"],
inducing_values: Float[Tensor, "... M"],
variational_inducing_covar: Optional[Float[LinearOperator, "... M M"]] = None,
**kwargs: Any,
) -> Float[MultivariateNormal, "... N"]:
if self.training:
# In training mode, the full inducing point set equals the full training dataset.
# Users may pass either None or a tensor of training data for x:
# - if x is None, a mini-batch of training data is sampled from the inducing points
# - otherwise, the indices of the inducing points equal to x are looked up
if x is None:
x_indices = self._get_training_indices()
kl_indices = x_indices
predictive_mean = self._variational_distribution.variational_mean[..., x_indices]
predictive_var = self._variational_distribution._variational_stddev[..., x_indices] ** 2
else:
# find the indices of inducing points that correspond to x
x_indices = self.nn_util.find_nn_idx(x.float(), k=1).squeeze(-1) # (*inducing_batch_shape, batch_size)
expanded_x_indices = x_indices.expand(*self._batch_shape, x_indices.shape[-1])
expanded_variational_mean = self._variational_distribution.variational_mean.expand(
*self._batch_shape, self.M
)
expanded_variational_var = (
self._variational_distribution._variational_stddev.expand(*self._batch_shape, self.M) ** 2
)
predictive_mean = expanded_variational_mean.gather(-1, expanded_x_indices)
predictive_var = expanded_variational_var.gather(-1, expanded_x_indices)
# sample a different set of indices for the stochastic estimate of the KL
kl_indices = self._get_training_indices()
kl = self._kl_divergence(kl_indices)
add_to_cache(self, "kl_divergence_memo", kl)
return MultivariateNormal(predictive_mean, DiagLinearOperator(predictive_var))
else:
nn_indices = self.nn_util.find_nn_idx(x.float())
x_batch_shape = x.shape[:-2]
batch_shape = torch.broadcast_shapes(self._batch_shape, x_batch_shape)
x_bsz = x.shape[-2]
assert nn_indices.shape == (*x_batch_shape, x_bsz, self.k), nn_indices.shape
# select K nearest neighbors from inducing points for test point x
expanded_nn_indices = nn_indices.unsqueeze(-1).expand(*x_batch_shape, x_bsz, self.k, self.D)
expanded_inducing_points = inducing_points.unsqueeze(-2).expand(*x_batch_shape, self.M, self.k, self.D)
inducing_points = expanded_inducing_points.gather(-3, expanded_nn_indices)
assert inducing_points.shape == (*x_batch_shape, x_bsz, self.k, self.D)
# get variational mean and covar for nearest neighbors
inducing_values = self._variational_distribution.variational_mean
expanded_inducing_values = inducing_values.unsqueeze(-1).expand(*batch_shape, self.M, self.k)
expanded_nn_indices = nn_indices.expand(*batch_shape, x_bsz, self.k)
inducing_values = expanded_inducing_values.gather(-2, expanded_nn_indices)
assert inducing_values.shape == (*batch_shape, x_bsz, self.k)
variational_stddev = self._variational_distribution._variational_stddev
assert variational_stddev.shape == (*self._model_batch_shape, self.M)
expanded_variational_stddev = variational_stddev.unsqueeze(-1).expand(*batch_shape, self.M, self.k)
variational_inducing_covar = expanded_variational_stddev.gather(-2, expanded_nn_indices) ** 2
assert variational_inducing_covar.shape == (*batch_shape, x_bsz, self.k)
variational_inducing_covar = DiagLinearOperator(variational_inducing_covar)
assert variational_inducing_covar.shape == (*batch_shape, x_bsz, self.k, self.k)
# Make everything batch mode
x = x.unsqueeze(-2)
assert x.shape == (*x_batch_shape, x_bsz, 1, self.D)
x = x.expand(*batch_shape, x_bsz, 1, self.D)
# Compute forward mode in the standard way
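# Move the per-test-point dimension (x_bsz) to the front so that each test point, together with
# its k nearest inducing points, is treated as an independent batch by the base class's forward()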
_batch_dims = tuple(range(len(batch_shape)))
_x = x.permute((-3,) + _batch_dims + (-2, -1)) # (x_bsz, *batch_shape, 1, D)
# inducing_points.shape (*x_batch_shape, x_bsz, self.k, self.D)
inducing_points = inducing_points.expand(*batch_shape, x_bsz, self.k, self.D)
_inducing_points = inducing_points.permute((-3,) + _batch_dims + (-2, -1)) # (x_bsz, *batch_shape, k, D)
_inducing_values = inducing_values.permute((-2,) + _batch_dims + (-1,))
_variational_inducing_covar = variational_inducing_covar.permute((-3,) + _batch_dims + (-2, -1))
dist = super().forward(_x, _inducing_points, _inducing_values, _variational_inducing_covar, **kwargs)
_x_batch_dims = tuple(range(1, 1 + len(batch_shape)))
predictive_mean = dist.mean # (x_bsz, *x_batch_shape, 1)
predictive_covar = dist.covariance_matrix # (x_bsz, *x_batch_shape, 1, 1)
predictive_mean = predictive_mean.permute(_x_batch_dims + (0, -1))
predictive_covar = predictive_covar.permute(_x_batch_dims + (0, -2, -1))
# Undo batch mode
predictive_mean = predictive_mean.squeeze(-1)
predictive_var = predictive_covar.squeeze(-2).squeeze(-1)
assert predictive_var.shape == predictive_covar.shape[:-2]
assert predictive_mean.shape == predictive_covar.shape[:-2]
# Return the distribution
return MultivariateNormal(predictive_mean, DiagLinearOperator(predictive_var))
def get_fantasy_model(
self,
inputs: Float[Tensor, "... N D"],
targets: Float[Tensor, "... N"],
mean_module: Optional[Module] = None,
covar_module: Optional[Module] = None,
**kwargs,
) -> ExactGP:
raise NotImplementedError(
f"No fantasy model support for {self.__class__.__name__}. "
"Only VariationalStrategy and UnwhitenedVariationalStrategy are currently supported."
)
def _set_training_iterator(self) -> None:
self._training_indices_iter = 0
if self.training_batch_size == self.M:
self._training_indices_iterator = (torch.arange(self.M, device=self.inducing_points.device),)
else:
# The first training batch always contains the first k inducing points
# This is because computing the KL divergence for the first k inducing points is special-cased
# (since the first k inducing points have < k neighbors)
# Note that there is a special function _firstk_kl_helper for this
training_indices = torch.randperm(self.M - self.k, device=self.inducing_points.device) + self.k
self._training_indices_iterator = (torch.arange(self.k),) + training_indices.split(self.training_batch_size)
self._total_training_batches = len(self._training_indices_iterator)
def _get_training_indices(self) -> LongTensor:
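# Return the current mini-batch of training indices and advance the iterator;
# once all batches have been consumed, rebuild (and reshuffle) the iterator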
self.current_training_indices = self._training_indices_iterator[self._training_indices_iter]
self._training_indices_iter += 1
if self._training_indices_iter == self._total_training_batches:
self._set_training_iterator()
return self.current_training_indices
def _firstk_kl_helper(self) -> Float[Tensor, "..."]:
# Compute the KL divergence for first k inducing points
train_x_firstk = self.inducing_points[..., : self.k, :]
full_output = self.model.forward(train_x_firstk)
induc_mean, induc_induc_covar = full_output.mean, full_output.lazy_covariance_matrix
induc_induc_covar = induc_induc_covar.add_jitter(self.jitter_val)
prior_dist = MultivariateNormal(induc_mean, induc_induc_covar)
inducing_values = self._variational_distribution.variational_mean[..., : self.k]
variational_covar_firstk = self._variational_distribution._variational_stddev[..., : self.k] ** 2
variational_inducing_covar = DiagLinearOperator(variational_covar_firstk)
variational_distribution = MultivariateNormal(inducing_values, variational_inducing_covar)
kl = torch.distributions.kl.kl_divergence(variational_distribution, prior_dist) # model_batch_shape
return kl
def _stochastic_kl_helper(self, kl_indices: Float[Tensor, "n_batch"]) -> Float[Tensor, "..."]: # noqa: F821
# Compute the KL divergence for a mini-batch of the remaining M - k inducing points
# See the paper's appendix for the KL decomposition
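# The mini-batch KL decomposes as 0.5 * (logdet_p - logdet_q - batch_size + trace_term + invquad_term);
# the four terms are computed in steps (1)-(4) below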
kl_bs = len(kl_indices) # training_batch_size
variational_mean = self._variational_distribution.variational_mean # (*model_bs, M)
variational_stddev = self._variational_distribution._variational_stddev
# (1) compute logdet_q
inducing_point_log_variational_covar = (variational_stddev[..., kl_indices] ** 2).log()
logdet_q = torch.sum(inducing_point_log_variational_covar, dim=-1) # model_bs
# (2) compute logdet_p
# Select a mini-batch of inducing points according to kl_indices
inducing_points = self.inducing_points[..., kl_indices, :].expand(*self._batch_shape, kl_bs, self.D)
# (*bs, kl_bs, D)
# Select their K nearest neighbors
nearest_neighbor_indices = self.nn_xinduce_idx[..., kl_indices - self.k, :].to(inducing_points.device)
# (*bs, kl_bs, K)
expanded_inducing_points_all = self.inducing_points.unsqueeze(-2).expand(
*self._batch_shape, self.M, self.k, self.D
)
expanded_nearest_neighbor_indices = nearest_neighbor_indices.unsqueeze(-1).expand(
*self._batch_shape, kl_bs, self.k, self.D
)
nearest_neighbors = expanded_inducing_points_all.gather(-3, expanded_nearest_neighbor_indices)
# (*bs, kl_bs, K, D)
# Compute prior distribution
# Move the kl_bs dimension to the first dimension to enable batch covar_module computation
nearest_neighbors_ = nearest_neighbors.permute((-3,) + tuple(range(len(self._batch_shape))) + (-2, -1))
# (kl_bs, *bs, K, D)
inducing_points_ = inducing_points.permute((-2,) + tuple(range(len(self._batch_shape))) + (-1,))
# (kl_bs, *bs, D)
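# Evaluate the prior jointly over each inducing point and its k nearest neighbors in a single
# forward pass; the first k entries along the joint dimension are the neighbors, the last entry
# is the inducing point itself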
full_output = self.model.forward(torch.cat([nearest_neighbors_, inducing_points_.unsqueeze(-2)], dim=-2))
full_mean, full_covar = full_output.mean, full_output.covariance_matrix
# Mean terms
_undo_permute_dims = tuple(range(1, 1 + len(self._batch_shape))) + (0, -1)
nearest_neighbors_prior_mean = full_mean[..., : self.k].permute(_undo_permute_dims) # (*inducing_bs, kl_bs, K)
inducing_prior_mean = full_mean[..., self.k :].permute(_undo_permute_dims).squeeze(-1) # (*inducing_bs, kl_bs)
# Covar terms
nearest_neighbors_prior_cov = full_covar[..., : self.k, : self.k]
nearest_neighbors_inducing_prior_cross_cov = full_covar[..., : self.k, self.k :]
inducing_prior_cov = full_covar[..., self.k :, self.k :]
inducing_prior_cov = (
inducing_prior_cov.squeeze(-1).squeeze(-1).permute((-1,) + tuple(range(len(self._batch_shape))))
)
# Interpolation term K_nn^{-1} k_{nu}
interp_term = torch.linalg.solve(
nearest_neighbors_prior_cov + self.jitter_val * torch.eye(self.k, device=self.inducing_points.device),
nearest_neighbors_inducing_prior_cross_cov,
).squeeze(
-1
) # (kl_bs, *inducing_bs, K)
interp_term = interp_term.permute(_undo_permute_dims) # (*inducing_bs, kl_bs, K)
nearest_neighbors_inducing_prior_cross_cov = nearest_neighbors_inducing_prior_cross_cov.squeeze(-1).permute(
_undo_permute_dims
) # k_{n(j),j}, (*inducing_bs, kl_bs, K)
invquad_term_for_F = torch.sum(
interp_term * nearest_neighbors_inducing_prior_cross_cov, dim=-1
) # (*inducing_bs, kl_bs)
inducing_prior_cov = self.model.covar_module.forward(
inducing_points, inducing_points, diag=True
) # (*inducing_bs, kl_bs)
F = inducing_prior_cov - invquad_term_for_F
F = F + self.jitter_val
# K_uu - k_un K_nn^{-1} k_nu
logdet_p = F.log().sum(dim=-1) # shape: inducing_bs
# (3) compute trace_term
expanded_variational_stddev = variational_stddev.unsqueeze(-1).expand(*self._batch_shape, self.M, self.k)
expanded_variational_mean = variational_mean.unsqueeze(-1).expand(*self._batch_shape, self.M, self.k)
expanded_nearest_neighbor_indices = nearest_neighbor_indices.expand(*self._batch_shape, kl_bs, self.k)
nearest_neighbor_variational_covar = (
expanded_variational_stddev.gather(-2, expanded_nearest_neighbor_indices) ** 2
) # (*batch_shape, kl_bs, k)
bjsquared_s_nearest_neighbors = torch.sum(
interp_term**2 * nearest_neighbor_variational_covar, dim=-1
) # (*batch_shape, kl_bs)
inducing_point_variational_covar = variational_stddev[..., kl_indices] ** 2 # (model_bs, kl_bs)
trace_term = (1.0 / F * (bjsquared_s_nearest_neighbors + inducing_point_variational_covar)).sum(
dim=-1
) # batch_shape
# (4) compute invquad_term
nearest_neighbors_variational_mean = expanded_variational_mean.gather(-2, expanded_nearest_neighbor_indices)
Bj_m_nearest_neighbors = torch.sum(
interp_term * (nearest_neighbors_variational_mean - nearest_neighbors_prior_mean), dim=-1
)
inducing_variational_mean = variational_mean[..., kl_indices]
invquad_term = torch.sum(
(inducing_variational_mean - inducing_prior_mean - Bj_m_nearest_neighbors) ** 2 / F, dim=-1
)
kl = (logdet_p - logdet_q - kl_bs + trace_term + invquad_term) * (1.0 / 2)
assert kl.shape == self._batch_shape, kl.shape
return kl
def _kl_divergence(
self, kl_indices: Optional[LongTensor] = None, batch_size: Optional[int] = None
) -> Float[Tensor, "..."]:
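# If compute_full_kl is set (or there is only a single training batch), sum the exact KL over
# all M inducing points in chunks of batch_size. Otherwise, return a stochastic estimate that
# rescales the current mini-batch's KL contribution by M / (mini-batch size).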
if self.compute_full_kl or (self._total_training_batches == 1):
if batch_size is None:
batch_size = self.training_batch_size
kl = self._firstk_kl_helper()
for kl_indices in torch.split(torch.arange(self.k, self.M), batch_size):
kl += self._stochastic_kl_helper(kl_indices)
else:
# compute a stochastic estimate
assert kl_indices is not None
if self._training_indices_iter == 1:
assert len(kl_indices) == self.k, (
f"kl_indices sould be the first batch data of length k, "
f"but got len(kl_indices) = {len(kl_indices)} and k = {self.k}."
)
kl = self._firstk_kl_helper() * self.M / self.k
else:
kl = self._stochastic_kl_helper(kl_indices) * self.M / len(kl_indices)
return kl
def kl_divergence(self) -> Float[Tensor, "..."]:
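# The KL term is computed and cached during the training-mode forward pass; popping it from
# the cache ensures each cached value is consumed exactly once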
try:
return pop_from_cache(self, "kl_divergence_memo")
except CachingError:
raise RuntimeError("KL Divergence of variational strategy was called before nearest neighbors were set.")
def _compute_nn(self) -> "NNVariationalStrategy":
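# Build the nearest-neighbor search structure over the inducing points and precompute, for each
# inducing point u_j with j > k, the indices of its k nearest neighbors among u_1, ..., u_{j-1}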
with torch.no_grad():
inducing_points_fl = self.inducing_points.data.float()
self.nn_util.set_nn_idx(inducing_points_fl)
self.nn_xinduce_idx = self.nn_util.build_sequential_nn_idx(inducing_points_fl)
# shape (*_inducing_batch_shape, M-k, k)
return self