Source code for gpytorch.lazy.block_diag_lazy_tensor

#!/usr/bin/env python3

from typing import Optional, Tuple

import torch
from torch import Tensor

from ..utils.memoize import cached
from .block_lazy_tensor import BlockLazyTensor
from .lazy_tensor import LazyTensor


class BlockDiagLazyTensor(BlockLazyTensor):
    """
    Represents a lazy tensor that is the block diagonal of square matrices.
    The `block_dim` attribute specifies which dimension of the base LazyTensor
    contains the blocks.
    For example, with `block_dim=-3`, a `k x n x n` tensor represents `k` `n x n` blocks (a `kn x kn` matrix).
    A `b x k x n x n` tensor represents `k` `b x n x n` blocks (a `b x kn x kn` batch matrix).

    Args:
        base_lazy_tensor (LazyTensor or Tensor): Must be at least 3 dimensional.
        block_dim (int): The dimension that specifies the blocks.
    """

    @property
    def num_blocks(self):
        return self.base_lazy_tensor.size(-3)

    def _add_batch_dim(self, other):
        # Reshape a `... x (k*n) x t` right-hand side into `... x k x n x t`,
        # so that each block operates on its own slice
        *batch_shape, num_rows, num_cols = other.shape
        batch_shape = list(batch_shape)
        batch_shape.append(self.num_blocks)
        other = other.view(*batch_shape, num_rows // self.num_blocks, num_cols)
        return other

    @cached(name="cholesky")
    def _cholesky(self, upper=False):
        from .triangular_lazy_tensor import TriangularLazyTensor

        chol = self.__class__(self.base_lazy_tensor.cholesky(upper=upper))
        return TriangularLazyTensor(chol, upper=upper)

    def _cholesky_solve(self, rhs, upper: bool = False):
        rhs = self._add_batch_dim(rhs)
        res = self.base_lazy_tensor._cholesky_solve(rhs, upper=upper)
        res = self._remove_batch_dim(res)
        return res

    def _get_indices(self, row_index, col_index, *batch_indices):
        # Figure out what block the row/column indices belong to
        row_index_block = torch.div(row_index, self.base_lazy_tensor.size(-2), rounding_mode="floor")
        col_index_block = torch.div(col_index, self.base_lazy_tensor.size(-1), rounding_mode="floor")

        # Find the row/col index within each block
        row_index = row_index.fmod(self.base_lazy_tensor.size(-2))
        col_index = col_index.fmod(self.base_lazy_tensor.size(-1))

        # If the row/column blocks do not agree, then we have off-diagonal elements
        # These elements should be zeroed out
        res = self.base_lazy_tensor._get_indices(row_index, col_index, *batch_indices, row_index_block)
        res = res * torch.eq(row_index_block, col_index_block).type_as(res)
        return res

    def _remove_batch_dim(self, other):
        # Inverse of `_add_batch_dim`: collapse the block dimension back into the rows
        shape = list(other.shape)
        del shape[-3]
        shape[-2] *= self.num_blocks
        other = other.reshape(*shape)
        return other

    def _root_decomposition(self):
        return self.__class__(self.base_lazy_tensor._root_decomposition())

    def _root_inv_decomposition(self, initial_vectors=None):
        return self.__class__(self.base_lazy_tensor._root_inv_decomposition(initial_vectors))

    def _size(self):
        # A base tensor of shape `... x k x n x n` represents a `... x (k*n) x (k*n)` matrix
        shape = list(self.base_lazy_tensor.shape)
        shape[-2] *= shape[-3]
        shape[-1] *= shape[-3]
        del shape[-3]
        return torch.Size(shape)

    def _solve(self, rhs, preconditioner, num_tridiag=0):
        if num_tridiag:
            return super()._solve(rhs, preconditioner, num_tridiag=num_tridiag)
        else:
            rhs = self._add_batch_dim(rhs)
            res = self.base_lazy_tensor._solve(rhs, preconditioner, num_tridiag=None)
            res = self._remove_batch_dim(res)
            return res

    def diag(self):
        res = self.base_lazy_tensor.diag().contiguous()
        return res.view(*self.batch_shape, self.size(-1))

    def inv_quad_logdet(self, inv_quad_rhs=None, logdet=False, reduce_inv_quad=True):
        if inv_quad_rhs is not None:
            inv_quad_rhs = self._add_batch_dim(inv_quad_rhs)
        inv_quad_res, logdet_res = self.base_lazy_tensor.inv_quad_logdet(
            inv_quad_rhs, logdet, reduce_inv_quad=reduce_inv_quad
        )
        if inv_quad_res is not None and inv_quad_res.numel():
            if reduce_inv_quad:
                inv_quad_res = inv_quad_res.view(*self.base_lazy_tensor.batch_shape)
                inv_quad_res = inv_quad_res.sum(-1)
            else:
                inv_quad_res = inv_quad_res.view(*self.base_lazy_tensor.batch_shape, inv_quad_res.size(-1))
                inv_quad_res = inv_quad_res.sum(-2)
        if logdet_res is not None and logdet_res.numel():
            logdet_res = logdet_res.view(*logdet_res.shape).sum(-1)
        return inv_quad_res, logdet_res

    def matmul(self, other):
        from .diag_lazy_tensor import DiagLazyTensor

        # this is trivial if we multiply two BlockDiagLazyTensors
        if isinstance(other, BlockDiagLazyTensor):
            return BlockDiagLazyTensor(self.base_lazy_tensor @ other.base_lazy_tensor)
        # special case if we have a DiagLazyTensor
        if isinstance(other, DiagLazyTensor):
            diag_reshape = other._diag.view(*self.base_lazy_tensor.shape[:-2], 1, -1)
            return BlockDiagLazyTensor(self.base_lazy_tensor * diag_reshape)
        return super().matmul(other)

    @cached(name="svd")
    def _svd(self) -> Tuple["LazyTensor", Tensor, "LazyTensor"]:
        U, S, V = self.base_lazy_tensor.svd()
        # Doesn't make much sense to sort here, o/w we lose the structure
        S = S.reshape(*S.shape[:-2], S.shape[-2:].numel())
        # can assume that block_dim is -3 here
        U = self.__class__(U)
        V = self.__class__(V)
        return U, S, V

    def _symeig(self, eigenvectors: bool = False) -> Tuple[Tensor, Optional[LazyTensor]]:
        evals, evecs = self.base_lazy_tensor.symeig(eigenvectors=eigenvectors)
        # Doesn't make much sense to sort here, o/w we lose the structure
        evals = evals.reshape(*evals.shape[:-2], evals.shape[-2:].numel())
        if eigenvectors:
            evecs = self.__class__(evecs)  # can assume that block_dim is -3 here
        else:
            evecs = None
        return evals, evecs
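
A minimal usage sketch, not part of the module above. It assumes a gpytorch release that still exposes the LazyTensor API (before the migration to linear_operator); the block count `k`, block size `n`, and variable names below are arbitrary choices for illustration.

# Usage sketch (assumes the pre-linear_operator LazyTensor API)
import torch
from gpytorch.lazy import BlockDiagLazyTensor, lazify

k, n = 4, 3
# Build k symmetric positive-definite n x n blocks
blocks = torch.randn(k, n, n)
blocks = blocks @ blocks.transpose(-1, -2) + n * torch.eye(n)

block_diag = BlockDiagLazyTensor(lazify(blocks))  # block_dim defaults to -3
print(block_diag.shape)  # torch.Size([12, 12]), i.e. (k*n) x (k*n)

# Each block acts on its own slice of the right-hand side
rhs = torch.randn(k * n, 2)
print(block_diag.matmul(rhs).shape)  # torch.Size([12, 2])

# The lazy representation matches the dense block-diagonal matrix
dense = block_diag.evaluate()
print(torch.allclose(dense, torch.block_diag(*blocks)))  # True

Because every operation is forwarded to the `k x n x n` base tensor, solves, log-determinants, and Cholesky factors cost roughly O(k * n^3) rather than O((kn)^3) for the dense matrix.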