#!/usr/bin/env python3

import warnings

import torch
from linear_operator.operators import RootLinearOperator

from ..module import Module
from ._variational_strategy import _VariationalStrategy

"""
:obj:`~gpytorch.variational.VariationalStrategy` to produce vector-valued (multi-task)
output distributions. Each task will be independent of one another.

The output will either be a :obj:`~gpytorch.distributions.MultitaskMultivariateNormal` distribution
(if we wish to evaluate all tasks for each input) or a :obj:`~gpytorch.distributions.MultivariateNormal`
(if we wish to evaluate a single task for each input).

The base variational strategy is assumed to operate on a batch of GPs. One of the batch
dimensions corresponds to the multiple tasks.

:param ~gpytorch.variational.VariationalStrategy base_variational_strategy: Base variational strategy
:param int task_dim: (Default: -1) Which batch dimension is the task dimension
"""

# NOTE(review): fragment — these statements read as an ``__init__`` body whose
# ``def`` line was lost in extraction; ``self`` and ``base_variational_strategy``
# are unbound at module level, so this cannot execute as written. Per the
# docstring above, the signature was presumably
# ``__init__(self, base_variational_strategy, task_dim=-1)``, and an assignment
# of ``task_dim`` appears to be missing too — restore from the upstream source.
Module.__init__(self)
self.base_variational_strategy = base_variational_strategy

@property
def prior_distribution(self):
    """Prior distribution, delegated to the wrapped base variational strategy."""
    base = self.base_variational_strategy
    return base.prior_distribution

@property
def variational_distribution(self):
    """Variational distribution, delegated to the wrapped base variational strategy."""
    base = self.base_variational_strategy
    return base.variational_distribution

@property
def variational_params_initialized(self):
    """Initialization flag, delegated to the wrapped base variational strategy."""
    base = self.base_variational_strategy
    return base.variational_params_initialized

def kl_divergence(self):
    """Total KL divergence: the parent's per-task KLs summed over the last batch dim.

    The base strategy operates on a batch of GPs where one batch dimension is
    the task dimension, so the parent returns one KL term per task; summing
    over ``dim=-1`` collapses them into the overall divergence.
    """
    per_task_kl = super().kl_divergence()
    return per_task_kl.sum(dim=-1)

# NOTE(review): this method is garbled from extraction and does not parse.
# The leading ``[docs]`` token on the ``def`` line is a Sphinx "view source"
# link artifact, not Python. The first ``if`` below lost its opening clauses
# (only the tail ``and self.task_dim + len(...) < 0`` of a compound condition
# remains) and its true-branch body; the outer ``if task_indices is None:``
# guard that the trailing ``else:`` belongs to is missing entirely; the
# statements that would build ``mean`` and ``covar`` are gone; and
# ``MultivariateNormal`` is not imported in the visible file. Left
# byte-identical — restore from the upstream source rather than patching here.
[docs]    def __call__(self, x, task_indices=None, prior=False, **kwargs):
r"""
See :class:`LMCVariationalStrategy`.
"""
function_dist = self.base_variational_strategy(x, prior=prior, **kwargs)

# Every data point will get an output for each task
# NOTE(review): condition truncated — opening clauses of the boolean lost.
if (
and self.task_dim + len(function_dist.batch_shape) < 0
):
# NOTE(review): true-branch body missing here.
else:
return function_dist

# NOTE(review): orphaned ``else:`` — its matching ``if`` (presumably
# ``if task_indices is None:``) is not in the extracted text.
else:
# Each data point will get a single output corresponding to a single task

num_batch = len(function_dist.batch_shape)

shape = list(function_dist.batch_shape + function_dist.event_shape)

# NOTE(review): ``mean`` and ``covar`` are never defined in the visible code.
return MultivariateNormal(mean, covar)

"""
:obj:`~gpytorch.variational.VariationalStrategy`
to produce a :obj:`~gpytorch.distributions.MultitaskMultivariateNormal` distribution.
All outputs will be independent of one another.

The base variational strategy is assumed to operate on a batch of GPs. One of the batch
dimensions corresponds to the multiple tasks.

:param ~gpytorch.variational.VariationalStrategy base_variational_strategy: Base variational strategy