#!/usr/bin/env python3
import torch
from torch.distributions import Gamma, HalfCauchy, HalfNormal, LogNormal, MultivariateNormal, Normal, Uniform
from torch.nn import Module as TModule
from .prior import Prior
from .utils import _bufferize_attributes, _del_attributes
# Names of MultivariateNormal attributes that MultivariateNormalPrior.cuda()/cpu()
# delete after moving the module, so they are re-derived on the new device.
MVN_LAZY_PROPERTIES = (
    "covariance_matrix",
    "scale_tril",
    "precision_matrix",
)
class NormalPrior(Prior, Normal):
    """
    Normal (Gaussian) Prior

    pdf(x) = (2 * pi * sigma^2)^-0.5 * exp(-(x - mu)^2 / (2 * sigma^2))

    where mu is the mean and sigma^2 is the variance.

    Args:
        loc: Mean (mu) of the distribution.
        scale: Standard deviation (sigma) of the distribution.
        validate_args: Forwarded to ``torch.distributions.Normal``.
        transform: Stored as ``self._transform``; how it is applied is defined by
            the ``Prior`` base class (not visible here).
    """

    def __init__(self, loc, scale, validate_args=False, transform=None):
        # Initialise the nn.Module machinery first so that the attributes set by
        # Normal.__init__ land on an initialised module.
        TModule.__init__(self)
        Normal.__init__(self, loc=loc, scale=scale, validate_args=validate_args)
        # Presumably re-registers loc/scale as module buffers so they move with the
        # module (see .utils._bufferize_attributes) -- TODO confirm.
        _bufferize_attributes(self, ("loc", "scale"))
        self._transform = transform

    def expand(self, batch_shape):
        """Return a new NormalPrior with ``loc`` and ``scale`` expanded to ``batch_shape``.

        NOTE(review): ``validate_args`` and ``transform`` are not carried over to the
        expanded prior; confirm this is intended.
        """
        batch_shape = torch.Size(batch_shape)
        return NormalPrior(self.loc.expand(batch_shape), self.scale.expand(batch_shape))
class HalfNormalPrior(Prior, HalfNormal):
    """
    Half-Normal prior.

    pdf(x) = 2 * (2 * pi * scale^2)^-0.5 * exp(-x^2 / (2 * scale^2)) for x >= 0; 0 for x < 0

    where scale^2 is the variance of the underlying (unfolded) normal distribution.

    Args:
        scale: Scale parameter of the distribution.
        validate_args: Forwarded to ``torch.distributions.HalfNormal``.
        transform: Stored as ``self._transform``; how it is applied is defined by
            the ``Prior`` base class (not visible here).
    """

    def __init__(self, scale, validate_args=None, transform=None):
        TModule.__init__(self)
        HalfNormal.__init__(self, scale=scale, validate_args=validate_args)
        # Presumably re-registers scale as a module buffer -- see .utils.
        _bufferize_attributes(self, ("scale",))
        self._transform = transform

    def expand(self, batch_shape):
        """Return a new HalfNormalPrior with ``scale`` expanded to ``batch_shape``."""
        # Build the prior subclass rather than a bare torch HalfNormal, so the
        # expanded object is still a Prior/Module like the one it came from.
        return HalfNormalPrior(self.scale.expand(batch_shape))
class LogNormalPrior(Prior, LogNormal):
    """
    Log Normal prior.

    Wraps ``torch.distributions.LogNormal``: ``log(X) ~ Normal(loc, scale)``.

    Args:
        loc: Mean of the underlying normal distribution (in log-space).
        scale: Standard deviation of the underlying normal distribution.
        validate_args: Forwarded to ``torch.distributions.LogNormal``.
        transform: Stored as ``self._transform``; its use is defined by the
            ``Prior`` base class (not visible here).
    """

    def __init__(self, loc, scale, validate_args=None, transform=None):
        TModule.__init__(self)
        LogNormal.__init__(self, loc=loc, scale=scale, validate_args=validate_args)
        _bufferize_attributes(self, ("loc", "scale"))
        self._transform = transform

    def expand(self, batch_shape):
        """Return a new LogNormalPrior whose parameters are broadcast to ``batch_shape``."""
        target = torch.Size(batch_shape)
        expanded_loc = self.loc.expand(target)
        expanded_scale = self.scale.expand(target)
        return LogNormalPrior(loc=expanded_loc, scale=expanded_scale)
class UniformPrior(Prior, Uniform):
    """
    Uniform prior.

    pdf(x) = 1 / (b - a) for a <= x < b; 0 otherwise

    Args:
        a: Lower bound (``low`` of ``torch.distributions.Uniform``).
        b: Upper bound (``high`` of ``torch.distributions.Uniform``).
        validate_args: Forwarded to ``torch.distributions.Uniform``.
        transform: Stored as ``self._transform``; its use is defined by the
            ``Prior`` base class (not visible here).
    """

    def __init__(self, a, b, validate_args=None, transform=None):
        TModule.__init__(self)
        Uniform.__init__(self, a, b, validate_args=validate_args)
        # Register the bounds the same way every other prior registers its
        # parameters, so module-level moves (e.g. .cuda()/.to()) also move
        # low/high instead of leaving them on the original device.
        _bufferize_attributes(self, ("low", "high"))
        self._transform = transform

    def expand(self, batch_shape):
        """Return a new UniformPrior with ``low``/``high`` expanded to ``batch_shape``."""
        batch_shape = torch.Size(batch_shape)
        return UniformPrior(self.low.expand(batch_shape), self.high.expand(batch_shape))
