#!/usr/bin/env python3
from __future__ import annotations
import copy
import inspect
import itertools
import operator
from typing import Callable, Iterator, Mapping, MutableSet, Optional, TypeVar, Union
import torch
from linear_operator.operators import LinearOperator
from torch import nn, Tensor
from torch.distributions import Distribution
from .constraints import Interval
from .priors import Prior
# Self-type placeholders so that methods can be annotated as returning the
# same (sub)class they were called on.
NnModuleSelf = TypeVar("NnModuleSelf", bound=nn.Module) # TODO: replace w/ typing.Self in Python 3.11
ModuleSelf = TypeVar("ModuleSelf", bound="Module") # TODO: replace w/ typing.Self in Python 3.11
RandomModuleSelf = TypeVar("RandomModuleSelf", bound="RandomModuleMixin") # TODO: replace w/ typing.Self in Python 3.11
# A prior's "closure": called with the owning module, returns the tensor the
# prior is evaluated on.
Closure = Callable[[NnModuleSelf], Tensor]
# A prior's "setting closure": called with the owning module and a value;
# writes that value back into the module and returns nothing.
SettingClosure = Callable[[ModuleSelf, Union[Tensor, float]], None]
# Mapping from (dot-prefixed) prior names to sampled values.
SamplesDict = Mapping[str, Union[Tensor, float]]
class RandomModuleMixin:
    """
    Mixin that overrides ``initialize`` so that values are stored as plain
    tensors (expanded to the old parameter's shape) rather than copied into
    ``torch.nn.Parameter`` objects.
    """

    def initialize(self: RandomModuleSelf, **kwargs) -> RandomModuleSelf:
        """
        Set a value for a parameter

        kwargs: (param_name, value) - parameter to initialize.
        Can also initialize recursively by passing in the full name of a
        parameter. For example if model has attribute model.likelihood,
        we can initialize the noise with either
        `model.initialize(**{'likelihood.noise': 0.1})`
        or
        `model.likelihood.initialize(noise=0.1)`.
        The former method would allow users to more easily store the
        initialization values as one object.

        Value must be a Tensor

        Returns:
            ``self``, to allow chaining.

        Raises:
            RuntimeError: if any value is not a ``Tensor``.
        """
        for name, value in kwargs.items():
            if not isinstance(value, Tensor):
                raise RuntimeError("Initialize in RandomModules can only be done with Tensor values.")

            # Split off only the last component, so paths of any depth
            # ("a.b.c") resolve to owner "a.b" and attribute "c".
            mod_name, _, param_name = name.rpartition(".")
            mod = operator.attrgetter(mod_name)(self) if mod_name else self

            old_param = getattr(mod, param_name)
            # Look the attribute up on the class that actually owns it; for a
            # dotted name that is the submodule's class, not ``type(self)``.
            is_property = isinstance(getattr(type(mod), param_name, None), property)
            new_value = value.expand(old_param.shape)
            if not isinstance(old_param, torch.nn.Parameter) or is_property:
                # Presumably we're calling a getter that will call initialize again on the actual parameter.
                setattr(mod, param_name, new_value)
            else:
                # Drop the registered Parameter first so the assignment below
                # stores a plain tensor attribute under the same name.
                delattr(mod, param_name)
                setattr(mod, param_name, new_value)

        return self
