# Source code for gpytorch.priors.smoothed_box_prior

#!/usr/bin/env python3

import math
from numbers import Number

import torch
from torch.distributions import constraints
from torch.distributions.utils import broadcast_all
from torch.nn import Module as TModule

from .prior import Prior
from .torch_priors import NormalPrior


class SmoothedBoxPrior(Prior):
    r"""A smoothed approximation of a uniform prior on the box :math:`[a, b]`.

    Has full support on the reals and is differentiable everywhere: the density
    is flat inside the box and decays like a Gaussian with standard deviation
    :math:`\sigma` in the distance to the box outside of it. With
    :math:`c = (a + b) / 2` and :math:`r = (b - a) / 2`, per coordinate :math:`i`
    of the event (last) dimension:

    .. math::

        \begin{equation*}
            B = \{x : a_i \leq x_i \leq b_i\}, \qquad
            d_i(x) = \max\left(\lvert x_i - c_i \rvert - r_i,\ 0\right)
        \end{equation*}

        \begin{equation*}
            p(x) = \prod_i \frac{1}{(b_i - a_i) + \sqrt{2\pi}\,\sigma_i}
                   \exp\left(-\frac{d_i(x)^2}{2\sigma_i^2}\right)
        \end{equation*}

    With a common :math:`\sigma` this is proportional to
    :math:`\exp\left(-d(x, B)^2 / (2\sigma^2)\right)`, where
    :math:`d(x, B) = \min_{x' \in B} \lVert x - x' \rVert` is the Euclidean
    distance from :math:`x` to the box.

    :param a: Lower bound(s) of the box (number or tensor); broadcast against
        ``b`` and ``sigma``. Scalar inputs are treated as a 1-dimensional event.
    :param b: Upper bound(s) of the box; must satisfy ``a < b`` element-wise.
    :param sigma: Standard deviation of the Gaussian tails outside the box.
    :param validate_args: Forwarded to the base prior and to the tail distribution.
    :param transform: Optional transform, stored as ``self._transform``.
        ``log_prob`` passes its input through ``self.transform`` (defined on the
        base class, presumably applying this transform) before evaluating the density.
    :raises ValueError: If ``a < b`` does not hold element-wise.
    """

    arg_constraints = {"sigma": constraints.positive, "a": constraints.real, "b": constraints.real}
    support = constraints.real
    has_rsample = True
    _validate_args = True

    def __init__(self, a, b, sigma=0.01, validate_args=False, transform=None):
        TModule.__init__(self)
        # Broadcast first so that Python numbers adopt the dtype/device of any
        # tensor argument instead of always becoming default-dtype CPU tensors.
        _a, _b, _sigma = broadcast_all(a, b, sigma)
        if _a.dim() < 1:
            # Scalar bounds describe a box with a one-dimensional event.
            _a, _b, _sigma = _a.view(-1), _b.view(-1), _sigma.view(-1)
        if not torch.all(constraints.less_than(_b).check(_a)):
            raise ValueError("must have that a < b (element-wise)")
        # TODO: Proper argument validation including broadcasting
        batch_shape, event_shape = _a.shape[:-1], _a.shape[-1:]
        # need to assign values before registering as buffers to make argument validation work
        self.a, self.b, self.sigma = _a, _b, _sigma
        super().__init__(batch_shape, event_shape, validate_args=validate_args)
        # now need to delete to be able to register buffer
        del self.a, self.b, self.sigma
        # Clone so every buffer owns dense storage: broadcast_all may return
        # expanded (stride-0) views or alias the caller's tensors, and in-place
        # writes such as load_state_dict's copy_ fail on overlapping memory.
        self.register_buffer("a", _a.clone())
        self.register_buffer("b", _b.clone())
        self.register_buffer("sigma", _sigma.clone())
        # Zero-mean Gaussian used for both tails, evaluated on the distance to the box.
        self.tails = NormalPrior(torch.zeros_like(_a), _sigma, validate_args=validate_args)
        self._transform = transform

    @property
    def _c(self):
        """Center of the box, ``(a + b) / 2``."""
        return (self.a + self.b) / 2

    @property
    def _r(self):
        """Half-width of the box, ``(b - a) / 2``."""
        return (self.b - self.a) / 2

    @property
    def _M(self):
        """Per-coordinate log normaliser relative to the tail Gaussian.

        Equals ``log(1 + (b - a) / (sqrt(2*pi) * sigma))``. Combined with the
        tail's own ``-log(sqrt(2*pi) * sigma)`` term this gives a total
        per-coordinate normaliser of ``-log((b - a) + sqrt(2*pi) * sigma)``.
        """
        return torch.log(1 + (self.b - self.a) / (math.sqrt(2 * math.pi) * self.sigma))

    def log_prob(self, x):
        """Log density at ``self.transform(x)``, summed over the event (last) dimension."""
        return self._log_prob(self.transform(x))

    def _log_prob(self, x):
        """Log density of already-transformed ``x``, summed over the last dimension."""
        # Per-coordinate distance from x to the box: 0 inside [a, b], |x - c| - r outside.
        X = ((x - self._c).abs_() - self._r).clamp(min=0)
        return (self.tails.log_prob(X) - self._M).sum(-1)

    def rsample(self, sample_shape=torch.Size()):
        """Draw samples as a mixture of a uniform on ``[a, b]`` and two half-Gaussian tails.

        Each coordinate independently lands in a tail with probability
        ``sqrt(2*pi)*sigma / ((b - a) + sqrt(2*pi)*sigma)``, which is the tails'
        share of the total mass, and is uniform on ``[a, b]`` otherwise. A tail
        draw ``eps ~ N(0, sigma^2)`` is placed at ``a + eps`` when ``eps < 0``
        and at ``b + eps`` otherwise. The Bernoulli component choice only enters
        as a selection mask.

        :param sample_shape: Leading sample dimensions.
        :returns: Tensor of shape ``sample_shape + batch_shape + event_shape``.
        """
        shape = self._extended_shape(sample_shape)
        gauss_max = 1 / (math.sqrt(2 * math.pi) * self.sigma)
        # Probability mass of the two tails relative to the whole density.
        gauss_weight = 1 / (gauss_max * (self.b - self.a) + 1)
        picks = torch.bernoulli(gauss_weight.expand(shape)).to(self.a.device)
        uniform_eps = torch.rand(shape, dtype=self.a.dtype, device=self.a.device)
        uniform_samples = self.a + uniform_eps * (self.b - self.a)
        gaussian_eps = self.tails.rsample(sample_shape).to(self.a.device)
        # Negative offsets extend below a, non-negative offsets extend above b.
        gaussian_samples = gaussian_eps + torch.where(gaussian_eps < 0.0, self.a, self.b)
        return torch.where(picks > 0, gaussian_samples, uniform_samples)