class HalfCauchyPrior(Prior, HalfCauchy):
    """
    Half-Cauchy prior.

    Args:
        scale: Scale parameter of the distribution.
        validate_args: Forwarded to ``torch.distributions.HalfCauchy``.
        transform: Stored as ``self._transform``; its use is defined by the
            ``Prior`` base class (not visible here).
    """

    def __init__(self, scale, validate_args=None, transform=None):
        TModule.__init__(self)
        HalfCauchy.__init__(self, scale=scale, validate_args=validate_args)
        # Presumably re-registers scale as a module buffer -- see .utils.
        _bufferize_attributes(self, ("scale",))
        self._transform = transform

    def expand(self, batch_shape):
        """Return a new HalfCauchyPrior with ``scale`` expanded to ``batch_shape``."""
        return HalfCauchyPrior(self.scale.expand(batch_shape))
class GammaPrior(Prior, Gamma):
    """Gamma Prior parameterized by concentration and rate

    pdf(x) = beta^alpha / Gamma(alpha) * x^(alpha - 1) * exp(-beta * x)

    where alpha > 0 and beta > 0 are the concentration and rate parameters, respectively.

    Args:
        concentration: Shape parameter alpha.
        rate: Rate parameter beta.
        validate_args: Forwarded to ``torch.distributions.Gamma``.
        transform: Stored as ``self._transform``; its use is defined by the
            ``Prior`` base class (not visible here).
    """

    def __init__(self, concentration, rate, validate_args=False, transform=None):
        TModule.__init__(self)
        Gamma.__init__(self, concentration=concentration, rate=rate, validate_args=validate_args)
        # Presumably re-registers the parameters as module buffers -- see .utils.
        _bufferize_attributes(self, ("concentration", "rate"))
        self._transform = transform

    def expand(self, batch_shape):
        """Return a new GammaPrior with both parameters expanded to ``batch_shape``."""
        batch_shape = torch.Size(batch_shape)
        return GammaPrior(self.concentration.expand(batch_shape), self.rate.expand(batch_shape))

    def __call__(self, *args, **kwargs):
        # Look up __call__ starting *after* Gamma in this class's MRO. Which class
        # supplies it depends on Prior's bases (not visible here); presumably this
        # is meant to reach torch.nn.Module.__call__ -- TODO confirm.
        return super(Gamma, self).__call__(*args, **kwargs)
class MultivariateNormalPrior(Prior, MultivariateNormal):
    """Multivariate Normal prior

    pdf(x) = det(2 * pi * Sigma)^-0.5 * exp(-0.5 * (x - mu)' Sigma^-1 (x - mu))

    where mu is the mean and Sigma > 0 is the covariance matrix.

    Args:
        loc: Mean vector(s) mu.
        covariance_matrix: Covariance matrix Sigma.
        precision_matrix: Inverse covariance matrix.
        scale_tril: Lower-triangular Cholesky factor of Sigma.
        validate_args: Forwarded to ``torch.distributions.MultivariateNormal``.
        transform: Stored as ``self._transform``; its use is defined by the
            ``Prior`` base class (not visible here).

    Exactly one of ``covariance_matrix``, ``precision_matrix`` or ``scale_tril``
    should be given, as required by ``torch.distributions.MultivariateNormal``.
    """

    def __init__(
        self, loc, covariance_matrix=None, precision_matrix=None, scale_tril=None, validate_args=False, transform=None
    ):
        TModule.__init__(self)
        MultivariateNormal.__init__(
            self,
            loc=loc,
            covariance_matrix=covariance_matrix,
            precision_matrix=precision_matrix,
            scale_tril=scale_tril,
            validate_args=validate_args,
        )
        # Only loc and the Cholesky factor are bufferized; the matrix views listed
        # in MVN_LAZY_PROPERTIES are deleted by cuda()/cpu() below.
        _bufferize_attributes(self, ("loc", "_unbroadcasted_scale_tril"))
        self._transform = transform

    def cuda(self, device=None):
        """Applies module-level cuda() call and resets all lazy properties"""
        module = self._apply(lambda t: t.cuda(device))
        # Drop cached covariance/precision/scale_tril views that may still refer to
        # the old device -- presumably _del_attributes tolerates absent ones; see .utils.
        _del_attributes(module, MVN_LAZY_PROPERTIES)
        return module

    def cpu(self):
        """Applies module-level cpu() call and resets all lazy properties"""
        module = self._apply(lambda t: t.cpu())
        _del_attributes(module, MVN_LAZY_PROPERTIES)
        return module

    def expand(self, batch_shape):
        """Return a new MultivariateNormalPrior with batch shape ``batch_shape``.

        ``loc`` is expanded to ``batch_shape + event_shape`` and ``scale_tril`` to
        ``batch_shape + event_shape + event_shape``.
        """
        batch_shape = torch.Size(batch_shape)
        loc_shape = batch_shape + self.event_shape
        # loc carries the event dimension once; scale_tril is a square matrix per
        # batch element, so it carries the event dimension twice.
        new_loc = self.loc.expand(loc_shape)
        new_scale_tril = self.scale_tril.expand(loc_shape + self.event_shape)
        return MultivariateNormalPrior(loc=new_loc, scale_tril=new_scale_tril)