# Source code for gpytorch.variational.variational_strategy

#!/usr/bin/env python3

from __future__ import annotations

import warnings
from collections.abc import Iterable
from typing import Any

import torch
from linear_operator import to_dense
from linear_operator.operators import (
    CholLinearOperator,
    DiagLinearOperator,
    LinearOperator,
    MatmulLinearOperator,
    RootLinearOperator,
    SumLinearOperator,
    TriangularLinearOperator,
)
from linear_operator.utils.cholesky import psd_safe_cholesky
from linear_operator.utils.errors import NotPSDError
from torch import Tensor

from gpytorch import settings

from gpytorch.variational._variational_strategy import _VariationalStrategy
from gpytorch.variational.cholesky_variational_distribution import CholeskyVariationalDistribution

from ..distributions import MultivariateNormal
from ..models import ApproximateGP
from ..settings import _linalg_dtype_cholesky, trace_mode
from ..utils.errors import CachingError
from ..utils.memoize import cached, clear_cache_hook, pop_from_cache_ignore_args
from ..utils.warnings import OldVersionWarning
from . import _VariationalDistribution


def _ensure_updated_strategy_flag_set(
    state_dict: dict[str, Tensor],
    prefix: str,
    local_metadata: dict[str, Any],
    strict: bool,
    missing_keys: Iterable[str],
    unexpected_keys: Iterable[str],
    error_msgs: Iterable[str],
):
    r"""Load-state-dict pre-hook that back-fills the ``updated_strategy`` flag.

    If ``state_dict`` has no ``prefix + "updated_strategy"`` entry, a ``False`` flag is inserted in place and an
    :class:`OldVersionWarning` is emitted. If the entry is already present, the state dict is left untouched.

    :param state_dict: The state dict about to be loaded. It is modified in place when the flag is missing.
    :param prefix: The key prefix of the module being loaded.
    :param local_metadata: Unused; part of the pre-hook signature.
    :param strict: Unused; part of the pre-hook signature.
    :param missing_keys: Unused; part of the pre-hook signature.
    :param unexpected_keys: Unused; part of the pre-hook signature.
    :param error_msgs: Unused; part of the pre-hook signature.
    """
    flag_key = prefix + "updated_strategy"
    if flag_key in state_dict:
        return

    # Put the flag on the device of the first tensor found in the state dict. If the state dict is empty (or holds
    # no tensors), `device=None` lets PyTorch pick its default device instead of raising while probing the device.
    device = next((value.device for value in state_dict.values() if isinstance(value, torch.Tensor)), None)
    state_dict[flag_key] = torch.tensor(False, device=device)
    warnings.warn(
        "You have loaded a variational GP model (using `VariationalStrategy`) from a previous version of "
        "GPyTorch. We have updated the parameters of your model to work with the new version of "
        "`VariationalStrategy` that uses whitened parameters.\nYour model will work as expected, but we "
        "recommend that you re-save your model.",
        OldVersionWarning,
    )


