#!/usr/bin/env python3
import math
from typing import Optional
import torch
from linear_operator.operators import LowRankRootLinearOperator, MatmulLinearOperator, RootLinearOperator
from torch import Tensor
from ..models import exact_prediction_strategies
from .kernel import Kernel


class RFFKernel(Kernel):
r"""
Computes a covariance matrix based on Random Fourier Features with the RBFKernel.
Random Fourier features was originally proposed in
'Random Features for Large-Scale Kernel Machines' by Rahimi and Recht (2008).
Instead of the shifted cosine features from Rahimi and Recht (2008), we use
the sine and cosine features which is a lower-variance estimator --- see
'On the Error of Random Fourier Features' by Sutherland and Schneider (2015).
By Bochner's theorem, any continuous kernel :math:`k` is positive definite
if and only if it is the Fourier transform of a non-negative measure :math:`p(\omega)`, i.e.
.. math::
\begin{equation}
k(x, x') = k(x - x') = \int p(\omega) e^{i(\omega^\top (x - x'))} d\omega.
\end{equation}
where :math:`p(\omega)` is a normalized probability measure if :math:`k(0)=1`.
For the RBF kernel,
.. math::
\begin{equation}
k(\Delta) = \exp{(-\frac{\Delta^2}{2\sigma^2})} \text{ and } p(\omega) = \exp{(-\frac{\sigma^2\omega^2}{2})}
\end{equation}
where :math:`\Delta = x - x'`.
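
    For the RBF kernel, this spectral representation can be sanity-checked by
    Monte Carlo (a standalone, illustrative sketch, not part of this class; it
    fixes :math:`\sigma = 1`, so that :math:`p(\omega)` is a standard normal density):

    .. code-block:: python

        import torch

        torch.manual_seed(0)
        delta = torch.randn(3)  # Delta = x - x'
        omega = torch.randn(50_000, 3)  # omega ~ N(0, I), i.e. p(omega) with sigma = 1
        approx = torch.cos(omega @ delta).mean()  # real part of E[exp(i omega^T Delta)]
        exact = torch.exp(-delta.pow(2).sum() / 2)  # RBF kernel value k(Delta)
        # `approx` matches `exact` up to O(1 / sqrt(50_000)) Monte Carlo error
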
    Given a datapoint :math:`x \in \mathbb{R}^d`, we can construct its random Fourier features
    :math:`z(x) \in \mathbb{R}^{2D}` by

    .. math::
        \begin{equation}
            z(x) = \sqrt{\frac{1}{D}}
            \begin{bmatrix}
                \cos(\omega_1^\top x) \\
                \sin(\omega_1^\top x) \\
                \vdots \\
                \cos(\omega_D^\top x) \\
                \sin(\omega_D^\top x)
            \end{bmatrix},
            \qquad \omega_1, \ldots, \omega_D \sim p(\omega),
        \end{equation}

    such that we have the unbiased Monte Carlo estimator

    .. math::
        \begin{equation}
            k(x, x') = k(x - x') \approx z(x)^\top z(x')
            = \frac{1}{D} \sum_{i=1}^D \cos(\omega_i^\top (x - x')).
        \end{equation}
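
    The variance remark above can likewise be checked empirically (an
    illustrative sketch, not part of this class; :math:`\sigma = 1`). With the
    same number of frequencies :math:`D`, both the cosine/sine estimator used
    here and the shifted cosine estimator of Rahimi and Recht (2008) are
    unbiased, but the former has lower variance:

    .. code-block:: python

        import math

        import torch

        torch.manual_seed(0)
        x1, x2 = torch.randn(3), torch.randn(3)
        cos_sin, shifted = [], []
        for _ in range(2_000):
            omega = torch.randn(100, 3)  # D = 100 frequencies per estimate
            b = 2 * math.pi * torch.rand(100)  # random phase shifts
            cos_sin.append(torch.cos(omega @ (x1 - x2)).mean())
            shifted.append((2 * torch.cos(omega @ x1 + b) * torch.cos(omega @ x2 + b)).mean())
        # Both estimator means approximate exp(-|x1 - x2|^2 / 2); the cosine/sine
        # estimator shows the smaller empirical variance.
        print(torch.stack(cos_sin).var(), torch.stack(shifted).var())
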
    .. note::
        When this kernel is used in batch mode, the random frequencies are drawn
        independently across the batch dimension as well by default (see the
        batched example at the end of this docstring).

    :param num_samples: Number of random frequencies to draw. This is :math:`D` in the above
        papers. This will produce :math:`D` sine features and :math:`D` cosine
        features for a total of :math:`2D` random Fourier features.
    :type num_samples: int
    :param num_dims: (Default `None`.) Dimensionality of the data space.
        This is :math:`d` in the above papers. Note that if you want an
        independent lengthscale for each dimension, set `ard_num_dims` equal to
        `num_dims`. If unspecified, it will be inferred the first time `forward`
        is called.
    :type num_dims: int, optional
    :var torch.Tensor randn_weights: The random frequencies that are drawn once and then fixed.

    Example:
        >>> # This will infer `num_dims` automatically
        >>> kernel = gpytorch.kernels.RFFKernel(num_samples=5)
        >>> x = torch.randn(10, 3)
        >>> kxx = kernel(x, x).to_dense()
        >>> print(kernel.randn_weights.size())
        torch.Size([3, 5])
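
        >>> # Batch mode (see the note above): independent frequencies are drawn per
        >>> # batch, so `randn_weights` has shape `batch_shape x d x D`. Illustrative
        >>> # sketch; `batch_shape` is handled by the base `Kernel` class.
        >>> batch_kernel = gpytorch.kernels.RFFKernel(num_samples=5, batch_shape=torch.Size([2]))
        >>> batch_x = torch.randn(2, 10, 3)
        >>> batch_kxx = batch_kernel(batch_x, batch_x).to_dense()
        >>> print(batch_kernel.randn_weights.size())
        torch.Size([2, 3, 5])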
"""

    has_lengthscale = True

    def __init__(self, num_samples: int, num_dims: Optional[int] = None, **kwargs):
        super().__init__(**kwargs)
        self.num_samples = num_samples
        if num_dims is not None:
            self._init_weights(num_dims, num_samples)

    def _init_weights(
        self, num_dims: Optional[int] = None, num_samples: Optional[int] = None, randn_weights: Optional[Tensor] = None
    ):
        if num_dims is not None and num_samples is not None:
            d = num_dims
            D = num_samples
        if randn_weights is None:
            # Draw the random frequencies once; they are fixed thereafter and are
            # saved with the module state via the buffer registered below
            randn_shape = torch.Size([*self._batch_shape, d, D])
            randn_weights = torch.randn(
                randn_shape, dtype=self.raw_lengthscale.dtype, device=self.raw_lengthscale.device
            )
        self.register_buffer("randn_weights", randn_weights)

    def forward(self, x1: Tensor, x2: Tensor, diag: bool = False, last_dim_is_batch: bool = False, **kwargs) -> Tensor:
        if last_dim_is_batch:
            x1 = x1.transpose(-1, -2).unsqueeze(-1)
            x2 = x2.transpose(-1, -2).unsqueeze(-1)
        num_dims = x1.size(-1)
        if not hasattr(self, "randn_weights"):
            # Lazily infer `num_dims` from the data on the first call
            self._init_weights(num_dims, self.num_samples)
        x1_eq_x2 = torch.equal(x1, x2)
        z1 = self._featurize(x1, normalize=False)
        if not x1_eq_x2:
            z2 = self._featurize(x2, normalize=False)
        else:
            z2 = z1
        D = float(self.num_samples)
        if diag:
            return (z1 * z2).sum(-1) / D
        if x1_eq_x2:
            # Exploit low-rank structure if there are fewer features than data points
            if z1.size(-1) < z2.size(-2):
                return LowRankRootLinearOperator(z1 / math.sqrt(D))
            else:
                return RootLinearOperator(z1 / math.sqrt(D))
        else:
            return MatmulLinearOperator(z1 / D, z2.transpose(-1, -2))

    def _featurize(self, x: Tensor, normalize: bool = False) -> Tensor:
        # Recompute the division on each call so gradients flow through the lengthscale;
        # the lengthscale is transposed to broadcast correctly under ARD
        x = x.matmul(self.randn_weights / self.lengthscale.transpose(-1, -2))
        z = torch.cat([torch.cos(x), torch.sin(x)], dim=-1)
        if normalize:
            D = self.num_samples
            z = z / math.sqrt(D)
        return z

    def prediction_strategy(self, train_inputs, train_prior_dist, train_labels, likelihood):
        # Allow for fast sampling
        return exact_prediction_strategies.RFFPredictionStrategy(
            train_inputs, train_prior_dist, train_labels, likelihood
        )