#!/usr/bin/env python3
from copy import deepcopy
import torch
from torch.nn import ModuleList
from .mean import Mean
class MultitaskMean(Mean):
    """
    Convenience :class:`gpytorch.means.Mean` implementation for defining a different mean for each task in a multitask
    model. Expects a list of `num_tasks` different mean functions, each of which is applied to the given data in
    :func:`~gpytorch.means.MultitaskMean.forward` and returned as an `n x t` matrix of means, one for each task.
    """

    def __init__(self, base_means, num_tasks):
        """
        Args:
            base_means (:obj:`list`, :obj:`tuple` or :obj:`gpytorch.means.Mean`): If a list (or tuple), each mean
                is applied to the data. If a single mean (or a list containing a single mean), that mean is
                deep-copied so that there are `num_tasks` independent means in total.
            num_tasks (int): Number of tasks. If base_means is a list with more than one entry, this should equal
                its length.

        Raises:
            RuntimeError: If `base_means` is not a mean, list or tuple, if its length is neither 1 nor
                `num_tasks`, or if `num_tasks` is smaller than 1.
        """
        super().__init__()

        # Normalise the accepted input forms to a plain list so the checks below are uniform.
        if isinstance(base_means, Mean):
            base_means = [base_means]
        elif isinstance(base_means, tuple):
            base_means = list(base_means)

        if not isinstance(base_means, list) or len(base_means) not in (1, num_tasks):
            raise RuntimeError("base_means should be a list of means of length either 1 or num_tasks")

        # Without this guard a single mean combined with num_tasks < 1 would be accepted and leave
        # `self.num_tasks` out of sync with the number of stored means.
        if num_tasks < 1:
            raise RuntimeError("num_tasks should be a positive integer")

        if len(base_means) == 1:
            # Build a new list (the caller's list is left untouched) holding the original mean plus
            # independent deep copies, one per remaining task.
            base_means = base_means + [deepcopy(base_means[0]) for _ in range(num_tasks - 1)]

        self.base_means = ModuleList(base_means)
        self.num_tasks = num_tasks

    def forward(self, input):
        """
        Evaluate each mean in self.base_means on the input data, and return as an `n x t` matrix of means.

        Args:
            input: Data passed unchanged to every base mean.

        Returns:
            The outputs of all base means stacked along a new trailing dimension, one slice per task.
        """
        # Stacking on a new last dimension is equivalent to unsqueezing each output at -1 and concatenating there.
        return torch.stack([sub_mean(input) for sub_mean in self.base_means], dim=-1)