#!/usr/bin/env python3
import torch
from linear_operator.operators import InterpolatedLinearOperator
from linear_operator.utils.interpolation import left_interp
from ..distributions import MultivariateNormal
from ..utils.interpolation import Interpolation
from ..utils.memoize import cached
from ._variational_strategy import _VariationalStrategy
class GridInterpolationVariationalStrategy(_VariationalStrategy):
    r"""
    This strategy constrains the inducing points to a grid and applies a deterministic
    relationship between :math:`\mathbf f` and :math:`\mathbf u`.
    It was introduced by `Wilson et al. (2016)`_.

    Here, the inducing points are not learned. Instead, the strategy
    automatically creates inducing points based on a set of grid sizes and grid
    bounds.

    .. _Wilson et al. (2016):
        https://arxiv.org/abs/1611.00336

    :param ~gpytorch.models.ApproximateGP model: Model this strategy is applied to.
        Typically passed in when the VariationalStrategy is created in the
        __init__ method of the user defined model.
    :param int grid_size: Number of grid points along each dimension. Must be
        greater than 2, because the grid padding divides by ``grid_size - 2``.
    :param list grid_bounds: Bounds of each dimension of the grid (should be a list of (float, float) tuples)
    :param ~gpytorch.variational.VariationalDistribution variational_distribution: A
        VariationalDistribution object that represents the form of the variational distribution :math:`q(\mathbf u)`
    :raises ValueError: If ``grid_size`` is not greater than 2.
    """

    def __init__(self, model, grid_size, grid_bounds, variational_distribution):
        # The padding below divides by (grid_size - 2). Reject unusable sizes up
        # front instead of failing with an opaque ZeroDivisionError
        # (grid_size == 2) or building a degenerate grid (grid_size < 2).
        if grid_size <= 2:
            raise ValueError(f"grid_size must be greater than 2. Got {grid_size}.")

        num_dims = len(grid_bounds)

        # Column i of `grid` holds the 1-D grid coordinates of dimension i. The
        # range is widened beyond the requested bounds by
        # (upper - lower) / (grid_size - 2) on each side.
        grid = torch.zeros(grid_size, num_dims)
        for i, bounds in enumerate(grid_bounds):
            lower, upper = bounds[0], bounds[1]
            grid_diff = float(upper - lower) / (grid_size - 2)
            grid[:, i] = torch.linspace(lower - grid_diff, upper + grid_diff, grid_size)

        # Build every combination of the per-dimension coordinates, one
        # dimension at a time. For dimension i the rows are split into
        # `grid_size` consecutive chunks of `grid_size**i` rows: chunk j gets
        # grid[j, i] in column i and a copy of the points already built for
        # dimensions < i in the leading columns (so dimension 0 varies fastest).
        inducing_points = torch.zeros(int(pow(grid_size, num_dims)), num_dims)
        prev_points = None
        for i in range(num_dims):
            chunk = grid_size**i
            for j in range(grid_size):
                rows = slice(j * chunk, (j + 1) * chunk)
                inducing_points[rows, i].fill_(grid[j, i])
                if prev_points is not None:
                    inducing_points[rows, :i].copy_(prev_points)
            # View of everything built so far; reused as the template for the
            # next dimension.
            prev_points = inducing_points[: grid_size ** (i + 1), : (i + 1)]

        super().__init__(model, inducing_points, variational_distribution, learn_inducing_locations=False)
        # Stored through object.__setattr__, i.e. bypassing this class's own
        # __setattr__. NOTE(review): presumably to keep the model from being
        # registered as a submodule of the strategy -- confirm against
        # _VariationalStrategy.
        object.__setattr__(self, "model", model)

        self.register_buffer("grid", grid)

    def _compute_grid(self, inputs):
        """
        Compute the interpolation indices and values of ``inputs`` on the grid.

        :param torch.Tensor inputs: Input locations; the last two dimensions are
            read as ``(n_data, n_dimensions)`` and any leading dimensions are
            treated as batch dimensions.
        :return: ``(interp_indices, interp_values)``, each reshaped to
            ``(*batch_shape, n_data, -1)``. When the number of batch dimensions
            differs from that of the variational distribution, both are
            expanded to the broadcast batch shape.
        :rtype: tuple
        """
        n_data, n_dimensions = inputs.size(-2), inputs.size(-1)
        batch_shape = inputs.shape[:-2]

        # The interpolation routine is called on a flat (N, d) set of points;
        # batch dimensions are folded in here and restored afterwards.
        inputs = inputs.reshape(-1, n_dimensions)
        interp_indices, interp_values = Interpolation().interpolate(self.grid, inputs)
        interp_indices = interp_indices.view(*batch_shape, n_data, -1)
        interp_values = interp_values.view(*batch_shape, n_data, -1)

        variational_batch_shape = self._variational_distribution.batch_shape
        if (interp_indices.dim() - 2) != len(variational_batch_shape):
            batch_shape = torch.broadcast_shapes(interp_indices.shape[:-2], variational_batch_shape)
            interp_indices = interp_indices.expand(*batch_shape, *interp_indices.shape[-2:])
            interp_values = interp_values.expand(*batch_shape, *interp_values.shape[-2:])
        return interp_indices, interp_values

    @property
    @cached(name="prior_distribution_memo")
    def prior_distribution(self):
        """
        Prior distribution at the grid inducing points.

        Evaluates ``self.model.forward`` at ``self.inducing_points`` and returns a
        :obj:`MultivariateNormal` with the same mean and with jitter of ``1e-3``
        added to the covariance. The result is wrapped by ``cached`` under the
        name ``"prior_distribution_memo"``.
        """
        out = self.model.forward(self.inducing_points)
        # TODO: investigate why smaller than 1e-3 breaks some tests
        res = MultivariateNormal(out.mean, out.lazy_covariance_matrix.add_jitter(1e-3))
        return res

    def forward(self, x, inducing_points, inducing_values, variational_inducing_covar=None):
        """
        Compute the predictive distribution at ``x`` by interpolating from the grid.

        :param torch.Tensor x: Locations at which to evaluate the distribution.
        :param inducing_points: Not used by this strategy; the interpolation is
            computed from the ``grid`` buffer instead.
        :param torch.Tensor inducing_values: Values at the inducing points; they are
            left-multiplied by the interpolation matrix to give the predictive mean.
        :param variational_inducing_covar: Only checked for ``None`` (taken to mean
            a non-Gaussian variational distribution). The covariance itself is
            read from ``self.variational_distribution.lazy_covariance_matrix``.
        :return: Distribution with the interpolated mean and covariance.
        :rtype: MultivariateNormal
        :raises RuntimeError: If ``variational_inducing_covar`` is ``None``.
        """
        if variational_inducing_covar is None:
            raise RuntimeError(
                "GridInterpolationVariationalStrategy is only compatible with Gaussian variational "
                f"distributions. Got ({self.variational_distribution.__class__.__name__})."
            )

        variational_distribution = self.variational_distribution

        # Get interpolations
        interp_indices, interp_values = self._compute_grid(x)

        # Compute test mean
        # Left multiply samples by interpolation matrix
        predictive_mean = left_interp(interp_indices, interp_values, inducing_values.unsqueeze(-1))
        predictive_mean = predictive_mean.squeeze(-1)

        # Compute test covar
        # The same interpolation is applied on the left and on the right.
        predictive_covar = InterpolatedLinearOperator(
            variational_distribution.lazy_covariance_matrix,
            interp_indices,
            interp_values,
            interp_indices,
            interp_values,
        )
        output = MultivariateNormal(predictive_mean, predictive_covar)
        return output