#!/usr/bin/env python3

from typing import Any, Optional, Union

import torch
from linear_operator import to_linear_operator
from linear_operator.operators import (
ConstantDiagLinearOperator,
DiagLinearOperator,
KroneckerProductDiagLinearOperator,
KroneckerProductLinearOperator,
LinearOperator,
RootLinearOperator,
)
from torch import Tensor
from torch.distributions import Normal

from ..constraints import GreaterThan, Interval
from ..lazy import LazyEvaluatedKernelTensor
from ..likelihoods import _GaussianLikelihoodBase, Likelihood
from ..priors import Prior
from .noise_models import FixedGaussianNoise, Noise

r"""
Base class for multi-task Gaussian Likelihoods, supporting general heteroskedastic noise models.

:param noise_covar: A model for the noise covariance. This can be a simple homoskedastic noise model, or a GP
that is to be fitted on the observed measurement errors.
:param rank: The rank of the task noise covariance matrix to fit. If rank
is set to 0, then a diagonal covariance matrix is fit.
matrix. Only used when :math:\text{rank} > 0.
:param batch_shape: Number of batches.
"""

def __init__(
self,
noise_covar: Union[Noise, FixedGaussianNoise],
rank: int = 0,
batch_shape: torch.Size = torch.Size(),
) -> None:
super().__init__(noise_covar=noise_covar)
if rank != 0:
self.tidcs: Tensor = tidcs[:, 1:]  # (1, 1) must be 1.0, no need to parameterize this
self.register_prior(
)
raise ValueError("Can only specify task_correlation_prior if rank>0")
self.rank = rank

def _eval_corr_matrix(self) -> Tensor:
fac_diag = torch.ones(*tnc.shape[:-1], self.num_tasks, device=tnc.device, dtype=tnc.dtype)
Cfac = torch.diag_embed(fac_diag)
# squared rows must sum to one for this to be a correlation matrix
C = Cfac / Cfac.pow(2).sum(dim=-1, keepdim=True).sqrt()
return C @ C.transpose(-1, -2)

