#!/usr/bin/env python3
from __future__ import annotations
import warnings
from collections.abc import Iterable
from copy import deepcopy
import torch
from torch import Tensor
from gpytorch.distributions import Distribution
from .. import settings
from ..distributions import MultitaskMultivariateNormal, MultivariateNormal
from ..likelihoods import _GaussianLikelihoodBase
from ..utils.warnings import GPInputWarning
from .exact_prediction_strategies import prediction_strategy
from .gp import GP
class ExactGP(GP):
    r"""
    The base class for any Gaussian process latent function to be used in conjunction
    with exact inference.

    :param torch.Tensor train_inputs: (size n x d) The training features :math:`\mathbf X`.
        May be a single tensor, an iterable of tensors (for models with several inputs),
        or ``None``. 1-D tensors of size n are treated as n x 1.
    :param torch.Tensor train_targets: (size n) The training targets :math:`\mathbf y`, or ``None``.
    :param ~gpytorch.likelihoods.GaussianLikelihood likelihood: The Gaussian likelihood that defines
        the observational distribution. Since we're using exact inference, the likelihood must be Gaussian.

    The :meth:`forward` function should describe how to compute the prior latent distribution
    on a given input. Typically, this will involve a mean and kernel function.
    The result must be a :obj:`~gpytorch.distributions.MultivariateNormal`.

    Calling this model will return the posterior of the latent Gaussian process when conditioned
    on the training data. The output will be a :obj:`~gpytorch.distributions.MultivariateNormal`.

    Example:
        >>> class MyGP(gpytorch.models.ExactGP):
        ...     def __init__(self, train_x, train_y, likelihood):
        ...         super().__init__(train_x, train_y, likelihood)
        ...         self.mean_module = gpytorch.means.ZeroMean()
        ...         self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())
        ...
        ...     def forward(self, x):
        ...         mean = self.mean_module(x)
        ...         covar = self.covar_module(x)
        ...         return gpytorch.distributions.MultivariateNormal(mean, covar)
        >>>
        >>> # train_x = ...; train_y = ...
        >>> likelihood = gpytorch.likelihoods.GaussianLikelihood()
        >>> model = MyGP(train_x, train_y, likelihood)
        >>>
        >>> # test_x = ...;
        >>> model(test_x)  # Returns the GP latent function at test_x
        >>> likelihood(model(test_x))  # Returns the (approximate) predictive posterior distribution at test_x
    """

    def __init__(
        self,
        train_inputs: Tensor | Iterable[Tensor] | None,
        train_targets: Tensor | None,
        likelihood: _GaussianLikelihoodBase,
    ):
        if train_inputs is not None:
            # Normalize to a tuple up front. Materializing the iterable before validating it keeps
            # one-shot iterables (e.g. generators) from being exhausted by the type check below.
            train_inputs = (train_inputs,) if isinstance(train_inputs, Tensor) else tuple(train_inputs)
            if not all(isinstance(train_input, Tensor) for train_input in train_inputs):
                raise RuntimeError("Train inputs must be a tensor, or a list/tuple of tensors")
        if not isinstance(likelihood, _GaussianLikelihoodBase):
            raise RuntimeError("ExactGP can only handle Gaussian likelihoods")

        super().__init__()
        if train_inputs is not None:
            # Promote 1-D inputs of size n to n x 1 so every input has an explicit feature dimension
            self.train_inputs = tuple(tri.unsqueeze(-1) if tri.ndimension() == 1 else tri for tri in train_inputs)
            self.train_targets = train_targets
        else:
            self.train_inputs = None
            self.train_targets = None
        self.likelihood = likelihood
        # Test-time caches; built lazily on the first posterior-mode call to the model
        self.prediction_strategy = None

    @property
    def train_targets(self) -> Tensor | None:
        """The training targets, or ``None`` if no training data has been set."""
        return self._train_targets

    @train_targets.setter
    def train_targets(self, value: Tensor | None) -> None:
        # object.__setattr__ bypasses any __setattr__ override in the module class hierarchy
        object.__setattr__(self, "_train_targets", value)

    def _apply(self, fn, *args, **kwargs):
        """
        Apply the per-tensor transform ``fn`` to the stored training data, then delegate to the
        parent ``_apply`` (which handles parameters, buffers and submodules).

        Inputs and targets are handled independently, so either may be ``None``. Any extra
        positional/keyword arguments are forwarded unchanged to the parent implementation.
        """
        if self.train_inputs is not None:
            self.train_inputs = tuple(fn(train_input) for train_input in self.train_inputs)
        if self.train_targets is not None:
            self.train_targets = fn(self.train_targets)
        return super()._apply(fn, *args, **kwargs)

    def _clear_cache(self) -> None:
        """Drop all test-time caches so they are rebuilt on the next posterior-mode call."""
        # The precomputed caches from test time live in prediction_strategy
        self.prediction_strategy = None

    def local_load_samples(self, samples_dict, memo, prefix):
        """
        Replace the model's learned hyperparameters with samples from a posterior distribution.

        The training inputs and targets are expanded along a new leading dimension whose size is
        the number of samples (dim 0 of the first tensor in ``samples_dict``) before the samples are
        loaded by the parent implementation.
        """
        # Pyro always puts the samples in the first batch dimension
        num_samples = next(iter(samples_dict.values())).size(0)
        self.train_inputs = tuple(tri.unsqueeze(0).expand(num_samples, *tri.shape) for tri in self.train_inputs)
        self.train_targets = self.train_targets.unsqueeze(0).expand(num_samples, *self.train_targets.shape)
        super().local_load_samples(samples_dict, memo, prefix)

    def set_train_data(
        self, inputs: Tensor | Iterable[Tensor] | None = None, targets: Tensor | None = None, strict: bool = True
    ) -> None:
        """
        Set training data (does not re-fit model hyper-parameters).

        :param inputs: The new training inputs (a tensor or an iterable of tensors). ``None`` keeps
            the current inputs. 1-D tensors of size n are treated as n x 1.
        :param targets: The new training targets. ``None`` keeps the current targets.
        :param strict: If `True`, the new inputs and targets must have the same shape,
            dtype, and device as the current inputs and targets. Otherwise, any
            shape/dtype/device are allowed.
        :raises RuntimeError: If ``strict`` is set and a shape/dtype/device differs.
        """
        # Checked in a fixed order so the reported mismatch is deterministic
        checked_attrs = ("shape", "dtype", "device")
        if inputs is not None:
            if isinstance(inputs, Tensor):
                inputs = (inputs,)
            inputs = tuple(input_.unsqueeze(-1) if input_.ndimension() == 1 else input_ for input_ in inputs)
            if strict:
                for input_, t_input in zip(inputs, self.train_inputs or (None,), strict=True):
                    for attr in checked_attrs:
                        expected_attr = getattr(t_input, attr, None)
                        found_attr = getattr(input_, attr, None)
                        if expected_attr != found_attr:
                            msg = "Cannot modify {attr} of inputs (expected {e_attr}, found {f_attr})."
                            msg = msg.format(attr=attr, e_attr=expected_attr, f_attr=found_attr)
                            raise RuntimeError(msg)
            self.train_inputs = inputs
        if targets is not None:
            if strict:
                for attr in checked_attrs:
                    expected_attr = getattr(self.train_targets, attr, None)
                    found_attr = getattr(targets, attr, None)
                    if expected_attr != found_attr:
                        msg = "Cannot modify {attr} of targets (expected {e_attr}, found {f_attr})."
                        msg = msg.format(attr=attr, e_attr=expected_attr, f_attr=found_attr)
                        raise RuntimeError(msg)
            self.train_targets = targets
        # New data invalidates any cached test-time quantities
        self.prediction_strategy = None

    def get_fantasy_model(self, inputs, targets, **kwargs):
        """
        Returns a new GP model that incorporates the specified inputs and targets as new training data.

        Using this method is more efficient than updating with `set_train_data` when the number of inputs is relatively
        small, because any computed test-time caches will be updated in linear time rather than computed from scratch.

        .. note::
            If `targets` is a batch (e.g. `b x m`), then the GP returned from this method will be a batch mode GP.
            If `inputs` is of the same (or lesser) dimension as `targets`, then it is assumed that the fantasy points
            are the same for each target batch.

        :param torch.Tensor inputs: (`b1 x ... x bk x m x d` or `f x b1 x ... x bk x m x d`) Locations of fantasy
            observations. May also be a list/tuple of such tensors for models with several inputs.
        :param torch.Tensor targets: (`b1 x ... x bk x m` or `f x b1 x ... x bk x m`) Labels of fantasy observations.
        :param kwargs: Forwarded to the model's forward call, except ``noise``, which (if given) is
            instead passed to the likelihood's and the prediction strategy's fantasy methods.
        :return: An `ExactGP` model with `n + m` training examples, where the `m` fantasy examples have been added
            and all test-time caches have been updated.
        :rtype: ~gpytorch.models.ExactGP
        :raises RuntimeError: If the model has not been called in posterior mode yet, or if the
            batch shapes of the model, inputs and targets are incompatible.
        """
        if self.prediction_strategy is None:
            raise RuntimeError(
                "Fantasy observations can only be added after making predictions with a model so that "
                "all test independent caches exist. Call the model on some data first!"
            )

        model_batch_shape = self.train_inputs[0].shape[:-2]

        # Accept a single tensor or any sequence (list/tuple) of tensors
        inputs = [inputs] if isinstance(inputs, Tensor) else list(inputs)
        inputs = [i.unsqueeze(-1) if i.ndimension() == 1 else i for i in inputs]

        # Multitask targets carry a trailing task dimension, so their data dims start one earlier
        if not isinstance(self.prediction_strategy.train_prior_dist, MultitaskMultivariateNormal):
            data_dim_start = -1
        else:
            data_dim_start = -2

        target_batch_shape = targets.shape[:data_dim_start]
        input_batch_shape = inputs[0].shape[:-2]
        tbdim, ibdim = len(target_batch_shape), len(input_batch_shape)

        if not (tbdim == ibdim + 1 or tbdim == ibdim):
            raise RuntimeError(
                f"Unsupported batch shapes: The target batch shape ({target_batch_shape}) must have either the "
                f"same dimension as or one more dimension than the input batch shape ({input_batch_shape})"
            )

        # Check whether we can properly broadcast batch dimensions
        try:
            torch.broadcast_shapes(model_batch_shape, target_batch_shape)
        except RuntimeError as err:
            raise RuntimeError(
                f"Model batch shape ({model_batch_shape}) and target batch shape "
                f"({target_batch_shape}) are not broadcastable."
            ) from err

        if len(model_batch_shape) > len(input_batch_shape):
            input_batch_shape = model_batch_shape
        if len(model_batch_shape) > len(target_batch_shape):
            target_batch_shape = model_batch_shape

        # If input has no fantasy batch dimension but target does, we can save memory and computation by not
        # computing the covariance for each element of the batch. Therefore we don't expand the inputs to the
        # size of the fantasy model here - this is done below, after the evaluation and fast fantasy update
        train_inputs = [tin.expand(input_batch_shape + tin.shape[-2:]) for tin in self.train_inputs]
        train_targets = self.train_targets.expand(target_batch_shape + self.train_targets.shape[data_dim_start:])

        full_inputs = [
            torch.cat(
                [train_input, fantasy_input.expand(input_batch_shape + fantasy_input.shape[-2:])],
                dim=-2,
            )
            for train_input, fantasy_input in zip(train_inputs, inputs, strict=True)
        ]
        full_targets = torch.cat(
            [train_targets, targets.expand(target_batch_shape + targets.shape[data_dim_start:])], dim=data_dim_start
        )

        # ``noise`` is consumed by the fantasy updates rather than by forward
        fantasy_kwargs = {"noise": kwargs.pop("noise")} if "noise" in kwargs else {}

        full_output = super().__call__(*full_inputs, **kwargs)

        # Copy model without copying training data or prediction strategy (since we'll overwrite those).
        # The attributes are detached only for the duration of the deepcopy and are always restored,
        # so a failing copy cannot leave this model without its data or likelihood.
        old_pred_strat = self.prediction_strategy
        old_train_inputs = self.train_inputs
        old_train_targets = self.train_targets
        old_likelihood = self.likelihood
        try:
            self.prediction_strategy = None
            self.train_inputs = None
            self.train_targets = None
            self.likelihood = None
            new_model = deepcopy(self)
        finally:
            self.prediction_strategy = old_pred_strat
            self.train_inputs = old_train_inputs
            self.train_targets = old_train_targets
            self.likelihood = old_likelihood

        new_model.likelihood = old_likelihood.get_fantasy_likelihood(**fantasy_kwargs)
        new_model.prediction_strategy = old_pred_strat.get_fantasy_strategy(
            inputs, targets, full_inputs, full_targets, full_output, **fantasy_kwargs
        )

        # if the fantasies are at the same points, we need to expand the inputs for the new model
        if tbdim == ibdim + 1:
            new_model.train_inputs = [fi.expand(target_batch_shape + fi.shape[-2:]) for fi in full_inputs]
        else:
            new_model.train_inputs = full_inputs
        new_model.train_targets = full_targets

        return new_model

    def __call__(self, *args, **kwargs):
        """
        Evaluate the model on ``args``.

        * Training mode (``self.training``): returns the output of ``forward`` on ``args``. Training
          data must be set; with ``settings.debug`` on, ``args`` must equal the training inputs.
        * Prior mode (``settings.prior_mode`` on, or no training inputs/targets): returns the output
          of ``forward`` on ``args``.
        * Otherwise: returns the posterior at ``args`` conditioned on the training data, building and
          caching the prediction strategy on first use.

        In every mode, 1-D input tensors of size n are reshaped to n x 1.
        """
        train_inputs = list(self.train_inputs) if self.train_inputs is not None else []
        inputs = [i.unsqueeze(-1) if i.ndimension() == 1 else i for i in args]

        # Training mode: optimizing
        if self.training:
            if self.train_inputs is None:
                raise RuntimeError(
                    "train_inputs cannot be None in training mode. "
                    "Call .eval() for prior predictions, or call .set_train_data() to add training data."
                )
            if settings.debug.on():
                if not all(
                    torch.equal(train_input, input_) for train_input, input_ in zip(train_inputs, inputs, strict=True)
                ):
                    raise RuntimeError("You must train on the training inputs!")
            res = super().__call__(*inputs, **kwargs)
            return res

        # Prior mode
        elif settings.prior_mode.on() or self.train_inputs is None or self.train_targets is None:
            # Use the reshaped ``inputs`` (not raw ``args``) so 1-D inputs are handled exactly as in
            # training and posterior mode.
            full_output = super().__call__(*inputs, **kwargs)
            if settings.debug.on():
                if not isinstance(full_output, MultivariateNormal):
                    raise RuntimeError("ExactGP.forward must return a MultivariateNormal")
            return full_output

        # Posterior mode
        else:
            if settings.debug.on():
                if all(
                    torch.equal(train_input, input_) for train_input, input_ in zip(train_inputs, inputs, strict=True)
                ):
                    warnings.warn(
                        "The input matches the stored training data. Did you forget to call model.train()?",
                        GPInputWarning,
                    )

            # Get the terms that only depend on training data
            if self.prediction_strategy is None:
                train_output = self._get_train_prior_distribution(train_inputs, **kwargs)
                # Create (and cache) the prediction strategy from the training data
                self.prediction_strategy = prediction_strategy(
                    train_inputs=train_inputs,
                    train_prior_dist=train_output,
                    train_labels=self.train_targets,
                    likelihood=self.likelihood,
                )

            (
                test_mean,
                test_test_covar,
                test_train_covar,
                batch_shape,
                test_shape,
                posterior_class,
            ) = self._get_test_prior_mean_and_covariances(train_inputs=train_inputs, test_inputs=inputs, **kwargs)

            # Make the prediction
            with settings.cg_tolerance(settings.eval_cg_tolerance.value()):
                predictive_mean, predictive_covar = self.prediction_strategy.exact_prediction(
                    test_mean=test_mean,
                    test_test_covar=test_test_covar,
                    test_train_covar=test_train_covar,
                )

            # Reshape predictive mean to match the appropriate event shape
            predictive_mean = predictive_mean.view(*batch_shape, *test_shape).contiguous()
            return posterior_class(predictive_mean, predictive_covar)

    def _get_train_prior_distribution(
        self,
        train_inputs: Iterable[Tensor],
        **kwargs,
    ) -> MultivariateNormal:
        """Computes the prior distribution on the training set.

        Override this method to customize train-train covariance computation.

        Args:
            train_inputs: The inputs in the training set.
            kwargs: Additional keyword arguments passed to the model's forward method.

        Returns:
            The prior distribution evaluated on the training set.
        """
        # No prior_mode context needed: super().__call__() bypasses ExactGP.__call__
        # and goes directly to Module.__call__() -> forward(), which computes the prior.
        return super().__call__(*train_inputs, **kwargs)

    def _get_test_prior_mean_and_covariances(
        self,
        train_inputs: Iterable[Tensor],
        test_inputs: Iterable[Tensor],
        **kwargs,
    ) -> tuple[Tensor, Tensor, Tensor, torch.Size, torch.Size, type[Distribution]]:
        """Computes the prior mean and covariances on the test set.

        Override this method to customize test-set covariance computations, e.g.,
        for models with partial observations or per-component additive inference.

        The returned covariances may have additional leading batch dimensions
        (e.g., for additive component-wise inference). The prediction strategy
        handles broadcasting with the train-train covariance.

        Note: This method is efficient even when test_inputs overlaps with
        train_inputs. Slicing the lazy joint covariance only evaluates
        K(test, [train||test]); K(train, train) is never computed.

        Requires ``self.prediction_strategy`` to be set (its ``train_shape`` and
        ``num_train`` are read to locate the test block).

        Args:
            train_inputs: The training inputs (indexable; ``train_inputs[0]`` is read).
            test_inputs: The test inputs.
            kwargs: Additional keyword arguments passed to the model's forward.

        Returns:
            A tuple of (test_mean, test_test_covar, test_train_covar, batch_shape,
            test_shape, posterior_class).
        """
        # Concatenate the input to the training input
        full_inputs = []
        batch_shape = train_inputs[0].shape[:-2]
        for train_input, test_input in zip(train_inputs, test_inputs, strict=True):
            # Make sure the batch shapes agree for training/test data
            if batch_shape != train_input.shape[:-2]:
                batch_shape = torch.broadcast_shapes(batch_shape, train_input.shape[:-2])
                train_input = train_input.expand(*batch_shape, *train_input.shape[-2:])
            if batch_shape != test_input.shape[:-2]:
                batch_shape = torch.broadcast_shapes(batch_shape, test_input.shape[:-2])
                train_input = train_input.expand(*batch_shape, *train_input.shape[-2:])
                test_input = test_input.expand(*batch_shape, *test_input.shape[-2:])
            full_inputs.append(torch.cat([train_input, test_input], dim=-2))

        # Get joint distribution (lazy when settings.lazily_evaluate_kernels is True)
        full_output = super().__call__(*full_inputs, **kwargs)
        if settings.debug.on():
            if not isinstance(full_output, MultivariateNormal):
                raise RuntimeError("ExactGP.forward must return a MultivariateNormal")
        joint_mean, joint_covar = full_output.loc, full_output.lazy_covariance_matrix

        # Determine the shape of the joint distribution
        batch_shape = full_output.batch_shape
        joint_shape = full_output.event_shape
        # For single-task GPs: event_shape = (num_points,), so tasks_shape = ()
        # For multitask GPs: event_shape = (num_points, num_tasks), so tasks_shape = (num_tasks,)
        # This captures any task dimensions beyond the primary data dimension.
        tasks_shape = joint_shape[1:]
        # Compute test_shape: the event shape for test predictions.
        # For single-task GPs: test_shape = (num_test,)
        # For multitask GPs: test_shape = (num_test, num_tasks)
        num_test = joint_shape[0] - self.prediction_strategy.train_shape[0]
        test_shape = torch.Size([num_test, *tasks_shape])

        # Find the components of the distribution that contain test data
        num_train = self.prediction_strategy.num_train
        test_mean = joint_mean[..., num_train:]
        # Extract test covariances. Slicing is lazy; K(train, train) is never computed.
        # evaluate_kernel() converts to the linear operator type needed by prediction.
        # NOTE: We must slice row and column indices together (not sequentially) for
        # compatibility with BlockInterleavedLinearOperator used in multitask GPs.
        test_test_covar = joint_covar[..., num_train:, num_train:].evaluate_kernel()
        test_train_covar = joint_covar[..., num_train:, :num_train].evaluate_kernel()
        posterior_class = full_output.__class__

        return (test_mean, test_test_covar, test_train_covar, batch_shape, test_shape, posterior_class)