Source code for gpytorch.lazy.diag_lazy_tensor

#!/usr/bin/env python3

from typing import Optional, Tuple

import torch
from torch import Tensor

from .. import settings
from ..utils.broadcasting import _mul_broadcast_shape
from ..utils.memoize import cached
from .block_diag_lazy_tensor import BlockDiagLazyTensor
from .lazy_tensor import LazyTensor
from .non_lazy_tensor import NonLazyTensor
from .triangular_lazy_tensor import TriangularLazyTensor


class DiagLazyTensor(TriangularLazyTensor):
    def __init__(self, diag):
        """
        Diagonal lazy tensor. Supports arbitrary batch sizes.

        Args:
            diag (Tensor): A `b1 x ... x bk x n` Tensor, representing a `b1 x ... x bk`-sized batch
                of `n x n` diagonal matrices
        """
        super(TriangularLazyTensor, self).__init__(diag)
        self._diag = diag

    def __add__(self, other):
        if isinstance(other, DiagLazyTensor):
            return self.add_diag(other._diag)
        from .added_diag_lazy_tensor import AddedDiagLazyTensor

        return AddedDiagLazyTensor(other, self)

    @cached(name="cholesky", ignore_args=True)
    def _cholesky(self, upper=False):
        # The Cholesky factor of a diagonal matrix is its elementwise square root
        return self.sqrt()

    def _cholesky_solve(self, rhs):
        return rhs / self._diag.unsqueeze(-1).pow(2)

    def _expand_batch(self, batch_shape):
        return self.__class__(self._diag.expand(*batch_shape, self._diag.size(-1)))

    def _get_indices(self, row_index, col_index, *batch_indices):
        res = self._diag[(*batch_indices, row_index)]
        # If the row and column indices don't agree, we have off-diagonal elements,
        # which should be zeroed out
        res = res * torch.eq(row_index, col_index).to(device=res.device, dtype=res.dtype)
        return res

    def _matmul(self, rhs):
        # To perform matrix multiplication with a diagonal matrix we can just
        # multiply element-wise with the diagonal (using proper broadcasting)
        if rhs.ndimension() == 1:
            return self._diag * rhs
        # special case if we have a NonLazyTensor
        if isinstance(rhs, NonLazyTensor):
            return NonLazyTensor(self._diag.unsqueeze(-1) * rhs.tensor)
        return self._diag.unsqueeze(-1) * rhs

    def _mul_constant(self, constant):
        return self.__class__(self._diag * constant.unsqueeze(-1))

    def _mul_matrix(self, other):
        return DiagLazyTensor(self.diag() * other.diag())

    def _prod_batch(self, dim):
        return self.__class__(self._diag.prod(dim))

    def _quad_form_derivative(self, left_vecs, right_vecs):
        # TODO: Use proper batching for input vectors (prepend to shape rather than append)
        if not self._diag.requires_grad:
            return (None,)

        res = left_vecs * right_vecs
        if res.ndimension() > self._diag.ndimension():
            res = res.sum(-1)
        return (res,)

    def _root_decomposition(self):
        return self.sqrt()

    def _root_inv_decomposition(self, initial_vectors=None):
        return self.inverse().sqrt()

    def _size(self):
        # the represented matrix is square, so append the diagonal length once more
        return self._diag.shape + self._diag.shape[-1:]

    def _sum_batch(self, dim):
        return self.__class__(self._diag.sum(dim))

    def _t_matmul(self, rhs):
        # Diagonal matrices always commute
        return self._matmul(rhs)

    def _transpose_nonbatch(self):
        return self

    def abs(self):
        return self.__class__(self._diag.abs())

    def add_diag(self, added_diag):
        shape = _mul_broadcast_shape(self._diag.shape, added_diag.shape)
        return DiagLazyTensor(self._diag.expand(shape) + added_diag.expand(shape))

    def diag(self):
        return self._diag

    @cached
    def evaluate(self):
        # materialize the dense diagonal matrix
        if self._diag.dim() == 0:
            return self._diag
        return torch.diag_embed(self._diag)

    def exp(self):
        return self.__class__(self._diag.exp())

    def inverse(self):
        # the inverse of a diagonal matrix is the elementwise reciprocal of its diagonal
        return self.__class__(self._diag.reciprocal())

    def inv_matmul(self, right_tensor, left_tensor=None):
        res = self.inverse()._matmul(right_tensor)
        if left_tensor is not None:
            res = left_tensor @ res
        return res

    def inv_quad_logdet(self, inv_quad_rhs=None, logdet=False, reduce_inv_quad=True):
        # TODO: Use proper batching for inv_quad_rhs (prepend to shape rather than append)
        if inv_quad_rhs is None:
            rhs_batch_shape = torch.Size()
        else:
            rhs_batch_shape = inv_quad_rhs.shape[1 + self.batch_dim :]

        if inv_quad_rhs is None:
            inv_quad_term = torch.empty(0, dtype=self.dtype, device=self.device)
        else:
            diag = self._diag
            for _ in rhs_batch_shape:
                diag = diag.unsqueeze(-1)
            inv_quad_term = inv_quad_rhs.div(diag).mul(inv_quad_rhs).sum(-(1 + len(rhs_batch_shape)))
            if reduce_inv_quad:
                inv_quad_term = inv_quad_term.sum(-1)

        if not logdet:
            logdet_term = torch.empty(0, dtype=self.dtype, device=self.device)
        else:
            logdet_term = self._diag.log().sum(-1)

        return inv_quad_term, logdet_term

    def log(self):
        return self.__class__(self._diag.log())

    def matmul(self, other):
        # this is trivial if we multiply two DiagLazyTensors
        if isinstance(other, DiagLazyTensor):
            return DiagLazyTensor(self._diag * other._diag)
        # special case if we have a NonLazyTensor
        if isinstance(other, NonLazyTensor):
            return NonLazyTensor(self._diag.unsqueeze(-1) * other.tensor)
        # special case if we have a BlockDiagLazyTensor
        if isinstance(other, BlockDiagLazyTensor):
            diag_reshape = self._diag.view(*other.base_lazy_tensor.shape[:-1], 1)
            return BlockDiagLazyTensor(diag_reshape * other.base_lazy_tensor)
        # special case if we have a TriangularLazyTensor
        if isinstance(other, TriangularLazyTensor):
            return TriangularLazyTensor(self._diag.unsqueeze(-1) * other._tensor, upper=other.upper)
        return super().matmul(other)

    def sqrt(self):
        return self.__class__(self._diag.sqrt())

    def sqrt_inv_matmul(self, rhs, lhs=None):
        matrix_inv_root = self._root_inv_decomposition()
        if lhs is None:
            return matrix_inv_root.matmul(rhs)
        else:
            sqrt_inv_matmul = lhs @ matrix_inv_root.matmul(rhs)
            inv_quad = (matrix_inv_root @ lhs.transpose(-2, -1)).transpose(-2, -1).pow(2).sum(dim=-1)
            return sqrt_inv_matmul, inv_quad

    def zero_mean_mvn_samples(self, num_samples):
        # scale standard normal samples by the matrix square root (elementwise sqrt of the diagonal)
        base_samples = torch.randn(num_samples, *self._diag.shape, dtype=self.dtype, device=self.device)
        return base_samples * self._diag.sqrt()

    @cached(name="svd")
    def _svd(self) -> Tuple[LazyTensor, Tensor, LazyTensor]:
        evals, evecs = self.symeig(eigenvectors=True)
        S = torch.abs(evals)
        U = evecs
        V = evecs * torch.sign(evals).unsqueeze(-1)
        return U, S, V

    def _symeig(self, eigenvectors: bool = False) -> Tuple[Tensor, Optional[LazyTensor]]:
        evals = self._diag
        if eigenvectors:
            diag_values = torch.ones(evals.shape[:-1], device=evals.device, dtype=evals.dtype).unsqueeze(-1)
            evecs = ConstantDiagLazyTensor(diag_values, diag_shape=evals.shape[-1])
        else:
            evecs = None
        return evals, evecs
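

# Illustrative usage sketch (not part of the original module): the underscore-prefixed helper
# below is hypothetical and only shows how DiagLazyTensor operations reduce to elementwise
# work on the stored diagonal, checked against plain torch operations.
def _diag_lazy_tensor_example():
    diag = torch.tensor([1.0, 2.0, 3.0])
    lazy_diag = DiagLazyTensor(diag)
    rhs = torch.randn(3, 2)

    # matmul scales each row of ``rhs`` by the corresponding diagonal entry
    assert torch.allclose(lazy_diag.matmul(rhs), diag.unsqueeze(-1) * rhs)

    # inv_matmul divides instead, since the inverse of a diagonal matrix is the
    # elementwise reciprocal of its diagonal
    assert torch.allclose(lazy_diag.inv_matmul(rhs), rhs / diag.unsqueeze(-1))

    # the log-determinant is the sum of the log-diagonal
    _, logdet = lazy_diag.inv_quad_logdet(logdet=True)
    assert torch.allclose(logdet, diag.log().sum())

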
class ConstantDiagLazyTensor(DiagLazyTensor):
    def __init__(self, diag_values, diag_shape):
        """
        Diagonal lazy tensor with constant entries. Supports arbitrary batch sizes.
        Used e.g. for adding jitter to matrices.

        Args:
            diag_values (Tensor): A `b1 x ... x bk x 1` Tensor, representing a `b1 x ... x bk`-sized batch
                of `diag_shape x diag_shape` diagonal matrices
            diag_shape (int): The (non-batch) dimension of the (square) matrix
        """
        if settings.debug.on():
            if not (diag_values.dim() and diag_values.size(-1) == 1):
                raise ValueError(
                    f"diag_values argument to ConstantDiagLazyTensor needs to have a final "
                    f"singleton dimension. Instead, got a value with shape {diag_values.shape}."
                )
        super(TriangularLazyTensor, self).__init__(diag_values, diag_shape=diag_shape)
        self.diag_values = diag_values
        self.diag_shape = diag_shape

    def __add__(self, other):
        if isinstance(other, ConstantDiagLazyTensor):
            if other.shape[-1] == self.shape[-1]:
                return ConstantDiagLazyTensor(self.diag_values + other.diag_values, self.diag_shape)
            raise RuntimeError(
                f"Trailing batch shapes must match for adding two ConstantDiagLazyTensors. "
                f"Instead, got shapes of {other.shape} and {self.shape}."
            )
        return super().__add__(other)

    @property
    def _diag(self):
        # expand the single stored value to the full diagonal (a view, no copy)
        return self.diag_values.expand(*self.diag_values.shape[:-1], self.diag_shape)

    def _expand_batch(self, batch_shape):
        return self.__class__(self.diag_values.expand(*batch_shape, 1), diag_shape=self.diag_shape)

    def _mul_constant(self, constant):
        return self.__class__(self.diag_values * constant, diag_shape=self.diag_shape)

    def _mul_matrix(self, other):
        if isinstance(other, ConstantDiagLazyTensor):
            if not self.diag_shape == other.diag_shape:
                raise ValueError(
                    "Dimension Mismatch: Must have same diag_shape, but got "
                    f"{self.diag_shape} and {other.diag_shape}"
                )
            return self.__class__(self.diag_values * other.diag_values, diag_shape=self.diag_shape)
        return super()._mul_matrix(other)

    def _prod_batch(self, dim):
        return self.__class__(self.diag_values.prod(dim), diag_shape=self.diag_shape)

    def _quad_form_derivative(self, left_vecs, right_vecs):
        # TODO: Use proper batching for input vectors (prepend to shape rather than append)
        if not self.diag_values.requires_grad:
            return (None,)

        res = (left_vecs * right_vecs).sum(dim=[-1, -2])
        res = res.unsqueeze(-1)
        return (res,)

    def _sum_batch(self, dim):
        return ConstantDiagLazyTensor(self.diag_values.sum(dim), diag_shape=self.diag_shape)

    def abs(self):
        return ConstantDiagLazyTensor(self.diag_values.abs(), diag_shape=self.diag_shape)

    def exp(self):
        return ConstantDiagLazyTensor(self.diag_values.exp(), diag_shape=self.diag_shape)

    def inverse(self):
        return ConstantDiagLazyTensor(self.diag_values.reciprocal(), diag_shape=self.diag_shape)

    def log(self):
        return ConstantDiagLazyTensor(self.diag_values.log(), diag_shape=self.diag_shape)

    def matmul(self, other):
        if isinstance(other, ConstantDiagLazyTensor):
            return self._mul_matrix(other)
        return super().matmul(other)

    def sqrt(self):
        return ConstantDiagLazyTensor(self.diag_values.sqrt(), diag_shape=self.diag_shape)
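

# Illustrative usage sketch (not part of the original module): ConstantDiagLazyTensor stores only
# one value per batch plus the matrix size, which is how a jitter term like sigma^2 * I can be
# represented without materializing the diagonal. The underscore-prefixed helper is hypothetical.
def _constant_diag_lazy_tensor_example():
    jitter = ConstantDiagLazyTensor(torch.tensor([0.1]), diag_shape=4)

    # the full diagonal is an expanded view of the single stored value
    assert torch.allclose(jitter.diag(), torch.full((4,), 0.1))

    # adding two ConstantDiagLazyTensors of the same size stays constant-diagonal
    summed = jitter + ConstantDiagLazyTensor(torch.tensor([0.2]), diag_shape=4)
    assert isinstance(summed, ConstantDiagLazyTensor)
    assert torch.allclose(summed.diag(), torch.full((4,), 0.3))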