Source code for gpytorch.utils.permutation

#!/usr/bin/env python3

from typing import Optional, Union

import torch


def apply_permutation(
    matrix: Union["LazyTensor", torch.Tensor],  # noqa: F821
    left_permutation: Optional[torch.Tensor] = None,
    right_permutation: Optional[torch.Tensor] = None,
):
    r"""
    Permute (or sub-select) the rows and/or columns of a matrix :math:`\mathbf K`:

    .. math::

        \begin{equation}
            \boldsymbol{\Pi}_\text{left} \mathbf K \boldsymbol{\Pi}_\text{right}^\top
        \end{equation}

    The permutation matrices :math:`\boldsymbol{\Pi}_\text{left}` and
    :math:`\boldsymbol{\Pi}_\text{right}^\top` are never materialized; they are
    given as index vectors :attr:`left_permutation` (row indices) and
    :attr:`right_permutation` (column indices). A vector may list every index
    (a full permutation) or only some of them (a partial permutation that
    selects a subset of rows/columns).

    If :math:`\mathbf K` is a batch of matrices, the index vectors may carry
    batch dimensions as well, in which case each batch entry is indexed with
    its own vector. Broadcasting rules apply between the batch dimensions of
    the matrix and those of the index vectors.

    If both vectors are ``None`` the matrix is returned via ``delazify``
    without any indexing. If exactly one is ``None``, the missing one is
    replaced by the identity ordering ``0, 1, ..., size - 1`` of that
    dimension.

    :param matrix: :math:`\mathbf K`
    :type matrix: ~gpytorch.lazy.LazyTensor or ~torch.Tensor (... x n x n)
    :param left_permutation: vector representing :math:`\boldsymbol{\Pi}_\text{left}`
    :type left_permutation: ~torch.Tensor, optional (... x <= n)
    :param right_permutation: vector representing :math:`\boldsymbol{\Pi}_\text{right}`
    :type right_permutation: ~torch.Tensor, optional (... x <= n)
    :return: :math:`\boldsymbol{\Pi}_\text{left} \mathbf K \boldsymbol{\Pi}_\text{right}^\top`
    :rtype: ~torch.Tensor

    Example:
        >>> factor = torch.randn(2, 3, 5, 5)
        >>> matrix = factor @ factor.transpose(-1, -2)  # 2 x 3 x 5 x 5
        >>> left_permutation = torch.tensor([
        >>>     [ 1, 3, 2, 4, 0 ],
        >>>     [ 2, 1, 0, 3, 4 ],
        >>>     [ 0, 1, 2, 4, 3 ],
        >>> ])  # Full permutation: 3 x 5 (broadcast over the leading batch dim)
        >>> right_permutation = torch.tensor([
        >>>     [ 1, 3, 2 ],
        >>>     [ 2, 1, 0 ],
        >>>     [ 0, 1, 2 ],
        >>> ])  # Partial permutation: 3 x 3 (broadcast over the leading batch dim)
        >>> apply_permutation(matrix, left_permutation, right_permutation)  # 2 x 3 x 5 x 3
    """
    from ..lazy import delazify

    # Nothing to permute: hand the matrix back as-is (after delazify).
    if left_permutation is None and right_permutation is None:
        return delazify(matrix)

    # One index tensor per batch dimension. The k-th tensor holds
    # 0..size_k-1 along axis k and is singleton everywhere else (including the
    # two trailing matrix axes), so that advanced indexing pairs every batch
    # entry with its own row/column index vectors instead of mixing batches.
    batch_shape = matrix.shape[:-2]
    index_rank = len(batch_shape) + 2
    batch_indices = [
        torch.arange(size, device=matrix.device).view(
            *(size if axis == dim else 1 for axis in range(index_rank))
        )
        for dim, size in enumerate(batch_shape)
    ]

    # A missing permutation is replaced by the identity ordering of that axis.
    row_index = left_permutation
    if row_index is None:
        row_index = torch.arange(matrix.size(-2), device=matrix.device)

    col_index = right_permutation
    if col_index is None:
        col_index = torch.arange(matrix.size(-1), device=matrix.device)

    # Rows vary along the second-to-last axis and columns along the last one;
    # the added singleton axes let the two broadcast into an outer product.
    full_index = (*batch_indices, row_index.unsqueeze(-1), col_index.unsqueeze(-2))
    return delazify(matrix[full_index])
def inverse_permutation(permutation):
    r"""
    Compute the inverse of a (batch of) permutation vector(s).

    For every position ``i`` along the last dimension, the value ``i`` is
    written to slot ``permutation[..., i]`` of the result, so that indexing
    with the result undoes indexing with :attr:`permutation`.

    :param permutation: permutation vector(s); the last dimension holds the
        indices, any leading dimensions are treated as batch dimensions
    :return: tensor with the same shape, dtype and device as
        :attr:`permutation` holding the inverse permutation(s)

    Example:
        >>> permutation = torch.tensor([ 1, 3, 2, 4, 0 ])
        >>> inverse_permutation(permutation)  # torch.tensor([ 4, 0, 2, 1, 3 ])
    """
    length = permutation.size(-1)

    # Source values 0..length-1, repeated (as a view) for every batch entry.
    positions = torch.arange(length, device=permutation.device).expand(permutation.shape)

    # Scatter each position into the slot named by the permutation.
    inverse = torch.zeros_like(permutation)
    inverse.scatter_(-1, permutation, positions)
    return inverse