# Source code for scarlet.model

from abc import ABC, abstractmethod

from .parameter import Parameter


class UpdateException(Exception):
    """Signal that the optimization needs to be interrupted.

    Intended to be raised from `Model.update` when a change made outside of
    the optimization forward path requires the optimization to stop.
    """
class Model(ABC):
    """Model base class.

    This class stores and provides access to parameters and sub-ordinate
    models.

    Parameters
    ----------
    parameters: list of `~scarlet.Parameter`
        Parameters owned directly by this model, passed as positional
        arguments.
    children: list of `~scarlet.Model`
        Subordinate models. A single `Model` is accepted as well and is
        treated as the only child.

    Raises
    ------
    TypeError
        If a parameter is not a `Parameter`, or if `children` is neither a
        `Model` nor an iterable of `Model` instances.
    ArithmeticError
        If any parameter (own or from the children) is not finite.
    """

    def __init__(self, *parameters, children=None):
        # `*parameters` always arrives as a tuple, so only its items need to be
        # validated. Explicit raises are used because `assert` statements are
        # removed when Python runs with -O.
        for p in parameters:
            if not isinstance(p, Parameter):
                raise TypeError(
                    "parameters must be Parameter instances, got {}".format(
                        type(p).__name__
                    )
                )
        self._parameters = parameters

        if children is None:
            children = ()
        elif isinstance(children, Model):
            # A Model is itself iterable (over its own children), so it has to
            # be recognized before any generic iterable handling; otherwise the
            # child would be replaced by its grandchildren.
            children = (children,)
        elif not isinstance(children, (list, tuple)):
            # Materialize one-shot iterables (e.g. generators) so validating
            # them below does not leave an exhausted iterator behind.
            try:
                children = tuple(children)
            except TypeError:
                raise TypeError(
                    "children must be a Model or an iterable of Models"
                ) from None

        for c in children:
            if not isinstance(c, Model):
                raise TypeError(
                    "children must be Model instances, got {}".format(
                        type(c).__name__
                    )
                )
        # Lists and tuples are stored as given (not copied).
        self._children = children

        self.check_parameters()

    @property
    def parameters(self):
        """List of parameters, including from the children

        Own parameters come first, followed by the parameters of every child
        in child order.
        """
        return self._parameters + tuple(p for c in self.children for p in c.parameters)

    @property
    def children(self):
        """List of child models
        """
        return self._children

    def __getitem__(self, i):
        """Return the child model(s) selected by index or slice `i`."""
        return self._children[i]

    def __iter__(self):
        """Iterate over the child models."""
        return iter(self._children)

    def __next__(self):
        # NOTE(review): the child container is a tuple or list, neither of
        # which implements `__next__`, so this call raises AttributeError.
        # Kept unchanged for interface compatibility; use `iter(model)`.
        return self._children.__next__()

    def get_parameter(self, i, *parameters):
        """Access parameters by list index or by name

        Parameters
        ----------
        i: int, slice, str
            Index, slice or name attribute of the requested parameter
        parameters: tuple
            Parameters used during optimization. If not set, uses `self`

        Returns
        -------
        Matching item or tuple of matching items. A name lookup returns the
        single match itself, or a tuple when zero or several parameters carry
        that name. `None` is returned if `i` is not an int, slice, or str.
        """
        # NOTE: index lookup only works if order is not changed by parameter fixing!
        candidates = parameters if parameters else self.parameters

        if isinstance(i, (int, slice)):
            return candidates[i]

        if isinstance(i, str):
            if parameters:
                # during optimization: parameters are passed by autograd, and
                # the name is read from the wrapped `_value`
                match = tuple(p for p in candidates if p._value.name == i)
            else:
                match = tuple(p for p in candidates if p.name == i)
            if len(match) == 1:
                return match[0]
            return match

        return None

    @abstractmethod
    def get_model(self, *parameters, **kwargs):
        """Get the model realization

        Parameters
        ----------
        parameters: tuple of optimization parameters

        Returns
        -------
        model: array
            Realization of the model
        """

    def get_models_of_children(self, *parameters, **kwargs):
        """Get realization of all child models

        Parameters
        ----------
        parameters: tuple of optimization parameters
            Laid out like `self.parameters`: own parameters first, then the
            parameters of each child in child order.

        Returns
        -------
        model: list
            Realization of the child models, ordered by child index
        """
        if not parameters:
            return [c.get_model(**kwargs) for c in self._children]

        # parameters during optimization: skip this model's own parameters,
        # then hand each child its contiguous slice
        models = []
        start = len(self._parameters)
        for c in self._children:
            stop = start + len(c.parameters)
            models.append(c.get_model(*parameters[start:stop], **kwargs))
            start = stop
        return models

    def check_parameters(self):
        """Check that all parameters have finite elements

        Raises
        ------
        `ArithmeticError` when non-finite elements are present
        """
        for p in self.parameters:
            if not p.is_finite:
                msg = "Model {}, Parameter '{}' is not finite:\n{}".format(
                    self.__class__.__name__, p.name, p
                )
                raise ArithmeticError(msg)

    def update(self):
        """Update internal state or configuration of the model

        The method is only needed to adjust setting or parameters outside of
        the optimization forward path. The base implementation does nothing.

        Raises
        ------
        `scarlet.model.UpdateException` if the optimization needs to be interrupted
        """
        pass