class Module(nn.Module):
    """
    ``torch.nn.Module`` subclass that keeps per-module registries of priors,
    constraints and added loss terms, and supports name-based initialization
    of parameters.

    Instance state created in ``__init__``:

    * ``_added_loss_terms`` -- name -> added loss term (``None`` until updated).
    * ``_priors`` -- prior name -> ``(prior, closure, setting_closure)``.
    * ``_constraints`` -- ``"<param_name>_constraint"`` -> constraint.
    * ``_strict_init`` -- when ``False``, :meth:`initialize` may replace a
      parameter's data with a tensor of a different shape.
    * ``_load_strict_shapes`` -- when ``False``, loading a state dict first
      swaps local parameter/buffer data for the incoming tensors.
    """

    def __init__(self):
        super().__init__()
        self._added_loss_terms = {}
        self._priors: dict[str, tuple[Prior, Closure, Optional[SettingClosure]]] = {}
        self._constraints: dict[str, Interval] = {}

        self._strict_init = True
        self._load_strict_shapes = True

        # Without this registration ``_load_state_hook_ignore_shapes`` is never invoked.
        self._register_load_state_dict_pre_hook(self._load_state_hook_ignore_shapes)

    def __call__(self, *inputs, **kwargs) -> Union[Tensor, Distribution, LinearOperator]:
        """Run :meth:`forward` and validate its output (each element separately if it is a list)."""
        outputs = self.forward(*inputs, **kwargs)
        if isinstance(outputs, list):
            return [_validate_module_outputs(output) for output in outputs]
        return _validate_module_outputs(outputs)

    def _clear_cache(self):
        """
        Clear any precomputed caches.
        Should be implemented by any module that caches any computation at test time.
        """
        pass

    def _get_module_and_name(self, parameter_name: str) -> tuple[nn.Module, str]:
        """Get module and name from full parameter name."""
        module, name = parameter_name.split(".", 1)
        if module in self._modules:
            return self.__getattr__(module), name
        raise AttributeError(
            "Invalid parameter name {}. {} has no module {}".format(parameter_name, type(self).__name__, module)
        )

    @staticmethod
    def _out_of_bounds_error(constraint) -> RuntimeError:
        """Build the error raised when a raw value fails ``constraint.check_raw``."""
        return RuntimeError(
            "Attempting to manually set a parameter value that is out of bounds of "
            f"its current constraints, {constraint}. "
            "Most likely, you want to do the following:\n likelihood = GaussianLikelihood"
            "(noise_constraint=gpytorch.constraints.GreaterThan(better_lower_bound))"
        )

    def _strict(self, value: bool) -> None:
        """Set ``_strict_init`` on this module and, recursively, on its children."""
        _set_strict(self, value)

    def added_loss_terms(self):
        """Iterate over the added loss terms (values of :meth:`named_added_loss_terms`)."""
        for _, strategy in self.named_added_loss_terms():
            yield strategy

    def forward(self, *inputs, **kwargs) -> Union[Tensor, Distribution, LinearOperator]:
        """Compute the module's output; must be overridden by subclasses."""
        raise NotImplementedError

    def constraints(self) -> Iterator[Interval]:
        """Iterate over the constraints (values of :meth:`named_constraints`)."""
        for _, constraint in self.named_constraints():
            yield constraint

    def hyperparameters(self):
        """Iterate over the parameters yielded by :meth:`named_hyperparameters`."""
        for _, param in self.named_hyperparameters():
            yield param

    def initialize(self: ModuleSelf, **kwargs) -> ModuleSelf:
        """
        Set a value for a parameter

        kwargs: (param_name, value) - parameter to initialize.
        Can also initialize recursively by passing in the full name of a
        parameter. For example if model has attribute model.likelihood,
        we can initialize the noise with either
        `model.initialize(**{'likelihood.noise': 0.1})`
        or
        `model.likelihood.initialize(noise=0.1)`.
        The former method would allow users to more easily store the
        initialization values as one object.

        Value can take the form of a tensor, a float, or an int

        Returns:
            ``self``, to allow chaining.

        Raises:
            AttributeError: unknown name, unknown submodule, or unsupported value type
                for a registered parameter/buffer.
            RuntimeError: the value fails the parameter's constraint check.
            ValueError: the resulting value is rejected by the registered prior.
        """
        for name, val in kwargs.items():
            if isinstance(val, int):
                val = float(val)

            if "." in name:
                module, name = self._get_module_and_name(name)
                if isinstance(module, nn.ModuleList):
                    idx, name = name.split(".", 1)
                    module[int(idx)].initialize(**{name: val})
                else:
                    module.initialize(**{name: val})
                # The submodule has handled this entry, including its own prior
                # check; don't look up the child's attribute name among *our* priors.
                continue
            elif not hasattr(self, name):
                raise AttributeError("Unknown parameter {p} for {c}".format(p=name, c=self.__class__.__name__))
            elif name not in self._parameters and name not in self._buffers:
                # Not a raw parameter/buffer (e.g. a property with a setter).
                setattr(self, name, val)
            elif isinstance(val, Tensor):
                constraint = self.constraint_for_parameter_name(name)
                if constraint is not None and constraint.enforced and not constraint.check_raw(val):
                    raise self._out_of_bounds_error(constraint)
                target = self.__getattr__(name)
                try:
                    target.data.copy_(val.expand_as(target))
                except RuntimeError:
                    if not self._strict_init:
                        # Non-strict mode: adopt the new tensor (and its shape) wholesale.
                        target.data = val
                    else:
                        target.data.copy_(val.view_as(target))
            elif isinstance(val, float):
                constraint = self.constraint_for_parameter_name(name)
                # NOTE(review): unlike the tensor branch this does not consult
                # ``constraint.enforced`` -- confirm whether that is intentional.
                if constraint is not None and not constraint.check_raw(val):
                    raise self._out_of_bounds_error(constraint)
                self.__getattr__(name).data.fill_(val)
            else:
                raise AttributeError("Type {t} not valid for initializing parameter {p}".format(t=type(val), p=name))

            # Ensure value is contained in support of prior (if present)
            prior_name = "_".join([name, "prior"])
            if prior_name in self._priors:
                prior, closure, _ = self._priors[prior_name]
                try:
                    prior._validate_sample(closure(self))
                except ValueError as e:
                    raise ValueError("Invalid input value for prior {}. Error:\n{}".format(prior_name, e)) from e

        return self

    def named_added_loss_terms(self):
        """Returns an iterator over module variational strategies, yielding both
        the name of the variational strategy as well as the strategy itself.

        Yields:
            (string, VariationalStrategy): Tuple containing the name of the
                strategy and the strategy
        """
        return _extract_named_added_loss_terms(module=self, memo=None, prefix="")

    def named_hyperparameters(self):
        """Yield ``(name, parameter)`` for every submodule that is not a ``_VariationalDistribution``."""
        from .variational._variational_distribution import _VariationalDistribution

        for module_prefix, module in self.named_modules():
            if not isinstance(module, _VariationalDistribution):
                for elem in module.named_parameters(prefix=module_prefix, recurse=False):
                    yield elem

    def named_priors(self) -> Iterator[tuple[str, nn.Module, Prior, Closure, SettingClosure | None]]:
        """Returns an iterator over the priors of this module and its submodules.

        Yields:
            (string, Module, Prior, callable, callable or None): Tuple containing:
                - the (dot-prefixed) name of the prior
                - the parent module of the prior
                - the prior
                - the closure registered with the prior (called with the parent module)
                - the setting closure registered with the prior, if any
        """
        return _extract_named_priors(module=self, prefix="")

    def named_constraints(self) -> Iterator[tuple[str, Interval]]:
        """Returns an iterator of ``(name, constraint)`` over this module and its submodules."""
        return _extract_named_constraints(module=self, memo=None, prefix="")

    def named_variational_parameters(self):
        """Yield ``(name, parameter)`` for every submodule that is a ``_VariationalDistribution``."""
        from .variational._variational_distribution import _VariationalDistribution

        for module_prefix, module in self.named_modules():
            if isinstance(module, _VariationalDistribution):
                for elem in module.named_parameters(prefix=module_prefix, recurse=False):
                    yield elem

    def register_added_loss_term(self, name):
        """Reserve ``name`` as an added loss term; it stays ``None`` until :meth:`update_added_loss_term`."""
        self._added_loss_terms[name] = None

    def register_parameter(self, name: str, parameter: Optional[nn.Parameter]) -> None:
        r"""
        Adds a parameter to the module. The parameter can be accessed as an attribute using the given name.

        Args:
            name:
                The name of the parameter
            parameter:
                The parameter
        """
        super().register_parameter(name, parameter)

    def register_prior(
        self,
        name: str,
        prior: Prior,
        param_or_closure: Union[str, Closure],
        setting_closure: Optional[SettingClosure] = None,
    ) -> None:
        """
        Adds a prior to the module. The prior can be accessed as an attribute using the given name.

        Args:
            name:
                The name of the prior
            prior:
                The prior to be registered
            param_or_closure:
                Either the name of the parameter, or a closure (which upon calling evaluates a function on
                the module instance and one or more parameters):
                single parameter without a transform: `.register_prior("foo_prior", foo_prior, "foo_param")`
                transform a single parameter (e.g. put a log-Normal prior on it):
                `.register_prior("foo_prior", NormalPrior(0, 1), lambda module: torch.log(module.foo_param))`
                function of multiple parameters:
                `.register_prior("foo2_prior", foo2_prior, lambda module: f(module.param1, module.param2)))`
            setting_closure:
                A function taking in the module instance and a tensor in (transformed) parameter space,
                initializing the internal parameter representation to the proper value by applying the
                inverse transform. Enables setting parameters directly in the transformed space, as well
                as sampling parameter values from priors (see `sample_from_prior`)

        Raises:
            AttributeError: ``param_or_closure`` names an attribute the module does not have.
            RuntimeError: ``setting_closure`` is given together with a parameter name.
            ValueError: a supplied closure takes too few arguments.
        """
        if isinstance(param_or_closure, str):
            param = param_or_closure
            if param not in self._parameters and not hasattr(self, param):
                raise AttributeError(
                    "Unknown parameter {name} for {module}".format(name=param, module=self.__class__.__name__)
                    + " Make sure the parameter is registered before registering a prior."
                )

            def closure_new(module: nn.Module) -> Tensor:
                return getattr(module, param)

            closure = closure_new

            if setting_closure is not None:
                raise RuntimeError("Must specify a closure instead of a parameter name when providing setting_closure")

            def setting_closure_new(module: Module, val: Union[Tensor, float]) -> None:
                module.initialize(**{param: val})

            setting_closure = setting_closure_new
        else:
            closure = param_or_closure
            if len(inspect.signature(closure).parameters) == 0:
                raise ValueError(
                    """As of version 1.4, `param_or_closure` must operate on a module instance. For example:

                    likelihood.noise_covar.register_prior(
                        "noise_std_prior",
                        gpytorch.priors.NormalPrior(0, 1),
                        lambda module: module.noise.sqrt()
                    )
                    """
                )

            if inspect.isfunction(setting_closure) and len(inspect.signature(setting_closure).parameters) < 2:
                raise ValueError(
                    """As of version 1.4, `setting_closure` must operate on a module instance and a tensor. For example:

                    kernel.register_prior(
                        "radius_prior",
                        gpytorch.priors.LogNormalPrior(0, 1),
                        lambda module: module.radius,
                        lambda module, value: module._set_radius(value),
                    )
                    """
                )

        self.add_module(name, prior)
        self._priors[name] = (prior, closure, setting_closure)

    def register_constraint(self, param_name: str, constraint: Interval, replace: bool = True) -> None:
        """
        Attach ``constraint`` to the registered parameter ``param_name``.

        Args:
            param_name: Name of a parameter in ``self._parameters``.
            constraint: The constraint to register (stored as ``"<param_name>_constraint"``).
            replace: If ``False`` and an ``Interval`` is already registered for this parameter,
                register ``constraint.intersect(existing)`` instead of ``constraint``.

        Raises:
            RuntimeError: ``param_name`` is not a registered parameter.
        """
        if param_name not in self._parameters:
            raise RuntimeError("Attempting to register constraint for nonexistent parameter.")

        constraint_name = param_name + "_constraint"
        current_constraint = self._constraints.get(constraint_name)

        if isinstance(current_constraint, Interval) and not replace:
            new_constraint = constraint.intersect(current_constraint)
        else:
            new_constraint = constraint

        self.add_module(constraint_name, new_constraint)
        self._constraints[constraint_name] = new_constraint

        # re-initialize the parameter if the constraint specifies an initial value
        if new_constraint.initial_value is not None:
            self.initialize(**{param_name: new_constraint.initial_value})

    def train(self, mode=True):
        """Switch training mode, clearing caches first unless the module stays in eval mode."""
        # Equivalent to ``(self.training and not mode) or mode``: caches are
        # kept only for the eval -> eval no-op transition.
        if self.training or mode:
            self._clear_cache()
        return super().train(mode=mode)

    def constraint_for_parameter_name(self, param_name: str) -> Interval | None:
        """
        Return the constraint registered for a (possibly dotted) parameter name.

        Dotted names are resolved one attribute at a time starting from ``self``.
        Returns ``None`` when no constraint is registered or when the owning
        object has no ``_constraints`` registry.
        """
        base_module = self
        base_name = param_name

        while "." in base_name:
            submodule_name, base_name = base_name.split(".", 1)
            base_module = getattr(base_module, submodule_name)

        try:
            return base_module._constraints.get(base_name + "_constraint")
        except AttributeError:  # submodule may not always be a gpytorch module
            return None

    def _load_state_hook_ignore_shapes(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        """
        Load-state-dict pre-hook. When ``_load_strict_shapes`` is ``False``, replace the data of
        every local parameter/buffer present in ``state_dict`` with the incoming tensor's data.
        """
        if not self._load_strict_shapes:
            local_name_params = itertools.chain(self._parameters.items(), self._buffers.items())
            local_state = {k: v for k, v in local_name_params if v is not None}

            for name, param in local_state.items():
                key = prefix + name
                if key in state_dict:
                    param.data = state_dict[key].data

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        """Clear caches, then defer to ``nn.Module._load_from_state_dict``."""
        # If we're loading from a state dict, we need to clear any precomputed caches
        self._clear_cache()
        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )

    def load_strict_shapes(self, value):
        """Set ``_load_strict_shapes`` to ``value`` on this module and all submodules."""

        def apply_fn(module):
            module._load_strict_shapes = value

        self.apply(apply_fn)

    def named_parameters_and_constraints(self) -> Iterator[tuple[str, nn.Parameter, Interval | None]]:
        """Yield ``(name, parameter, constraint_or_None)`` for every named parameter."""
        for name, param in self.named_parameters():
            yield name, param, self.constraint_for_parameter_name(name)

    def sample_from_prior(self, prior_name: str) -> None:
        """Sample parameter values from prior. Modifies the module's parameters in-place."""
        if prior_name not in self._priors:
            raise RuntimeError("Unknown prior name '{}'".format(prior_name))
        prior, _, setting_closure = self._priors[prior_name]
        if setting_closure is None:
            raise RuntimeError("Must provide inverse transform to be able to sample from prior.")
        setting_closure(self, prior.sample())

    def to_pyro_random_module(self) -> Module:
        """Alias for :meth:`to_random_module`."""
        return self.to_random_module()

    def to_random_module(self) -> Module:
        """
        Return a deep copy of this module whose class additionally mixes in
        ``RandomModuleMixin``; ``Module`` children are converted recursively.
        """
        new_module = copy.deepcopy(self)
        # Only wrap once: a class built from (RandomModuleMixin, <subclass of
        # RandomModuleMixin>) has no consistent MRO and ``type(...)`` would raise.
        if not isinstance(self, RandomModuleMixin):
            random_module_cls = type("_Random" + self.__class__.__name__, (RandomModuleMixin, self.__class__), {})
            new_module.__class__ = random_module_cls  # hack

        for mname, child in new_module.named_children():
            if isinstance(child, Module):
                setattr(new_module, mname, child.to_random_module())
        return new_module

    def pyro_sample_from_prior(self) -> Module:
        """
        For each parameter in this Module and submodule that have defined priors, sample a value for that parameter
        from its corresponding prior with a pyro.sample primitive and load the resulting value in to the parameter.

        This method can be used in a Pyro model to conveniently define pyro sample sites for all
        parameters of the model that have GPyTorch priors registered to them.

        Returns:
            A random-module copy of ``self`` with the sampled values loaded.
        """
        new_module = self.to_pyro_random_module()
        return _pyro_sample_from_prior(module=new_module, memo=None, prefix="")

    def local_load_samples(self, samples_dict: SamplesDict, memo: MutableSet[Prior], prefix: str) -> None:
        """
        Defines local behavior of this Module when loading parameters from a samples_dict generated by a Pyro
        sampling mechanism.

        The default behavior here should almost always be called from any overriding class. However, a class may
        want to add additional functionality, such as reshaping things to account for the fact that parameters will
        acquire an extra batch dimension corresponding to the number of samples drawn.

        Args:
            samples_dict: Mapping from (dot-prefixed) prior names to sample values.
            memo: Priors already handled; updated in place.
            prefix: Dotted path of this module, prepended to prior names for the lookup.
        """
        # Relax strict initialization while loading so that sample tensors of a
        # different shape can replace the parameter data; always re-enable it.
        self._strict(False)
        try:
            for name, (prior, _, setting_closure) in self._priors.items():
                if prior is not None and prior not in memo:
                    memo.add(prior)
                    if setting_closure is None:
                        raise RuntimeError("Must provide setting_closure to load samples.")
                    setting_closure(self, samples_dict[prefix + ("." if prefix else "") + name])
        finally:
            self._strict(True)

    def pyro_load_from_samples(self, samples_dict: SamplesDict) -> None:
        """
        Convert this Module in to a batch Module by loading parameters from the given `samples_dict`. `samples_dict`
        is typically produced by a Pyro sampling mechanism.

        Note that the keys of the samples_dict should correspond to prior names (covar_module.outputscale_prior) rather
        than parameter names (covar_module.raw_outputscale), because we will use the setting_closure associated with
        the prior to properly set the unconstrained parameter.

        Args:
            samples_dict: Dictionary mapping *prior names* to sample values.
        """
        _pyro_load_from_samples(module=self, samples_dict=samples_dict, memo=None, prefix="")

    def update_added_loss_term(self, name, added_loss_term):
        """
        Store ``added_loss_term`` under a previously registered ``name``.

        Raises:
            RuntimeError: the term is not an ``AddedLossTerm`` or ``name`` was never registered.
        """
        from .mlls import AddedLossTerm

        if not isinstance(added_loss_term, AddedLossTerm):
            raise RuntimeError("added_loss_term must be a AddedLossTerm")
        if name not in self._added_loss_terms:
            raise RuntimeError("added_loss_term {} not registered".format(name))
        self._added_loss_terms[name] = added_loss_term

    def variational_parameters(self):
        """Iterate over the parameters yielded by :meth:`named_variational_parameters`."""
        for _, param in self.named_variational_parameters():
            yield param
