# Source code for gpytorch.variational.natural_variational_distribution

#!/usr/bin/env python3

import abc

import torch
from linear_operator.operators import CholLinearOperator, TriangularLinearOperator
from linear_operator.utils.cholesky import psd_safe_cholesky

from ..distributions import MultivariateNormal
from ._variational_distribution import _VariationalDistribution

class _NaturalVariationalDistribution(_VariationalDistribution, abc.ABC):
    r"""Any :obj:`~gpytorch.variational._VariationalDistribution` which calculates
    natural gradients with respect to its parameters.

    This class adds no behavior of its own; it is an abstract base that natural-parameter
    variational distributions (such as :obj:`~gpytorch.variational.NaturalVariationalDistribution`)
    inherit from.
    """
class NaturalVariationalDistribution(_NaturalVariationalDistribution):
    r"""A multivariate normal :obj:`~gpytorch.variational._VariationalDistribution`,
    parameterized by **natural** parameters.

    The distribution is stored through two parameters: ``natural_vec``
    (:math:`\boldsymbol \Theta_\text{vec} = \boldsymbol \Sigma^{-1} \boldsymbol \mu`) and
    ``natural_mat`` (:math:`\mathbf \Theta_\text{mat} = -\frac{1}{2} \boldsymbol \Sigma^{-1}`).

    .. note::

       The :obj:`~gpytorch.variational.NaturalVariationalDistribution` can only be used with
       :obj:`gpytorch.optim.NGD`, or other optimizers that follow exactly the gradient direction.
       Failure to do so will cause the natural matrix :math:`\mathbf \Theta_\text{mat}` to stop
       being negative definite (i.e. :math:`-2 \mathbf \Theta_\text{mat}` stops being positive
       definite), and a :obj:`~RuntimeError` will be raised.

    .. seealso::

        The `natural gradient descent tutorial
        <examples/04_Variational_and_Approximate_GPs/Natural_Gradient_Descent.ipynb>`_
        for use instructions.

        The :obj:`~gpytorch.variational.TrilNaturalVariationalDistribution` for a more
        numerically stable parameterization, at the cost of needing more iterations to make
        variational regression converge.

    :param int num_inducing_points: Size of the variational distribution. This implies that the
        variational mean should be this size, and the variational covariance matrix should have
        this many rows and columns.
    :param batch_shape: Specifies an optional batch size for the variational parameters.
        This is useful for example when doing additive variational inference.
    :type batch_shape: :obj:`torch.Size`, optional
    :param float mean_init_std: (Default: 1e-3) Standard deviation of gaussian noise to add to
        the mean initialization.
    """

    def __init__(self, num_inducing_points, batch_shape=torch.Size([]), mean_init_std=1e-3, **kwargs):
        # ``**kwargs`` is accepted for signature compatibility; it is not used here.
        super().__init__(num_inducing_points=num_inducing_points, batch_shape=batch_shape, mean_init_std=mean_init_std)
        # Start from mu = 0 and Sigma = I, i.e. Theta_vec = 0 and Theta_mat = -1/2 * I.
        scaled_mean_init = torch.zeros(num_inducing_points)
        neg_prec_init = torch.eye(num_inducing_points, num_inducing_points).mul(-0.5)
        scaled_mean_init = scaled_mean_init.repeat(*batch_shape, 1)
        neg_prec_init = neg_prec_init.repeat(*batch_shape, 1, 1)

        # eta1 and eta2 parameterization of the variational distribution
        self.register_parameter(name="natural_vec", parameter=torch.nn.Parameter(scaled_mean_init))
        self.register_parameter(name="natural_mat", parameter=torch.nn.Parameter(neg_prec_init))

    def forward(self):
        """Return the variational distribution as a :obj:`MultivariateNormal`.

        The natural parameters are converted to a mean and a lower Cholesky factor of the
        covariance by ``_NaturalToMuVarSqrt``, whose custom backward supplies the natural
        gradient with respect to ``natural_vec`` and ``natural_mat``.
        """
        mean, chol_covar = _NaturalToMuVarSqrt.apply(self.natural_vec, self.natural_mat)
        res = MultivariateNormal(mean, CholLinearOperator(TriangularLinearOperator(chol_covar)))
        return res

    def initialize_variational_distribution(self, prior_dist):
        """Initialize the natural parameters from ``prior_dist``.

        With prior mean ``mu_p`` and covariance ``Sigma_p``, this sets
        ``natural_vec = Sigma_p^{-1} mu_p + noise`` and ``natural_mat = -1/2 * Sigma_p^{-1}``,
        where ``noise`` is Gaussian with standard deviation ``self.mean_init_std``.

        :param prior_dist: Distribution exposing ``mean`` and ``covariance_matrix``
            (presumably the GP prior over the inducing values).
        """
        prior_prec = prior_dist.covariance_matrix.inverse()
        prior_mean = prior_dist.mean
        noise = torch.randn_like(prior_mean).mul_(self.mean_init_std)

        # Theta_vec = Sigma^{-1} mu (perturbed by the initialization noise)
        self.natural_vec.data.copy_((prior_prec @ prior_mean.unsqueeze(-1)).squeeze(-1).add_(noise))
        # Theta_mat = -1/2 Sigma^{-1}, the same convention used by __init__ (-1/2 * I)
        self.natural_mat.data.copy_(prior_prec.mul(-0.5))