class ComputePredictiveUpdates(torch.autograd.Function):
    r"""Autograd function for the predictive mean and variance updates with a hand-written backward pass.

    Writing `L` for `chol`, `A = L^{-1} K_ZX`, `M` for `middle` and `m` for `inducing_values`, the forward pass
    returns `A^T m` and `diag(A^T M A)`.
    """

    @staticmethod
    def forward(
        ctx,
        chol: Tensor,
        induc_data_covar: Tensor,
        middle: Tensor,
        inducing_values: Tensor,
    ) -> tuple[Tensor, Tensor]:
        r"""Compute the predictive mean and variance updates as in `VariationalStrategy._compute_predictive_updates`.

        Only the diagonal of the predictive covariance update (the variance update) is produced; the off-diagonal
        entries are never formed.

        :param chol: Lower-triangular factor that `induc_data_covar` is solved against.
        :param induc_data_covar: The right-hand side `K_ZX` of the triangular solve.
        :param middle: The middle term of the quadratic form (`S - I` in the caller).
        :param inducing_values: The vector `m` multiplied by the solved term to form the mean update.
        :return: The mean update and the variance update.
        """
        # `A = L^{-1} K_ZX`
        whitened_cross = torch.linalg.solve_triangular(chol, induc_data_covar, upper=False)
        whitened_cross_t = whitened_cross.mT

        mean_update = (whitened_cross_t @ inducing_values.unsqueeze(-1)).squeeze(-1)
        # Row-wise quadratic form: entry `i` is `a_i^T M a_i`, where `a_i` is the i-th column of `A`
        variance_update = torch.sum(whitened_cross_t * (whitened_cross_t @ middle), dim=-1)

        # NOTE: `induc_data_covar` is not saved. The backward pass only ever needs it through `A`.
        ctx.save_for_backward(chol, whitened_cross, middle, inducing_values)

        return mean_update, variance_update

    @staticmethod
    def backward(
        ctx,
        d_mean: Tensor,
        d_variance: Tensor,
    ) -> tuple[Tensor, Tensor, Tensor, Tensor]:
        r"""A custom backward pass more efficient than PyTorch's autograd by rearranging tensor operations.

        This backward is bottlenecked by two O(m^2 n) matmuls where `m` is the number of inducing points and `n` is
        the number of data points. In contrast, PyTorch's backward pass would require three O(m^2 n) matmuls and a
        O(m^2 n) triangular solve. Thus, this implementation is about 2x faster when `m << n`.

        :param d_mean: Gradient with respect to the mean update.
        :param d_variance: Gradient with respect to the variance update.
        :return: Gradients with respect to `chol`, `induc_data_covar`, `middle` and `inducing_values`, in that order.
        """
        chol, whitened_cross, middle, inducing_values = ctx.saved_tensors
        chol_t = chol.mT

        # Terms shared by several gradients
        cross_times_dmean = whitened_cross @ d_mean.unsqueeze(-1)  # A @ d_mean
        cross_scaled_by_dvar = whitened_cross * d_variance.unsqueeze(-2)  # A @ diag(d_variance)

        # `L^{-T} (S - I)`
        # NOTE: Empirically, the triangular solve against `S - I` still seems to be stable in FP32. However, hitting
        # `S - I` with the inverse factor on both sides would be numerically unstable, which should be avoided.
        solved_middle = torch.linalg.solve_triangular(chol_t, middle, upper=True)

        # `L^{-T} m`
        solved_values = torch.linalg.solve_triangular(chol_t, inducing_values.unsqueeze(-1), upper=True)

        # Gradient of the mean update `A^T m` with respect to `m`
        grad_inducing_values = cross_times_dmean.squeeze(-1)

        # Gradient of the variance update with respect to the middle term: `A diag(d_variance) A^T`
        grad_middle = cross_scaled_by_dvar @ whitened_cross.mT

        # Gradient with respect to `K_ZX`. The first summand comes from the variance; the factor of 2 is there
        # because `K_ZX` appears on both sides of the quadratic form and symmetry is exploited. The second summand
        # comes from the mean: `L^{-T} m d_mean^T`.
        grad_induc_data_covar = 2.0 * solved_middle @ cross_scaled_by_dvar + solved_values @ d_mean.unsqueeze(-2)

        # Gradient with respect to `chol`, again one summand from the variance (with the symmetry factor of 2) and
        # one from the mean.
        # NOTE: In principle, the upper triangular part has to be zeroed out because `chol` is lower triangular. It
        # is actually not necessary here, because this gradient is immediately fed into `cholesky_backward`, which
        # does not care about the upper triangular part. `tril` is kept for consistency with PyTorch's implementation.
        # https://github.com/pytorch/pytorch/blob/4a0693682a8574bdc36e1ca2ea7bd2ddf5c19340/torch/csrc/autograd/FunctionsManual.cpp#L1999-L2003
        # NOTE: If we want to get fancy, fusing this backward with `cholesky_backward` will save a matmul. It may not
        # be worth the effort. It's only useful when there are more inducing points than the data.
        grad_chol = (-2.0 * solved_middle @ grad_middle - solved_values @ cross_times_dmean.mT).tril()

        return grad_chol, grad_induc_data_covar, grad_middle, grad_inducing_values


