Source code for gpytorch.kernels.inducing_point_kernel

#!/usr/bin/env python3

import copy
import math
from typing import Optional, Tuple

import torch
from torch import Tensor

from .. import settings
from ..distributions import MultivariateNormal
from ..lazy import DiagLazyTensor, LowRankRootAddedDiagLazyTensor, LowRankRootLazyTensor, MatmulLazyTensor, delazify
from ..likelihoods import Likelihood
from ..mlls import InducingPointKernelAddedLossTerm
from ..models import exact_prediction_strategies
from ..utils.cholesky import psd_safe_cholesky
from .kernel import Kernel


class InducingPointKernel(Kernel):
    """Kernel that approximates ``base_kernel`` through a set of learnable inducing points.

    For inputs ``x1`` and ``x2`` the covariance is assembled as
    ``K(x1, U) @ K(U, U)^{-1} @ K(U, x2)``, where ``U`` are the inducing points and ``K`` is
    ``base_kernel``. The inverse is never formed explicitly; an inverse root obtained from a
    Cholesky factor of ``K(U, U)`` is used instead (see :attr:`_inducing_inv_root`).

    In training mode every :meth:`forward` call also refreshes the added loss term registered
    under the name ``"inducing_point_loss_term"``.

    Args:
        base_kernel: Kernel evaluated between data points and inducing points.
        inducing_points: Initial inducing point locations. A 1-dimensional tensor is given a
            trailing singleton dimension. The tensor is registered as the trainable parameter
            ``inducing_points``.
        likelihood: Likelihood handed to ``InducingPointKernelAddedLossTerm`` during training.
        active_dims: Forwarded unchanged to ``Kernel.__init__``.
    """

    def __init__(
        self,
        base_kernel: Kernel,
        inducing_points: Tensor,
        likelihood: Likelihood,
        active_dims: Optional[Tuple[int, ...]] = None,
    ):
        super(InducingPointKernel, self).__init__(active_dims=active_dims)
        self.base_kernel = base_kernel
        self.likelihood = likelihood

        # Treat a flat vector of locations as a column of one-dimensional points.
        if inducing_points.ndimension() == 1:
            inducing_points = inducing_points.unsqueeze(-1)

        self.register_parameter(name="inducing_points", parameter=torch.nn.Parameter(inducing_points))
        self.register_added_loss_term("inducing_point_loss_term")

    def _clear_cache(self):
        """Drop the eval-mode caches of ``K(U, U)`` and its inverse root, if present."""
        if hasattr(self, "_cached_kernel_mat"):
            del self._cached_kernel_mat
        if hasattr(self, "_cached_kernel_inv_root"):
            del self._cached_kernel_inv_root

    @property
    def _inducing_mat(self):
        """Dense ``base_kernel`` matrix between the inducing points and themselves.

        The result is cached on the instance only when the module is not in training mode;
        in training mode it is recomputed on every access.
        """
        if not self.training and hasattr(self, "_cached_kernel_mat"):
            return self._cached_kernel_mat

        res = delazify(self.base_kernel(self.inducing_points, self.inducing_points))
        if not self.training:
            self._cached_kernel_mat = res
        return res

    @property
    def _inducing_inv_root(self):
        """Inverse of the Cholesky factor of :attr:`_inducing_mat`.

        The factor ``R`` is requested with ``upper=True`` and ``R X = I`` is solved for ``X``.
        Assuming ``psd_safe_cholesky`` returns ``R`` with ``K(U, U) = R^T R`` (TODO confirm
        against the helper), ``X @ X^T`` equals ``K(U, U)^{-1}``.

        The result is cached on the instance only when the module is not in training mode.
        """
        if not self.training and hasattr(self, "_cached_kernel_inv_root"):
            return self._cached_kernel_inv_root

        chol = psd_safe_cholesky(self._inducing_mat, upper=True)
        eye = torch.eye(chol.size(-1), device=chol.device, dtype=chol.dtype)
        # NOTE(review): torch.triangular_solve is deprecated in recent PyTorch releases in favour of
        # torch.linalg.solve_triangular; kept as-is because the minimum supported torch version is
        # not visible from this file.
        inv_root = torch.triangular_solve(eye, chol)[0]

        if not self.training:
            self._cached_kernel_inv_root = inv_root
        return inv_root

    def _get_covariance(self, x1, x2):
        """Build the inducing-point approximation of the covariance between ``x1`` and ``x2``.

        Returns:
            * ``x1`` equal to ``x2``: a ``LowRankRootLazyTensor`` with root
              ``K(x1, U) @ inv_root``. Outside training mode, and when
              ``settings.sgpr_diagonal_correction`` is on, a non-negative diagonal correction
              is added so the diagonal is not smaller than the base kernel's diagonal.
            * otherwise: a ``MatmulLazyTensor`` of the two projected cross-covariances.
        """
        k_ux1 = delazify(self.base_kernel(x1, self.inducing_points))
        if torch.equal(x1, x2):
            covar = LowRankRootLazyTensor(k_ux1.matmul(self._inducing_inv_root))

            # Diagonal correction for predictive posterior
            if not self.training and settings.sgpr_diagonal_correction.on():
                # Clamp at zero so the correction can only add to the low-rank diagonal.
                correction = (self.base_kernel(x1, x2, diag=True) - covar.diag()).clamp(0, math.inf)
                covar = LowRankRootAddedDiagLazyTensor(covar, DiagLazyTensor(correction))
        else:
            k_ux2 = delazify(self.base_kernel(x2, self.inducing_points))
            covar = MatmulLazyTensor(
                k_ux1.matmul(self._inducing_inv_root),
                k_ux2.matmul(self._inducing_inv_root).transpose(-1, -2),
            )

        return covar

    def _covar_diag(self, inputs):
        """Return the diagonal of ``base_kernel`` at ``inputs`` wrapped in a ``DiagLazyTensor``."""
        if inputs.ndimension() == 1:
            inputs = inputs.unsqueeze(1)

        # Get diagonal of covar
        covar_diag = delazify(self.base_kernel(inputs, diag=True))
        return DiagLazyTensor(covar_diag)

    def forward(self, x1, x2, diag=False, **kwargs):
        """Evaluate the approximate covariance between ``x1`` and ``x2``.

        In training mode the added loss term ``"inducing_point_loss_term"`` is replaced with a
        fresh ``InducingPointKernelAddedLossTerm`` built from the base kernel's diagonal, the
        approximate covariance and ``self.likelihood``.

        Args:
            x1: First set of inputs.
            x2: Second set of inputs. Must equal ``x1`` in training mode.
            diag: When true, return only the diagonal of the covariance.
            **kwargs: Accepted for interface compatibility; not used here.

        Raises:
            RuntimeError: In training mode when ``x1`` and ``x2`` differ.
        """
        # Check the training precondition up front instead of after building a covariance
        # that would be thrown away.
        if self.training and not torch.equal(x1, x2):
            raise RuntimeError("x1 should equal x2 in training mode")

        covar = self._get_covariance(x1, x2)

        if self.training:
            # One zero per input point: drop the last dimension of x1.
            zero_mean = torch.zeros_like(x1.select(-1, 0))
            new_added_loss_term = InducingPointKernelAddedLossTerm(
                MultivariateNormal(zero_mean, self._covar_diag(x1)),
                MultivariateNormal(zero_mean, covar),
                self.likelihood,
            )
            self.update_added_loss_term("inducing_point_loss_term", new_added_loss_term)

        if diag:
            return covar.diag()
        return covar

    def num_outputs_per_input(self, x1, x2):
        """Delegate to ``base_kernel.num_outputs_per_input``."""
        return self.base_kernel.num_outputs_per_input(x1, x2)

    def __deepcopy__(self, memo):
        """Copy the kernel, deep-copying the base kernel and the inducing points.

        The likelihood and any cached ``K(U, U)`` / inverse-root tensors are shared with the
        original by reference rather than copied.
        """
        cp = self.__class__(
            # Pass ``memo`` along so objects shared elsewhere in the graph being copied stay
            # shared in the copy, as the ``__deepcopy__`` protocol requires.
            base_kernel=copy.deepcopy(self.base_kernel, memo),
            inducing_points=copy.deepcopy(self.inducing_points, memo),
            likelihood=self.likelihood,
            active_dims=self.active_dims,
        )

        # The constructor always yields a module in training mode. Mirror the source's mode so
        # an eval-mode kernel copies to an eval-mode kernel and the caches below are honoured.
        cp.training = self.training

        if hasattr(self, "_cached_kernel_inv_root"):
            cp._cached_kernel_inv_root = self._cached_kernel_inv_root
        if hasattr(self, "_cached_kernel_mat"):
            cp._cached_kernel_mat = self._cached_kernel_mat

        return cp

    def prediction_strategy(self, train_inputs, train_prior_dist, train_labels, likelihood):
        """Return an ``SGPRPredictionStrategy`` built from the given training data and likelihood."""
        # Allow for fast variances
        return exact_prediction_strategies.SGPRPredictionStrategy(
            train_inputs, train_prior_dist, train_labels, likelihood
        )