def _triangular_inverse(A, upper=False):
    """Return the inverse of the triangular matrix (or batch of matrices) ``A``.

    The inverse is obtained by solving ``A @ X = I`` with a triangular solve.

    :param torch.Tensor A: Square triangular matrix, or batch of them, of shape ``(..., n, n)``.
    :param bool upper: ``True`` if ``A`` is upper triangular, ``False`` (default) if lower.
    :return: ``A^{-1}`` with the same dtype and device as ``A``.
    """
    eye = torch.eye(A.size(-1), dtype=A.dtype, device=A.device)
    return torch.linalg.solve_triangular(A, eye, upper=upper)


def _phi_for_cholesky_(A):
    """Modifies A to be the phi function used in differentiating through Cholesky.

    Phi zeroes the strict upper triangle of ``A`` and halves its diagonal. The update is done
    **in place** (hence the trailing underscore); ``A`` is also returned for convenience.
    """
    A.tril_().diagonal(offset=0, dim1=-2, dim2=-1).mul_(0.5)
    return A


def _cholesky_backward(dout_dL, L, L_inverse):
    r"""Back-propagate a gradient through the lower Cholesky factorization ``Sigma = L L^T``.

    Computes ``sym(L^{-T} Phi(L^T dout_dL) L^{-1})``, where ``Phi`` is
    :func:`_phi_for_cholesky_` and ``sym(X) = (X + X^T) / 2``.

    :param torch.Tensor dout_dL: Gradient of the output with respect to ``L``.
    :param torch.Tensor L: Lower-triangular Cholesky factor of ``Sigma``.
    :param torch.Tensor L_inverse: ``L^{-1}``, passed in so callers that already have it can reuse it.
    :return: Symmetric gradient of the output with respect to ``Sigma``.
    """
    # c.f. the Cholesky backward formula in PyTorch's tools/autograd/templates/Functions.cpp
    A = L.transpose(-1, -2) @ dout_dL
    # ``A`` is a freshly allocated product, so the in-place Phi does not touch ``dout_dL``.
    phi = _phi_for_cholesky_(A)
    grad_input = (L_inverse.transpose(-1, -2) @ phi) @ L_inverse
    # Symmetrize gradient
    return grad_input.add(grad_input.transpose(-1, -2)).mul_(0.5)