def _validate_module_outputs(outputs):
    """
    Check that a module's output has an accepted type and return it.

    Accepted types are ``Tensor``, ``Distribution`` and ``LinearOperator``, or a
    tuple whose elements are all of those types. A one-element tuple is
    unwrapped to its single element.

    Raises:
        RuntimeError: if ``outputs`` (or any tuple element) has another type.
    """
    valid_types = (Tensor, Distribution, LinearOperator)

    if isinstance(outputs, tuple):
        if not all(isinstance(output, valid_types) for output in outputs):
            raise RuntimeError(
                "All outputs must be a torch.Tensor, Distribution, or LinearOperator. "
                "Got {}".format([output.__class__.__name__ for output in outputs])
            )
        if len(outputs) == 1:
            outputs = outputs[0]
        return outputs

    if isinstance(outputs, valid_types):
        return outputs

    raise RuntimeError(
        "Output must be a torch.Tensor, Distribution, or LinearOperator. Got {}".format(outputs.__class__.__name__)
    )
def _set_strict(module: nn.Module, value: bool) -> None:
if hasattr(module, "_strict_init"):
module._strict_init = value
for mname, module_ in module.named_children():
_set_strict(module_, value)
def _pyro_sample_from_prior(
    module: NnModuleSelf, memo: Optional[MutableSet[Prior]] = None, prefix: str = ""
) -> NnModuleSelf:
    """
    Recursively draw a ``pyro.sample`` for every prior registered on ``module``
    and its children, and write each sample back through the prior's setting closure.

    Args:
        module: Module to process (modified in place).
        memo: Priors already sampled; each prior object is sampled at most once.
        prefix: Dotted path of ``module``; prepended to prior names to form sample-site names.

    Returns:
        ``module`` itself.

    Raises:
        RuntimeError: if pyro is not installed, or a prior has no setting closure.
    """
    try:
        import pyro
    except ImportError as err:
        raise RuntimeError("Cannot call pyro_sample_from_prior without pyro installed!") from err

    if memo is None:
        memo = set()

    if isinstance(module, Module):
        for prior_name, (prior, closure, setting_closure) in module._priors.items():
            if prior is not None and prior not in memo:
                if setting_closure is None:
                    raise RuntimeError(
                        "Cannot use Pyro for sampling without a setting_closure for each prior,"
                        f" but the following prior had none: {prior_name}, {prior}."
                    )
                # Record the original prior object before rebinding ``prior`` to
                # its expanded form, so shared priors are only sampled once.
                memo.add(prior)
                prior = prior.expand(closure(module).shape)
                value = pyro.sample(prefix + ("." if prefix else "") + prior_name, prior)
                setting_closure(module, value)

    for mname, module_ in module.named_children():
        submodule_prefix = prefix + ("." if prefix else "") + mname
        _pyro_sample_from_prior(module=module_, memo=memo, prefix=submodule_prefix)

    return module
