# Source code for gpytorch.kernels.grid_kernel

#!/usr/bin/env python3

import warnings
from typing import Optional, Tuple

import torch
from linear_operator import to_dense
from linear_operator.operators import KroneckerProductLinearOperator, ToeplitzLinearOperator
from torch import Tensor

from .. import settings
from ..utils.grid import convert_legacy_grid, create_data_from_grid
from .kernel import Kernel


class GridKernel(Kernel):
    r"""
    If the input data :math:`X` are regularly spaced on a grid, then
    `GridKernel` can dramatically speed up computations for stationary kernels.

    GridKernel exploits Toeplitz and Kronecker structure within the covariance matrix.
    See `Fast kernel learning for multidimensional pattern extrapolation`_ for more info.

    .. note::

        `GridKernel` can only wrap **stationary kernels** (such as RBF, Matern,
        Periodic, Spectral Mixture, etc.)

    Args:
        base_kernel (Kernel):
            The kernel to speed up with grid methods.
        grid (Tensor or list of Tensors):
            Either a g x d tensor where column i consists of the projections of the
            grid in dimension i, or a list of d per-dimension projection tensors
            (which may differ in length, i.e. a "ragged" grid).
        active_dims (tuple of ints, optional):
            Passed to the parent :class:`Kernel` constructor, i.e. applied by this
            kernel itself (it is not forwarded to `base_kernel`).
        interpolation_mode (bool):
            Used for GridInterpolationKernel where we want the covariance
            between points in the projections of the grid of each dimension.
            We do this by treating `grid` as d batches of g x 1 tensors by
            calling base_kernel(grid, grid) with last_dim_is_batch to get a d x g x g Tensor
            which we Kronecker product to get a :math:`g^d \times g^d`
            KroneckerProductLinearOperator (for a ragged grid, each side is the product
            of the per-dimension grid sizes).

    .. _Fast kernel learning for multidimensional pattern extrapolation:
        http://www.cs.cmu.edu/~andrewgw/manet.pdf
    """

    is_stationary = True

    def __init__(
        self,
        base_kernel: Kernel,
        grid: Tensor,
        interpolation_mode: bool = False,
        active_dims: Optional[Tuple[int, ...]] = None,
    ):
        if not base_kernel.is_stationary:
            raise RuntimeError("The base_kernel for GridKernel must be stationary.")

        super().__init__(active_dims=active_dims)
        if torch.is_tensor(grid):
            # A g x d tensor is split into one projection per dimension.
            grid = convert_legacy_grid(grid)
        self.interpolation_mode = interpolation_mode
        self.base_kernel = base_kernel
        self.num_dims = len(grid)
        self.register_buffer_list("grid", grid)
        if not self.interpolation_mode:
            # The dense grid is only needed to recognise "x1/x2 are the grid" in forward().
            self.register_buffer("full_grid", create_data_from_grid(grid))

    def _clear_cache(self):
        """Drop the cached grid covariance (if any) so it is recomputed on the next call."""
        if hasattr(self, "_cached_kernel_mat"):
            del self._cached_kernel_mat
        if hasattr(self, "_cached_kernel_mat_last_dim_is_batch"):
            del self._cached_kernel_mat_last_dim_is_batch

    def train(self, mode: bool = True):
        """
        Set training mode (see :meth:`torch.nn.Module.train`) and invalidate the cache.

        ``forward`` caches the grid covariance in eval mode without tracking whether the
        hyperparameters changed since it was computed, so the cache is dropped on every
        mode switch to avoid serving a stale matrix.
        """
        self._clear_cache()
        return super().train(mode)

    def register_buffer_list(self, base_name, tensors):
        """
        Helper to register several buffers at once under a single base name.

        Tensor ``i`` of ``tensors`` is registered as the buffer ``f"{base_name}_{i}"``.
        """
        for i, tensor in enumerate(tensors):
            self.register_buffer(f"{base_name}_{i}", tensor)

    @property
    def grid(self):
        """The per-dimension grid projections, read back from the ``grid_{i}`` buffers."""
        return [getattr(self, f"grid_{i}") for i in range(self.num_dims)]

    def update_grid(self, grid):
        """
        Supply a new `grid` if it ever changes.

        Args:
            grid: A g x d tensor (converted with ``convert_legacy_grid``) or a list with
                one projection tensor per dimension.

        Returns:
            ``self``, so the call can be chained.

        Raises:
            RuntimeError: if the new grid has a different number of dimensions.
        """
        if torch.is_tensor(grid):
            grid = convert_legacy_grid(grid)
        if len(grid) != self.num_dims:
            raise RuntimeError("New grid should have the same number of dimensions as before.")
        for i, proj in enumerate(grid):
            setattr(self, f"grid_{i}", proj)
        if not self.interpolation_mode:
            self.full_grid = create_data_from_grid(self.grid)
        # Any cached covariance was built from the old grid.
        self._clear_cache()
        return self

    @property
    def is_ragged(self):
        """Whether the per-dimension projections differ in size."""
        grid = self.grid
        first_size = grid[0].size()
        return not all(proj.size() == first_size for proj in grid)

    @staticmethod
    def _pad_grid(grid):
        """Zero-pad each projection along its last dim to the length of the longest one."""
        max_grid_size = max(proj.size(-1) for proj in grid)
        padded_grid = []
        for proj in grid:
            padding_size = max_grid_size - proj.size(-1)
            if padding_size > 0:
                padding = torch.zeros(*proj.shape[:-1], padding_size, dtype=proj.dtype, device=proj.device)
                # Pad along the same (last) axis the padding tensor was sized for.
                proj = torch.cat([proj, padding], dim=-1)
            padded_grid.append(proj)
        return padded_grid

    def forward(self, x1, x2, diag=False, last_dim_is_batch=False, **params):
        """
        Compute the covariance between ``x1`` and ``x2``.

        When ``interpolation_mode`` is set, or when both ``x1`` and ``x2`` equal the full
        grid, the structured (Toeplitz / Kronecker) grid covariance is returned; ``diag``
        is not applied on that path. Otherwise this defers to ``base_kernel.forward``.

        In eval mode the structured covariance is cached and reused. The cache is keyed
        only on ``last_dim_is_batch`` (not on ``x1``/``x2`` or ``params``).

        Raises:
            ValueError: if ``last_dim_is_batch`` is requested outside interpolation mode.
        """
        if last_dim_is_batch and not self.interpolation_mode:
            raise ValueError("last_dim_is_batch is only valid with interpolation mode")

        if not self.interpolation_mode:
            if x1.dim() > 2:
                full_grid = self.full_grid.expand(*x1.shape[:-2], *self.full_grid.shape[-2:])
            else:
                full_grid = self.full_grid
            if not (torch.equal(x1, full_grid) and torch.equal(x2, full_grid)):
                # Inputs are not the grid: no structure to exploit.
                return self.base_kernel.forward(x1, x2, diag=diag, last_dim_is_batch=last_dim_is_batch, **params)

        if not self.training and hasattr(self, "_cached_kernel_mat"):
            # Only reuse a cached matrix built for the same output structure.
            cached_flag = getattr(self, "_cached_kernel_mat_last_dim_is_batch", last_dim_is_batch)
            if cached_flag == last_dim_is_batch:
                return self._cached_kernel_mat

        grid = self.grid
        # Pad the grid so that every dimension has the same number of points (batch mode).
        padded_grid = self._pad_grid(grid) if self.is_ragged else grid
        stacked_grid = torch.stack(padded_grid, dim=-1)

        if settings.use_toeplitz.on():
            # Can exploit Toeplitz structure if grid points in each dimension are equally
            # spaced and using a translation-invariant kernel (spacing is not checked here;
            # this path is selected by ``settings.use_toeplitz``).
            first_grid_point = torch.stack([proj[0].unsqueeze(0) for proj in grid], dim=-1)
            with warnings.catch_warnings():
                # Hide the GPyTorch 2.0 deprecation warning
                warnings.simplefilter("ignore", DeprecationWarning)
                covars = to_dense(
                    self.base_kernel(first_grid_point, stacked_grid, last_dim_is_batch=True, **params)
                )
            # Get rid of the dimension corresponding to the first point
            covars = covars.squeeze(-2)
            if last_dim_is_batch:
                # Toeplitz expects batches of columns; this requires all the dimensions
                # to have the same number of grid points.
                covar = ToeplitzLinearOperator(covars)
            else:
                # Un-pad the grid
                toeplitz_factors = [
                    ToeplitzLinearOperator(covars[..., i, : proj.size(-1)]) for i, proj in enumerate(grid)
                ]
                # Due to legacy reasons, KroneckerProductLinearOperator(A, B, C) is actually (C Kron B Kron A)
                covar = KroneckerProductLinearOperator(*toeplitz_factors[::-1])
        else:
            with warnings.catch_warnings():
                # Hide the GPyTorch 2.0 deprecation warning
                warnings.simplefilter("ignore", DeprecationWarning)
                covars = to_dense(self.base_kernel(stacked_grid, stacked_grid, last_dim_is_batch=True, **params))
            if last_dim_is_batch:
                # Note that this requires all the dimensions to have the same number of grid points
                covar = covars
            else:
                # Un-pad each per-dimension block before taking the Kronecker product.
                dense_factors = [covars[..., i, : proj.size(-1), : proj.size(-1)] for i, proj in enumerate(grid)]
                # Due to legacy reasons, KroneckerProductLinearOperator(A, B, C) is actually (C Kron B Kron A)
                covar = KroneckerProductLinearOperator(*dense_factors[::-1])

        if not self.training:
            self._cached_kernel_mat = covar
            self._cached_kernel_mat_last_dim_is_batch = last_dim_is_batch
        return covar

    def num_outputs_per_input(self, x1, x2):
        """Delegate to ``base_kernel``."""
        return self.base_kernel.num_outputs_per_input(x1, x2)