from math import pi
import torch
from ..distributions import MultitaskMultivariateNormal, MultivariateNormal
# Rebind the float imported from ``math`` as a 0-dim torch tensor; the metric
# functions in this module use the name ``pi`` inside tensor expressions.
pi = torch.as_tensor(pi)
def mean_absolute_error(
    pred_dist: MultivariateNormal,
    test_y: torch.Tensor,
):
    """
    Mean Absolute Error.

    Takes the absolute difference between ``pred_dist.mean`` and ``test_y``
    and averages it over a single dimension: ``-2`` when ``pred_dist`` is a
    ``MultitaskMultivariateNormal``, ``-1`` otherwise.

    :param pred_dist: predictive distribution; only its ``mean`` is used.
    :param test_y: observed targets to compare against the predictive mean.
    :return: the absolute error averaged over the reduced dimension.
    """
    if isinstance(pred_dist, MultitaskMultivariateNormal):
        reduce_dim = -2
    else:
        reduce_dim = -1
    abs_residual = (pred_dist.mean - test_y).abs()
    return abs_residual.mean(dim=reduce_dim)
def mean_squared_error(
    pred_dist: MultivariateNormal,
    test_y: torch.Tensor,
    squared: bool = True,
):
    """
    Mean Squared Error.

    Squares the difference between ``pred_dist.mean`` and ``test_y`` and
    averages it over a single dimension: ``-2`` when ``pred_dist`` is a
    ``MultitaskMultivariateNormal``, ``-1`` otherwise.

    :param pred_dist: predictive distribution; only its ``mean`` is used.
    :param test_y: observed targets to compare against the predictive mean.
    :param squared: if ``True`` (default) return the mean squared error; if
        ``False`` return its square root (root mean squared error).
    :return: the (root) mean squared error over the reduced dimension.
    """
    if isinstance(pred_dist, MultitaskMultivariateNormal):
        reduce_dim = -2
    else:
        reduce_dim = -1
    residual = pred_dist.mean - test_y
    mse = residual.square().mean(dim=reduce_dim)
    return mse if squared else mse**0.5
def negative_log_predictive_density(
    pred_dist: MultivariateNormal,
    test_y: torch.Tensor,
):
    """
    Negative Log Predictive Density.

    Negates ``pred_dist.log_prob(test_y)`` and divides it by the size of one
    dimension of ``test_y``: ``-2`` when ``pred_dist`` is a
    ``MultitaskMultivariateNormal``, ``-1`` otherwise.

    :param pred_dist: predictive distribution providing ``log_prob``.
    :param test_y: observed targets at which the density is evaluated.
    :return: the negated log probability divided by that dimension's size.
    """
    if isinstance(pred_dist, MultitaskMultivariateNormal):
        size_dim = -2
    else:
        size_dim = -1
    # Evaluate the density before reading the shape, as the expression
    # ``-log_prob(...) / shape[...]`` would.
    log_density = pred_dist.log_prob(test_y)
    n_entries = test_y.shape[size_dim]
    return -log_density / n_entries
def mean_standardized_log_loss(
    pred_dist: MultivariateNormal,
    test_y: torch.Tensor,
):
    """
    Mean Standardized Log Loss.
    Reference: Page No. 23,
    Gaussian Processes for Machine Learning,
    Carl Edward Rasmussen and Christopher K. I. Williams,
    The MIT Press, 2006. ISBN 0-262-18253-X

    For every target the negative log density of a univariate Gaussian with
    the predictive mean and variance is evaluated,

        0.5 * log(2 * pi * var) + (y - mean)^2 / (2 * var),

    and the result is averaged over one dimension: ``-2`` when ``pred_dist``
    is a ``MultitaskMultivariateNormal``, ``-1`` otherwise.

    NOTE: no trivial-model loss is subtracted here, so the returned value is
    the plain mean negative log density under the predictive marginals.

    :param pred_dist: predictive distribution; its ``mean`` and ``variance``
        are used.
    :param test_y: observed targets.
    :return: the averaged per-target negative log density.
    """
    combine_dim = -2 if isinstance(pred_dist, MultitaskMultivariateNormal) else -1
    f_mean = pred_dist.mean
    f_var = pred_dist.variance
    # The 0.5 factor applies to the log-normaliser only; the quadratic term
    # already carries its own 1 / (2 * var). Multiplying the whole sum by 0.5
    # would halve the quadratic term a second time.
    log_normaliser = 0.5 * torch.log(2 * pi * f_var)
    quadratic = torch.square(test_y - f_mean) / (2 * f_var)
    return (log_normaliser + quadratic).mean(dim=combine_dim)
def quantile_coverage_error(
    pred_dist: MultivariateNormal,
    test_y: torch.Tensor,
    quantile: float = 95.0,
):
    """
    Quantile coverage error.

    Builds a symmetric interval around ``pred_dist.mean`` whose half-width is
    ``pred_dist.stddev`` times the standard-normal quantile at
    ``0.5 + 0.5 * quantile / 100``, measures the fraction of ``test_y``
    entries strictly inside that interval along one dimension (``-2`` for a
    ``MultitaskMultivariateNormal``, ``-1`` otherwise), and returns the
    absolute gap between that fraction and ``quantile / 100``.

    :param pred_dist: predictive distribution; its ``mean`` and ``stddev``
        are used.
    :param test_y: observed targets.
    :param quantile: nominal coverage in percent, strictly between 0 and 100.
    :return: ``|empirical coverage - quantile / 100|``.
    :raises NotImplementedError: if ``quantile`` is ``<= 0`` or ``>= 100``.
    """
    if quantile <= 0 or quantile >= 100:
        raise NotImplementedError("Quantile must be between 0 and 100")

    if isinstance(pred_dist, MultitaskMultivariateNormal):
        reduce_dim = -2
    else:
        reduce_dim = -1

    # Number of standard deviations that gives the requested central mass
    # under a standard normal.
    unit_normal = torch.distributions.Normal(loc=0.0, scale=1.0)
    z_score = unit_normal.icdf(torch.as_tensor(0.5 + 0.5 * (quantile / 100)))

    center = pred_dist.mean
    half_width = z_score * pred_dist.stddev
    lower = center - half_width
    upper = center + half_width

    # Strict inequalities: targets exactly on a bound count as outside.
    inside = (test_y > lower) & (test_y < upper)
    coverage = inside.sum(reduce_dim) / test_y.shape[reduce_dim]
    return torch.abs(coverage - quantile / 100)