class _NaturalToMuVarSqrt(torch.autograd.Function):
    r"""Map natural parameters to ``(mu, chol(Sigma))`` with a natural-gradient backward pass.

    Forward: given ``nat_mean`` (``eta_1 = Sigma^{-1} mu``) and ``nat_covar``
    (``eta_2 = -1/2 Sigma^{-1}``), returns ``mu = Sigma @ nat_mean`` and the lower Cholesky
    factor of ``Sigma = (-2 * nat_covar)^{-1}``.

    Backward: instead of the ordinary gradient with respect to the inputs, returns the gradient
    with respect to the expectation parameters ``(mu, mu mu^T + Sigma)``. This is used as the
    natural gradient with respect to ``(nat_mean, nat_covar)``.
    """

    @staticmethod
    def _forward(nat_mean, nat_covar):
        """Convert natural parameters to ``(mu, L)``.

        :param torch.Tensor nat_mean: Natural vector ``Sigma^{-1} mu``.
        :param torch.Tensor nat_covar: Natural matrix ``-1/2 Sigma^{-1}``; must be negative definite.
        :return: ``(mu, L)`` where ``Sigma = (-2 * nat_covar)^{-1}``, ``mu = Sigma @ nat_mean``
            and ``L`` is the lower Cholesky factor of ``Sigma``.
        :raises RuntimeError: If ``-2 * nat_covar`` cannot be Cholesky-factorized.
        """
        try:
            L_inv = psd_safe_cholesky(-2.0 * nat_covar, upper=False)
        except RuntimeError as e:
            # NOTE(review): newer linear_operator versions raise NotPSDError, whose message may not
            # start with "cholesky"; confirm this user-facing message is still reached.
            if str(e).startswith("cholesky"):
                raise RuntimeError(
                    "Non-negative-definite natural covariance. You probably "
                    "updated it using an optimizer other than gpytorch.optim.NGD (such as Adam). "
                    "This is not supported."
                ) from e
            raise
        # -2 * nat_covar = L_inv @ L_inv^T, hence Sigma = L^T @ L with L = L_inv^{-1}.
        L = _triangular_inverse(L_inv, upper=False)
        S = L.transpose(-1, -2) @ L
        mu = (S @ nat_mean.unsqueeze(-1)).squeeze(-1)
        # Two choleskys are annoying, but we don't have good support for a
        # LinearOperator of form L.T @ L
        return mu, psd_safe_cholesky(S, upper=False)

    @staticmethod
    def forward(ctx, nat_mean, nat_covar):
        mu, L = _NaturalToMuVarSqrt._forward(nat_mean, nat_covar)
        ctx.save_for_backward(mu, L)
        return mu, L

    @staticmethod
    def _backward(dout_dmu, dout_dL, mu, L, C):
        """Calculate dout/d(eta1, eta2) for the expectation parameters

            eta1 = mu
            eta2 = mu*mu^T + LL^T = mu*mu^T + Sigma

        Since mu = eta1 and Sigma = eta2 - eta1*eta1^T:

            dout/deta2 = dout/dSigma = _cholesky_backward(dout/dL, L, L^{-1})
            dout/deta1 = dout/dmu + (dout/dSigma)(dSigma/deta1)
                       = dout/dmu - 2 * (dout/dSigma) mu      (dout/dSigma is symmetric)

        :param dout_dmu: Gradient of the output with respect to ``mu``.
        :param dout_dL: Gradient of the output with respect to the Cholesky factor ``L``.
        :param mu: Mean returned by the forward pass.
        :param L: Lower Cholesky factor of ``Sigma`` returned by the forward pass.
        :param C: ``L^{-1}``.
        :return: ``(dout/deta1, dout/deta2)``.
        """
        dout_dSigma = _cholesky_backward(dout_dL, L, C)
        dout_deta1 = dout_dmu - 2 * (dout_dSigma @ mu.unsqueeze(-1)).squeeze(-1)
        return dout_deta1, dout_dSigma

    @staticmethod
    def backward(ctx, dout_dmu, dout_dL):
        "Calculates the natural gradient with respect to nat_mean, nat_covar"
        mu, L = ctx.saved_tensors
        C = _triangular_inverse(L, upper=False)
        return _NaturalToMuVarSqrt._backward(dout_dmu, dout_dL, mu, L, C)