Source code for gpytorch.lazy.added_diag_lazy_tensor

#!/usr/bin/env python3

import warnings
from typing import Optional, Tuple

import torch
from torch import Tensor

from .. import settings
from ..utils import broadcasting, pivoted_cholesky
from ..utils.memoize import cached
from ..utils.warnings import NumericalWarning
from .diag_lazy_tensor import ConstantDiagLazyTensor, DiagLazyTensor
from .lazy_tensor import LazyTensor
from .psd_sum_lazy_tensor import PsdSumLazyTensor
from .root_lazy_tensor import RootLazyTensor
from .sum_lazy_tensor import SumLazyTensor


class AddedDiagLazyTensor(SumLazyTensor):
    """
    A SumLazyTensor of exactly two lazy tensors, one of which must be a DiagLazyTensor.

    The DiagLazyTensor may be passed in either position. It is stored as ``_diag_tensor`` and the
    other component as ``_lazy_tensor``.

    Args:
        lazy_tensors: the two summands; exactly one of them must be a DiagLazyTensor.
        preconditioner_override: optional callable. When set, ``_preconditioner`` returns
            ``preconditioner_override(self)`` instead of building its own preconditioner, so it is
            expected to return the same ``(closure, precond_lt, logdet)`` 3-tuple.

    Raises:
        RuntimeError: if there are not exactly two components, if both are DiagLazyTensors, or if
            neither is a DiagLazyTensor.
    """

    def __init__(self, *lazy_tensors, preconditioner_override=None):
        lazy_tensors = list(lazy_tensors)
        super().__init__(*lazy_tensors, preconditioner_override=preconditioner_override)
        # `!= 2` rather than `> 2`: a single component would otherwise fail later with an opaque
        # IndexError on `lazy_tensors[1]`.
        if len(lazy_tensors) != 2:
            raise RuntimeError("An AddedDiagLazyTensor can only have two components")

        # Called only for validation (the result is discarded); presumably raises when the two
        # shapes cannot be broadcast together.
        broadcasting._mul_broadcast_shape(lazy_tensors[0].shape, lazy_tensors[1].shape)

        if isinstance(lazy_tensors[0], DiagLazyTensor) and isinstance(lazy_tensors[1], DiagLazyTensor):
            raise RuntimeError("Trying to lazily add two DiagLazyTensors. Create a single DiagLazyTensor instead.")
        elif isinstance(lazy_tensors[0], DiagLazyTensor):
            self._diag_tensor = lazy_tensors[0]
            self._lazy_tensor = lazy_tensors[1]
        elif isinstance(lazy_tensors[1], DiagLazyTensor):
            self._diag_tensor = lazy_tensors[1]
            self._lazy_tensor = lazy_tensors[0]
        else:
            raise RuntimeError("One of the LazyTensors input to AddedDiagLazyTensor must be a DiagLazyTensor!")

        self.preconditioner_override = preconditioner_override

        # Placeholders, filled lazily by _preconditioner / _init_cache
        self._constant_diag = None
        self._noise = None
        self._piv_chol_self = None  # <- Doesn't need to be an attribute, but used for testing purposes
        self._precond_lt = None
        self._precond_logdet_cache = None
        self._q_cache = None
        self._r_cache = None

    def _matmul(self, rhs):
        # (K + D) @ rhs == K @ rhs + d * rhs, with the diagonal entries d broadcast over rhs columns.
        return torch.addcmul(self._lazy_tensor._matmul(rhs), self._diag_tensor._diag.unsqueeze(-1), rhs)

    def add_diag(self, added_diag):
        """Return a new AddedDiagLazyTensor with `added_diag` folded into the diagonal component."""
        return AddedDiagLazyTensor(self._lazy_tensor, self._diag_tensor.add_diag(added_diag))

    def __add__(self, other):
        """
        Add `other`, keeping the result an AddedDiagLazyTensor.

        A DiagLazyTensor is merged into the diagonal component; anything else is added to the
        non-diagonal component.
        """
        if isinstance(other, DiagLazyTensor):
            return AddedDiagLazyTensor(self._lazy_tensor, self._diag_tensor + other)
        else:
            return AddedDiagLazyTensor(self._lazy_tensor + other, self._diag_tensor)

    def _preconditioner(self):
        """
        Build (or reuse) a pivoted-Cholesky preconditioner ``P = L L^T + D``.

        Returns:
            ``(precondition_closure, precond_lt, precond_logdet)``.
            - ``precondition_closure(t)`` computes ``P^{-1} t``.
            - ``precond_lt`` is ``P`` as a lazy tensor.
            - ``precond_logdet`` is ``log det P``.
            Returns ``(None, None, None)`` in three cases: preconditioning is disabled, the matrix is
            smaller than ``settings.min_preconditioning_size``, or the pivoted Cholesky factor
            contains NaNs.
        """
        if self.preconditioner_override is not None:
            return self.preconditioner_override(self)

        if settings.max_preconditioner_size.value() == 0 or self.size(-1) < settings.min_preconditioning_size.value():
            return None, None, None

        if self._q_cache is None:
            max_iter = settings.max_preconditioner_size.value()
            self._piv_chol_self = pivoted_cholesky.pivoted_cholesky(self._lazy_tensor, max_iter)
            if torch.any(torch.isnan(self._piv_chol_self)).item():
                warnings.warn(
                    "NaNs encountered in preconditioner computation. Attempting to continue without preconditioning.",
                    NumericalWarning,
                )
                return None, None, None
            self._init_cache()

        # NOTE: We cannot memoize this precondition closure as it causes a memory leak
        def precondition_closure(tensor):
            # Woodbury-style solve with the cached Q factor from the _init_cache_* helpers.
            qqt = self._q_cache.matmul(self._q_cache.transpose(-2, -1).matmul(tensor))
            if self._constant_diag:
                return (1 / self._noise) * (tensor - qqt)
            return (tensor / self._noise) - qqt

        return (precondition_closure, self._precond_lt, self._precond_logdet_cache)

    def _init_cache(self):
        """Compute and cache the QR factors, log-determinant and lazy tensor of the preconditioner."""
        *batch_shape, n, k = self._piv_chol_self.shape
        self._noise = self._diag_tensor.diag().unsqueeze(-1)

        # the check for constant diag needs to be done carefully for batches:
        # compare every entry against the first entry of its own batch element.
        noise_first_element = self._noise[..., :1, :]
        self._constant_diag = torch.equal(self._noise, noise_first_element * torch.ones_like(self._noise))
        eye = torch.eye(k, dtype=self._piv_chol_self.dtype, device=self._piv_chol_self.device)

        if self._constant_diag:
            self._init_cache_for_constant_diag(eye, batch_shape, n, k)
        else:
            self._init_cache_for_non_constant_diag(eye, batch_shape, n)

        self._precond_lt = PsdSumLazyTensor(RootLazyTensor(self._piv_chol_self), self._diag_tensor)

    def _init_cache_for_constant_diag(self, eye, batch_shape, n, k):
        """Cache factors for ``P = L L^T + s I`` with a scalar noise ``s`` per batch element."""
        # We can factor out the noise for both QR and solves.
        self._noise = self._noise.narrow(-2, 0, 1)
        self._q_cache, self._r_cache = torch.qr(torch.cat((self._piv_chol_self, self._noise.sqrt() * eye), dim=-2))
        self._q_cache = self._q_cache[..., :n, :]

        # Use the matrix determinant lemma for the logdet, using the fact that R'R = L_k'L_k + s*I
        logdet = self._r_cache.diagonal(dim1=-1, dim2=-2).abs().log().sum(-1).mul(2)
        logdet = logdet + (n - k) * self._noise.squeeze(-2).squeeze(-1).log()
        self._precond_logdet_cache = logdet.view(*batch_shape) if len(batch_shape) else logdet.squeeze()

    def _init_cache_for_non_constant_diag(self, eye, batch_shape, n):
        """Cache factors for ``P = L L^T + D`` with a general (non-constant) diagonal ``D``."""
        # With non-constant diagonals we can't factor out the noise, so QR-factorize
        # [D^{-1/2} L; I] instead, which gives R'R = L'D^{-1}L + I.
        noise_sqrt = self._noise.sqrt()
        scaled_piv_chol = self._piv_chol_self / noise_sqrt
        # `eye` is an unbatched (k, k) tensor, and torch.cat does not broadcast. Expand it to the
        # batch shape of the scaled factor so batched inputs can be concatenated along the row dim.
        batch_eye = eye.expand(*scaled_piv_chol.shape[:-2], -1, -1)
        self._q_cache, self._r_cache = torch.qr(torch.cat((scaled_piv_chol, batch_eye), dim=-2))
        self._q_cache = self._q_cache[..., :n, :] / noise_sqrt

        # logdet(L L' + D) = logdet(R'R) + logdet(D), with logdet(D) written as -sum(log(1 / d)).
        logdet = self._r_cache.diagonal(dim1=-1, dim2=-2).abs().log().sum(-1).mul(2)
        logdet = logdet - (1.0 / self._noise).log().sum([-1, -2])
        self._precond_logdet_cache = logdet.view(*batch_shape) if len(batch_shape) else logdet.squeeze()

    @cached(name="svd")
    def _svd(self) -> Tuple["LazyTensor", Tensor, "LazyTensor"]:
        """SVD; when the diagonal is a ConstantDiagLazyTensor, shift the component's singular values."""
        if isinstance(self._diag_tensor, ConstantDiagLazyTensor):
            # NOTE(review): S_ + c equals the singular values of K + cI only when K is symmetric PSD
            # (so its SVD coincides with its eigendecomposition) and c >= 0 -- presumably guaranteed
            # by callers; confirm.
            U, S_, V = self._lazy_tensor.svd()
            S = S_ + self._diag_tensor.diag()
            return U, S, V
        return super()._svd()

    def _symeig(self, eigenvectors: bool = False) -> Tuple[Tensor, Optional[LazyTensor]]:
        """Symmetric eigendecomposition; a constant diagonal shifts eigenvalues and keeps eigenvectors."""
        if isinstance(self._diag_tensor, ConstantDiagLazyTensor):
            evals_, evecs = self._lazy_tensor.symeig(eigenvectors=eigenvectors)
            evals = evals_ + self._diag_tensor.diag()
            return evals, evecs
        return super()._symeig(eigenvectors=eigenvectors)