Source code for gpytorch.lazy.added_diag_lazy_tensor

#!/usr/bin/env python3

import warnings
from typing import Optional, Tuple

import torch
from torch import Tensor

from .. import settings
from ..utils import broadcasting
from ..utils.memoize import cached
from ..utils.warnings import NumericalWarning
from .diag_lazy_tensor import ConstantDiagLazyTensor, DiagLazyTensor
from .lazy_tensor import LazyTensor
from .psd_sum_lazy_tensor import PsdSumLazyTensor
from .root_lazy_tensor import RootLazyTensor
from .sum_lazy_tensor import SumLazyTensor


class AddedDiagLazyTensor(SumLazyTensor):
    """
    A SumLazyTensor of exactly two lazy tensors, exactly one of which must be
    a DiagLazyTensor. The diagonal component may be passed in either position.

    Args:
        *lazy_tensors: The two components of the sum.
        preconditioner_override: Optional callable. When given,
            :meth:`_preconditioner` returns ``preconditioner_override(self)``
            instead of building the default pivoted Cholesky preconditioner.

    Raises:
        RuntimeError: If the number of components is not two, if both
            components are DiagLazyTensors, or if neither is.
    """

    def __init__(self, *lazy_tensors, preconditioner_override=None):
        lazy_tensors = list(lazy_tensors)
        super(AddedDiagLazyTensor, self).__init__(*lazy_tensors, preconditioner_override=preconditioner_override)
        # Reject too few components as well as too many: the indexing below
        # would otherwise fail with an uninformative IndexError.
        if len(lazy_tensors) != 2:
            raise RuntimeError("An AddedDiagLazyTensor can only have two components")

        # Return value is discarded; presumably this raises when the two
        # shapes cannot be broadcast together (verify against utils.broadcasting).
        broadcasting._mul_broadcast_shape(lazy_tensors[0].shape, lazy_tensors[1].shape)

        if isinstance(lazy_tensors[0], DiagLazyTensor) and isinstance(lazy_tensors[1], DiagLazyTensor):
            raise RuntimeError("Trying to lazily add two DiagLazyTensors. Create a single DiagLazyTensor instead.")
        elif isinstance(lazy_tensors[0], DiagLazyTensor):
            self._diag_tensor = lazy_tensors[0]
            self._lazy_tensor = lazy_tensors[1]
        elif isinstance(lazy_tensors[1], DiagLazyTensor):
            self._diag_tensor = lazy_tensors[1]
            self._lazy_tensor = lazy_tensors[0]
        else:
            raise RuntimeError("One of the LazyTensors input to AddedDiagLazyTensor must be a DiagLazyTensor!")

        self.preconditioner_override = preconditioner_override

        # Placeholders (filled lazily by _preconditioner / _init_cache)
        self._constant_diag = None
        self._noise = None
        self._piv_chol_self = None  # <- Doesn't need to be an attribute, but used for testing purposes
        self._precond_lt = None
        self._precond_logdet_cache = None
        self._q_cache = None
        self._r_cache = None

    def _matmul(self, rhs):
        """
        Multiply the sum by ``rhs``.

        Evaluates ``lazy_tensor @ rhs + diag[..., None] * rhs`` in a single
        fused ``torch.addcmul`` call, so the diagonal is never materialized
        as a matrix.
        """
        return torch.addcmul(self._lazy_tensor._matmul(rhs), self._diag_tensor._diag.unsqueeze(-1), rhs)

    def add_diag(self, added_diag):
        """
        Return a new instance of the same class whose diagonal component is
        ``self._diag_tensor.add_diag(added_diag)``; the non-diagonal component
        is reused unchanged.
        """
        return self.__class__(self._lazy_tensor, self._diag_tensor.add_diag(added_diag))

    def __add__(self, other):
        """
        Add ``other`` while keeping the "lazy tensor + diagonal" structure.

        A DiagLazyTensor is folded into the diagonal component; anything else
        is added to the non-diagonal component.
        """
        if isinstance(other, DiagLazyTensor):
            return self.__class__(self._lazy_tensor, self._diag_tensor + other)
        return self.__class__(self._lazy_tensor + other, self._diag_tensor)

    def _preconditioner(self):
        r"""
        Here we use a partial pivoted Cholesky preconditioner:

        K \approx L L^T + D

        where L L^T is a low rank approximation, and D is a diagonal.
        We can compute the preconditioner's inverse using Woodbury

        (L L^T + D)^{-1} = D^{-1} - D^{-1} L (I + L^T D^{-1} L)^{-1} L^T D^{-1}

        This function returns:
        - A function `precondition_closure` that computes the solve (L L^T + D)^{-1} x
        - A LazyTensor `precondition_lt` that represents (L L^T + D)
        - The log determinant of (L L^T + D)

        If ``preconditioner_override`` was supplied, its result is returned
        instead. ``(None, None, None)`` is returned when preconditioning is
        disabled (``settings.max_preconditioner_size`` is 0), when the matrix
        is smaller than ``settings.min_preconditioning_size``, or when the
        pivoted Cholesky factor contains NaNs (a ``NumericalWarning`` is
        emitted in that last case).
        """
        if self.preconditioner_override is not None:
            return self.preconditioner_override(self)

        if settings.max_preconditioner_size.value() == 0 or self.size(-1) < settings.min_preconditioning_size.value():
            return None, None, None

        # Cache a QR decomposition of the low-rank factor L stacked with a
        # k x k identity block (scaled by the diagonal; see the _init_cache_* helpers).
        # This makes it fast to compute solves and log determinants with it
        #
        # Through woodbury, (L L^T + D)^{-1} reduces down to D^{-1} minus a Q Q^T term
        # Through matrix determinant lemma, log |L L^T + D| reduces down to 2 log |R|
        # plus a correction that depends only on D
        if self._q_cache is None:
            max_iter = settings.max_preconditioner_size.value()
            self._piv_chol_self = self._lazy_tensor.pivoted_cholesky(rank=max_iter)
            if torch.any(torch.isnan(self._piv_chol_self)).item():
                warnings.warn(
                    "NaNs encountered in preconditioner computation. Attempting to continue without preconditioning.",
                    NumericalWarning,
                )
                # _q_cache is still None here, so a later call will retry.
                return None, None, None
            self._init_cache()

        # NOTE: We cannot memoize this precondition closure as it causes a memory leak
        def precondition_closure(tensor):
            # This makes it fast to compute solves with it
            qqt = self._q_cache.matmul(self._q_cache.transpose(-2, -1).matmul(tensor))
            if self._constant_diag:
                # Constant diagonal s: (L L^T + s I)^{-1} x = (x - Q Q^T x) / s
                return (1 / self._noise) * (tensor - qqt)
            # General diagonal: D^{-1/2} is already folded into the cached Q
            return (tensor / self._noise) - qqt

        return (precondition_closure, self._precond_lt, self._precond_logdet_cache)

    def _init_cache(self):
        """
        Populate the preconditioner caches from ``self._piv_chol_self``.

        Sets ``_noise``, ``_constant_diag``, ``_q_cache``, ``_r_cache``,
        ``_precond_logdet_cache`` and ``_precond_lt``.
        """
        *batch_shape, n, k = self._piv_chol_self.shape
        # Diagonal values with a trailing singleton dim so they broadcast over columns.
        self._noise = self._diag_tensor.diag().unsqueeze(-1)

        # the check for constant diag needs to be done carefully for batches.
        # Each batch entry is compared against its own first diagonal element.
        noise_first_element = self._noise[..., :1, :]
        self._constant_diag = torch.equal(self._noise, noise_first_element * torch.ones_like(self._noise))
        eye = torch.eye(k, dtype=self._piv_chol_self.dtype, device=self._piv_chol_self.device)
        eye = eye.expand(*batch_shape, k, k)

        if self._constant_diag:
            self._init_cache_for_constant_diag(eye, batch_shape, n, k)
        else:
            self._init_cache_for_non_constant_diag(eye, batch_shape, n)

        self._precond_lt = PsdSumLazyTensor(RootLazyTensor(self._piv_chol_self), self._diag_tensor)

    def _init_cache_for_constant_diag(self, eye, batch_shape, n, k):
        """
        Build the QR and log-determinant caches when every diagonal entry
        within a batch element is the same value ``s``.
        """
        # We can factor out the noise for both QR and solves.
        # Keep a single diagonal entry per batch element.
        self._noise = self._noise.narrow(-2, 0, 1)
        # QR of [L; sqrt(s) I_k], so that R'R = L'L + s*I
        self._q_cache, self._r_cache = torch.linalg.qr(
            torch.cat((self._piv_chol_self, self._noise.sqrt() * eye), dim=-2)
        )
        # Only the first n rows of Q (the block aligned with L) are needed for solves.
        self._q_cache = self._q_cache[..., :n, :]

        # Use the matrix determinant lemma for the logdet, using the fact that R'R = L_k'L_k + s*I
        # log|L L' + s I| = 2 * sum(log|R_ii|) + (n - k) * log(s)
        logdet = self._r_cache.diagonal(dim1=-1, dim2=-2).abs().log().sum(-1).mul(2)
        logdet = logdet + (n - k) * self._noise.squeeze(-2).squeeze(-1).log()
        self._precond_logdet_cache = logdet.view(*batch_shape) if len(batch_shape) else logdet.squeeze()

    def _init_cache_for_non_constant_diag(self, eye, batch_shape, n):
        """
        Build the QR and log-determinant caches when the diagonal entries
        differ within a batch element.
        """
        # With non-constant diagonals, we cant factor out the noise as easily
        # QR of [D^{-1/2} L; I_k], so that R'R = L' D^{-1} L + I
        self._q_cache, self._r_cache = torch.linalg.qr(
            torch.cat((self._piv_chol_self / self._noise.sqrt(), eye), dim=-2)
        )
        # Fold a second D^{-1/2} into Q so the closure can subtract Q Q^T x directly.
        self._q_cache = self._q_cache[..., :n, :] / self._noise.sqrt()

        # Use the matrix determinant lemma for the logdet, using the fact that R'R = L_k' D^{-1} L_k + I
        # log|L L' + D| = 2 * sum(log|R_ii|) + sum(log(d_i))
        logdet = self._r_cache.diagonal(dim1=-1, dim2=-2).abs().log().sum(-1).mul(2)
        logdet -= (1.0 / self._noise).log().sum([-1, -2])
        self._precond_logdet_cache = logdet.view(*batch_shape) if len(batch_shape) else logdet.squeeze()

    @cached(name="svd")
    def _svd(self) -> Tuple["LazyTensor", Tensor, "LazyTensor"]:
        """
        Return the SVD ``(U, S, V)`` of the sum.

        When the diagonal is a ConstantDiagLazyTensor, the non-diagonal
        component's SVD is reused and the diagonal values are added to its
        singular values; otherwise the parent implementation is used.
        """
        if isinstance(self._diag_tensor, ConstantDiagLazyTensor):
            U, S_, V = self._lazy_tensor.svd()
            S = S_ + self._diag_tensor.diag()
            return U, S, V
        return super()._svd()

    def _symeig(self, eigenvectors: bool = False) -> Tuple[Tensor, Optional[LazyTensor]]:
        """
        Return the eigenvalues (and optionally eigenvectors) of the sum.

        When the diagonal is a ConstantDiagLazyTensor, the non-diagonal
        component's eigendecomposition is reused and the diagonal values are
        added to its eigenvalues; otherwise the parent implementation is used.
        """
        if isinstance(self._diag_tensor, ConstantDiagLazyTensor):
            evals_, evecs = self._lazy_tensor.symeig(eigenvectors=eigenvectors)
            evals = evals_ + self._diag_tensor.diag()
            return evals, evecs
        return super()._symeig(eigenvectors=eigenvectors)

    def evaluate_kernel(self):
        """
        Overriding this is currently necessary to allow for subclasses of AddedDiagLT to be created. For example,
        consider the following:

            >>> covar1 = covar_module(x).add_diag(torch.tensor(1.)).evaluate_kernel()
            >>> covar2 = covar_module(x).evaluate_kernel().add_diag(torch.tensor(1.))

        Unless we override this method (or find a better solution), covar1 and covar2 might not be the same type.
        In particular, covar1 would *always* be a standard AddedDiagLazyTensor, but covar2 might be a subtype.
        """
        # Rebuild this object from its representation, then re-add the two
        # components so that `+` can choose the resulting type.
        added_diag_lazy_tsr = self.representation_tree()(*self.representation())
        return added_diag_lazy_tsr._lazy_tensor + added_diag_lazy_tsr._diag_tensor