[1]:
# smoke_test = True
import gpytorch
import torch
from torch.utils import benchmark
Kernels with Additive or Product Structure¶
One of the most powerful properties of kernels is their closure under various composition operations. Many important covariance functions can be written as the sum or the product of \(m\) component kernels:
\[k(\boldsymbol x, \boldsymbol x') = \sum_{i=1}^m k_i(\boldsymbol x, \boldsymbol x')
\qquad \text{or} \qquad
k(\boldsymbol x, \boldsymbol x') = \prod_{i=1}^m k_i(\boldsymbol x, \boldsymbol x').\]
Additive and product kernels are used for a variety of reasons:
1. They are often more interpretable, as argued in Duvenaud et al. (2011).
2. They can be extremely powerful and expressive, as demonstrated by Wilson and Adams (2013).
3. They can be extremely sample efficient for Bayesian optimization, as demonstrated by Kandasamy et al. (2015) and Gardner et al. (2017).
We will discuss various ways to perform additive and product compositions of kernels in GPyTorch. The simplest mechanism is to add/multiply the kernel objects together, or to add/multiply their outputs. However, there are more complex but far more efficient ways to add/multiply kernels that share the same functional form, which enable significant parallelism, especially on GPUs.
Simple Sums and Products¶
As an example, consider the spectral mixture kernel with two components on a univariate input. If we remove the scaling components, it can be written as:
\[k_\text{SM}(x, x') = \exp\left(-\frac{(x - x')^2}{2\ell_1^2}\right)\cos\bigl(2\pi \, \omega_1 (x - x')\bigr) + \exp\left(-\frac{(x - x')^2}{2\ell_2^2}\right)\cos\bigl(2\pi \, \omega_2 (x - x')\bigr),\]
where \(\ell_1, \ell_2, \omega_1, \omega_2\) are hyperparameters. We can naively implement this kernel in two ways…
[2]:
# Toy data
X = torch.randn(10, 1)
# Base kernels
rbf_kernel_1 = gpytorch.kernels.RBFKernel()
cos_kernel_1 = gpytorch.kernels.CosineKernel()
rbf_kernel_2 = gpytorch.kernels.RBFKernel()
cos_kernel_2 = gpytorch.kernels.CosineKernel()
# Implementation 1:
spectral_mixture_kernel = (rbf_kernel_1 * cos_kernel_1) + (rbf_kernel_2 * cos_kernel_2)
covar = spectral_mixture_kernel(X)
# Implementation 2:
covar = rbf_kernel_1(X) * cos_kernel_1(X) + rbf_kernel_2(X) * cos_kernel_2(X)
Implementation 1 constructs a spectral_mixture_kernel object by applying + and * directly to the component kernel objects. Implementation 2 constructs the resulting covariance matrix by applying + and * to the outputs of the component kernels. Both implementations are equivalent: under the hood, the spectral_mixture_kernel object created by Implementation 1 essentially performs Implementation 2.
(Of course, neither implementation should be used in practice for the spectral mixture kernel. The built-in SpectralMixtureKernel class is far more efficient.)
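For reference, here is a minimal sketch of using the built-in class instead; it handles the mixture weights and hyperparameter initialization internally, and computes all components in a single batched operation.
[ ]:
# Built-in spectral mixture kernel with two components on univariate data
spectral_mixture_kernel = gpytorch.kernels.SpectralMixtureKernel(num_mixtures=2, ard_num_dims=1)
covar = spectral_mixture_kernel(X)  # n x n covariance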
Efficient Parallel Implementations of Additive Structure or Product Structure Kernels¶
Above we considered sums and products of kernels with different functional forms. However, we often want the sum or product of kernels that share the same functional form. The above example is simple to read, but quite slow in practice. Under the hood, each of the kernels (and their compositions) is computed sequentially: GPyTorch will compute the first cosine kernel, followed by the first RBF kernel, followed by their product, and so on.
When the component kernels have the same functional form, we can get massive efficiency gains by exploiting parallelism. We combine all of the component kernels into a batch kernel so that each component kernel can be computed simultaneously. We then compute the sum or prod over the batch dimension. This strategy yields significant speedups, especially on the GPU.
Example #1: Efficient Summations of Univariate Kernels¶
As an example, let’s assume that we have \(d\)-dimensional input data \(\boldsymbol x, \boldsymbol x' \in \mathbb R^d\). We can define an additive kernel that is the sum of \(d\) univariate RBF kernels, each of which acts on a single dimension of \(\boldsymbol x\) and \(\boldsymbol x'\):
\[k(\boldsymbol x, \boldsymbol x') = \sum_{i=1}^d \exp\left(-\frac{\bigl(x^{(i)} - x'^{(i)}\bigr)^2}{2 \bigl(\ell^{(i)}\bigr)^2}\right)\]
Here, \(\ell^{(i)}\) is the lengthscale associated with dimension \(i\). Note that we are using a different lengthscale for each of the component kernels. Nevertheless, we can efficiently compute each of the component kernels in parallel using batching. First we define an RBFKernel object designed to compute a batch of \(d\) univariate kernels:
[3]:
d = 3
batch_univariate_rbf_kernel = gpytorch.kernels.RBFKernel(
batch_shape=torch.Size([d]), # A batch of d...
ard_num_dims=1, # ...univariate kernels
)
Including the batch_shape argument ensures that the lengthscale parameter of the batch_univariate_rbf_kernel is a d x 1 x 1 tensor; i.e. each univariate kernel will have its own lengthscale. (We could instead have each univariate kernel share the same lengthscale by omitting the batch_shape argument.)
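As a quick sanity check (not part of the original pipeline), we can inspect the shape of the lengthscale parameter directly:
[ ]:
batch_univariate_rbf_kernel.lengthscale.shape  # torch.Size([3, 1, 1]): one lengthscale per univariate kernel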
To compute the univariate kernel matrices, we need to feed the appropriate dimensions of \(\boldsymbol X\) into each of the component kernels. We accomplish this by reshaping the n x d matrix representing \(\boldsymbol X\) into a batch of \(d\) n x 1 matrices (i.e. a d x n x 1 tensor).
[4]:
n = 10
X = torch.randn(n, d) # Some random data in a n x d matrix
batched_dimensions_of_X = X.mT.unsqueeze(-1) # Now a d x n x 1 tensor
We then feed the batches of univariate data into the batched kernel object to get our batch of univariate kernel matrices:
[5]:
univariate_rbf_covars = batch_univariate_rbf_kernel(batched_dimensions_of_X)
univariate_rbf_covars.shape # d x n x n
[5]:
torch.Size([3, 10, 10])
And finally, to get the multivariate kernel, we compute the sum over the batch dimension (i.e. the sum over the univariate kernels):
[6]:
additive_covar = univariate_rbf_covars.sum(dim=-3) # Computes the sum over the batch dimension
additive_covar.shape # n x n
[6]:
torch.Size([10, 10])
On a small dataset, this approach is comparable to the naive approach described above. However, it becomes much faster on larger, higher-dimensional datasets, especially on the GPU.
[7]:
d = 10
n = 500
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
X = torch.randn(n, d, device=device)
naive_additive_kernel = (
gpytorch.kernels.RBFKernel(ard_num_dims=1, active_dims=[0]) +
gpytorch.kernels.RBFKernel(ard_num_dims=1, active_dims=[1]) +
gpytorch.kernels.RBFKernel(ard_num_dims=1, active_dims=[2]) +
gpytorch.kernels.RBFKernel(ard_num_dims=1, active_dims=[3]) +
gpytorch.kernels.RBFKernel(ard_num_dims=1, active_dims=[4]) +
gpytorch.kernels.RBFKernel(ard_num_dims=1, active_dims=[5]) +
gpytorch.kernels.RBFKernel(ard_num_dims=1, active_dims=[6]) +
gpytorch.kernels.RBFKernel(ard_num_dims=1, active_dims=[7]) +
gpytorch.kernels.RBFKernel(ard_num_dims=1, active_dims=[8]) +
gpytorch.kernels.RBFKernel(ard_num_dims=1, active_dims=[9])
).to(device=device)
with gpytorch.settings.lazily_evaluate_kernels(False):
print(benchmark.Timer(
stmt="naive_additive_kernel(X)",
globals={"naive_additive_kernel": naive_additive_kernel, "X": X}
).timeit(100))
<torch.utils.benchmark.utils.common.Measurement object at 0x7f35cc074640>
naive_additive_kernel(X)
3.37 ms
1 measurement, 100 runs , 1 thread
[8]:
batch_univariate_rbf_kernel = gpytorch.kernels.RBFKernel(
batch_shape=torch.Size([d]), ard_num_dims=1,
).to(device=device)
with gpytorch.settings.lazily_evaluate_kernels(False):
print(benchmark.Timer(
stmt="batch_univariate_rbf_kernel(X.mT.unsqueeze(-1)).sum(dim=-3)",
globals={"batch_univariate_rbf_kernel": batch_univariate_rbf_kernel, "X": X}
).timeit(100))
<torch.utils.benchmark.utils.common.Measurement object at 0x7f35cc076ec0>
batch_univariate_rbf_kernel(X.mT.unsqueeze(-1)).sum(dim=-3)
940.30 us
1 measurement, 100 runs , 1 thread
Full Example¶
Putting it all together, a GP using this efficient additive kernel would look something like…
[9]:
class AdditiveKernelGP(gpytorch.models.ExactGP):
def __init__(self, X_train, y_train, d):
likelihood = gpytorch.likelihoods.GaussianLikelihood()
super().__init__(X_train, y_train, likelihood)
self.mean_module = gpytorch.means.ConstantMean()
self.covar_module = gpytorch.kernels.ScaleKernel(
gpytorch.kernels.RBFKernel(batch_shape=torch.Size([d]), ard_num_dims=1)
)
def forward(self, X):
mean = self.mean_module(X)
batched_dimensions_of_X = X.mT.unsqueeze(-1) # Now a d x n x 1 tensor
covar = self.covar_module(batched_dimensions_of_X).sum(dim=-3)
return gpytorch.distributions.MultivariateNormal(mean, covar)
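As an illustrative usage sketch (the data, dimensions, and single optimization step below are made-up assumptions, not part of the tutorial), the model can be fit with the usual exact marginal log likelihood:
[ ]:
# Illustrative usage only: random toy data and one training step
X_train = torch.randn(100, 5)
y_train = torch.sin(X_train).sum(dim=-1)
model = AdditiveKernelGP(X_train, y_train, d=5)
mll = gpytorch.mlls.ExactMarginalLogLikelihood(model.likelihood, model)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
model.train()
optimizer.zero_grad()
loss = -mll(model(X_train), y_train)  # negative marginal log likelihood
loss.backward()
optimizer.step()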
Example #2: Efficient Products of Univariate Kernels¶
As another example, we can consider a multivariate kernel defined as the product of univariate kernels, i.e.:
\[k(\boldsymbol x, \boldsymbol x') = \prod_{i=1}^d k^{(i)}\bigl(x^{(i)}, x'^{(i)}\bigr)\]
[10]:
d = 3
n = 10
batch_univariate_rbf_kernel = gpytorch.kernels.RBFKernel(
batch_shape=torch.Size([d]), ard_num_dims=1,
)
X = torch.randn(n, d)
univariate_rbf_covars = batch_univariate_rbf_kernel(X.mT.unsqueeze(-1))
with gpytorch.settings.lazily_evaluate_kernels(False):
prod_covar = univariate_rbf_covars.prod(dim=-3)
prod_covar.shape # n x n
[10]:
torch.Size([10, 10])
This particular example is a bit silly, since the multivariate RBF kernel is exactly equivalent to the product of \(d\) univariate RBF kernels:
\[\prod_{i=1}^d \exp\left(-\frac{\bigl(x^{(i)} - x'^{(i)}\bigr)^2}{2 \bigl(\ell^{(i)}\bigr)^2}\right) = \exp\left(-\sum_{i=1}^d \frac{\bigl(x^{(i)} - x'^{(i)}\bigr)^2}{2 \bigl(\ell^{(i)}\bigr)^2}\right).\]
However, this strategy becomes advantageous when we approximate each of the univariate component kernels with a scalable (\(\ll \mathcal O(n^3)\)) approximation. See the tutorial on SKIP (structured kernel interpolation of products) for an example of exploiting product structure for scalability.
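To make the equivalence concrete, here is a small sketch that checks the product of univariate RBF kernels against a single multivariate RBF kernel; the shared lengthscale value is an arbitrary choice for illustration.
[ ]:
# Illustrative check: product of univariate RBF kernels vs. one multivariate RBF kernel
univariate = gpytorch.kernels.RBFKernel(batch_shape=torch.Size([d]), ard_num_dims=1)
multivariate = gpytorch.kernels.RBFKernel(ard_num_dims=d)
univariate.lengthscale = 0.7   # arbitrary value, shared across all dimensions
multivariate.lengthscale = 0.7
product_covar = univariate(X.mT.unsqueeze(-1)).to_dense().prod(dim=-3)
full_covar = multivariate(X).to_dense()
print(torch.allclose(product_covar, full_covar, atol=1e-5))  # expected: True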
Summing Higher Order Interactions Between Univariate Kernels (Additive Gaussian Processes)¶
Duvenaud et al. (2011) introduce “Additive Gaussian Processes,” which are GPs that additively compose interaction terms between univariate kernels. For example, with \(d\)-dimensional data and a max degree of \(3\) for the interaction terms, the corresponding kernel would be:
\[k(\boldsymbol x, \boldsymbol x') = \sum_{i} k^{(i)} + \sum_{i < j} k^{(i)} k^{(j)} + \sum_{i < j < l} k^{(i)} k^{(j)} k^{(l)},\]
where \(k^{(i)}\) is shorthand for a univariate kernel acting on dimension \(i\), i.e. \(k^{(i)}\bigl(x^{(i)}, x'^{(i)}\bigr)\).
Despite the summations containing a combinatorially large number of terms, this kernel can be computed in \(\mathcal O(d^2)\) time using the Newton-Girard formula.
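To make the trick concrete, here is a minimal sketch of the Newton-Girard recursion applied elementwise to a batch of univariate covariance matrices. This is not GPyTorch's implementation, and the per-degree weights from the paper are omitted for simplicity.
[ ]:
def newton_girard_sum(univariate_covars, max_degree):
    # univariate_covars: d x n x n tensor of univariate kernel matrices
    # Power sums p_k = sum_i K_i^k (elementwise), for k = 1, ..., max_degree
    power_sums = [(univariate_covars ** k).sum(dim=-3) for k in range(1, max_degree + 1)]
    # Newton's identities: e_k = (1/k) * sum_{j=1}^{k} (-1)^{j-1} e_{k-j} p_j
    elementary = [torch.ones_like(power_sums[0])]  # e_0 = 1
    for k in range(1, max_degree + 1):
        e_k = sum((-1) ** (j - 1) * elementary[k - j] * power_sums[j - 1] for j in range(1, k + 1)) / k
        elementary.append(e_k)
    return sum(elementary[1:])  # all interaction terms from degree 1 up to max_degree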
To compute this kernel in GPyTorch, we begin with a batch of the univariate covariance matrices (stored in a d x n x n Tensor or LinearOperator). We follow the same techniques as we used before:
[11]:
d = 4
n = 10
batch_univariate_rbf_kernel = gpytorch.kernels.RBFKernel(
batch_shape=torch.Size([d]), ard_num_dims=1,
)
X = torch.randn(n, d)
with gpytorch.settings.lazily_evaluate_kernels(False):
univariate_rbf_covars = batch_univariate_rbf_kernel(X.mT.unsqueeze(-1))
univariate_rbf_covars.shape # d x n x n
[11]:
torch.Size([4, 10, 10])
We then use gpytorch.utils.sum_interaction_terms to compute and sum all of the higher-order interaction terms in \(\mathcal O(d^2)\) time:
[12]:
covar = gpytorch.utils.sum_interaction_terms(univariate_rbf_covars, max_degree=3, dim=-3)
covar.shape # n x n
[12]:
torch.Size([10, 10])
The full GP proposed by Duvenaud et al. (2011) would then look like:
[13]:
class AdditiveGP(gpytorch.models.ExactGP):
def __init__(self, X_train, y_train, d, max_degree):
likelihood = gpytorch.likelihoods.GaussianLikelihood()
super().__init__(X_train, y_train, likelihood)
self.mean_module = gpytorch.means.ConstantMean()
self.covar_module = gpytorch.kernels.ScaleKernel(
gpytorch.kernels.RBFKernel(batch_shape=torch.Size([d]), ard_num_dims=1)
)
self.max_degree = max_degree
def forward(self, X):
mean = self.mean_module(X)
batched_dimensions_of_X = X.mT.unsqueeze(-1) # Now a d x n x 1 tensor
univariate_rbf_covars = self.covar_module(batched_dimensions_of_X)
covar = gpytorch.utils.sum_interaction_terms(
univariate_rbf_covars, max_degree=self.max_degree, dim=-3
)
return gpytorch.distributions.MultivariateNormal(mean, covar)
(For those familiar with previous versions of GPyTorch, sum_interaction_terms replaces what was previously implemented by NewtonGirardAdditiveKernel.)