#!/usr/bin/env python3
import math
import numpy as np
import torch
from torch.nn import Module
from .. import settings
def _pad_with_singletons(obj: torch.Tensor, num_singletons_before: int = 0, num_singletons_after: int = 0):
"""
Pad obj with singleton dimensions on the left and right
Example:
>>> x = torch.randn(10, 5)
>>> _pad_width_singletons(x, 2, 3).shape
>>> # [1, 1, 10, 5, 1, 1, 1]
"""
new_shape = [1] * num_singletons_before + list(obj.shape) + [1] * num_singletons_after
return obj.view(*new_shape)
class GaussHermiteQuadrature1D(Module):
    """
    Implements Gauss-Hermite quadrature for integrating a function with respect to several 1D Gaussian distributions
    in batch mode. Within GPyTorch, this is useful primarily for computing expected log likelihoods for variational
    inference.

    This is implemented as a Module because Gauss-Hermite quadrature has a set of locations and weights that it
    should initialize one time, but that should obey parent calls to .cuda(), .double() etc.

    Args:
        - num_locs (int, optional): Number of quadrature locations. If None, the value of
          ``settings.num_gauss_hermite_locs`` is used.
    """

    def __init__(self, num_locs=None):
        super().__init__()
        if num_locs is None:
            num_locs = settings.num_gauss_hermite_locs.value()
        self.num_locs = num_locs
        locations, weights = self._locs_and_weights(num_locs)
        # Stored as plain tensor attributes (not registered parameters/buffers), so they are not part of
        # state_dict; _apply below keeps them in sync with device/dtype conversions.
        self.locations = locations
        self.weights = weights

    def _apply(self, fn):
        # nn.Module._apply only converts registered parameters and buffers, so the plain tensor attributes
        # holding the quadrature rule are converted by hand before delegating to the parent implementation.
        self.locations = fn(self.locations)
        self.weights = fn(self.weights)
        return super()._apply(fn)

    def _locs_and_weights(self, num_locs):
        """
        Get locations and weights for Gauss-Hermite quadrature. Note that this is **not** intended to be used
        externally, because it directly creates tensors with no knowledge of a device or dtype to cast to.
        Instead, create a GaussHermiteQuadrature1D object and read its ``locations`` and ``weights`` attributes,
        which follow the module's device/dtype.

        The rule comes from ``numpy.polynomial.hermite.hermgauss`` and is defined for the weight function
        ``exp(-x ** 2)`` on the whole real line.

        Args:
            - num_locs (int): Number of quadrature locations.

        Returns:
            - (locations, weights): two 1D tensors of length ``num_locs`` in the default tensor dtype.
        """
        locations, weights = np.polynomial.hermite.hermgauss(num_locs)
        locations = torch.Tensor(locations)
        weights = torch.Tensor(weights)
        return locations, weights

    def forward(self, func, gaussian_dists):
        """
        Runs Gauss-Hermite quadrature on the callable func, integrating against the Gaussian distributions specified
        by gaussian_dists.

        For each Gaussian with mean ``mu`` and variance ``var``, this computes
        ``(1 / sqrt(pi)) * sum_i w_i * func(sqrt(2 * var) * x_i + mu)``,
        where ``(x_i, w_i)`` are the Gauss-Hermite locations and weights.

        Args:
            - func (callable): Function to integrate. It is called once, on a tensor whose leading dimension
              indexes the ``num_locs`` quadrature locations and whose remaining dimensions broadcast with
              ``gaussian_dists.mean``. Its output must keep that leading location dimension, which is summed out.
            - gaussian_dists (Distribution): Either a MultivariateNormal whose covariance is assumed to be diagonal
              or a :obj:`torch.distributions.Normal`.

        Returns:
            - Result of integrating func against each univariate Gaussian in gaussian_dists.
        """
        means = gaussian_dists.mean
        variances = gaussian_dists.variance

        # Append one singleton dim per dim of `means` so the locations broadcast against the batch of Gaussians.
        locations = _pad_with_singletons(self.locations, num_singletons_before=0, num_singletons_after=means.dim())

        # Change of variables x -> sqrt(2 * var) * x + mean maps the exp(-x^2) rule onto each Gaussian.
        shifted_locs = torch.sqrt(2.0 * variances) * locations + means
        log_probs = func(shifted_locs)

        # Broadcast the weights along every non-location dimension of func's output.
        weights = _pad_with_singletons(self.weights, num_singletons_before=0, num_singletons_after=log_probs.dim() - 1)

        # 1 / sqrt(pi) turns the exp(-x^2) weighting into an expectation under the Gaussian.
        res = (1 / math.sqrt(math.pi)) * (log_probs * weights)
        # Sum out the leading quadrature-location dimension(s).
        res = res.sum(tuple(range(self.locations.dim())))
        return res