class VariationalStrategy(_VariationalStrategy):
    r"""
    The standard variational strategy, as defined by `Hensman et al. (2015)`_.
    This strategy takes a set of :math:`m \ll n` inducing points :math:`\mathbf Z`
    and applies an approximate distribution :math:`q( \mathbf u)` over their function values.
    (Here, we use the common notation :math:`\mathbf u = f(\mathbf Z)`.)
    The approximate function distribution for any arbitrary input :math:`\mathbf X` is given by:

    .. math::

        q( f(\mathbf X) ) = \int p( f(\mathbf X) \mid \mathbf u) q(\mathbf u) \: d\mathbf u

    This variational strategy uses "whitening" to accelerate the optimization of the variational
    parameters. See `Matthews (2017)`_ for more info.

    :param model: Model this strategy is applied to.
        Typically passed in when the VariationalStrategy is created in the
        __init__ method of the user defined model.
    :param inducing_points: Tensor containing a set of inducing
        points to use for variational inference.
    :param variational_distribution: A
        VariationalDistribution object that represents the form of the variational distribution :math:`q(\mathbf u)`
    :param learn_inducing_locations: (Default True): Whether or not
        the inducing point locations :math:`\mathbf Z` should be learned (i.e. are they
        parameters of the model).
    :param jitter_val: Amount of diagonal jitter to add for Cholesky factorization numerical stability

    .. _Hensman et al. (2015):
        http://proceedings.mlr.press/v38/hensman15.pdf
    .. _Matthews (2017):
        https://www.repository.cam.ac.uk/handle/1810/278022
    """

    def __init__(
        self,
        model: ApproximateGP,
        inducing_points: Tensor,
        variational_distribution: _VariationalDistribution,
        learn_inducing_locations: bool = True,
        jitter_val: float | None = None,
    ):
        super().__init__(
            model, inducing_points, variational_distribution, learn_inducing_locations, jitter_val=jitter_val
        )
        # Freshly constructed strategies already use whitened parameters. The pre-hook inserts a `False` flag when
        # a loaded state dict has no such entry, which makes `__call__` whiten the loaded parameters once.
        self.register_buffer("updated_strategy", torch.tensor(True))
        self._register_load_state_dict_pre_hook(_ensure_updated_strategy_flag_set)
        self.has_fantasy_strategy = True

    @cached(name="cholesky_factor", ignore_args=True)
    def _cholesky_factor(self, induc_induc_covar: LinearOperator) -> TriangularLinearOperator:
        """Return the (cached) Cholesky factor of ``induc_induc_covar`` as a triangular linear operator.

        The factorization is carried out on the dense matrix in the dtype given by ``_linalg_dtype_cholesky``.
        """
        L = psd_safe_cholesky(to_dense(induc_induc_covar).type(_linalg_dtype_cholesky.value()))
        return TriangularLinearOperator(L)

    @property
    @cached(name="prior_distribution_memo")
    def prior_distribution(self) -> MultivariateNormal:
        """The (cached) prior over the inducing values: zero mean and identity covariance.

        Shape, dtype and device are taken from the variational distribution.
        """
        zeros = torch.zeros(
            self._variational_distribution.shape(),
            dtype=self._variational_distribution.dtype,
            device=self._variational_distribution.device,
        )
        ones = torch.ones_like(zeros)
        res = MultivariateNormal(zeros, DiagLinearOperator(ones))
        return res

    @property
    @cached(name="pseudo_points_memo")
    def pseudo_points(self) -> tuple[Tensor, Tensor]:
        """Compute the (cached) pseudo-target covariance and mean from the variational distribution.

        :return: The tuple ``(pseudo_target_covar, pseudo_target_mean)``.
        :raises NotImplementedError: If the variational distribution is not a ``CholeskyVariationalDistribution``.
        """
        # TODO: have var_mean, var_cov come from a method of _variational_distribution
        # while having Kmm_root be a root decomposition to enable CIQVariationalDistribution support.

        # retrieve the variational mean, m and covariance matrix, S.
        if not isinstance(self._variational_distribution, CholeskyVariationalDistribution):
            raise NotImplementedError(
                "Only CholeskyVariationalDistribution has pseudo-point support currently, "
                f"but your _variational_distribution is a {type(self._variational_distribution).__name__}"
            )

        var_cov_root = TriangularLinearOperator(self._variational_distribution.chol_variational_covar)
        var_cov = CholLinearOperator(var_cov_root)
        var_mean = self.variational_distribution.mean
        # NOTE(review): this check also skips the unsqueeze when the last dimension happens to have size 1
        # (e.g. a single inducing point) - confirm that this is intended.
        if var_mean.shape[-1] != 1:
            var_mean = var_mean.unsqueeze(-1)

        # compute R = I - S
        cov_diff = var_cov.add_jitter(-1.0)
        cov_diff = -1.0 * cov_diff

        # K^{1/2}
        Kmm = self.model.covar_module(self.inducing_points)
        Kmm_root = Kmm.cholesky()

        # D_a = (S^{-1} - K^{-1})^{-1} = S + S R^{-1} S
        # note that in the whitened case R = I - S, unwhitened R = K - S
        # we compute (R R^{T})^{-1} R^T S for stability reasons as R is probably not PSD.
        eval_var_cov = var_cov.to_dense()
        eval_rhs = cov_diff.transpose(-1, -2).matmul(eval_var_cov)
        inner_term = cov_diff.matmul(cov_diff.transpose(-1, -2))
        # TODO: flag the jitter here
        inner_solve = inner_term.add_jitter(self.jitter_val).solve(eval_rhs, eval_var_cov.transpose(-1, -2))
        inducing_covar = var_cov + inner_solve

        inducing_covar = Kmm_root.matmul(inducing_covar).matmul(Kmm_root.transpose(-1, -2))

        # mean term: D_a S^{-1} m
        # unwhitened: (S - S R^{-1} S) S^{-1} m = (I - S R^{-1}) m
        rhs = cov_diff.transpose(-1, -2).matmul(var_mean)
        # TODO: this jitter too
        inner_rhs_mean_solve = inner_term.add_jitter(self.jitter_val).solve(rhs)
        pseudo_target_mean = Kmm_root.matmul(inner_rhs_mean_solve)

        # ensure inducing covar is psd
        # TODO: make this be an explicit root decomposition
        try:
            pseudo_target_covar = CholLinearOperator(inducing_covar.add_jitter(self.jitter_val).cholesky()).to_dense()
        except NotPSDError:
            # Fall back to an eigendecomposition and add the jitter to the eigenvalues
            evals, evecs = torch.linalg.eigh(inducing_covar)
            pseudo_target_covar = (
                evecs.matmul(DiagLinearOperator(evals + self.jitter_val)).matmul(evecs.transpose(-1, -2)).to_dense()
            )

        return pseudo_target_covar, pseudo_target_mean

    def _compute_predictive_updates(
        self,
        chol: LinearOperator,
        induc_data_covar: Tensor,
        inducing_values: Tensor,
        variational_inducing_covar: LinearOperator | None,
        prior_covar: LinearOperator,
        diag: bool = True,
    ) -> tuple[Tensor, LinearOperator]:
        r"""Compute the predictive mean and covariance updates.

        Adding the return values of this method to the prior mean and covariance yields the predictive mean and
        covariance.

        The predictive mean update is `K_{XZ} K_{ZZ}^{-1/2} m`. The predictive covariance update is
        `K_{XZ} K_{ZZ}^{-1/2} (S - I) K_{ZZ}^{-1/2} K_{ZX}`.

        :param chol: The Cholesky factor of `K_{ZZ}` (i.e. `K_{ZZ}^{1/2}`); it is solved against to apply
            `K_{ZZ}^{-1/2}`.
        :param induc_data_covar: The covariance between the inducing points and the data `K_{ZX}`.
        :param inducing_values: The whitened variational mean `m`.
        :param variational_inducing_covar: The variational covariance `S`.
        :param prior_covar: The prior covariance, typically an identity matrix `I`.
        :param diag: If true, this method computes the predictive variance instead of the full covariance in train
            mode if there are more data than inducing points.
        :return: The predictive mean update and the predictive covariance update.
        """
        middle_term = prior_covar.mul(-1)
        if variational_inducing_covar is not None:
            middle_term = SumLinearOperator(variational_inducing_covar, middle_term)  # `S - I`

        # The custom autograd function doesn't compute the off-diagonal entries. Besides, it's only optimized for the
        # setting where the batch size is larger than the number of inducing points.
        if diag and self.training and induc_data_covar.size(-2) < induc_data_covar.size(-1):
            predictive_mean_update, predictive_variance_update = ComputePredictiveUpdates.apply(
                chol.to_dense().type(induc_data_covar.dtype),
                induc_data_covar,
                middle_term.to_dense(),
                inducing_values,
            )

            return predictive_mean_update, DiagLinearOperator(predictive_variance_update)
        else:
            # The eval mode uses the same implementation as v1.14.3
            # NOTE: `torch.linalg.solve_triangular(A, B)` seems to support mixed precision solve when `A` and `B` have
            # different dtypes. `B` is likely cast to the dtype of `A` internally. Thus, there is no need for explicit
            # type casting. Removing the type casting would be slightly faster and avoid memory allocation. However, we
            # keep the explicit type casting here because this behavior is not documented on the PyTorch side.
            interp_term = chol.solve(induc_data_covar.type(chol.dtype))
            interp_term = interp_term.type(induc_data_covar.dtype)

            # Compute the predictive mean update
            # k_XZ K_ZZ^{-1/2} (m - K_ZZ^{-1/2} \mu_Z)
            predictive_mean_update = (interp_term.mT @ inducing_values.unsqueeze(-1)).squeeze(-1)

            if settings.trace_mode.on():
                middle_term = middle_term.to_dense()

            # Compute the predictive covariance update
            # k_XZ K_ZZ^{-1/2} (S - I) K_ZZ^{-1/2} k_ZX
            predictive_covar_update = MatmulLinearOperator(interp_term.mT, middle_term @ interp_term)

            return predictive_mean_update, predictive_covar_update

    def forward(
        self,
        x: Tensor,
        inducing_points: Tensor,
        inducing_values: Tensor,
        variational_inducing_covar: LinearOperator | None = None,
        diag: bool = True,
        **kwargs,
    ) -> MultivariateNormal:
        """Compute the approximate distribution at ``x`` from the inducing points and variational parameters.

        :param x: The inputs to evaluate the distribution at.
        :param inducing_points: The inducing point locations; concatenated with ``x`` before calling the model.
        :param inducing_values: The variational mean passed on to ``_compute_predictive_updates``.
        :param variational_inducing_covar: The variational covariance passed on to
            ``_compute_predictive_updates``.
        :param diag: Passed on to ``_compute_predictive_updates``.
        :param kwargs: Forwarded to ``self.model.forward``.
        :return: A ``MultivariateNormal`` with the predictive mean and covariance.
        """
        # Compute full prior distribution
        full_inputs = torch.cat([inducing_points, x], dim=-2)
        full_output = self.model.forward(full_inputs, **kwargs)
        full_covar = full_output.lazy_covariance_matrix

        # Covariance terms
        num_induc = inducing_points.size(-2)
        test_mean = full_output.mean[..., num_induc:]
        induc_induc_covar = full_covar[..., :num_induc, :num_induc].add_jitter(self.jitter_val)
        induc_data_covar = full_covar[..., :num_induc, num_induc:].to_dense()
        data_data_covar = full_covar[..., num_induc:, num_induc:]

        # Compute interpolation terms
        # K_ZZ^{-1/2} K_ZX
        # K_ZZ^{-1/2} \mu_Z
        L = self._cholesky_factor(induc_induc_covar)
        if L.shape != induc_induc_covar.shape:
            # Aggressive caching can cause nasty shape incompatibilities when evaluating with different batch shapes
            # TODO: Use a hook for this
            try:
                pop_from_cache_ignore_args(self, "cholesky_factor")
            except CachingError:
                # Deliberately ignored; the factor is recomputed right below either way
                pass
            L = self._cholesky_factor(induc_induc_covar)

        mean_update, covar_update = self._compute_predictive_updates(
            chol=L,
            induc_data_covar=induc_data_covar,
            inducing_values=inducing_values,
            variational_inducing_covar=variational_inducing_covar,
            prior_covar=self.prior_distribution.lazy_covariance_matrix,
            diag=diag,
        )

        predictive_mean = test_mean + mean_update
        predictive_covar = SumLinearOperator(data_data_covar.add_jitter(self.jitter_val), covar_update)

        if trace_mode.on():
            predictive_covar = predictive_covar.to_dense()

        return MultivariateNormal(predictive_mean, predictive_covar)

    def __call__(self, x: Tensor, prior: bool = False, diag: bool = True, **kwargs) -> MultivariateNormal:
        """Evaluate the strategy at ``x``, whitening the variational parameters first if they are not yet updated.

        When the ``updated_strategy`` buffer is ``False`` and ``prior`` is ``False``, the variational parameters are
        converted to their whitened form once (without tracking gradients), the cache is cleared and the buffer is
        set to ``True``. The call is then delegated to the parent class.
        """
        if not self.updated_strategy.item() and not prior:
            with torch.no_grad():
                # Get unwhitened p(u). Whitening needs the full covariance.
                prior_function_dist = self(self.inducing_points, prior=True, diag=False)
                prior_mean = prior_function_dist.loc
                L = self._cholesky_factor(prior_function_dist.lazy_covariance_matrix.add_jitter(self.jitter_val))

                # Temporarily turn off noise that's added to the mean
                orig_mean_init_std = self._variational_distribution.mean_init_std
                self._variational_distribution.mean_init_std = 0.0

                try:
                    # Change the variational parameters to be whitened
                    variational_dist = self.variational_distribution
                    if isinstance(variational_dist, MultivariateNormal):
                        linalg_dtype = _linalg_dtype_cholesky.value()
                        param_dtype = variational_dist.loc.dtype

                        mean_diff = (variational_dist.loc - prior_mean).unsqueeze(-1).type(linalg_dtype)
                        whitened_mean = L.solve(mean_diff).squeeze(-1).to(param_dtype)
                        covar_root = variational_dist.lazy_covariance_matrix.root_decomposition().root.to_dense()
                        covar_root = covar_root.type(linalg_dtype)
                        whitened_covar = RootLinearOperator(L.solve(covar_root).to(param_dtype))
                        whitened_variational_distribution = variational_dist.__class__(whitened_mean, whitened_covar)
                        self._variational_distribution.initialize_variational_distribution(
                            whitened_variational_distribution
                        )
                finally:
                    # Reset the random noise parameter of the model, even if the whitening above failed
                    self._variational_distribution.mean_init_std = orig_mean_init_std

                # Reset the cache
                clear_cache_hook(self)

                # Mark that we have updated the variational strategy
                self.updated_strategy.fill_(True)

        return super().__call__(x, prior=prior, diag=diag, **kwargs)