def marginal(
self, function_dist: MultitaskMultivariateNormal, *params: Any, **kwargs: Any
r"""
If :math:\text{rank} = 0, adds the task noises to the diagonal of the
covariance matrix of the supplied
:obj:~gpytorch.distributions.MultivariateNormal or
:obj:~gpytorch.distributions.MultitaskMultivariateNormal.  Otherwise,
adds a rank rank covariance matrix to it.

To accomplish this, we form a new
:obj:~linear_operator.operators.KroneckerProductLinearOperator
between :math:I_{n}, an identity matrix with size equal to the data
and a (not necessarily diagonal) matrix containing the task noises
:math:D_{t}.

We also incorporate a shared noise parameter from the base
:class:gpytorch.likelihoods.GaussianLikelihood that we extend.

The final covariance matrix after this method is then
:math:\mathbf K + \mathbf D_{t} \otimes \mathbf I_{n} + \sigma^{2} \mathbf I_{nt}.

:param function_dist: Random variable whose covariance
matrix is a :obj:~linear_operator.operators.LinearOperator we intend to augment.
:rtype: gpytorch.distributions.MultitaskMultivariateNormal:
:return: A new random variable whose covariance matrix is a
:obj:~linear_operator.operators.LinearOperator with
:math:\mathbf D_{t} \otimes \mathbf I_{n} and :math:\sigma^{2} \mathbf I_{nt} added.
"""
mean, covar = function_dist.mean, function_dist.lazy_covariance_matrix

# ensure that sumKroneckerLT is actually called
if isinstance(covar, LazyEvaluatedKernelTensor):
covar = covar.evaluate_kernel()

covar_kron_lt = self._shaped_noise_covar(
)
covar = covar + covar_kron_lt

return function_dist.__class__(mean, covar, interleaved=function_dist._interleaved)

def _shaped_noise_covar(
self, shape: torch.Size, add_noise: Optional[bool] = True, interleaved: bool = True, *params: Any, **kwargs: Any
) -> LinearOperator:
noise = ConstantDiagLinearOperator(self.noise, diag_shape=shape[-2] * self.num_tasks)
return noise

if self.rank == 0:
ckl_init = KroneckerProductDiagLinearOperator
else:
ckl_init = KroneckerProductLinearOperator

eye_lt = ConstantDiagLinearOperator(
torch.ones(*shape[:-2], 1, dtype=dtype, device=device), diag_shape=shape[-2]
)

# to add the latent noise we exploit the fact that
# I \kron D_T + \sigma^2 I_{NT} = I \kron (D_T + \sigma^2 I)
# which allows us to move the latent noise inside the task dependent noise
# thereby allowing exploitation of Kronecker structure in this likelihood.

if interleaved:
else:

return covar_kron_lt

def forward(self, function_samples: Tensor, *params: Any, **kwargs: Any) -> Normal:
noise = self._shaped_noise_covar(function_samples.shape, *params, **kwargs).diagonal(dim1=-1, dim2=-2)
noise = noise.reshape(*noise.shape[:-1], *function_samples.shape[-2:])
return base_distributions.Independent(base_distributions.Normal(function_samples, noise.sqrt()), 1)

r"""
A convenient extension of the :class:~gpytorch.likelihoods.GaussianLikelihood to the multitask setting that allows
for a full cross-task covariance structure for the noise. The fitted covariance matrix has rank rank.
If a strictly diagonal task noise covariance matrix is desired, then rank=0 should be set. (This option still
allows for a different noise parameter for each task.)

Like the Gaussian likelihood, this object can be used with exact inference.

.. note::
At least one of :attr:has_global_noise or :attr:has_task_noise should be specified.

.. note::
MultittaskGaussianLikelihood has an analytic marginal distribution.

:param noise_covar: A model for the noise covariance. This can be a simple homoskedastic noise model, or a GP
that is to be fitted on the observed measurement errors.
:param rank: The rank of the task noise covariance matrix to fit. If rank
is set to 0, then a diagonal covariance matrix is fit.
matrix. Only used when :math:\text{rank} > 0.
:param batch_shape: Number of batches.
:param has_global_noise: Whether to include a :math:\sigma^2 \mathbf I_{nt} term in the noise model.
:math:\mathbf I_n \otimes \mathbf D_T into the noise model.

:ivar torch.Tensor task_noises: (Optional) task specific noise variances (added onto the task_noise_covar)
:ivar torch.Tensor noise: (Optional) global noise variance (added onto the task_noise_covar)
"""

def __init__(
self,
rank: int = 0,
batch_shape: torch.Size = torch.Size(),
noise_prior: Optional[Prior] = None,
noise_constraint: Optional[Interval] = None,
has_global_noise: bool = True,
) -> None:
super(Likelihood, self).__init__()  # pyre-ignore[20]
if noise_constraint is None:
noise_constraint = GreaterThan(1e-4)

if not has_task_noise and not has_global_noise:
raise ValueError(
"At least one of has_task_noise or has_global_noise must be specified. "
"Attempting to specify a likelihood that has no noise terms."
)

if rank == 0:
self.register_parameter(
)
if noise_prior is not None:
raise RuntimeError("Cannot set a task_prior if rank=0")
else:
self.register_parameter(
)
self.rank = rank

if has_global_noise:
self.register_parameter(name="raw_noise", parameter=torch.nn.Parameter(torch.zeros(*batch_shape, 1)))
self.register_constraint("raw_noise", noise_constraint)
if noise_prior is not None:
self.register_prior("raw_noise_prior", noise_prior, lambda m: m.noise)

self.has_global_noise = has_global_noise

@property
def noise(self) -> Optional[Tensor]:
return self.raw_noise_constraint.transform(self.raw_noise)

@noise.setter
def noise(self, value: Union[float, Tensor]) -> None:
self._set_noise(value)

@property
if self.rank == 0:
else:
raise AttributeError("Cannot set diagonal task noises when covariance has ", self.rank, ">0")

def task_noises(self, value: Union[float, Tensor]) -> None:
if self.rank == 0:
else:
raise AttributeError("Cannot set diagonal task noises when covariance has ", self.rank, ">0")

def _set_noise(self, value: Union[float, Tensor]) -> None:
self.initialize(raw_noise=self.raw_noise_constraint.inverse_transform(value))

def _set_task_noises(self, value: Union[float, Tensor]) -> None:

@property
if self.rank > 0:
else:
raise AttributeError("Cannot retrieve task noises when covariance is diagonal.")

def task_noise_covar(self, value: Tensor) -> None:
# internally uses a pivoted cholesky decomposition to construct a low rank
# approximation of the covariance
if self.rank > 0:
else:
raise AttributeError("Cannot set non-diagonal task noises when covariance is diagonal.")

def _eval_covar_matrix(self) -> Tensor:
:return: Analytic marginal :math:p(\mathbf y).