#!/usr/bin/env python3
import torch
from .mean import Mean
[docs]class LinearMean(Mean):
def __init__(self, input_size, batch_shape=torch.Size(), bias=True):
super().__init__()
self.register_parameter(name="weights", parameter=torch.nn.Parameter(torch.randn(*batch_shape, input_size, 1)))
if bias:
self.register_parameter(name="bias", parameter=torch.nn.Parameter(torch.randn(*batch_shape, 1)))
else:
self.bias = None
def forward(self, x):
res = x.matmul(self.weights).squeeze(-1)
if self.bias is not None:
res = res + self.bias
return res