#!/usr/bin/env python3
from math import pi
from typing import Callable, Optional
import torch
from ..constraints import Interval, Positive
from ..priors import Prior
from .kernel import Kernel
[docs]class ArcKernel(Kernel):
r""" Computes a covariance matrix based on the Arc Kernel
(https://arxiv.org/abs/1409.4011) between inputs :math:`\mathbf{x_1}`
and :math:`\mathbf{x_2}`. First it applies a cylindrical embedding:
.. math::
g_{i}(\mathbf{x}) = \begin{cases}
[0, 0]^{T} & \delta_{i}(\mathbf{x}) = \text{false}\\
\omega_{i} \left[ \sin{\pi\rho_{i}\frac{x_{i}}{u_{i}-l_{i}}},
\cos{\pi\rho_{i}\frac{x_{i}}{u_{i}-l_{i}}} \right] & \text{otherwise}
\end{cases}
where
* :math:`\rho` is the angle parameter.
* :math:`\omega` is a radius parameter.
then the kernel is built with the particular covariance function, e.g.
.. math::
\begin{equation}
k_{i}(\mathbf{x}, \mathbf{x'}) =
\sigma^{2}\exp \left(-\frac{1}{2}d_{i}(\mathbf{x}, \mathbf{x^{'}}) \right)^{2}
\end{equation}
and the produt between dimensions
.. math::
\begin{equation}
k_{i}(\mathbf{x}, \mathbf{x'}) =
\sigma^{2}\exp \left(-\frac{1}{2}d_{i}(\mathbf{x}, \mathbf{x^{'}}) \right)^{2}
\end{equation}
.. note::
This kernel does not have an `outputscale` parameter. To add a scaling
parameter, decorate this kernel with a
:class:`gpytorch.kernels.ScaleKernel`.
When using with an input of `b x n x d` dimensions, decorate this
kernel with :class:`gpytorch.kernel.ProductStructuredKernel , setting
the number of dims, `num_dims to d.`
.. note::
This kernel does not have an ARD lengthscale option.
:param base_kernel: (Default :obj:`gpytorch.kernels.MaternKernel(nu=2.5)`.)
The euclidean covariance of choice.
:type base_kernel: :obj:`~gpytorch.kernels.Kernel`
:param ard_num_dims: (Default `None`.) The number of dimensions to compute the kernel for.
The kernel has two parameters which are individually defined for each
dimension, defaults to None
:type ard_num_dims: int, optional
:param angle_prior: Set this if you want to apply a prior to the period angle parameter.
:type angle_prior: :obj:`~gpytorch.priors.Prior`, optional
:param radius_prior: Set this if you want to apply a prior to the lengthscale parameter.
:type radius_prior: :obj:`~gpytorch.priors.Prior`, optional
:var torch.Tensor radius: The radius parameter. Size = `*batch_shape x 1`.
:var torch.Tensor angle: The period angle parameter. Size = `*batch_shape x 1`.
Example:
>>> x = torch.randn(10, 5)
>>> # Non-batch: Simple option
... base_kernel = gpytorch.kernels.MaternKernel(nu=2.5)
>>> base_kernel.raw_lengthscale.requires_grad_(False)
>>> covar_module = gpytorch.kernels.ProductStructureKernel(
gpytorch.kernels.ScaleKernel(
ArcKernel(base_kernel,
angle_prior=gpytorch.priors.GammaPrior(0.5,1),
radius_prior=gpytorch.priors.GammaPrior(3,2),
ard_num_dims=x.shape[-1])),
num_dims=x.shape[-1])
>>> covar = covar_module(x)
>>> print(covar.shape)
>>> # Now with batch
>>> covar_module = gpytorch.kernels.ProductStructureKernel(
gpytorch.kernels.ScaleKernel(
ArcKernel(base_kernel,
angle_prior=gpytorch.priors.GammaPrior(0.5,1),
radius_prior=gpytorch.priors.GammaPrior(3,2),
ard_num_dims=x.shape[-1])),
num_dims=x.shape[-1])
>>> covar = covar_module(x
>>> print(covar.shape)
"""
has_lengthscale = True
def __init__(
self,
base_kernel: Kernel,
delta_func: Optional[Callable] = None,
angle_prior: Optional[Prior] = None,
radius_prior: Optional[Prior] = None,
**kwargs,
):
super(ArcKernel, self).__init__(has_lengthscale=True, **kwargs)
if self.ard_num_dims is None:
self.last_dim = 1
else:
self.last_dim = self.ard_num_dims
if delta_func is None:
self.delta_func = self.default_delta_func
else:
self.delta_func = delta_func
# TODO: check the errors given by interval
angle_constraint = Interval(0.1, 0.9)
self.register_parameter(
name="raw_angle",
parameter=torch.nn.Parameter(torch.zeros(*self.batch_shape, 1, self.last_dim)),
)
if angle_prior is not None:
if not isinstance(angle_prior, Prior):
raise TypeError("Expected gpytorch.priors.Prior but got " + type(angle_prior).__name__)
self.register_prior(
"angle_prior",
angle_prior,
lambda m: m.angle,
lambda m, v: m._set_angle(v),
)
self.register_constraint("raw_angle", angle_constraint)
self.register_parameter(
name="raw_radius",
parameter=torch.nn.Parameter(torch.zeros(*self.batch_shape, 1, self.last_dim)),
)
if radius_prior is not None:
if not isinstance(radius_prior, Prior):
raise TypeError("Expected gpytorch.priors.Prior but got " + type(radius_prior).__name__)
self.register_prior(
"radius_prior",
radius_prior,
lambda m: m.radius,
lambda m, v: m._set_radius(v),
)
radius_constraint = Positive()
self.register_constraint("raw_radius", radius_constraint)
self.base_kernel = base_kernel
if self.base_kernel.has_lengthscale:
self.base_kernel.lengthscale = 1
self.base_kernel.raw_lengthscale.requires_grad_(False)
@property
def angle(self):
return self.raw_angle_constraint.transform(self.raw_angle)
@angle.setter
def angle(self, value):
self._set_angle(value)
def _set_angle(self, value):
if not torch.is_tensor(value):
value = torch.as_tensor(value).to(self.raw_angle)
self.initialize(raw_angle=self.raw_angle_constraint.inverse_transform(value))
@property
def radius(self):
return self.raw_radius_constraint.transform(self.raw_radius)
@radius.setter
def radius(self, value):
self._set_radius(value)
def _set_radius(self, value):
if not torch.is_tensor(value):
value = torch.as_tensor(value).to(self.raw_radius)
self.initialize(raw_radius=self.raw_radius_constraint.inverse_transform(value))
def embedding(self, x):
mask = self.delta_func(x)
x_ = x.div(self.lengthscale)
x_s = self.radius * torch.sin(pi * self.angle * x_) * mask
x_c = self.radius * torch.cos(pi * self.angle * x_) * mask
x_ = torch.cat((x_s, x_c), dim=-1)
return x_
def default_delta_func(self, x):
return torch.ones_like(x)
def forward(self, x1, x2, diag=False, **params):
x1_, x2_ = self.embedding(x1), self.embedding(x2)
return self.base_kernel(x1_, x2_, diag=diag)