def _pyro_load_from_samples(
    module: nn.Module, samples_dict: SamplesDict, memo: Optional[MutableSet[Prior]] = None, prefix: str = ""
) -> None:
    """
    Recursively call ``local_load_samples`` on ``module`` and its descendants.

    Args:
        module: Root of the module tree to load into.
        samples_dict: Mapping from (dot-prefixed) prior names to sample values.
        memo: Set shared across the recursion and handed to every ``local_load_samples`` call.
        prefix: Dotted path of ``module`` within the tree.
    """
    if memo is None:
        memo = set()

    # Only ``Module`` instances know how to load samples; other nn.Modules are
    # just traversed.
    if isinstance(module, Module):
        module.local_load_samples(samples_dict, memo, prefix)

    for mname, module_ in module.named_children():
        submodule_prefix = prefix + ("." if prefix else "") + mname
        _pyro_load_from_samples(module_, samples_dict, memo=memo, prefix=submodule_prefix)
def _extract_named_added_loss_terms(module, memo=None, prefix=""):
if memo is None:
memo = set()
if hasattr(module, "_added_loss_terms"):
for name, strategy in module._added_loss_terms.items():
if strategy is not None and strategy not in memo:
yield prefix + ("." if prefix else "") + name, strategy
for mname, module_ in module.named_children():
submodule_prefix = prefix + ("." if prefix else "") + mname
for name, strategy in _extract_named_added_loss_terms(module=module_, memo=memo, prefix=submodule_prefix):
yield name, strategy
def _extract_named_priors(
    module: nn.Module, prefix: str = ""
) -> Iterator[tuple[str, nn.Module, Prior, Closure, SettingClosure | None]]:
    """
    Yield ``(dotted_name, owner, prior, closure, setting_closure)`` for every
    non-``None`` prior registered on ``module`` and its descendants.

    No de-duplication is performed: a prior registered under several names is
    yielded once per name.

    Args:
        module: Root of the module tree to walk; only ``Module`` instances
            contribute priors, other modules are just traversed.
        prefix: Dotted path of ``module``; prepended to the yielded names.
    """
    separator = "." if prefix else ""

    if isinstance(module, Module):
        for name, (prior, closure, inv_closure) in module._priors.items():
            if prior is not None:
                yield prefix + separator + name, module, prior, closure, inv_closure

    for mname, module_ in module.named_children():
        yield from _extract_named_priors(module_, prefix=prefix + separator + mname)
def _extract_named_constraints(
    module: nn.Module, memo: Optional[MutableSet[Interval]] = None, prefix: str = ""
) -> Iterator[tuple[str, Interval]]:
    """
    Yield ``(dotted_name, constraint)`` for every non-``None`` constraint
    registered on ``module`` and its descendants.

    Args:
        module: Root of the module tree to walk; only ``Module`` instances
            contribute constraints, other modules are just traversed.
        memo: Constraints already yielded; a constraint object shared by
            several parameters is reported only once.
        prefix: Dotted path of ``module``; prepended to the yielded names.
    """
    if memo is None:
        memo = set()

    separator = "." if prefix else ""

    if isinstance(module, Module):
        for name, constraint in module._constraints.items():
            if constraint is not None and constraint not in memo:
                memo.add(constraint)
                yield prefix + separator + name, constraint

    for mname, module_ in module.named_children():
        yield from _extract_named_constraints(module_, memo=memo, prefix=prefix + separator + mname)