gpytorch.lazy¶
LazyTensor¶
-
class
gpytorch.lazy.
LazyTensor
(*args, **kwargs)[source]¶ Base class for LazyTensors in GPyTorch.
In GPyTorch, nearly all covariance matrices for Gaussian processes are handled internally as some variety of LazyTensor. A LazyTensor is an object that represents a tensor, similar to torch.tensor, but typically differs in two ways:
- A tensor represented by a LazyTensor can typically be stored more compactly than a full matrix. For example, a LazyTensor representing \(K=XX^{\top}\) where \(K\) is \(n \times n\) but \(X\) is \(n \times d\) might store \(X\) instead of \(K\) directly.
- A LazyTensor typically defines a matmul routine that computes \(KM\) more efficiently than multiplying by the full matrix. Using the above example, for a vector \(M\), performing \(KM=X(X^{\top}M)\) requires only \(O(nd)\) time, rather than the \(O(n^2)\) time required if we were storing \(K\) directly.
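The saving described in the two points above can be illustrated with a small sketch. It uses plain Python lists so that it has no dependencies; the class name `LowRankRoot` and its methods are hypothetical stand-ins, not part of GPyTorch.

```python
def matmul(A, B):
    """Dense matrix product of two matrices given as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def transpose(A):
    return [list(col) for col in zip(*A)]


class LowRankRoot:
    """Hypothetical stand-in that stores X (n x d) and represents K = X X^T."""

    def __init__(self, X):
        self.X = X

    def matmul(self, M):
        # K M = X (X^T M): the n x n matrix K is never materialized
        return matmul(self.X, matmul(transpose(self.X), M))

    def evaluate(self):
        # Explicitly form K, only for comparison
        return matmul(self.X, transpose(self.X))


X = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # n = 3, d = 2
M = [[1.0], [0.0], [-1.0]]
lazy_result = LowRankRoot(X).matmul(M)
dense_result = matmul(LowRankRoot(X).evaluate(), M)
```

Both routes give the same product; only the second one builds the full \(n \times n\) matrix.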
In order to define a new LazyTensor class that can be used as a covariance matrix in GPyTorch, a user must define at a minimum the following methods (in each example, \(K\) denotes the matrix that the LazyTensor represents):
- _matmul(), which performs a matrix multiplication \(KM\).
- _size(), which returns a torch.Size containing the dimensions of \(K\).
- _transpose_nonbatch(), which returns a transposed version of the LazyTensor.
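As a rough illustration of the three required hooks, the sketch below implements them for a toy diagonal matrix. It is an assumption-laden stand-in: `ToyDiag` is hypothetical, it does not subclass `gpytorch.lazy.LazyTensor`, it works on plain Python lists, and `_size()` returns a tuple where the real hook returns a torch.Size.

```python
class ToyDiag:
    """Hypothetical diagonal matrix K = diag(d) exposing the three required hooks."""

    def __init__(self, diag_values):
        self.diag_values = list(diag_values)

    def _matmul(self, rhs):
        # K M scales row i of M by d_i; K is never stored as an n x n matrix
        return [[d * x for x in row] for d, row in zip(self.diag_values, rhs)]

    def _size(self):
        n = len(self.diag_values)
        return (n, n)  # the real hook returns a torch.Size

    def _transpose_nonbatch(self):
        # A diagonal matrix is its own transpose
        return ToyDiag(self.diag_values)


K = ToyDiag([2.0, 3.0])
product = K._matmul([[1.0, 1.0], [1.0, -1.0]])
```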
In addition to these, the following methods should be implemented for maximum efficiency:
- _quad_form_derivative(), which computes the derivative of a quadratic form with the LazyTensor (e.g. \(d (a^T X b) / dX\)).
- _get_indices(), which returns a torch.Tensor containing elements that are given by various tensor indices.
- _expand_batch(), which expands the batch dimensions of LazyTensors.
- _check_args(), which performs error checking on the arguments supplied to the LazyTensor constructor.
In addition to these, a LazyTensor may need to define the following functions if it does anything interesting with the batch dimensions (e.g. sums along them, adds additional ones, etc): _unsqueeze_batch(), _getitem(), and _permute_batch(). See the documentation for these methods for details.
Note
The base LazyTensor class provides default implementations of many other operations in order to mimic the behavior of a standard tensor as closely as possible. For example, we provide default implementations of __getitem__(), __add__(), etc. that either make use of other lazy tensors or exploit the functions that must be defined above.
Rather than overriding the public methods, we recommend that you override the private versions associated with these methods (e.g. write a custom _getitem rather than a custom __getitem__). This is because the public methods do quite a bit of error checking and case handling that doesn't need to be repeated.
Note
LazyTensors are designed by default to optionally represent batches of matrices. Thus, the size of a LazyTensor may be (for example) \(b \times n \times n\). Many of the methods are designed to efficiently operate on these batches if present.
-
add_jitter
(jitter_val=0.001)[source]¶ Adds jitter (i.e., a small diagonal component) to the matrix this LazyTensor represents. This could potentially be implemented as a no-op, however this could lead to numerical instabilities, so this should only be done at the user’s risk.
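To make the effect of jitter concrete, here is a minimal sketch on a dense matrix stored as plain Python lists. The free function `add_jitter` below is a hypothetical illustration, not the LazyTensor method itself.

```python
def add_jitter(K, jitter_val=1e-3):
    """Return a copy of the square matrix K with jitter_val added to its diagonal."""
    return [
        [v + jitter_val if i == j else v for j, v in enumerate(row)]
        for i, row in enumerate(K)
    ]


K = [[1.0, 1.0], [1.0, 1.0]]  # singular: its determinant is 0
K_jittered = add_jitter(K, 0.5)
# The determinant of the 2 x 2 jittered matrix is now strictly positive
det = K_jittered[0][0] * K_jittered[1][1] - K_jittered[0][1] * K_jittered[1][0]
```

The jitter value 0.5 is chosen only to keep the arithmetic readable; the default in the signature above is far smaller.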
-
add_low_rank
(low_rank_mat, root_decomp_method: Optional[str] = None, root_inv_decomp_method: Optional[str] = None, generate_roots: Optional[bool] = True, **root_decomp_kwargs)[source]¶ Adds a low rank matrix to the matrix that this LazyTensor represents, e.g. computes A + BB^T. We then update both the tensor and its root decomposition.
We have access to L and M, where \(A \approx LL^\top\) and \(A^{-1} \approx MM^\top\). We compute \(\tilde{A} = A + BB^\top = L(I + M B B^\top M^\top)L^\top\) and then decompose \((I + M B B^\top M^\top) \approx RR^\top\), using \(LR\) as our new root decomposition.
This strategy is described in more detail in “Kernel Interpolation for Scalable Online Gaussian Processes,” Stanton et al, AISTATS, 2021. https://arxiv.org/abs/2103.01454.
- Args:
- low_rank_mat (torch.tensor): the matrix B that we are adding to A.
- root_decomp_method (str): how to compute the root decomposition of A.
- root_inv_decomp_method (str): how to compute the root inverse decomposition of A.
- generate_roots (bool): whether to generate the root decomposition of \(A\) even if it has not been created yet.
- Returns:
SumLazyTensor: addition of A and BB^T.
-
batch_dim
¶ Returns the dimension of the shape over which the tensor is batched.
-
batch_shape
¶ Returns the shape over which the tensor is batched.
-
cat_rows
(cross_mat, new_mat, generate_roots=True, generate_inv_roots=True, **root_decomp_kwargs)[source]¶ Concatenates new rows and columns to the matrix that this LazyTensor represents, e.g. C = [A B^T; B D], where A is the existing lazy tensor, and B (cross_mat) and D (new_mat) are new components. This is most commonly used when fantasizing with kernel matrices.
We have access to \(A \approx LL^\top\) and \(A^{-1} \approx RR^\top\), where L and R are low rank matrices resulting from root and root inverse decompositions (see the LOVE paper).
- To update R, we first update L:
- [A B^T; B D] = [E 0; F G][E^T F^T; 0 G^T]
- Solving this matrix equation, we get:
- A = EE^T = LL^T ==> E = L
- B = EF^T ==> F = BR
- D = FF^T + GG^T ==> G = (D - FF^T)^{1/2}
Once we've computed Z = [E 0; F G], we have that the new kernel matrix \([K\ U; U^\top\ S] \approx ZZ^\top\). Therefore, we can form a pseudo-inverse of Z directly to approximate \([K\ U; U^\top\ S]^{-1/2}\).
This strategy is also described in “Efficient Nonmyopic Bayesian Optimization via One-Shot Multistep Trees,” Jiang et al, NeurIPS, 2020. https://arxiv.org/abs/2006.15779.
- Args:
- cross_mat (torch.tensor): the matrix \(B\) we are appending to the matrix \(A\). If \(A\) is n x n, then this matrix should be n x k.
- new_mat (torch.tensor): the matrix \(D\) we are appending to the matrix \(A\). If \(B\) is n x k, then this matrix should be k x k.
- generate_roots (bool): whether to generate the root decomposition of \(A\) even if it has not been created yet.
- generate_inv_roots (bool): whether to generate the root inv decomposition of \(A\) even if it has not been created yet.
- Returns:
LazyTensor: concatenated lazy tensor with the new rows and columns.
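The block-factor update above can be sketched for the simplest case: an exact lower-triangular Cholesky root and a single appended row and column (k = 1). This is an illustration under those assumptions, not the library's implementation. In particular, it solves a triangular system where the description above multiplies by the root inverse R, and the helper names `forward_solve` and `append_row` are hypothetical.

```python
import math


def forward_solve(L, b):
    """Solve L f = b for lower-triangular L by forward substitution."""
    f = []
    for i, row in enumerate(L):
        f.append((b[i] - sum(row[j] * f[j] for j in range(i))) / row[i])
    return f


def append_row(L, b, d):
    """Given A = L L^T, return Z with Z Z^T = [[A, b], [b^T, d]]."""
    f = forward_solve(L, b)                   # the F block, from the cross term
    g = math.sqrt(d - sum(x * x for x in f))  # G = (D - F F^T)^{1/2}
    return [row + [0.0] for row in L] + [f + [g]]


L = [[2.0, 0.0], [1.0, 3.0]]  # A = L L^T = [[4, 2], [2, 10]]
Z = append_row(L, b=[2.0, 4.0], d=6.0)
# Multiply Z by its transpose to recover the enlarged matrix
C = [[sum(x * y for x, y in zip(r1, r2)) for r2 in Z] for r1 in Z]
```

The existing factor is reused unchanged as the top-left block of Z, which is what makes the update cheaper than refactorizing the enlarged matrix.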
-
cholesky
(upper=False)[source]¶ Cholesky-factorizes the LazyTensor
- Parameters:
- upper (bool) - upper triangular or lower triangular factor (default: False)
- Returns:
- (LazyTensor) Cholesky factor (triangular, upper/lower depending on “upper” arg)
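For reference, the relationship between the lower and upper factors can be sketched with a textbook dense Cholesky on plain Python lists. The function below is a hypothetical illustration and makes no claim about how the LazyTensor method is implemented internally.

```python
import math


def cholesky(K, upper=False):
    """Cholesky factor of a symmetric positive definite matrix (list of rows)."""
    n = len(K)
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = sum(L[i][k] * L[j][k] for k in range(j))
            if i == j:
                L[i][j] = math.sqrt(K[i][i] - s)
            else:
                L[i][j] = (K[i][j] - s) / L[j][j]
    # The upper factor is the transpose of the lower one
    return [list(col) for col in zip(*L)] if upper else L


K = [[4.0, 2.0], [2.0, 10.0]]
lower = cholesky(K)
upper = cholesky(K, upper=True)
```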
-
cpu
()[source]¶
- Returns:
LazyTensor: a new LazyTensor identical to self, but on the CPU.
-
cuda
(device_id=None)[source]¶ This method operates identically to torch.nn.Module.cuda().
- Args:
- device_id (str, optional): Device ID of GPU to use.
- Returns:
LazyTensor: a new LazyTensor identical to self, but on the GPU.
-
detach
()[source]¶ Removes the LazyTensor from the current computation graph. (In practice, this function removes all Tensors that make up the LazyTensor from the computation graph.)
-
diag
()[source]¶ As torch.diag(), returns the diagonal of the matrix \(K\) this LazyTensor represents as a vector.
Return type: torch.tensor
Returns: The diagonal of \(K\). If \(K\) is \(n \times n\), this will be a length n vector. If this LazyTensor represents a batch (e.g., is \(b \times n \times n\)), this will be a \(b \times n\) matrix of diagonals, one for each matrix in the batch.
-
diagonalization
(method: Optional[str] = None)[source]¶ Returns a (usually partial) diagonalization of a symmetric PSD matrix. Options are either “lanczos” or “symeig”. “lanczos” runs Lanczos while “symeig” runs LazyTensor.symeig.
-
dim
()[source]¶ Alias of
ndimension()
-
evaluate
()[source]¶ Explicitly evaluates the matrix this LazyTensor represents. This function should return a Tensor storing an exact representation of this LazyTensor.
-
evaluate_kernel
()[source]¶ Return a new LazyTensor representing the same one as this one, but with all lazily evaluated kernels actually evaluated.
-
inv_matmul
(right_tensor, left_tensor=None)[source]¶ Computes a linear solve (w.r.t. self = \(A\)) with several right hand sides \(R\), i.e. computes
\[A^{-1} R,\]
where \(R\) is right_tensor and \(A\) is the LazyTensor.
If left_tensor is supplied, computes
\[L A^{-1} R,\]
where \(L\) is left_tensor. Supplying this can reduce the number of CG calls required.
- Args:
- right_tensor (torch.tensor, n x k) - Matrix \(R\) of right hand sides
- left_tensor (torch.tensor, m x n) - Optional matrix \(L\) to perform left multiplication with
- Returns:
torch.tensor - \(A^{-1}R\) or \(LA^{-1}R\).
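The mention of CG calls above refers to conjugate gradients, which solves the system using only matrix-vector products. The following is a minimal textbook sketch for a single right hand side on plain Python lists. The function name, tolerance, and iteration cap are illustrative assumptions and do not mirror GPyTorch's own solver or its settings.

```python
def conjugate_gradients(matmul, rhs, tol=1e-10, max_iter=100):
    """Solve A x = rhs for symmetric positive definite A, using only matmul(v) = A v."""
    x = [0.0] * len(rhs)
    r = list(rhs)  # residual rhs - A x for the initial guess x = 0
    p = list(r)    # current search direction
    rs_old = sum(v * v for v in r)
    for _ in range(max_iter):
        if rs_old < tol:  # squared residual norm is small enough
            break
        Ap = matmul(p)
        alpha = rs_old / sum(pi * api for pi, api in zip(p, Ap))
        x = [xi + alpha * pi for xi, pi in zip(x, p)]
        r = [ri - alpha * api for ri, api in zip(r, Ap)]
        rs_new = sum(v * v for v in r)
        p = [ri + (rs_new / rs_old) * pi for ri, pi in zip(r, p)]
        rs_old = rs_new
    return x


A = [[4.0, 1.0], [1.0, 3.0]]
x = conjugate_gradients(
    lambda v: [sum(a * b for a, b in zip(row, v)) for row in A],
    [1.0, 2.0],
)
```

Because the solver touches the matrix only through `matmul`, any representation with a fast matrix-vector product can be plugged in without forming the dense matrix.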
-
inv_quad
(tensor, reduce_inv_quad=True)[source]¶ Computes an inverse quadratic form (w.r.t. self) with several right hand sides, i.e. computes tr( tensor^T self^{-1} tensor ).
NOTE: Don't override this function! Instead, override inv_quad_logdet.
- Args:
- tensor (tensor nxk) - Vector (or matrix) for inverse quad
- Returns:
- tensor - tr( tensor^T (self)^{-1} tensor )
-
inv_quad_logdet
(inv_quad_rhs=None, logdet=False, reduce_inv_quad=True)[source]¶ Computes an inverse quadratic form (w.r.t. self) with several right hand sides, i.e. computes tr( inv_quad_rhs^T self^{-1} inv_quad_rhs ). In addition, computes an (approximate) log determinant of the matrix.
- Args:
- inv_quad_rhs (tensor nxk) - Vector (or matrix) for inverse quad
- Returns:
- scalar - tr( inv_quad_rhs^T (self)^{-1} inv_quad_rhs )
- scalar - log determinant
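Both returned quantities can be read off a Cholesky factor when one is available, which gives a convenient reference for small matrices. The sketch below is exact and dense, on plain Python lists, for a single right hand side vector. It is not the (approximate) computation the method itself performs, and the function name is hypothetical.

```python
import math


def inv_quad_logdet_from_chol(L, t):
    """Given lower-triangular L with A = L L^T, return (t^T A^{-1} t, log det A)."""
    # Forward substitution: solve L y = t, so that t^T A^{-1} t = y^T y
    y = []
    for i, row in enumerate(L):
        y.append((t[i] - sum(row[j] * y[j] for j in range(i))) / row[i])
    inv_quad = sum(v * v for v in y)
    # det A = (prod_i L_ii)^2, hence log det A = 2 * sum_i log L_ii
    logdet = 2.0 * sum(math.log(L[i][i]) for i in range(len(L)))
    return inv_quad, logdet


L = [[2.0, 0.0], [1.0, 3.0]]  # A = [[4, 2], [2, 10]], det A = 36
inv_quad, logdet = inv_quad_logdet_from_chol(L, [2.0, 4.0])
```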
-
logdet
()[source]¶ Computes an (approximate) log determinant of the matrix.
NOTE: Don't override this function! Instead, override inv_quad_logdet.
- Returns:
- scalar: log determinant
-
matmul
(other)[source]¶ Multiplies self by a matrix.
- Args:
- other (torch.tensor): Matrix or vector to multiply with. Can be either a torch.tensor or a gpytorch.lazy.LazyTensor.
- Returns:
torch.tensor: Tensor or LazyTensor containing the result of the matrix multiplication \(KM\), where \(K\) is the (batched) matrix that this gpytorch.lazy.LazyTensor represents, and \(M\) is the (batched) matrix input to this method.
-
matrix_shape
¶ Returns the shape of the matrix being represented (without batching).
-
mul
(other)[source]¶ Multiplies the matrix by a constant, or elementwise multiplies the matrix by another matrix.
- Args:
- other (torch.tensor or LazyTensor): constant or matrix to elementwise multiply by.
- Returns:
gpytorch.lazy.LazyTensor: Another lazy tensor representing the result of the multiplication. If other was a constant (or batch of constants), this will likely be a gpytorch.lazy.ConstantMulLazyTensor. If other was another matrix, this will likely be a gpytorch.lazy.MulLazyTensor.
-
pivoted_cholesky
(rank, error_tol=None, return_pivots=False)[source]¶ Performs a partial pivoted Cholesky factorization of the (positive definite) LazyTensor. \(\mathbf L \mathbf L^\top = \mathbf K\). The partial pivoted Cholesky factor \(\mathbf L \in \mathbb R^{N \times \text{rank}}\) forms a low rank approximation to the LazyTensor.
The pivots are selected greedily, corresponding to the maximum diagonal element in the residual after each Cholesky iteration. See Harbrecht et al., 2012.
Parameters:
- rank (int) - The size of the partial pivoted Cholesky factor.
- error_tol (float, optional) - Defines an optional stopping criterion. If the residual of the factorization is less than error_tol, then the factorization will exit early. This will result in a \(\leq \text{ rank}\) factor.
- return_pivots (bool) - (default: False) Whether or not to return the pivots alongside the partial pivoted Cholesky factor.
Returns: the … x N x rank factor (and optionally the … x N pivots)
Return type: torch.Tensor or tuple(torch.Tensor, torch.Tensor)
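The greedy pivot rule described above can be sketched on a small dense matrix stored as plain Python lists. This is an illustrative, non-batched version under stated assumptions. It measures the residual by the sum of the remaining residual diagonal entries, and on early exit it leaves the unused columns as zeros. Neither choice is claimed to match the library's behavior, and the pivot ordering it returns is specific to this sketch.

```python
import math


def pivoted_cholesky(K, rank, error_tol=None):
    """Greedy partial pivoted Cholesky: returns (L, perm) with L L^T approximating K."""
    n = len(K)
    d = [K[i][i] for i in range(n)]  # diagonal of the current residual
    perm = list(range(n))
    L = [[0.0] * rank for _ in range(n)]
    for m in range(rank):
        # Pick the largest remaining residual diagonal entry as the pivot
        best = max(range(m, n), key=lambda idx: d[perm[idx]])
        perm[m], perm[best] = perm[best], perm[m]
        p = perm[m]
        L[p][m] = math.sqrt(d[p])
        for j in perm[m + 1:]:
            partial = sum(L[p][k] * L[j][k] for k in range(m))
            L[j][m] = (K[p][j] - partial) / L[p][m]
            d[j] -= L[j][m] ** 2
        # Optional early exit once the remaining residual diagonal is small
        if error_tol is not None and sum(d[j] for j in perm[m + 1:]) < error_tol:
            break
    return L, perm


# A rank-one matrix v v^T with v = (1, 2, 3) is recovered by a rank-1 factor
K = [[1.0, 2.0, 3.0], [2.0, 4.0, 6.0], [3.0, 6.0, 9.0]]
L, pivots = pivoted_cholesky(K, rank=1)
```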
-
prod
(dim=None)[source]¶ For a b x n x m LazyTensor, compute the product over the batch dimension.
The mul_batch_size controls whether or not the batch dimension is grouped when multiplying.
- mul_batch_size=None (default): The entire batch dimension is multiplied. Returns an n x m LazyTensor.
- mul_batch_size=k: Creates b/k groups, and multiplies the k entries of each group. (The LazyTensor is reshaped as a b/k x k x n x m LazyTensor and the k dimension is multiplied over.) Returns a b/k x n x m LazyTensor.
- Args:
- mul_batch_size (int or None): Controls the number of groups that are multiplied over (default: None).
- Returns:
LazyTensor
- Example:
>>> lazy_tensor = gpytorch.lazy.NonLazyTensor(torch.tensor([
...     [[2, 4], [1, 2]],
...     [[1, 1], [0, -1]],
...     [[2, 1], [1, 0]],
...     [[3, 2], [2, -1]],
... ]))
>>> lazy_tensor.mul_batch().evaluate()
>>> # Returns: torch.Tensor([[12, 8], [0, 0]])
>>> lazy_tensor.mul_batch(mul_batch_size=2)
>>> # Returns: torch.Tensor([[[2, 4], [0, -2]], [[6, 2], [2, 0]]])
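The numbers in the example above are elementwise products taken across the batch. The sketch below reproduces them on plain Python lists; `prod_batch` and `group_size` are hypothetical names used only for this illustration.

```python
from functools import reduce


def hadamard(A, B):
    """Elementwise product of two equally sized matrices."""
    return [[a * b for a, b in zip(ra, rb)] for ra, rb in zip(A, B)]


def prod_batch(batch, group_size=None):
    """Multiply the matrices of a batch elementwise, optionally in consecutive groups."""
    if group_size is None:
        return reduce(hadamard, batch)
    return [
        reduce(hadamard, batch[i:i + group_size])
        for i in range(0, len(batch), group_size)
    ]


batch = [[[2, 4], [1, 2]], [[1, 1], [0, -1]], [[2, 1], [1, 0]], [[3, 2], [2, -1]]]
full = prod_batch(batch)                   # [[12, 8], [0, 0]]
grouped = prod_batch(batch, group_size=2)  # [[[2, 4], [0, -2]], [[6, 2], [2, 0]]]
```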
-
repeat
(*sizes)[source]¶ Repeats this tensor along the specified dimensions.
Currently, this only works to create repeated batches of a 2D LazyTensor. I.e. all calls should be lazy_tensor.repeat(<size>, 1, 1).
- Example:
>>> lazy_tensor = gpytorch.lazy.ToeplitzLazyTensor(torch.tensor([4., 1., 0.5]))
>>> lazy_tensor.repeat(2, 1, 1).evaluate()
tensor([[[4.0000, 1.0000, 0.5000],
         [1.0000, 4.0000, 1.0000],
         [0.5000, 1.0000, 4.0000]],
        [[4.0000, 1.0000, 0.5000],
         [1.0000, 4.0000, 1.0000],
         [0.5000, 1.0000, 4.0000]]])
-
representation_tree
()[source]¶ Returns a
gpytorch.lazy.LazyTensorRepresentationTree
tree object that recursively encodes the representation of this lazy tensor. In particular, if the definition of this lazy tensor depends on other lazy tensors, the tree is an object that can be used to reconstruct the full structure of this lazy tensor, including all subobjects. This is used internally.
-
requires_grad_
(val)[source]¶ Sets requires_grad=val on all the Tensors that make up the LazyTensor. This is an in-place operation.
-
rmatmul
(other)[source]¶ Multiplies a matrix by self.
- Args:
- other (torch.tensor): Matrix or vector to multiply with. Can be either a torch.tensor or a gpytorch.lazy.LazyTensor.
- Returns:
torch.tensor: Tensor or LazyTensor containing the result of the matrix multiplication \(MK\), where \(M\) is the (batched) matrix input to this method, and \(K\) is the (batched) matrix that this gpytorch.lazy.LazyTensor represents.
-
root_decomposition
(method: Optional[str] = None)[source]¶ Returns a (usually low-rank) root decomposition lazy tensor of a PSD matrix. This can be used for sampling from a Gaussian distribution, or for obtaining a low-rank version of a matrix
-
root_inv_decomposition
(initial_vectors=None, test_vectors=None, method: Optional[str] = None)[source]¶ Returns a (usually low-rank) root decomposition lazy tensor of the inverse of a PSD matrix. This can be used for sampling from a Gaussian distribution, or for obtaining a low-rank version of a matrix.
-
sqrt_inv_matmul
(rhs, lhs=None)[source]¶ If A is positive definite, computes either lhs A^{-1/2} rhs or A^{-1/2} rhs.
-
sum
(dim=None)[source]¶ Sum the LazyTensor across a dimension. The dim controls which batch dimension is summed over. If set to None, then sums all dimensions.
- Args:
- dim (int): Which dimension is being summed over (default=None)
- Returns:
LazyTensor or Tensor.
- Example:
>>> lazy_tensor = gpytorch.lazy.NonLazyTensor(torch.tensor([
...     [[2, 4], [1, 2]],
...     [[1, 1], [0, -1]],
...     [[2, 1], [1, 0]],
...     [[3, 2], [2, -1]],
... ]))
>>> lazy_tensor.sum(0).evaluate()
-
svd
() → Tuple[gpytorch.lazy.lazy_tensor.LazyTensor, torch.Tensor, gpytorch.lazy.lazy_tensor.LazyTensor][source]¶ Compute the SVD of the lazy tensor M s.t. M = U @ S @ V.T. This can be very slow for large tensors and should be special-cased for tensors with particular structure. Does NOT sort the singular values.
- Returns:
- LazyTensor: The left singular vectors (U).
- torch.Tensor: The singular values (S).
- LazyTensor: The right singular vectors (V).
-
symeig
(eigenvectors: bool = False) → Tuple[torch.Tensor, Optional[gpytorch.lazy.lazy_tensor.LazyTensor]][source]¶ Compute the symmetric eigendecomposition of the lazy tensor. This can be very slow for large tensors and should be special-cased for tensors with particular structure. Does NOT sort the eigenvalues.
- Args:
- eigenvectors (bool): If True, compute the eigenvectors in addition to the eigenvalues.
- Returns:
- torch.Tensor: The eigenvalues.
- LazyTensor: The eigenvectors. If eigenvectors=False, this is None. Otherwise, this LazyTensor contains the orthonormal eigenvectors of the matrix.
-
t
()[source]¶ Alias of transpose() for a 2D LazyTensor. (Transposes the two dimensions.)
-
to
(*args, **kwargs)[source]¶ A device-agnostic method of moving the lazy_tensor to the specified device or dtype. Note that non_blocking and other torch.to options besides device and dtype are NOT supported; such options are silently ignored.
- Args:
- device (torch.device): Which device to use (GPU or CPU).
- dtype (torch.dtype): Which dtype to use (double, float, or half).
- Returns:
LazyTensor: New LazyTensor identical to self on the specified device
-
transpose
(dim1, dim2)[source]¶ Transpose the dimensions dim1 and dim2 of the LazyTensor.
- Example:
>>> lazy_tensor = gpytorch.lazy.NonLazyTensor(torch.randn(3, 5))
>>> lazy_tensor.transpose(0, 1)
-
zero_mean_mvn_samples
(num_samples)[source]¶ Assumes that self is a covariance matrix, or a batch of covariance matrices. Returns samples from a zero-mean MVN, defined by self (as covariance matrix).
Self should be symmetric, either (batch_size x num_dim x num_dim) or (num_dim x num_dim).
- Args:
- num_samples (int): Number of samples to draw.
- Returns:
torch.tensor: Samples from MVN (num_samples x batch_size x num_dim) or (num_samples x num_dim)
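One common way to draw such samples is to multiply standard normal noise by a root of the covariance. The sketch below shows this for a single, non-batched covariance on plain Python lists. It assumes a root L with covariance L L^T is already available, and the free function (which reuses the method name only for readability) and its arguments are hypothetical.

```python
import random


def zero_mean_mvn_samples(root, num_samples, rng):
    """Draw samples x = L z with z ~ N(0, I), so that Cov[x] = L L^T."""
    samples = []
    for _ in range(num_samples):
        z = [rng.gauss(0.0, 1.0) for _ in range(len(root[0]))]
        samples.append([sum(l * zi for l, zi in zip(row, z)) for row in root])
    return samples


root = [[2.0, 0.0], [1.0, 3.0]]  # covariance L L^T = [[4, 2], [2, 10]]
samples = zero_mean_mvn_samples(root, num_samples=5, rng=random.Random(0))
```

The result has shape num_samples x num_dim, matching the non-batched case in the Returns entry above.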
-
class
gpytorch.lazy.
BlockLazyTensor
(base_lazy_tensor, block_dim=-3)[source]¶ An abstract LazyTensor class for block tensors. Subclasses determine how the different blocks are laid out (e.g. block diagonal, sum over blocks, etc.).
BlockLazyTensors represent the groups of blocks as a batched Tensor. The block_dim attribute specifies which dimension of the base LazyTensor specifies the blocks. For example, with block_dim=-3, a k x n x n tensor represents k n x n blocks. A b x k x n x n tensor represents k b x n x n blocks.
- Args:
- base_lazy_tensor (LazyTensor or Tensor): Must be at least 3 dimensional.
- block_dim (int): The dimension that specifies blocks.
Kernel LazyTensors¶
-
class
gpytorch.lazy.
LazyEvaluatedKernelTensor
(x1, x2, kernel, last_dim_is_batch=False, **params)[source]¶
Structured LazyTensors¶
BlockDiagLazyTensor¶
-
class
gpytorch.lazy.
BlockDiagLazyTensor
(base_lazy_tensor, block_dim=-3)[source]¶ Represents a lazy tensor that is the block diagonal of square matrices. The block_dim attribute specifies which dimension of the base LazyTensor specifies the blocks. For example, with block_dim=-3, a k x n x n tensor represents k n x n blocks (a kn x kn matrix). A b x k x n x n tensor represents k b x n x n blocks (a b x kn x kn batch matrix).
- Args:
- base_lazy_tensor (LazyTensor or Tensor): Must be at least 3 dimensional.
- block_dim (int): The dimension that specifies the blocks.
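The k x n x n to kn x kn correspondence can be sketched with plain Python lists: one helper assembles the dense block-diagonal matrix, the other multiplies by it one block at a time without ever forming it. Both helper names are hypothetical, and the sketch covers only the non-batched case.

```python
def block_diag_evaluate(blocks):
    """Assemble k blocks of size n x n into the dense kn x kn block-diagonal matrix."""
    k, n = len(blocks), len(blocks[0])
    dense = [[0.0] * (k * n) for _ in range(k * n)]
    for b, block in enumerate(blocks):
        for i in range(n):
            for j in range(n):
                dense[b * n + i][b * n + j] = block[i][j]
    return dense


def block_diag_matvec(blocks, v):
    """Multiply by the block-diagonal matrix block by block, never forming it."""
    n = len(blocks[0])
    out = []
    for b, block in enumerate(blocks):
        chunk = v[b * n:(b + 1) * n]  # the slice of v that this block acts on
        out.extend(sum(a * x for a, x in zip(row, chunk)) for row in block)
    return out


blocks = [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]  # k = 2, n = 2
dense = block_diag_evaluate(blocks)
result = block_diag_matvec(blocks, [1.0, 1.0, 1.0, 1.0])
```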
CatLazyTensor¶
-
class
gpytorch.lazy.
CatLazyTensor
(*lazy_tensors, dim=0, output_device=None)[source]¶ A LazyTensor that represents the concatenation of other lazy tensors. Each LazyTensor must have the same shape except in the concatenating dimension.
- Args:
- lazy_tensors (list of LazyTensors): A list of LazyTensors whose sizes are the same except in the concatenating dimension dim.
- dim (int): The concatenating dimension, which can be a batch dimension.
- output_device (torch.device): The CatLazyTensor will appear to reside on output_device and will place any output torch.Tensors on output_device.
CholLazyTensor¶
MatmulLazyTensor¶
ToeplitzLazyTensor¶
Composition/Decoration LazyTensors¶
AddedDiagLazyTensor¶
-
class
gpytorch.lazy.
AddedDiagLazyTensor
(*lazy_tensors, preconditioner_override=None)[source]¶ A SumLazyTensor, but of only two lazy tensors, the second of which must be a DiagLazyTensor.
-
evaluate_kernel
()[source]¶ Overriding this is currently necessary to allow for subclasses of AddedDiagLT to be created. For example, consider the following:
>>> covar1 = covar_module(x).add_diag(torch.tensor(1.)).evaluate_kernel()
>>> covar2 = covar_module(x).evaluate_kernel().add_diag(torch.tensor(1.))
Unless we override this method (or find a better solution), covar1 and covar2 might not be the same type. In particular, covar1 would always be a standard AddedDiagLazyTensor, but covar2 might be a subtype.
-
ConstantMulLazyTensor¶
-
class
gpytorch.lazy.
ConstantMulLazyTensor
(base_lazy_tensor, constant)[source]¶ A LazyTensor that multiplies a base LazyTensor by a scalar constant:
constant_mul_lazy_tensor = constant * base_lazy_tensor
Note
To element-wise multiply two lazy tensors, see gpytorch.lazy.MulLazyTensor
- Args:
- base_lazy_tensor (LazyTensor, n x m or b x n x m): The base lazy tensor
- constant (Tensor): The constant
If base_lazy_tensor represents a matrix (non-batch), then constant must be a 0D tensor, or a 1D tensor with one element.
If base_lazy_tensor represents a batch of matrices (b x m x n), then constant can be either:
- A 0D tensor - the same constant is applied to all matrices in the batch
- A 1D tensor with one element - the same constant is applied to all matrices
- A 1D tensor with b elements - a different constant is applied to each matrix
Example:
>>> base_base_lazy_tensor = gpytorch.lazy.ToeplitzLazyTensor(torch.tensor([1., 2., 3.]))
>>> constant = torch.tensor(1.2)
>>> new_base_lazy_tensor = gpytorch.lazy.ConstantMulLazyTensor(base_base_lazy_tensor, constant)
>>> new_base_lazy_tensor.evaluate()
>>> # Returns:
>>> # [[ 1.2, 2.4, 3.6 ]
>>> #  [ 2.4, 1.2, 2.4 ]
>>> #  [ 3.6, 2.4, 1.2 ]]
>>>
>>> base_base_lazy_tensor = gpytorch.lazy.ToeplitzLazyTensor(torch.tensor([[1., 2., 3.], [2., 3., 4.]]))
>>> constant = torch.tensor([1.2, 0.5])
>>> new_base_lazy_tensor = gpytorch.lazy.ConstantMulLazyTensor(base_base_lazy_tensor, constant)
>>> new_base_lazy_tensor.evaluate()
>>> # Returns:
>>> # [[[ 1.2, 2.4, 3.6 ]
>>> #   [ 2.4, 1.2, 2.4 ]
>>> #   [ 3.6, 2.4, 1.2 ]]
>>> #  [[ 1, 1.5, 2 ]
>>> #   [ 1.5, 1, 1.5 ]
>>> #   [ 2, 1.5, 1 ]]]
InterpolatedLazyTensor¶
KroneckerProductLazyTensor¶
-
class
gpytorch.lazy.
KroneckerProductLazyTensor
(*lazy_tensors)[source]¶ Returns the Kronecker product of the given lazy tensors.
- Args:
- lazy_tensors: List of lazy tensors
-
diag
()[source]¶ As torch.diag(), returns the diagonal of the matrix \(K\) this LazyTensor represents as a vector.
Return type: torch.tensor
Returns: The diagonal of \(K\). If \(K\) is \(n \times n\), this will be a length n vector. If this LazyTensor represents a batch (e.g., is \(b \times n \times n\)), this will be a \(b \times n\) matrix of diagonals, one for each matrix in the batch.
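A Kronecker product has a convenient property here: its diagonal is the Kronecker product of the factors' diagonals, so it can be computed without forming the large matrix. The sketch below checks this identity for two square factors on plain Python lists. The helper names are hypothetical and this is not a description of the library's implementation.

```python
def kron(A, B):
    """Dense Kronecker product of two matrices given as lists of rows."""
    return [[a * b for a in row_a for b in row_b] for row_a in A for row_b in B]


def kron_diag(A, B):
    """Diagonal of kron(A, B) for square A and B, using only their diagonals."""
    return [A[i][i] * B[j][j] for i in range(len(A)) for j in range(len(B))]


A = [[1.0, 2.0], [3.0, 4.0]]
B = [[5.0, 6.0], [7.0, 8.0]]
dense = kron(A, B)
fast = kron_diag(A, B)                           # touches only 2 + 2 entries
slow = [dense[i][i] for i in range(len(dense))]  # reads the 4 x 4 dense product
```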
MulLazyTensor¶
PsdSumLazyTensor¶
SumBatchLazyTensor¶
-
class
gpytorch.lazy.
SumBatchLazyTensor
(base_lazy_tensor, block_dim=-3)[source]¶ Represents a lazy tensor that is actually the sum of several lazy tensor blocks. The block_dim attribute specifies which dimension of the base LazyTensor specifies the blocks. For example, with block_dim=-3, a k x n x n tensor represents k n x n blocks (an n x n matrix). A b x k x n x n tensor represents k b x n x n blocks (a b x n x n batch matrix).
- Args:
- base_lazy_tensor (LazyTensor): A k x n x n LazyTensor, or a b x k x n x n LazyTensor.
- block_dim (int): The dimension that specifies the blocks.
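Since the represented matrix is the sum of the blocks, a product with a vector can be accumulated block by block without forming the sum first. The sketch below shows this for the non-batched case on plain Python lists; the helper name is hypothetical.

```python
def sum_blocks_matvec(blocks, v):
    """Compute (sum_i B_i) v as sum_i (B_i v), one block at a time."""
    out = [0.0] * len(blocks[0])
    for block in blocks:
        for r, row in enumerate(block):
            out[r] += sum(a * x for a, x in zip(row, v))
    return out


blocks = [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]  # k = 2 blocks, n = 2
result = sum_blocks_matvec(blocks, [1.0, -1.0])
```

Here the summed matrix is [[6, 8], [10, 12]], and multiplying it by (1, -1) gives the same vector as the blockwise accumulation.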