#!/usr/bin/env python3
import torch
from .mean import Mean
[docs]class ConstantMeanGrad(Mean):
def __init__(self, prior=None, batch_shape=torch.Size(), **kwargs):
super(ConstantMeanGrad, self).__init__()
self.batch_shape = batch_shape
self.register_parameter(name="constant", parameter=torch.nn.Parameter(torch.zeros(*batch_shape, 1)))
if prior is not None:
self.register_prior("mean_prior", prior, "constant")
def forward(self, input):
batch_shape = torch.broadcast_shapes(self.batch_shape, input.shape[:-2])
mean = self.constant.unsqueeze(-1).expand(*batch_shape, input.size(-2), input.size(-1) + 1).contiguous()
mean[..., 1:] = 0
return mean