#!/usr/bin/env python3
from typing import List, Optional, Tuple, Union
import torch
from linear_operator import to_linear_operator
from linear_operator.operators import InterpolatedLinearOperator
from ..models.exact_prediction_strategies import InterpolatedPredictionStrategy
from ..utils.grid import create_grid
from ..utils.interpolation import Interpolation
from .grid_kernel import GridKernel
from .kernel import Kernel
[docs]class GridInterpolationKernel(GridKernel):
r"""
Implements the KISS-GP (or SKI) approximation for a given kernel.
It was proposed in `Kernel Interpolation for Scalable Structured Gaussian Processes`_,
and offers extremely fast and accurate Kernel approximations for large datasets.
Given a base kernel `k`, the covariance :math:`k(\mathbf{x_1}, \mathbf{x_2})` is approximated by
using a grid of regularly spaced *inducing points*:
.. math::
\begin{equation*}
k(\mathbf{x_1}, \mathbf{x_2}) = \mathbf{w_{x_1}}^\top K_{U,U} \mathbf{w_{x_2}}
\end{equation*}
where
* :math:`U` is the set of gridded inducing points
* :math:`K_{U,U}` is the kernel matrix between the inducing points
* :math:`\mathbf{w_{x_1}}` and :math:`\mathbf{w_{x_2}}` are sparse vectors based on
:math:`\mathbf{x_1}` and :math:`\mathbf{x_2}` that apply cubic interpolation.
The user should supply the size of the grid (using the grid_size attribute).
To choose a reasonable grid value, we highly recommend using the
:func:`gpytorch.utils.grid.choose_grid_size` helper function.
The bounds of the grid will automatically be determined by data.
(Alternatively, you can hard-code bounds using the grid_bounds, which
will speed up this kernel's computations.)
.. note::
`GridInterpolationKernel` can only wrap **stationary kernels** (such as RBF, Matern,
Periodic, Spectral Mixture, etc.)
Args:
base_kernel (Kernel):
The kernel to approximate with KISS-GP
grid_size (Union[int, List[int]]):
The size of the grid in each dimension.
If a single int is provided, then every dimension will have the same grid size.
num_dims (int):
The dimension of the input data. Required if `grid_bounds=None`
grid_bounds (tuple(float, float), optional):
The bounds of the grid, if known (high performance mode).
The length of the tuple must match the number of dimensions.
The entries represent the min/max values for each dimension.
active_dims (tuple of ints, optional):
Passed down to the `base_kernel`.
.. _Kernel Interpolation for Scalable Structured Gaussian Processes:
http://proceedings.mlr.press/v37/wilson15.pdf
"""
def __init__(
self,
base_kernel: Kernel,
grid_size: Union[int, List[int]],
num_dims: Optional[int] = None,
grid_bounds: Optional[Tuple[float, float]] = None,
active_dims: Optional[Tuple[int, ...]] = None,
):
has_initialized_grid = 0
grid_is_dynamic = True
# Make some temporary grid bounds, if none exist
if grid_bounds is None:
if num_dims is None:
raise RuntimeError("num_dims must be supplied if grid_bounds is None")
else:
# Create some temporary grid bounds - they'll be changed soon
grid_bounds = tuple((-1.0, 1.0) for _ in range(num_dims))
else:
has_initialized_grid = 1
grid_is_dynamic = False
if num_dims is None:
num_dims = len(grid_bounds)
elif num_dims != len(grid_bounds):
raise RuntimeError(
"num_dims ({}) disagrees with the number of supplied "
"grid_bounds ({})".format(num_dims, len(grid_bounds))
)
if isinstance(grid_size, int):
grid_sizes = [grid_size for _ in range(num_dims)]
else:
grid_sizes = list(grid_size)
if len(grid_sizes) != num_dims:
raise RuntimeError("The number of grid sizes provided through grid_size do not match num_dims.")
# Initialize values and the grid
self.grid_is_dynamic = grid_is_dynamic
self.num_dims = num_dims
self.grid_sizes = grid_sizes
self.grid_bounds = grid_bounds
grid = create_grid(self.grid_sizes, self.grid_bounds)
super(GridInterpolationKernel, self).__init__(
base_kernel=base_kernel,
grid=grid,
interpolation_mode=True,
active_dims=active_dims,
)
self.register_buffer("has_initialized_grid", torch.tensor(has_initialized_grid, dtype=torch.bool))
@property
def _tight_grid_bounds(self):
grid_spacings = tuple((bound[1] - bound[0]) / self.grid_sizes[i] for i, bound in enumerate(self.grid_bounds))
return tuple(
(bound[0] + 2.01 * spacing, bound[1] - 2.01 * spacing)
for bound, spacing in zip(self.grid_bounds, grid_spacings)
)
def _compute_grid(self, inputs, last_dim_is_batch=False):
n_data, n_dimensions = inputs.size(-2), inputs.size(-1)
if last_dim_is_batch:
inputs = inputs.transpose(-1, -2).unsqueeze(-1)
n_dimensions = 1
batch_shape = inputs.shape[:-2]
inputs = inputs.reshape(-1, n_dimensions)
interp_indices, interp_values = Interpolation().interpolate(self.grid, inputs)
interp_indices = interp_indices.view(*batch_shape, n_data, -1)
interp_values = interp_values.view(*batch_shape, n_data, -1)
return interp_indices, interp_values
def _inducing_forward(self, last_dim_is_batch, **params):
return super().forward(self.grid, self.grid, last_dim_is_batch=last_dim_is_batch, **params)
def forward(self, x1, x2, diag=False, last_dim_is_batch=False, **params):
# See if we need to update the grid or not
if self.grid_is_dynamic: # This is true if a grid_bounds wasn't passed in
if torch.equal(x1, x2):
x = x1.reshape(-1, self.num_dims)
else:
x = torch.cat([x1.reshape(-1, self.num_dims), x2.reshape(-1, self.num_dims)])
x_maxs = x.max(0)[0].tolist()
x_mins = x.min(0)[0].tolist()
# We need to update the grid if
# 1) it hasn't ever been initialized, or
# 2) if any of the grid points are "out of bounds"
update_grid = (not self.has_initialized_grid.item()) or any(
x_min < bound[0] or x_max > bound[1]
for x_min, x_max, bound in zip(x_mins, x_maxs, self._tight_grid_bounds)
)
# Update the grid if needed
if update_grid:
grid_spacings = tuple(
(x_max - x_min) / (gs - 4.02) for gs, x_min, x_max in zip(self.grid_sizes, x_mins, x_maxs)
)
self.grid_bounds = tuple(
(x_min - 2.01 * spacing, x_max + 2.01 * spacing)
for x_min, x_max, spacing in zip(x_mins, x_maxs, grid_spacings)
)
grid = create_grid(
self.grid_sizes,
self.grid_bounds,
dtype=self.grid[0].dtype,
device=self.grid[0].device,
)
self.update_grid(grid)
base_lazy_tsr = to_linear_operator(self._inducing_forward(last_dim_is_batch=last_dim_is_batch, **params))
if last_dim_is_batch and base_lazy_tsr.size(-3) == 1:
base_lazy_tsr = base_lazy_tsr.repeat(*x1.shape[:-2], x1.size(-1), 1, 1)
left_interp_indices, left_interp_values = self._compute_grid(x1, last_dim_is_batch)
if torch.equal(x1, x2):
right_interp_indices = left_interp_indices
right_interp_values = left_interp_values
else:
right_interp_indices, right_interp_values = self._compute_grid(x2, last_dim_is_batch)
batch_shape = torch.broadcast_shapes(
base_lazy_tsr.batch_shape,
left_interp_indices.shape[:-2],
right_interp_indices.shape[:-2],
)
res = InterpolatedLinearOperator(
base_lazy_tsr.expand(*batch_shape, *base_lazy_tsr.matrix_shape),
left_interp_indices.detach().expand(*batch_shape, *left_interp_indices.shape[-2:]),
left_interp_values.expand(*batch_shape, *left_interp_values.shape[-2:]),
right_interp_indices.detach().expand(*batch_shape, *right_interp_indices.shape[-2:]),
right_interp_values.expand(*batch_shape, *right_interp_values.shape[-2:]),
)
if diag:
return res.diagonal(dim1=-1, dim2=-2)
else:
return res
def prediction_strategy(self, train_inputs, train_prior_dist, train_labels, likelihood):
return InterpolatedPredictionStrategy(train_inputs, train_prior_dist, train_labels, likelihood)