Source code for gpytorch.variational.cholesky_variational_distribution

#!/usr/bin/env python3

import torch
from linear_operator.operators import CholLinearOperator, TriangularLinearOperator

from ..distributions import MultivariateNormal
from ._variational_distribution import _VariationalDistribution


class CholeskyVariationalDistribution(_VariationalDistribution):
    """
    A :obj:`~gpytorch.variational._VariationalDistribution` that is defined to be a multivariate normal distribution
    with a full covariance matrix.

    The most common way this distribution is defined is to parameterize it in terms of a mean vector and a covariance
    matrix. In order to ensure that the covariance matrix remains positive definite, we only consider the lower
    triangle.

    :param num_inducing_points: Size of the variational distribution. This implies that the variational mean
        should be this size, and the variational covariance matrix should have this many rows and columns.
    :param batch_shape: Specifies an optional batch size for the variational parameters. This is useful for example
        when doing additive variational inference.
    :param mean_init_std: (Default: 1e-3) Standard deviation of gaussian noise to add to the mean initialization.
    :param kwargs: Accepted for call-site compatibility; not used by this constructor.
    """

    def __init__(
        self,
        num_inducing_points: int,
        batch_shape: torch.Size = torch.Size([]),
        mean_init_std: float = 1e-3,
        **kwargs,
    ):
        super().__init__(num_inducing_points=num_inducing_points, batch_shape=batch_shape, mean_init_std=mean_init_std)

        # Start from a zero mean and an identity factor, tiled over the batch dimensions.
        mean_init = torch.zeros(num_inducing_points)
        covar_init = torch.eye(num_inducing_points, num_inducing_points)
        mean_init = mean_init.repeat(*batch_shape, 1)
        covar_init = covar_init.repeat(*batch_shape, 1, 1)

        self.register_parameter(name="variational_mean", parameter=torch.nn.Parameter(mean_init))
        self.register_parameter(name="chol_variational_covar", parameter=torch.nn.Parameter(covar_init))

    def forward(self) -> MultivariateNormal:
        """
        Build the variational distribution from the current parameters.

        :return: A multivariate normal whose mean is ``variational_mean`` and whose covariance is
            represented through the lower-triangular part of ``chol_variational_covar``.
        """
        chol_variational_covar = self.chol_variational_covar
        dtype = chol_variational_covar.dtype
        device = chol_variational_covar.device

        # First make sure the Cholesky factor is lower triangular: the parameter is stored as a
        # full square matrix, so everything above the diagonal is zeroed out by the mask.
        lower_mask = torch.ones(self.chol_variational_covar.shape[-2:], dtype=dtype, device=device).tril(0)
        chol_variational_covar = TriangularLinearOperator(chol_variational_covar.mul(lower_mask))

        # Now construct the actual covariance matrix from its Cholesky factor
        variational_covar = CholLinearOperator(chol_variational_covar)
        return MultivariateNormal(self.variational_mean, variational_covar)

    def initialize_variational_distribution(self, prior_dist: MultivariateNormal) -> None:
        """
        Initialize the variational parameters in place from ``prior_dist``.

        The mean is set to the prior mean plus standard-normal noise scaled by ``self.mean_init_std``;
        the factor parameter is set to the dense Cholesky factor of the prior covariance.

        :param prior_dist: Distribution whose mean and covariance seed the variational parameters.
        """
        self.variational_mean.data.copy_(prior_dist.mean)
        self.variational_mean.data.add_(torch.randn_like(prior_dist.mean), alpha=self.mean_init_std)
        self.chol_variational_covar.data.copy_(prior_dist.lazy_covariance_matrix.cholesky().to_dense())