# Source code for gpytorch.kernels.scale_kernel

#!/usr/bin/env python3

from typing import Optional

import torch
from linear_operator.operators import to_dense

from ..constraints import Interval, Positive
from ..priors import Prior
from .kernel import Kernel

class ScaleKernel(Kernel):
    r"""
    Decorates an existing kernel object with an output scale, i.e.

    .. math::

       \begin{equation*}
          K_{\text{scaled}} = \theta_\text{scale} K_{\text{orig}}
       \end{equation*}

    where :math:`\theta_\text{scale}` is the `outputscale` parameter.

    In batch-mode (i.e. when :math:`x_1` and :math:`x_2` are batches of input matrices), each
    batch of data can have its own `outputscale` parameter by setting the `batch_shape`
    keyword argument to the appropriate batch shape.

    .. note::

        The outputscale is stored as an unconstrained raw parameter (`raw_outputscale`) and is
        mapped through `outputscale_constraint` (:class:`~gpytorch.constraints.Positive` by
        default) to keep it positive. You can set a prior on this parameter using the
        `outputscale_prior` argument.

    Args:
        base_kernel (Kernel):
            The base kernel to be scaled.
        batch_shape (torch.Size, optional):
            Set this if you want a separate outputscale for each batch of input data. It should
            be `torch.Size([b])` if x1 is a `b x n x d` tensor. Default: `torch.Size([])`
        outputscale_prior (Prior, optional):
            Set this if you want to apply a prior to the outputscale parameter. Default: `None`
        outputscale_constraint (Constraint, optional):
            Set this if you want to apply a constraint to the outputscale parameter.
            Default: `Positive`.

    Attributes:
        base_kernel (Kernel):
            The kernel module to be scaled.
        outputscale (Tensor):
            The outputscale parameter. Size/shape of parameter depends on the `batch_shape`
            argument.

    Example:
        >>> x = torch.randn(10, 5)
        >>> base_covar_module = gpytorch.kernels.RBFKernel()
        >>> scaled_covar_module = gpytorch.kernels.ScaleKernel(base_covar_module)
        >>> covar = scaled_covar_module(x)  # Output: LinearOperator of size (10 x 10)
    """

    @property
    def is_stationary(self) -> bool:
        """Kernel is stationary if (and only if) the base kernel is stationary."""
        return self.base_kernel.is_stationary

    def __init__(
        self,
        base_kernel: Kernel,
        outputscale_prior: Optional[Prior] = None,
        outputscale_constraint: Optional[Interval] = None,
        **kwargs,
    ):
        """
        Wrap ``base_kernel`` and register the ``raw_outputscale`` parameter.

        Args:
            base_kernel: Kernel whose output is scaled.
            outputscale_prior: Optional prior on ``outputscale``; must be a ``Prior`` instance.
            outputscale_constraint: Constraint applied to ``raw_outputscale``;
                ``Positive()`` when ``None``.
            **kwargs: Forwarded to ``Kernel.__init__`` (e.g. ``batch_shape``).

        Raises:
            TypeError: If ``outputscale_prior`` is given but is not a ``Prior``.
        """
        # Forward the base kernel's active_dims to the parent constructor; presumably so this
        # wrapper applies the same input-dimension selection as the wrapped kernel.
        if base_kernel.active_dims is not None:
            kwargs["active_dims"] = base_kernel.active_dims
        super().__init__(**kwargs)
        if outputscale_constraint is None:
            outputscale_constraint = Positive()

        self.base_kernel = base_kernel
        # One raw outputscale per batch entry; for an empty batch_shape this is a 0-d tensor.
        outputscale = torch.zeros(self.batch_shape)
        self.register_parameter(name="raw_outputscale", parameter=torch.nn.Parameter(outputscale))
        if outputscale_prior is not None:
            if not isinstance(outputscale_prior, Prior):
                raise TypeError("Expected gpytorch.priors.Prior but got " + type(outputscale_prior).__name__)
            # Bound methods (rather than lambdas) are used as the prior's getter/setter.
            self.register_prior(
                "outputscale_prior", outputscale_prior, self._outputscale_param, self._outputscale_closure
            )

        self.register_constraint("raw_outputscale", outputscale_constraint)

    def _outputscale_param(self, m):
        """Prior getter: return the constrained outputscale of module ``m``."""
        return m.outputscale

    def _outputscale_closure(self, m, v):
        """Prior setter: set the outputscale of module ``m`` to ``v``."""
        return m._set_outputscale(v)

    @property
    def outputscale(self):
        """The constrained outputscale, i.e. ``raw_outputscale`` mapped through its constraint."""
        return self.raw_outputscale_constraint.transform(self.raw_outputscale)

    @outputscale.setter
    def outputscale(self, value):
        self._set_outputscale(value)

    def _set_outputscale(self, value):
        """
        Set the outputscale to ``value`` (in constrained space).

        Non-tensor values are converted to a tensor matching ``raw_outputscale``'s dtype/device.
        The value is mapped back to raw space with the constraint's inverse transform.
        """
        if not torch.is_tensor(value):
            value = torch.as_tensor(value).to(self.raw_outputscale)
        self.initialize(raw_outputscale=self.raw_outputscale_constraint.inverse_transform(value))

    def forward(self, x1, x2, last_dim_is_batch=False, diag=False, **params):
        """
        Evaluate the base kernel and multiply its output by ``outputscale``.

        The outputscale gets trailing singleton dimensions so that it broadcasts against the
        base kernel output:
          * ``last_dim_is_batch`` adds one extra trailing dimension;
          * ``diag=True`` output has one trailing data dimension, so one more is added;
          * otherwise the output is a matrix, so two trailing dimensions are added.

        Returns:
            A dense tensor when ``diag=True``, otherwise the result of
            ``orig_output.mul(outputscales)`` on the base kernel's output.
        """
        orig_output = self.base_kernel.forward(x1, x2, diag=diag, last_dim_is_batch=last_dim_is_batch, **params)
        outputscales = self.outputscale
        if last_dim_is_batch:
            outputscales = outputscales.unsqueeze(-1)
        if diag:
            outputscales = outputscales.unsqueeze(-1)
            return to_dense(orig_output) * outputscales
        else:
            outputscales = outputscales.view(*outputscales.shape, 1, 1)
            return orig_output.mul(outputscales)

    def num_outputs_per_input(self, x1, x2):
        """Delegate to the base kernel; scaling does not change the number of outputs."""
        return self.base_kernel.num_outputs_per_input(x1, x2)

    def prediction_strategy(self, train_inputs, train_prior_dist, train_labels, likelihood):
        """Delegate prediction-strategy construction to the base kernel."""
        return self.base_kernel.prediction_strategy(train_inputs, train_prior_dist, train_labels, likelihood)