#!/usr/bin/env python3
import torch
from ..utils.broadcasting import _mul_broadcast_shape
from ..utils.getitem import _compute_getitem_size
from ..utils.memoize import cached
from .lazy_tensor import LazyTensor
class ZeroLazyTensor(LazyTensor):
"""
    Special LazyTensor representing an all-zero matrix (or batch of matrices).
    Arithmetic short-circuits accordingly: adding returns the other operand,
    multiplying returns zero, and inverses/decompositions raise errors, since
    a zero matrix is singular.
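
    Example (a minimal illustrative sketch):

        >>> z = ZeroLazyTensor(2, 3, 3)  # batch of two 3x3 zero matrices
        >>> z.evaluate().shape
        torch.Size([2, 3, 3])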
"""
    def __init__(self, *sizes, dtype=None, device=None):
        super().__init__(*sizes)
        self.sizes = list(sizes)
        # Fall back to the global default dtype and the CPU device
        self._dtype = dtype or torch.get_default_dtype()
        self._device = device or torch.device("cpu")
@property
def dtype(self):
return self._dtype
@property
def device(self):
return self._device
def _expand_batch(self, batch_shape):
return self.__class__(*batch_shape, *self.sizes[-2:], dtype=self._dtype, device=self._device)
    def _get_indices(self, row_index, col_index, *batch_indices):
        new_size = _compute_getitem_size(self, batch_indices + (row_index, col_index))
        return ZeroLazyTensor(*new_size, dtype=self._dtype, device=self._device)
    def _getitem(self, row_index, col_index, *batch_indices):
        new_size = _compute_getitem_size(self, batch_indices + (row_index, col_index))
        return ZeroLazyTensor(*new_size, dtype=self._dtype, device=self._device)
def _matmul(self, rhs):
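        # A zero matrix maps any rhs to zero, so the result is a dense zero
        # tensor with the usual matmul output shape (a 1D vector rhs is
        # treated as a column vector).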
rhs_size_ind = -2 if rhs.ndimension() > 1 else -1
if self.size(-1) != rhs.size(rhs_size_ind):
raise RuntimeError("Size mismatch, self: {}, rhs: {}".format(self.size(), rhs.size()))
new_m = self.size(-2)
if rhs_size_ind == -1:
*batch_shape, m = rhs.shape
output_shape = (*batch_shape, new_m)
else:
*batch_shape, m, n = rhs.shape
output_shape = (*batch_shape, new_m, n)
return torch.zeros(*output_shape, dtype=rhs.dtype, device=rhs.device)
def _prod_batch(self, dim):
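        # An elementwise product of all-zero matrices over a batch dimension
        # is still zero; only the shape changes.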
sizes = list(self.sizes)
del sizes[dim]
return self.__class__(*sizes, dtype=self._dtype, device=self._device)
    def _quad_form_derivative(self, left_vecs, right_vecs):
        raise RuntimeError("Backpropagating through a ZeroLazyTensor is not possible")
def _root_decomposition(self):
raise RuntimeError("ZeroLazyTensors are not positive definite!")
def _root_inv_decomposition(self, initial_vectors=None):
raise RuntimeError("ZeroLazyTensors are not positive definite!")
def _root_decomposition_size(self):
raise RuntimeError("ZeroLazyTensors are not positive definite!")
def _size(self):
return torch.Size(self.sizes)
def _sum_batch(self, dim):
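        # A sum of all-zero matrices over a batch dimension is still zero.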
sizes = list(self.sizes)
del sizes[dim]
return self.__class__(*sizes, dtype=self._dtype, device=self._device)
def _t_matmul(self, rhs):
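        # Same as _matmul, but for self^T: the inner dimension is size(-2)
        # and the output's row dimension is size(-1).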
rhs_size_ind = -2 if rhs.ndimension() > 1 else -1
if self.size(-2) != rhs.size(rhs_size_ind):
raise RuntimeError("Size mismatch, self: {}, rhs: {}".format(self.size(), rhs.size()))
new_m = self.size(-1)
if rhs_size_ind == -1:
*batch_shape, m = rhs.shape
output_shape = (*batch_shape, new_m)
else:
*batch_shape, m, n = rhs.shape
output_shape = (*batch_shape, new_m, n)
return torch.zeros(*output_shape, dtype=rhs.dtype, device=rhs.device)
def _transpose_nonbatch(self):
return self.transpose(-2, -1)
def _unsqueeze_batch(self, dim):
sizes = self.sizes.copy()
sizes.insert(dim, 1)
return self.__class__(*sizes, dtype=self._dtype, device=self._device)
def add_diag(self, diag):
from .diag_lazy_tensor import DiagLazyTensor
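        # Adding a diagonal to an all-zero matrix yields exactly that diagonal
        # matrix, so the result is a DiagLazyTensor (after broadcasting diag
        # up to the right shape).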
if self.size(-1) != self.size(-2):
raise RuntimeError("add_diag only defined for square matrices")
if self.ndimension() == 3:
if diag.ndimension() == 0:
diag = diag.view(1, 1).expand(self.size(0), self.size(1))
elif diag.ndimension() == 1:
diag = diag.unsqueeze(0).expand(self.size(0), self.size(1))
elif diag.ndimension() == 2:
diag = diag.expand(self.size(0), self.size(1))
else:
                raise RuntimeError(
                    "For a 3D tensor ({}), add_diag expects a 0D, 1D, or 2D diag. "
                    "Got size ({})".format(self.size(), diag.size())
                )
else:
if diag.ndimension() == 0:
diag = diag.view(1).expand(self.size(0))
elif diag.ndimension() == 1:
diag = diag.expand(self.size(0))
else:
                raise RuntimeError(
                    "For a 2D tensor ({}), add_diag expects a 0D or 1D diag. "
                    "Got size ({})".format(self.size(), diag.size())
                )
res = DiagLazyTensor(diag)
if res.size() != self.size():
raise RuntimeError(
"Diag dimensions are incompatible with the base LazyTensor dimensions. "
"Diag size corresponds to a {} Tensor - expected {}".format(res.size(), self.size())
)
return res
def diag(self):
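        # The diagonal of an (optionally batched) zero matrix is a zero vector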
shape = self.shape
if shape[-1] != shape[-2]:
raise RuntimeError("diag works on square matrices (or batches)")
return torch.zeros(shape[:-1], dtype=self.dtype, device=self.device)
    @cached
    def evaluate(self):
        return torch.zeros(*self.sizes, dtype=self._dtype, device=self._device)
def inv_matmul(self, right_tensor, left_tensor=None):
raise RuntimeError("ZeroLazyTensors are not invertible!")
def inv_quad(self, tensor):
raise RuntimeError("ZeroLazyTensors are not invertible!")
def inv_quad_logdet(self, inv_quad_rhs=None, logdet=False, reduce_inv_quad=True):
raise RuntimeError("ZeroLazyTensors are not invertible!")
    def logdet(self):
        # The determinant of a zero matrix is 0, so the log-determinant is -inf
        return torch.log(torch.tensor(0.0, dtype=self._dtype, device=self._device))
def matmul(self, tensor):
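        # Unlike _matmul, this stays lazy: the product with a zero matrix is
        # itself a ZeroLazyTensor with the matmul output shape.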
tensor_size_ind = -2 if tensor.ndimension() > 1 else -1
if self.size(-1) != tensor.size(tensor_size_ind):
raise RuntimeError("Size mismatch, self: {}, tensor: {}".format(self.size(), tensor.size()))
new_m = self.size(-2)
if tensor_size_ind == -1:
*batch_shape, m = tensor.shape
output_shape = (*batch_shape, new_m)
else:
*batch_shape, m, n = tensor.shape
output_shape = (*batch_shape, new_m, n)
return ZeroLazyTensor(*output_shape, dtype=tensor.dtype, device=tensor.device)
def mul(self, other):
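        # An elementwise product with zero is zero, broadcast to the common shape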
shape = _mul_broadcast_shape(self.shape, other.shape)
return self.__class__(*shape, dtype=self._dtype, device=self._device)
    def transpose(self, dim1, dim2):
        sizes = self.sizes.copy()
        sizes[dim1], sizes[dim2] = sizes[dim2], sizes[dim1]
        return ZeroLazyTensor(*sizes, dtype=self._dtype, device=self._device)
    def __add__(self, other):
        # 0 + other = other
        return other
    def __truediv__(self, other):
        # 0 / other = 0; Python 3 uses __truediv__ (__div__ is Python 2 only)
        return self
    def __mul__(self, other):
        # 0 * other = 0
        return self
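# A minimal usage sketch (illustrative only; it relies just on the methods
# defined above, but the exact call sites are assumptions):
#
#     z = ZeroLazyTensor(3, 3, dtype=torch.float64)
#     (z + z).evaluate()             # __add__ returns the other operand
#     z.matmul(torch.randn(3, 2))    # a lazy 3x2 zero (ZeroLazyTensor)
#     z.add_diag(torch.ones(3))      # the identity, as a DiagLazyTensor
#     z.transpose(-2, -1)            # still a 3x3 ZeroLazyTensor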