gpytorch.lazy

LazyTensor

class gpytorch.lazy.LazyTensor(*args, **kwargs)[source]

Base class for LazyTensors in GPyTorch.

In GPyTorch, nearly all covariance matrices for Gaussian processes are handled internally as some variety of LazyTensor. A LazyTensor is an object that represents a tensor object, similar to torch.tensor, but typically differs in two ways:

  1. A tensor represented by a LazyTensor can typically be represented more efficiently than storing a full matrix. For example, a LazyTensor representing \(K=XX^{\top}\) where \(K\) is \(n \times n\) but \(X\) is \(n \times d\) might store \(X\) instead of \(K\) directly.
  2. A LazyTensor typically defines a matmul routine that performs \(KM\) that is more efficient than storing the full matrix. Using the above example, performing \(KM=X(X^{\top}M)\) requires only \(O(nd)\) time, rather than the \(O(n^2)\) time required if we were storing \(K\) directly.

In order to define a new LazyTensor class that can be used as a covariance matrix in GPyTorch, a user must define at a minimum the following methods (in each example, \(K\) denotes the matrix that the LazyTensor represents; a minimal sketch appears further below):

  • _matmul(), which performs a matrix multiplication \(KM\)
  • _size(), which returns a torch.Size containing the dimensions of \(K\).
  • _transpose_nonbatch(), which returns a transposed version of the LazyTensor

In addition to these, the following methods should be implemented for maximum efficiency

  • _quad_form_derivative(), which computes the derivative of a quadratic form with the LazyTensor (e.g. \(d (a^T X b) / dX\)).
  • _get_indices(), which returns a torch.Tensor containing elements that are given by various tensor indices.
  • _expand_batch(), which expands the batch dimensions of LazyTensors.
  • _check_args(), which performs error checking on the arguments supplied to the LazyTensor constructor.

In addition to these, a LazyTensor may need to define the following functions if it does anything interesting with the batch dimensions (e.g. sums along them, adds additional ones, etc): _unsqueeze_batch(), _getitem(), and _permute_batch(). See the documentation for these methods for details.
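For concreteness, a minimal custom LazyTensor for the \(K = XX^{\top}\) example above might look roughly like the sketch below. (GPyTorch already provides gpytorch.lazy.RootLazyTensor for exactly this structure; the class name and implementation here are purely illustrative.)

import torch
from gpytorch.lazy import LazyTensor

class OuterProductLazyTensor(LazyTensor):
    # Hypothetical example: represents K = X X^T without ever forming the n x n matrix.
    def __init__(self, x):
        super().__init__(x)
        self.x = x

    def _matmul(self, rhs):
        # K M = X (X^T M); only X is stored, so the n x n matrix is never materialized
        return self.x @ (self.x.transpose(-1, -2) @ rhs)

    def _size(self):
        # X is ... x n x d, so K is ... x n x n
        return torch.Size((*self.x.shape[:-1], self.x.size(-2)))

    def _transpose_nonbatch(self):
        # K = X X^T is symmetric, so transposing is a no-op
        return self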

Note

The base LazyTensor class provides default implementations of many other operations in order to mimic the behavior of a standard tensor as closely as possible. For example, we provide default implementations of __getitem__(), __add__(), etc that either make use of other lazy tensors or exploit the functions that must be defined above.

Rather than overriding the public methods, we recommend that you override the private versions associated with these methods (e.g., write a custom _getitem rather than a custom __getitem__). This is because the public methods do quite a bit of error checking and casing that doesn’t need to be repeated.

Note

LazyTensors are designed by default to optionally represent batches of matrices. Thus, the size of a LazyTensor may be (for example) \(b \times n \times n\). Many of the methods are designed to efficiently operate on these batches if present.

add_diag(diag)[source]

Adds an element to the diagonal of the matrix.

Parameters:diag (torch.Tensor) – the element(s) to add to the diagonal of the matrix.
add_jitter(jitter_val=0.001)[source]

Adds jitter (i.e., a small diagonal component) to the matrix that this LazyTensor represents. This could potentially be implemented as a no-op; however, doing so could lead to numerical instabilities, so it is done at the user’s own risk.
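For illustration, a minimal usage sketch (the matrix and the values added are arbitrary):

>>> x = torch.randn(5, 3)
>>> lazy_k = gpytorch.lazy.RootLazyTensor(x)      # lazily represents x @ x.t()
>>> noisy = lazy_k.add_diag(torch.tensor(0.1))    # K + 0.1 * I
>>> stable = lazy_k.add_jitter(1e-3)              # K + 1e-3 * I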

add_low_rank(low_rank_mat, root_decomp_method: Optional[str] = None, root_inv_decomp_method: Optional[str] = None, generate_roots: Optional[bool] = True, **root_decomp_kwargs)[source]

Adds a low rank matrix to the matrix that this LazyTensor represents, e.g. computes A + BB^T. We then update both the tensor and its root decomposition.

We have access to \(L\) and \(M\), where \(A \approx LL^{\top}\) and \(A^{-1} \approx MM^{\top}\). We compute \(\tilde{A} = A + BB^{\top} = L(I + MBB^{\top}M^{\top})L^{\top}\), decompose \(I + MBB^{\top}M^{\top} \approx RR^{\top}\), and use \(LR\) as the new root decomposition.

This strategy is described in more detail in “Kernel Interpolation for Scalable Online Gaussian Processes,” Stanton et al, AISTATS, 2021. https://arxiv.org/abs/2103.01454.

Parameters:
  • low_rank_mat (torch.tensor) – the matrix B that we are adding to A.
  • root_decomp_method (str) – how to compute the root decomposition of A.
  • root_inv_decomp_method (str) – how to compute the root inverse decomposition of A.
  • generate_roots (bool) – whether to generate the root decomposition of \(A\) even if it has not been created yet.
Returns:

addition of A and BB^T.

Return type:

SumLazyTensor
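A minimal usage sketch (assuming a positive definite base matrix so that the root decompositions exist):

>>> A = gpytorch.lazy.NonLazyTensor(2.0 * torch.eye(20))
>>> B = torch.randn(20, 2)
>>> updated = A.add_low_rank(B)    # lazily represents A + B B^T
>>> updated.shape
torch.Size([20, 20])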

batch_dim

Returns the dimension of the shape over which the tensor is batched.

batch_shape

Returns the shape over which the tensor is batched.

cat_rows(cross_mat, new_mat, generate_roots=True, generate_inv_roots=True, **root_decomp_kwargs)[source]

Concatenates new rows and columns to the matrix that this LazyTensor represents, e.g. C = [A B^T; B D], where A is the existing lazy tensor, and B (cross_mat) and D (new_mat) are new components. This is most commonly used when fantasizing with kernel matrices.

We have access to \(A \approx LL^{\top}\) and \(A^{-1} \approx RR^{\top}\), where \(L\) and \(R\) are low rank matrices resulting from the root and root inverse decompositions (see the LOVE paper).

To update R, we first update L:
[A B^T; B D] = [E 0; F G][E^T F^T; 0 G^T]
Solving this matrix equation, we get:
A = EE^T = LL^T  ==>  E = L
B = EF^T  ==>  F = BR
D = FF^T + GG^T  ==>  G = (D - FF^T)^{1/2}

Once we’ve computed Z = [E 0; F G], we have that the new kernel matrix [K U; U^T S] \(\approx\) ZZ^T. Therefore, we can form a pseudo-inverse of Z directly to approximate [K U; U^T S]^{-1/2}.

This strategy is also described in “Efficient Nonmyopic Bayesian Optimization via One-Shot Multistep Trees,” Jiang et al, NeurIPS, 2020. https://arxiv.org/abs/2006.15779.

Parameters:
  • cross_mat (torch.tensor) – the matrix \(B\) we are appending to the matrix \(A\). If \(A\) is n x n, then this matrix should be n x k.
  • new_mat (torch.tensor) – the matrix \(D\) we are appending to the matrix \(A\). If \(B\) is n x k, then this matrix should be k x k.
  • generate_roots (bool) – whether to generate the root decomposition of \(A\) even if it has not been created yet.
  • generate_inv_roots (bool) – whether to generate the root inv decomposition of \(A\) even if it has not been created yet.
Returns:

concatenated lazy tensor with the new rows and columns.

Return type:

LazyTensor
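A usage sketch for the fantasization setting described above, using an RBF kernel purely for illustration (the jitter terms are only there to keep the decompositions well conditioned):

>>> covar_module = gpytorch.kernels.RBFKernel()
>>> train_x, new_x = torch.randn(20, 2), torch.randn(3, 2)
>>> A = covar_module(train_x).add_jitter(1e-3)             # existing 20 x 20 matrix
>>> cross_mat = covar_module(train_x, new_x).evaluate()    # 20 x 3 block
>>> new_mat = covar_module(new_x).evaluate() + 1e-3 * torch.eye(3)   # 3 x 3 block
>>> C = A.cat_rows(cross_mat, new_mat)                     # 23 x 23 LazyTensor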

cholesky(upper=False)[source]

Cholesky-factorizes the LazyTensor

Parameters:upper (bool) – whether to return an upper (rather than lower) triangular Cholesky factor (default: False)
Returns:(LazyTensor) Cholesky factor (triangular, upper/lower depending on “upper” arg)
clone()[source]

Clones the LazyTensor (creates clones of all underlying tensors)

cpu()[source]
Returns:a new LazyTensor identical to self, but on the CPU.
Return type:LazyTensor
cuda(device_id=None)[source]

This method operates identically to torch.nn.Module.cuda().

Parameters:device_id (str, optional) – Device ID of GPU to use.
Returns:a new LazyTensor identical to self, but on the GPU.
Return type:LazyTensor
detach()[source]

Removes the LazyTensor from the current computation graph. (In practice, this function removes all Tensors that make up the LazyTensor from the computation graph.)

detach_()[source]

An in-place version of detach.

diag()[source]

As torch.diag(), returns the diagonal of the matrix \(K\) this LazyTensor represents as a vector.

Return type:torch.tensor
Returns:The diagonal of \(K\). If \(K\) is \(n \times n\), this will be a length n vector. If this LazyTensor represents a batch (e.g., is \(b \times n \times n\)), this will be a \(b \times n\) matrix of diagonals, one for each matrix in the batch.
diagonalization(method: Optional[str] = None)[source]

Returns a (usually partial) diagonalization of a symmetric PSD matrix. Options are either “lanczos” or “symeig”. “lanczos” runs Lanczos while “symeig” runs LazyTensor.symeig.

dim()[source]

Alias of ndimension()

double(device_id=None)[source]

This method operates identically to torch.Tensor.double().

evaluate()[source]

Explicitly evaluates the matrix this LazyTensor represents. This function should return a Tensor storing an exact representation of this LazyTensor.

evaluate_kernel()[source]

Return a new LazyTensor representing the same one as this one, but with all lazily evaluated kernels actually evaluated.

float(device_id=None)[source]

This method operates identically to torch.Tensor.float().

half(device_id=None)[source]

This method operates identically to torch.Tensor.half().

inv_matmul(right_tensor, left_tensor=None)[source]

Computes a linear solve (w.r.t self = \(A\)) with several right hand sides \(R\). I.e. computes


\begin{equation}
    A^{-1} R,
\end{equation}

where \(R\) is right_tensor and \(A\) is the LazyTensor.

If left_tensor is supplied, computes


\begin{equation}
    L A^{-1} R,
\end{equation}

where \(L\) is left_tensor. Supplying this can reduce the number of CG calls required.

Parameters:
  • right_tensor (torch.tensor) – the right hand side \(R\) (n x k)
  • left_tensor (torch.tensor, optional) – the left hand side \(L\) (m x n)

Returns:
  • torch.tensor - \(A^{-1}R\) or \(LA^{-1}R\).
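For example (a small exact solve; for larger matrices the solve would be performed iteratively):

>>> A = gpytorch.lazy.NonLazyTensor(torch.tensor([[4., 1.], [1., 3.]]))
>>> rhs = torch.tensor([[1.], [2.]])
>>> A.inv_matmul(rhs)
>>> # Returns (approximately): tensor([[0.0909], [0.6364]])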
inv_quad(tensor, reduce_inv_quad=True)[source]

Computes an inverse quadratic form (w.r.t self) with several right hand sides. I.e. computes tr( tensor^T self^{-1} tensor )

NOTE: Don’t overwrite this function! Instead, overwrite inv_quad_logdet

Parameters:tensor (torch.tensor) – the tensor of right hand sides for the inverse quadratic form.
Returns:
  • tensor - tr( tensor^T (self)^{-1} tensor )
inv_quad_logdet(inv_quad_rhs=None, logdet=False, reduce_inv_quad=True)[source]

Computes an inverse quadratic form (w.r.t. self) with several right hand sides, i.e. computes tr( tensor^T self^{-1} tensor ). In addition, computes an (approximate) log determinant of the matrix.

Parameters:
  • inv_quad_rhs (torch.tensor, optional) – the right hand sides for the inverse quadratic term.
  • logdet (bool) – whether to compute the (approximate) log determinant.
Returns:
  • scalar - tr( tensor^T (self)^{-1} tensor )
  • scalar - log determinant
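A small self-contained sketch (for small matrices both quantities are essentially exact; for large matrices they are approximations):

>>> A = gpytorch.lazy.NonLazyTensor(torch.tensor([[4., 1.], [1., 3.]]))
>>> rhs = torch.tensor([[1.], [2.]])
>>> inv_quad_term, logdet_term = A.inv_quad_logdet(inv_quad_rhs=rhs, logdet=True)
>>> # inv_quad_term ~ 1.3636 (= rhs^T A^{-1} rhs), logdet_term ~ 2.3979 (= log 11)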
logdet()[source]

Computes an (approximate) log determinant of the matrix

NOTE: Don’t overwrite this function! Instead, overwrite inv_quad_logdet

Returns:log determinant
Return type:
  • scalar
matmul(other)[source]

Multiplies self by a matrix

Parameters:other (torch.tensor) – Matrix or vector to multiply with. Can be either a torch.tensor or a gpytorch.lazy.LazyTensor.
Returns:Tensor or LazyTensor containing the result of the matrix multiplication \(KM\), where \(K\) is the (batched) matrix that this gpytorch.lazy.LazyTensor represents, and \(M\) is the (batched) matrix input to this method.
Return type:torch.tensor
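For example, using the \(K = XX^{\top}\) structure from the introduction:

>>> X = torch.randn(100, 5)
>>> K = gpytorch.lazy.RootLazyTensor(X)   # lazily represents X X^T
>>> M = torch.randn(100, 3)
>>> K.matmul(M).shape                     # computed as X (X^T M)
torch.Size([100, 3])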
matrix_shape

Returns the shape of the matrix being represented (without batching).

mul(other)[source]

Multiplies the matrix by a constant, or elementwise the matrix by another matrix

Parameters:
  • other (torch.tensor or LazyTensor) – constant or matrix to elementwise
  • by. (multiply) –
Returns:

Another lazy tensor representing the result of the multiplication. if other was a constant (or batch of constants), this will likely be a gpytorch.lazy.ConstantMulLazyTensor. If other was another matrix, this will likely be a gpytorch.lazy.MulLazyTensor.

Return type:

gpytorch.lazy.LazyTensor

ndimension()[source]

Returns the number of dimensions

numel()[source]

Returns the number of elements

numpy()[source]

Return self as an evaluated numpy array

pivoted_cholesky(rank, error_tol=None, return_pivots=False)[source]

Performs a partial pivoted Cholesky factorization of the (positive definite) LazyTensor. \(\mathbf L \mathbf L^\top = \mathbf K\). The partial pivoted Cholesky factor \(\mathbf L \in \mathbb R^{N \times \text{rank}}\) forms a low rank approximation to the LazyTensor.

The pivots are selected greedily, corresponding to the maximum diagonal element in the residual after each Cholesky iteration. See Harbrecht et al., 2012.

Parameters:
  • rank (int) – The size of the partial pivoted Cholesky factor.
  • error_tol (float, optional) – Defines an optional stopping criterion. If the residual of the factorization is less than error_tol, then the factorization will exit early. This will result in a \(\leq \text{ rank}\) factor.
  • return_pivots (bool) – (default: False) Whether or not to return the pivots alongside the partial pivoted Cholesky factor.
Returns:

the … x N x rank factor (and optionally the … x N pivots)

Return type:

torch.Tensor or tuple(torch.Tensor, torch.Tensor)
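A minimal usage sketch (the RBF kernel matrix here is just a convenient positive definite example):

>>> x = torch.randn(100, 2)
>>> K = gpytorch.kernels.RBFKernel()(x).evaluate_kernel()
>>> L = K.pivoted_cholesky(rank=10)    # 100 x 10 factor with L L^T ~= K
>>> L.shape
torch.Size([100, 10])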

prod(dim=None)[source]

For a b x n x m LazyTensor, computes the product over a batch dimension.

Parameters:dim (int) – the batch dimension to take the product over.
Returns:LazyTensor

Example

>>> base_tensor = torch.tensor([
        [[2, 4], [1, 2]],
        [[1, 1], [0, -1]],
        [[2, 1], [1, 0]],
        [[3, 2], [2, -1]],
    ], dtype=torch.float)
>>> lazy_tensor = gpytorch.lazy.NonLazyTensor(base_tensor)
>>> lazy_tensor.prod(dim=-3).evaluate()
>>> # Returns: torch.Tensor([[12, 8], [0, 0]])
>>> # Grouped products can be obtained by reshaping the batch dimension before calling prod:
>>> gpytorch.lazy.NonLazyTensor(base_tensor.view(2, 2, 2, 2)).prod(dim=-3).evaluate()
>>> # Returns: torch.Tensor([[[2, 4], [0, -2]], [[6, 2], [2, 0]]])
repeat(*sizes)[source]

Repeats this tensor along the specified dimensions.

Currently, this only works to create repeated batches of a 2D LazyTensor. I.e. all calls should be lazy_tensor.repeat(<size>, 1, 1).

Example

>>> lazy_tensor = gpytorch.lazy.ToeplitzLazyTensor(torch.tensor([4., 1., 0.5]))
>>> lazy_tensor.repeat(2, 1, 1).evaluate()
tensor([[[4.0000, 1.0000, 0.5000],
         [1.0000, 4.0000, 1.0000],
         [0.5000, 1.0000, 4.0000]],
        [[4.0000, 1.0000, 0.5000],
         [1.0000, 4.0000, 1.0000],
         [0.5000, 1.0000, 4.0000]]])
representation()[source]

Returns the Tensors that are used to define the LazyTensor

representation_tree()[source]

Returns a gpytorch.lazy.LazyTensorRepresentationTree tree object that recursively encodes the representation of this lazy tensor. In particular, if the definition of this lazy tensor depends on other lazy tensors, the tree is an object that can be used to reconstruct the full structure of this lazy tensor, including all subobjects. This is used internally.

requires_grad_(val)[source]

Sets requires_grad=val on all the Tensors that make up the LazyTensor. This is an in-place operation.

rmatmul(other)[source]

Multiplies a matrix by self.

Parameters:other (torch.tensor) – Matrix or vector to multiply with. Can be either a torch.tensor or a gpytorch.lazy.LazyTensor.
Returns:Tensor or LazyTensor containing the result of the matrix multiplication \(MK\), where \(M\) is the (batched) matrix input to this method, and \(K\) is the (batched) matrix that this gpytorch.lazy.LazyTensor represents.
Return type:torch.tensor
root_decomposition(method: Optional[str] = None)[source]

Returns a (usually low-rank) root decomposition lazy tensor of a PSD matrix. This can be used for sampling from a Gaussian distribution, or for obtaining a low-rank version of a matrix
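A minimal sketch (assuming a positive definite matrix):

>>> K = gpytorch.lazy.RootLazyTensor(torch.randn(50, 3)).add_jitter(1e-3)
>>> R = K.root_decomposition().root       # (lazy) root R with R R^T ~= K
>>> R.matmul(R.t()).evaluate()            # approximately equal to K.evaluate()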

root_inv_decomposition(initial_vectors=None, test_vectors=None, method: Optional[str] = None)[source]

Returns a (usually low-rank) root decomposition lazy tensor of the inverse of a PSD matrix. This can be used for sampling from a Gaussian distribution, or for obtaining a low-rank version of a matrix inverse.

size(val=None)[source]

Returns the size of the resulting Tensor that the lazy tensor represents

sqrt_inv_matmul(rhs, lhs=None)[source]

If A is positive definite, computes either lhs A^{-1/2} rhs or A^{-1/2} rhs.

sum(dim=None)[source]

Sum the LazyTensor across a dimension. The dim controls which batch dimension is summed over. If set to None, then sums all dimensions

Parameters:dim (int) – Which dimension is being summed over (default=None)
Returns:LazyTensor or Tensor.

Example

>>> lazy_tensor = gpytorch.lazy.NonLazyTensor(torch.tensor([
        [[2, 4], [1, 2]],
        [[1, 1], [0, -1]],
        [[2, 1], [1, 0]],
        [[3, 2], [2, -1]],
    ], dtype=torch.float))
>>> lazy_tensor.sum(0).evaluate()
>>> # Returns: torch.Tensor([[8, 8], [4, 0]])
svd() → Tuple[gpytorch.lazy.lazy_tensor.LazyTensor, torch.Tensor, gpytorch.lazy.lazy_tensor.LazyTensor][source]

Compute the SVD of the lazy tensor M s.t. M = U @ S @ V.T. This can be very slow for large tensors. Should be special-cased for tensors with particular structure. Does NOT sort the singular values.

Returns:Tuple containing the left singular vectors (U), the singular values (S), and the right singular vectors (V).
symeig(eigenvectors: bool = False) → Tuple[torch.Tensor, Optional[gpytorch.lazy.lazy_tensor.LazyTensor]][source]

Compute the symmetric eigendecomposition of the lazy tensor. This can be very slow for large tensors. Should be special-cased for tensors with particular structure. Does NOT sort the eigenvalues.

Parameters:eigenvectors (bool) – If True, compute the eigenvectors in addition to the eigenvalues.
Returns:Tuple containing the eigenvalues and the eigenvectors. If eigenvectors=False, the second element of the tuple is None; otherwise it is a LazyTensor containing the orthonormal eigenvectors of the matrix.
t()[source]

Alias of transpose() for 2D LazyTensor. (Transposes the two dimensions.)

to(*args, **kwargs)[source]

A device-agnostic method of moving the lazy_tensor to the specified device or dtype. Note that we do NOT support torch.to options other than device and dtype (e.g. non_blocking); any such options will be silently ignored.

Parameters:
  • device (torch.device) – Which device to use (GPU or CPU).
  • dtype (torch.dtype) – Which dtype to use (double, float, or half).
Returns:

New LazyTensor identical to self on specified device

Return type:

LazyTensor

transpose(dim1, dim2)[source]

Transpose the dimensions dim1 and dim2 of the LazyTensor.

Example

>>> lazy_tensor = gpytorch.lazy.NonLazyTensor(torch.randn(3, 5))
>>> lazy_tensor.transpose(0, 1)
type(dtype)[source]

This method operates similarly to torch.Tensor.type().

zero_mean_mvn_samples(num_samples)[source]

Assumes that self is a covariance matrix, or a batch of covariance matrices. Returns samples from a zero-mean MVN, defined by self (as covariance matrix)

Self should be symmetric, either (batch_size x num_dim x num_dim) or (num_dim x num_dim)

Parameters:num_samples (int) – Number of samples to draw.
Returns:Samples from MVN (num_samples x batch_size x num_dim) or (num_samples x num_dim)
Return type:torch.tensor
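For instance, a minimal sketch with a randomly generated covariance matrix:

>>> covar = gpytorch.lazy.RootLazyTensor(torch.randn(10, 4)).add_jitter(1e-3)
>>> samples = covar.zero_mean_mvn_samples(1000)
>>> samples.shape
torch.Size([1000, 10])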
class gpytorch.lazy.BlockLazyTensor(base_lazy_tensor, block_dim=-3)[source]

An abstract LazyTensor class for block tensors. Subclasses determine how the different blocks are laid out (e.g. block diagonal, sum over blocks, etc.).

BlockLazyTensors represent the groups of blocks as a batched Tensor. The block_dim attribute specifies which dimension of the base LazyTensor specifies the blocks. For example (with block_dim=-3), a k x n x n tensor represents k n x n blocks. A b x k x n x n tensor represents k b x n x n blocks.

Parameters:
  • base_lazy_tensor (LazyTensor or Tensor) – Must be at least 3 dimensional.
  • block_dim (int) – The dimension that specifies blocks.

Kernel LazyTensors

class gpytorch.lazy.LazyEvaluatedKernelTensor(x1, x2, kernel, last_dim_is_batch=False, **params)[source]
diag()[source]

Getting the diagonal of a kernel can be handled more efficiently by transposing the batch and data dimension before calling the kernel. Implementing it this way allows us to compute predictions more efficiently in cases where only the variances are required.

evaluate_kernel()[source]

NB: This is a meta LazyTensor, in the sense that evaluate can return a LazyTensor if the kernel being evaluated does so.

Structured LazyTensors

BlockDiagLazyTensor

class gpytorch.lazy.BlockDiagLazyTensor(base_lazy_tensor, block_dim=-3)[source]

Represents a lazy tensor that is the block diagonal of square matrices. The block_dim attribute specifies which dimension of the base LazyTensor specifies the blocks. For example (with block_dim=-3), a k x n x n tensor represents k n x n blocks (a kn x kn matrix). A b x k x n x n tensor represents k b x n x n blocks (a b x kn x kn batch matrix).

Parameters:
  • base_lazy_tensor (LazyTensor or Tensor) – Must be at least 3 dimensional.
  • block_dim (int) – The dimension that specifies the blocks.
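For instance, three 4 x 4 blocks form a 12 x 12 block-diagonal matrix (the random blocks here are illustrative only):

>>> blocks = torch.randn(3, 4, 4)
>>> blocks = blocks @ blocks.transpose(-1, -2)    # three symmetric 4 x 4 blocks
>>> block_diag = gpytorch.lazy.BlockDiagLazyTensor(gpytorch.lazy.NonLazyTensor(blocks))
>>> block_diag.shape
torch.Size([12, 12])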

CatLazyTensor

class gpytorch.lazy.CatLazyTensor(*lazy_tensors, dim=0, output_device=None)[source]

A LazyTensor that represents the concatenation of other lazy tensors. Each LazyTensor must have the same shape except in the concatenating dimension.

Parameters:
  • lazy_tensors (list of LazyTensors) – A list of LazyTensors whose sizes are the same except in concatenating dimension dim
  • dim (int) – The concatenating dimension which can be a batch dimension.
  • output_device (torch.device) – The CatLazyTensor will appear to be on output_device, and any output torch.Tensors will be placed on output_device.
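For instance, concatenating two blocks along the row dimension (a minimal sketch with random blocks):

>>> top = gpytorch.lazy.NonLazyTensor(torch.randn(3, 5))
>>> bottom = gpytorch.lazy.NonLazyTensor(torch.randn(2, 5))
>>> cat = gpytorch.lazy.CatLazyTensor(top, bottom, dim=0)   # concatenate along the row dimension
>>> cat.shape
torch.Size([5, 5])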
all_to(device_id)[source]

Creates a new CatLazyTensor with all of its constituent LazyTensors moved to the device given by device_id. The new CatLazyTensor also has device_id as its output_device.

to(*args, **kwargs)[source]

Returns a new CatLazyTensor with device as the output_device and dtype as the dtype. Warning: this does not move the LazyTensors in this CatLazyTensor to device.

CholLazyTensor

class gpytorch.lazy.CholLazyTensor(chol: gpytorch.lazy.triangular_lazy_tensor._TriangularLazyTensorBase, upper: bool = False)[source]

DiagLazyTensor

class gpytorch.lazy.DiagLazyTensor(diag)[source]

MatmulLazyTensor

class gpytorch.lazy.MatmulLazyTensor(left_lazy_tensor, right_lazy_tensor)[source]

RootLazyTensor

class gpytorch.lazy.RootLazyTensor(root)[source]

NonLazyTensor

class gpytorch.lazy.NonLazyTensor(tsr)[source]

ToeplitzLazyTensor

class gpytorch.lazy.ToeplitzLazyTensor(column)[source]
diag()[source]

Gets the diagonal of the Toeplitz matrix wrapped by this object.

ZeroLazyTensor

class gpytorch.lazy.ZeroLazyTensor(*sizes, dtype=None, device=None)[source]

Special LazyTensor representing zero.

Composition/Decoration LazyTensors

AddedDiagLazyTensor

class gpytorch.lazy.AddedDiagLazyTensor(*lazy_tensors, preconditioner_override=None)[source]

A SumLazyTensor, but of only two lazy tensors, the second of which must be a DiagLazyTensor.

evaluate_kernel()[source]

Overriding this is currently necessary to allow for subclasses of AddedDiagLT to be created. For example, consider the following:

>>> covar1 = covar_module(x).add_diag(torch.tensor(1.)).evaluate_kernel()
>>> covar2 = covar_module(x).evaluate_kernel().add_diag(torch.tensor(1.))

Unless we override this method (or find a better solution), covar1 and covar2 might not be the same type. In particular, covar1 would always be a standard AddedDiagLazyTensor, but covar2 might be a subtype.

ConstantMulLazyTensor

class gpytorch.lazy.ConstantMulLazyTensor(base_lazy_tensor, constant)[source]

A LazyTensor that multiplies a base LazyTensor by a scalar constant:

constant_mul_lazy_tensor = constant * base_lazy_tensor

Note

To element-wise multiply two lazy tensors, see gpytorch.lazy.MulLazyTensor

Parameters:
  • base_lazy_tensor (LazyTensor (n x m) or (b x n x m)) – The base lazy tensor
  • constant (Tensor) – The constant

If base_lazy_tensor represents a matrix (non-batch), then constant must be a 0D tensor, or a 1D tensor with one element.

If base_lazy_tensor represents a batch of matrices (b x m x n), then constant can be either:
  • A 0D tensor – the same constant is applied to all matrices in the batch
  • A 1D tensor with one element – the same constant is applied to all matrices
  • A 1D tensor with b elements – a different constant is applied to each matrix

Example:

>>> base_base_lazy_tensor = gpytorch.lazy.ToeplitzLazyTensor(torch.tensor([1., 2., 3.]))
>>> constant = torch.tensor(1.2)
>>> new_base_lazy_tensor = gpytorch.lazy.ConstantMulLazyTensor(base_base_lazy_tensor, constant)
>>> new_base_lazy_tensor.evaluate()
>>> # Returns:
>>> # [[ 1.2, 2.4, 3.6 ]
>>> #  [ 2.4, 1.2, 2.4 ]
>>> #  [ 3.6, 2.4, 1.2 ]]
>>>
>>> base_base_lazy_tensor = gpytorch.lazy.ToeplitzLazyTensor(torch.tensor([[1., 2., 3.], [2., 3., 4.]]))
>>> constant = torch.tensor([1.2, 0.5])
>>> new_base_lazy_tensor = gpytorch.lazy.ConstantMulLazyTensor(base_base_lazy_tensor, constant)
>>> new_base_lazy_tensor.evaluate()
>>> # Returns:
>>> # [[[ 1.2, 2.4, 3.6 ]
>>> #   [ 2.4, 1.2, 2.4 ]
>>> #   [ 3.6, 2.4, 1.2 ]]
>>> #  [[ 1, 1.5, 2 ]
>>> #   [ 1.5, 1, 1.5 ]
>>> #   [ 2, 1.5, 1 ]]]

InterpolatedLazyTensor

class gpytorch.lazy.InterpolatedLazyTensor(base_lazy_tensor, left_interp_indices=None, left_interp_values=None, right_interp_indices=None, right_interp_values=None)[source]

KroneckerProductLazyTensor

class gpytorch.lazy.KroneckerProductLazyTensor(*lazy_tensors)[source]

Returns the Kronecker product of the given lazy tensors

Parameters:lazy_tensors (list of LazyTensors) – the lazy tensors to take the Kronecker product of.

add_diag(diag)[source]

Adds a diagonal to a KroneckerProductLazyTensor

diag()[source]

As torch.diag(), returns the diagonal of the matrix \(K\) this LazyTensor represents as a vector.

Return type:torch.tensor
Returns:The diagonal of \(K\). If \(K\) is \(n \times n\), this will be a length n vector. If this LazyTensor represents a batch (e.g., is \(b \times n \times n\)), this will be a \(b \times n\) matrix of diagonals, one for each matrix in the batch.

MulLazyTensor

class gpytorch.lazy.MulLazyTensor(left_lazy_tensor, right_lazy_tensor)[source]
representation()[source]

Returns the Tensors that are used to define the LazyTensor

PsdSumLazyTensor

class gpytorch.lazy.PsdSumLazyTensor(*lazy_tensors, **kwargs)[source]

A SumLazyTensor, but where every component of the sum is positive semi-definite

SumBatchLazyTensor

class gpytorch.lazy.SumBatchLazyTensor(base_lazy_tensor, block_dim=-3)[source]

Represents a lazy tensor that is actually the sum of several lazy tensor blocks. The block_dim attribute specifies which dimension of the base LazyTensor specifies the blocks. For example (with block_dim=-3), a k x n x n tensor represents the sum of k n x n blocks (an n x n matrix). A b x k x n x n tensor represents the sum of k b x n x n blocks (a b x n x n batch matrix).

Parameters:
  • base_lazy_tensor (LazyTensor) – A k x n x n LazyTensor, or a b x k x n x n LazyTensor.
  • block_dim (int) – The dimension that specifies the blocks.