# Source code for gpytorch.kernels.inducing_point_kernel

#!/usr/bin/env python3

import copy
import math
from typing import Optional, Tuple

import torch
from linear_operator import to_dense
from linear_operator.operators import (
    DiagLinearOperator,
    LowRankRootAddedDiagLinearOperator,
    LowRankRootLinearOperator,
    MatmulLinearOperator,
)
from linear_operator.utils.cholesky import psd_safe_cholesky
from torch import Tensor

from .. import settings
from ..distributions import MultivariateNormal
from ..likelihoods import Likelihood
from ..mlls import InducingPointKernelAddedLossTerm
from ..models import exact_prediction_strategies
from .kernel import Kernel


class InducingPointKernel(Kernel):
    r"""A sparse (inducing-point) approximation of a base kernel.

    The kernel matrix is approximated through a set of learnable inducing
    points :math:`Z`: covariances are computed against :math:`Z` and combined
    through the inverse Cholesky root of :math:`K_{ZZ}`, yielding a low-rank
    operator. An added loss term (registered as ``"inducing_point_loss_term"``)
    contributes the variational/SGPR correction to the marginal log likelihood
    during training, and ``SGPRPredictionStrategy`` is used for fast predictive
    variances.

    :param base_kernel: The kernel being approximated.
    :param inducing_points: An ``(m x d)`` (or ``(m,)``, auto-unsqueezed to
        ``(m x 1)``) tensor of initial inducing-point locations; registered as
        a learnable parameter.
    :param likelihood: Likelihood used by the added loss term.
    :param active_dims: Optional subset of input dimensions this kernel
        operates on (forwarded to :class:`Kernel`).
    """

    def __init__(
        self,
        base_kernel: Kernel,
        inducing_points: Tensor,
        likelihood: Likelihood,
        active_dims: Optional[Tuple[int, ...]] = None,
    ):
        super(InducingPointKernel, self).__init__(active_dims=active_dims)
        self.base_kernel = base_kernel
        self.likelihood = likelihood

        # A 1-D tensor of inducing points is interpreted as m points in 1 dimension.
        if inducing_points.ndimension() == 1:
            inducing_points = inducing_points.unsqueeze(-1)

        self.register_parameter(name="inducing_points", parameter=torch.nn.Parameter(inducing_points))
        self.register_added_loss_term("inducing_point_loss_term")

    def _clear_cache(self):
        # Drop eval-mode caches (they are only populated when not training).
        if hasattr(self, "_cached_kernel_mat"):
            del self._cached_kernel_mat
        if hasattr(self, "_cached_kernel_inv_root"):
            del self._cached_kernel_inv_root

    @property
    def _inducing_mat(self):
        """Dense K_ZZ; cached in eval mode since inducing points are then fixed."""
        if not self.training and hasattr(self, "_cached_kernel_mat"):
            return self._cached_kernel_mat
        else:
            res = to_dense(self.base_kernel(self.inducing_points, self.inducing_points))
            if not self.training:
                self._cached_kernel_mat = res
            return res

    @property
    def _inducing_inv_root(self):
        """Inverse of the upper Cholesky factor of K_ZZ; cached in eval mode."""
        if not self.training and hasattr(self, "_cached_kernel_inv_root"):
            return self._cached_kernel_inv_root
        else:
            chol = psd_safe_cholesky(self._inducing_mat, upper=True)
            eye = torch.eye(chol.size(-1), device=chol.device, dtype=chol.dtype)
            inv_root = torch.linalg.solve_triangular(chol, eye, upper=True)

            res = inv_root
            if not self.training:
                self._cached_kernel_inv_root = res
            return res

    def _get_covariance(self, x1, x2):
        """Low-rank approximate covariance K_x1,Z K_ZZ^{-1} K_Z,x2 as a LinearOperator."""
        k_ux1 = to_dense(self.base_kernel(x1, self.inducing_points))
        if torch.equal(x1, x2):
            # Symmetric case: represent as R R^T with R = K_xZ L^{-1}.
            covar = LowRankRootLinearOperator(k_ux1.matmul(self._inducing_inv_root))

            # Diagonal correction for predictive posterior
            if not self.training and settings.sgpr_diagonal_correction.on():
                correction = (self.base_kernel(x1, x2, diag=True) - covar.diagonal(dim1=-1, dim2=-2)).clamp(0, math.inf)
                covar = LowRankRootAddedDiagLinearOperator(covar, DiagLinearOperator(correction))
        else:
            k_ux2 = to_dense(self.base_kernel(x2, self.inducing_points))
            covar = MatmulLinearOperator(
                k_ux1.matmul(self._inducing_inv_root), k_ux2.matmul(self._inducing_inv_root).transpose(-1, -2)
            )

        return covar

    def _covar_diag(self, inputs):
        """Exact diagonal of the base kernel at ``inputs``, as a DiagLinearOperator."""
        if inputs.ndimension() == 1:
            inputs = inputs.unsqueeze(1)

        # Get diagonal of covar
        covar_diag = to_dense(self.base_kernel(inputs, diag=True))
        return DiagLinearOperator(covar_diag)

    def forward(self, x1, x2, diag=False, **kwargs):
        covar = self._get_covariance(x1, x2)

        if self.training:
            if not torch.equal(x1, x2):
                raise RuntimeError("x1 should equal x2 in training mode")
            # Register the (exact diag vs. approximate) KL-style loss term that
            # penalizes the approximation error during training.
            zero_mean = torch.zeros_like(x1.select(-1, 0))
            new_added_loss_term = InducingPointKernelAddedLossTerm(
                MultivariateNormal(zero_mean, self._covar_diag(x1)),
                MultivariateNormal(zero_mean, covar),
                self.likelihood,
            )
            self.update_added_loss_term("inducing_point_loss_term", new_added_loss_term)

        if diag:
            return covar.diagonal(dim1=-1, dim2=-2)
        else:
            return covar

    def num_outputs_per_input(self, x1, x2):
        # Delegate: this wrapper does not change the output multiplicity.
        return self.base_kernel.num_outputs_per_input(x1, x2)

    def __deepcopy__(self, memo):
        """Deep-copy the kernel while carrying over eval-mode caches.

        The likelihood is intentionally shared (not copied), matching the
        original behavior. The ``memo`` dict is threaded through the nested
        ``deepcopy`` calls so that shared sub-objects are copied only once and
        reference cycles terminate; the new instance is registered in ``memo``
        before returning.
        """
        replace_inv_root = False
        replace_kernel_mat = False

        if hasattr(self, "_cached_kernel_inv_root"):
            replace_inv_root = True
            kernel_inv_root = self._cached_kernel_inv_root
        if hasattr(self, "_cached_kernel_mat"):
            replace_kernel_mat = True
            kernel_mat = self._cached_kernel_mat

        # Bug fix: pass `memo` to the nested deepcopy calls (previously omitted,
        # which duplicated shared objects and could recurse on cycles).
        cp = self.__class__(
            base_kernel=copy.deepcopy(self.base_kernel, memo),
            inducing_points=copy.deepcopy(self.inducing_points, memo),
            likelihood=self.likelihood,
            active_dims=self.active_dims,
        )
        memo[id(self)] = cp

        if replace_inv_root:
            cp._cached_kernel_inv_root = kernel_inv_root
        if replace_kernel_mat:
            cp._cached_kernel_mat = kernel_mat

        return cp

    def prediction_strategy(self, train_inputs, train_prior_dist, train_labels, likelihood):
        # Allow for fast variances
        return exact_prediction_strategies.SGPRPredictionStrategy(
            train_inputs, train_prior_dist, train_labels, likelihood
        )