Source code for gpytorch.variational.batch_decoupled_variational_strategy

#!/usr/bin/env python3

from typing import Optional, Tuple

import torch
from linear_operator.operators import LinearOperator, MatmulLinearOperator, SumLinearOperator
from torch import Tensor
from torch.distributions.kl import kl_divergence

from ..distributions import Delta, MultivariateNormal
from ..models import ApproximateGP
from ..utils.errors import CachingError
from ..utils.memoize import pop_from_cache_ignore_args
from ._variational_distribution import _VariationalDistribution
from .delta_variational_distribution import DeltaVariationalDistribution
from .variational_strategy import VariationalStrategy


class BatchDecoupledVariationalStrategy(VariationalStrategy):
    r"""
    A VariationalStrategy that uses a different set of inducing points for the
    variational mean and variational covar.

    It follows the "decoupled" model proposed by `Jankowiak et al. (2020)`_ (which is roughly
    based on the strategies proposed by `Cheng et al. (2017)`_).

    Let :math:`\mathbf Z_\mu` and :math:`\mathbf Z_\sigma` be the mean/variance inducing points.
    The variational distribution for an input :math:`\mathbf x` is given by:

    .. math::

        \begin{align*}
            \mathbb E[ f(\mathbf x) ] &= \mathbf k_{\mathbf Z_\mu \mathbf x}^\top
                \mathbf K_{\mathbf Z_\mu \mathbf Z_\mu}^{-1} \mathbf m
            \\
            \text{Var}[ f(\mathbf x) ] &= k_{\mathbf x \mathbf x} - \mathbf k_{\mathbf Z_\sigma \mathbf x}^\top
                \mathbf K_{\mathbf Z_\sigma \mathbf Z_\sigma}^{-1}
                \left( \mathbf K_{\mathbf Z_\sigma \mathbf Z_\sigma} - \mathbf S \right)
                \mathbf K_{\mathbf Z_\sigma \mathbf Z_\sigma}^{-1}
                \mathbf k_{\mathbf Z_\sigma \mathbf x}
        \end{align*}

    where :math:`\mathbf m` and :math:`\mathbf S` are the variational parameters.
    Unlike the original proposed implementation, :math:`\mathbf Z_\mu` and :math:`\mathbf Z_\sigma` have
    **the same number of inducing points**, which allows us to perform batched operations.

    Additionally, you can use a different set of kernel hyperparameters for the mean and the variance function.
    We recommend using this feature only with the :obj:`~gpytorch.mlls.PredictiveLogLikelihood` objective function
    as proposed in "Parametric Gaussian Process Regressors" (`Jankowiak et al. (2020)`_).
    Use ``mean_var_batch_dim`` to indicate which batch dimension corresponds to the different mean/var kernels.

    .. note::

        We recommend using the "right-most" batch dimension (i.e. ``mean_var_batch_dim=-1``)
        for the dimension that corresponds to the different mean/variance kernel parameters.

        Assuming you want `b1` many independent GPs, the
        :obj:`~gpytorch.variational._VariationalDistribution` objects should have a batch shape of `b1`,
        and the mean/covar modules of the GP should have a batch shape of `b1 x 2`.
        (The 2 corresponds to the mean/variance hyperparameters.)

    .. seealso::
        :obj:`~gpytorch.variational.OrthogonallyDecoupledVariationalStrategy` (a variant proposed
        by `Salimbeni et al. (2018)`_ that uses orthogonal projections.)

    :param model: Model this strategy is applied to.
        Typically passed in when the VariationalStrategy is created in the
        __init__ method of the user defined model.
    :param inducing_points: Tensor containing a set of inducing
        points to use for variational inference.
    :param variational_distribution: A VariationalDistribution object that
        represents the form of the variational distribution :math:`q(\mathbf u)`
    :param learn_inducing_locations: (Default True): Whether or not
        the inducing point locations :math:`\mathbf Z` should be learned (i.e. are they
        parameters of the model).
    :param mean_var_batch_dim: (Default `None`):
        Set this parameter (ideally to `-1`) to indicate which dimension corresponds to different
        kernel hyperparameters for the mean/variance functions.
    :param jitter_val: Amount of diagonal jitter to add for Cholesky factorization numerical stability

    .. _Cheng et al. (2017):
        https://arxiv.org/abs/1711.10127
    .. _Salimbeni et al. (2018):
        https://arxiv.org/abs/1809.08820
    .. _Jankowiak et al. (2020):
        https://arxiv.org/abs/1910.07123

    Example (**different** hypers for mean/variance):
        >>> class MeanFieldDecoupledModel(gpytorch.models.ApproximateGP):
        >>>     '''
        >>>     A batch of 3 independent MeanFieldDecoupled PPGPR models.
        >>>     '''
        >>>     def __init__(self, inducing_points):
        >>>         # The variational parameters have a batch_shape of [3]
        >>>         variational_distribution = gpytorch.variational.MeanFieldVariationalDistribution(
        >>>             inducing_points.size(-1), batch_shape=torch.Size([3]),
        >>>         )
        >>>         variational_strategy = gpytorch.variational.BatchDecoupledVariationalStrategy(
        >>>             self, inducing_points, variational_distribution, learn_inducing_locations=True,
        >>>             mean_var_batch_dim=-1
        >>>         )
        >>>
        >>>         # The mean/covar modules have a batch_shape of [3, 2]
        >>>         # where the last batch dim corresponds to the mean & variance hyperparameters
        >>>         super().__init__(variational_strategy)
        >>>         self.mean_module = gpytorch.means.ConstantMean(batch_shape=torch.Size([3, 2]))
        >>>         self.covar_module = gpytorch.kernels.ScaleKernel(
        >>>             gpytorch.kernels.RBFKernel(batch_shape=torch.Size([3, 2])),
        >>>             batch_shape=torch.Size([3, 2]),
        >>>         )

    Example (**shared** hypers for mean/variance):
        >>> class MeanFieldDecoupledModel(gpytorch.models.ApproximateGP):
        >>>     '''
        >>>     A batch of 3 independent MeanFieldDecoupled PPGPR models.
        >>>     '''
        >>>     def __init__(self, inducing_points):
        >>>         # The variational parameters have a batch_shape of [3]
        >>>         variational_distribution = gpytorch.variational.MeanFieldVariationalDistribution(
        >>>             inducing_points.size(-1), batch_shape=torch.Size([3]),
        >>>         )
        >>>         variational_strategy = gpytorch.variational.BatchDecoupledVariationalStrategy(
        >>>             self, inducing_points, variational_distribution, learn_inducing_locations=True,
        >>>         )
        >>>
        >>>         # The mean/covar modules have a batch_shape of [3, 1]
        >>>         # where the singleton dimension corresponds to the shared mean/variance hyperparameters
        >>>         super().__init__(variational_strategy)
        >>>         self.mean_module = gpytorch.means.ConstantMean(batch_shape=torch.Size([3, 1]))
        >>>         self.covar_module = gpytorch.kernels.ScaleKernel(
        >>>             gpytorch.kernels.RBFKernel(batch_shape=torch.Size([3, 1])),
        >>>             batch_shape=torch.Size([3, 1]),
        >>>         )
    """

    def __init__(
        self,
        model: ApproximateGP,
        inducing_points: Tensor,
        variational_distribution: _VariationalDistribution,
        learn_inducing_locations: bool = True,
        mean_var_batch_dim: Optional[int] = None,
        jitter_val: Optional[float] = None,
    ):
        if isinstance(variational_distribution, DeltaVariationalDistribution):
            raise NotImplementedError(
                "BatchDecoupledVariationalStrategy does not work with DeltaVariationalDistribution"
            )

        if mean_var_batch_dim is not None and mean_var_batch_dim >= 0:
            raise ValueError(f"mean_var_batch_dim should be negative indexed, got {mean_var_batch_dim}")
        self.mean_var_batch_dim = mean_var_batch_dim

        # Maybe unsqueeze inducing points
        if inducing_points.dim() == 1:
            inducing_points = inducing_points.unsqueeze(-1)

        # We're going to create two sets of inducing points
        # One set for computing the mean, one set for computing the variance
        if self.mean_var_batch_dim is not None:
            inducing_points = torch.stack([inducing_points, inducing_points], dim=(self.mean_var_batch_dim - 2))
        else:
            inducing_points = torch.stack([inducing_points, inducing_points], dim=-3)

        super().__init__(
            model, inducing_points, variational_distribution, learn_inducing_locations, jitter_val=jitter_val
        )

    def _expand_inputs(self, x: Tensor, inducing_points: Tensor) -> Tuple[Tensor, Tensor]:
        # If we haven't explicitly marked a dimension as batch, add the corresponding batch dimension to the input
        if self.mean_var_batch_dim is None:
            x = x.unsqueeze(-3)
        else:
            x = x.unsqueeze(self.mean_var_batch_dim - 2)
        return super()._expand_inputs(x, inducing_points)

    def forward(
        self,
        x: Tensor,
        inducing_points: Tensor,
        inducing_values: Tensor,
        variational_inducing_covar: Optional[LinearOperator] = None,
        **kwargs,
    ) -> MultivariateNormal:
        # We'll compute the covariance, and cross-covariance terms for both the
        # pred-mean and pred-covar, using their different inducing points (and maybe kernel hypers)
        mean_var_batch_dim = self.mean_var_batch_dim or -1

        # Compute full prior distribution
        full_inputs = torch.cat([inducing_points, x], dim=-2)
        full_output = self.model.forward(full_inputs, **kwargs)
        full_covar = full_output.lazy_covariance_matrix

        # Covariance terms
        num_induc = inducing_points.size(-2)
        test_mean = full_output.mean[..., num_induc:]
        induc_induc_covar = full_covar[..., :num_induc, :num_induc].add_jitter(self.jitter_val)
        induc_data_covar = full_covar[..., :num_induc, num_induc:].to_dense()
        data_data_covar = full_covar[..., num_induc:, num_induc:]

        # Compute interpolation terms
        # K_ZZ^{-1/2} K_ZX
        # K_ZZ^{-1/2} \mu_Z
        L = self._cholesky_factor(induc_induc_covar)
        if L.shape != induc_induc_covar.shape:
            # Aggressive caching can cause nasty shape incompatibilities when evaluating with different batch shapes
            # TODO: Use a hook to make this cleaner
            try:
                pop_from_cache_ignore_args(self, "cholesky_factor")
            except CachingError:
                pass
            L = self._cholesky_factor(induc_induc_covar)
        interp_term = L.solve(induc_data_covar.double()).to(full_inputs.dtype)
        mean_interp_term = interp_term.select(mean_var_batch_dim - 2, 0)
        var_interp_term = interp_term.select(mean_var_batch_dim - 2, 1)

        # Compute the mean of q(f)
        # k_XZ K_ZZ^{-1/2} m + \mu_X
        # Here we're using the terms that correspond to the mean's inducing points
        predictive_mean = torch.add(
            torch.matmul(mean_interp_term.transpose(-1, -2), inducing_values.unsqueeze(-1)).squeeze(-1),
            test_mean.select(mean_var_batch_dim - 1, 0),
        )

        # Compute the covariance of q(f)
        # K_XX + k_XZ K_ZZ^{-1/2} (S - I) K_ZZ^{-1/2} k_ZX
        middle_term = self.prior_distribution.lazy_covariance_matrix.mul(-1)
        if variational_inducing_covar is not None:
            middle_term = SumLinearOperator(variational_inducing_covar, middle_term)
        predictive_covar = SumLinearOperator(
            data_data_covar.add_jitter(self.jitter_val).to_dense().select(mean_var_batch_dim - 2, 1),
            MatmulLinearOperator(var_interp_term.transpose(-1, -2), middle_term @ var_interp_term),
        )

        return MultivariateNormal(predictive_mean, predictive_covar)

    def kl_divergence(self) -> Tensor:
        variational_dist = self.variational_distribution
        prior_dist = self.prior_distribution

        mean_dist = Delta(variational_dist.mean)
        covar_dist = MultivariateNormal(
            torch.zeros_like(variational_dist.mean), variational_dist.lazy_covariance_matrix
        )
        return kl_divergence(mean_dist, prior_dist) + kl_divergence(covar_dist, prior_dist)
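
For context, below is a minimal end-to-end sketch of how this strategy is typically wired into a model and trained with the :obj:`~gpytorch.mlls.PredictiveLogLikelihood` objective, as the docstring recommends. The model class mirrors the docstring's first example (different mean/variance hypers) with a ``forward`` method added; the toy data, inducing-point locations, learning rate, and iteration count are illustrative assumptions, not part of this module.

import torch
import gpytorch

# Illustrative toy data (assumption): 3 independent 1-d regression tasks, 100 points each
train_x = torch.linspace(0, 1, 100)
train_y = torch.stack([torch.sin(train_x * (2 * i + 1) * 3.14) for i in range(3)])


class MeanFieldDecoupledModel(gpytorch.models.ApproximateGP):
    """A batch of 3 independent MeanFieldDecoupled PPGPR models (mirrors the docstring example)."""

    def __init__(self, inducing_points):
        # The variational parameters have a batch_shape of [3]
        variational_distribution = gpytorch.variational.MeanFieldVariationalDistribution(
            inducing_points.size(-1), batch_shape=torch.Size([3])
        )
        variational_strategy = gpytorch.variational.BatchDecoupledVariationalStrategy(
            self, inducing_points, variational_distribution,
            learn_inducing_locations=True, mean_var_batch_dim=-1,
        )
        super().__init__(variational_strategy)
        # Batch shape [3, 2]: the trailing 2 holds separate mean/variance hyperparameters
        self.mean_module = gpytorch.means.ConstantMean(batch_shape=torch.Size([3, 2]))
        self.covar_module = gpytorch.kernels.ScaleKernel(
            gpytorch.kernels.RBFKernel(batch_shape=torch.Size([3, 2])),
            batch_shape=torch.Size([3, 2]),
        )

    def forward(self, x):
        return gpytorch.distributions.MultivariateNormal(self.mean_module(x), self.covar_module(x))


# Illustrative choices: 25 inducing locations on [0, 1], Adam with lr=0.05, 200 steps
model = MeanFieldDecoupledModel(inducing_points=torch.linspace(0, 1, 25))
likelihood = gpytorch.likelihoods.GaussianLikelihood(batch_shape=torch.Size([3]))
mll = gpytorch.mlls.PredictiveLogLikelihood(likelihood, model, num_data=train_y.size(-1))
optimizer = torch.optim.Adam(list(model.parameters()) + list(likelihood.parameters()), lr=0.05)

model.train()
likelihood.train()
for _ in range(200):
    optimizer.zero_grad()
    output = model(train_x)              # a batch of 3 MultivariateNormals over the 100 inputs
    loss = -mll(output, train_y).sum()   # sum the per-GP objectives before backprop
    loss.backward()
    optimizer.step()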