Source code for gpytorch.kernels.spectral_delta_kernel

import math
from typing import Optional

import torch
from linear_operator.operators import MatmulLinearOperator, RootLinearOperator

from ..constraints import Interval, Positive
from .kernel import Kernel


[docs]class SpectralDeltaKernel(Kernel): """ A kernel that supports spectral learning for GPs, where the underlying spectral density is modeled as a mixture of delta distributions (e.g., with point masses). This has been explored e.g. in Lazaro-Gredilla et al., 2010. Conceptually, this kernel is similar to random Fourier features as implemented in RFFKernel, but instead of sampling a Gaussian to determine the spectrum sites, they are treated as learnable parameters. When using CG for inference, this kernel supports linear space and time (in N) for training and inference. :param int num_dims: Dimensionality of input data that this kernel will operate on. Note that if active_dims is used, this should be the length of the active dim set. :param int num_deltas: Number of point masses to learn. """ has_lengthscale = True def __init__( self, num_dims: int, num_deltas: Optional[int] = 128, Z_constraint: Optional[Interval] = None, batch_shape: Optional[torch.Size] = torch.Size([]), **kwargs, ): Kernel.__init__(self, has_lengthscale=True, batch_shape=batch_shape, **kwargs) self.raw_Z = torch.nn.Parameter(torch.rand(*batch_shape, num_deltas, num_dims)) if Z_constraint: self.register_constraint("raw_Z", Z_constraint) else: self.register_constraint("raw_Z", Positive()) self.num_dims = num_dims
[docs] def initialize_from_data(self, train_x, train_y): """ Initialize the point masses for this kernel from the empirical spectrum of the data. To do this, we estimate the empirical spectrum's CDF and then simply sample from it. This is analogous to how the SM kernel's mixture is initialized, but we skip the last step of fitting a GMM to the samples and just use the samples directly. """ import numpy as np from scipy.fftpack import fft from scipy.integrate import cumtrapz N = train_x.size(-2) emp_spect = np.abs(fft(train_y.cpu().detach().numpy())) ** 2 / N M = math.floor(N / 2) freq1 = np.arange(M + 1) freq2 = np.arange(-M + 1, 0) freq = np.hstack((freq1, freq2)) / N freq = freq[: M + 1] emp_spect = emp_spect[: M + 1] total_area = np.trapz(emp_spect, freq) spec_cdf = np.hstack((np.zeros(1), cumtrapz(emp_spect, freq))) spec_cdf = spec_cdf / total_area a = np.random.rand(self.raw_Z.size(-2), 1) p, q = np.histogram(a, spec_cdf) bins = np.digitize(a, q) slopes = (spec_cdf[bins] - spec_cdf[bins - 1]) / (freq[bins] - freq[bins - 1]) intercepts = spec_cdf[bins - 1] - slopes * freq[bins - 1] inv_spec = (a - intercepts) / slopes self.Z = inv_spec
def initialize_from_data_simple(self, train_x, train_y, **kwargs): if not torch.is_tensor(train_x) or not torch.is_tensor(train_y): raise RuntimeError("train_x and train_y should be tensors") if train_x.ndimension() == 1: train_x = train_x.unsqueeze(-1) if train_x.ndimension() == 2: train_x = train_x.unsqueeze(0) train_x_sort = train_x.sort(1)[0] min_dist_sort = (train_x_sort[:, 1:, :] - train_x_sort[:, :-1, :]).squeeze(0) ard_num_dims = 1 if self.ard_num_dims is None else self.ard_num_dims min_dist = torch.zeros(1, ard_num_dims, dtype=self.Z.dtype, device=self.Z.device) for ind in range(ard_num_dims): min_dist[:, ind] = min_dist_sort[(torch.nonzero(min_dist_sort[:, ind]))[0], ind] z_init = torch.rand_like(self.Z).mul_(0.5).div_(min_dist) self.Z = z_init @property def Z(self): return self.raw_Z_constraint.transform(self.raw_Z) @Z.setter def Z(self, value): self._set_Z(value) def _set_Z(self, value): if not torch.is_tensor(value): value = torch.as_tensor(value).to(self.raw_Z) self.initialize(raw_Z=self.raw_Z_constraint.inverse_transform(value)) def forward(self, x1, x2, diag=False, **params): x1_ = x1.div(self.lengthscale) x2_ = x2.div(self.lengthscale) Z = self.Z # Z1_ and Z2_ are s x d x1z1 = x1_.matmul(Z.transpose(-2, -1)) # n x s x2z2 = x2_.matmul(Z.transpose(-2, -1)) # n x s x1z1 = x1z1 * 2 * math.pi x2z2 = x2z2 * 2 * math.pi x1z1 = torch.cat([x1z1.cos(), x1z1.sin()], dim=-1) / math.sqrt(x1z1.size(-1)) x2z2 = torch.cat([x2z2.cos(), x2z2.sin()], dim=-1) / math.sqrt(x2z2.size(-1)) if x1.size() == x2.size() and torch.equal(x1, x2): prod = RootLinearOperator(x1z1) else: prod = MatmulLinearOperator(x1z1, x2z2.transpose(-2, -1)) if diag: return prod.diagonal(dim1=-1, dim2=-2) else: return prod