from functools import partial
import numpy.ma as ma
import autograd.numpy as np
from autograd import grad
from autograd.extend import defvjp, primitive
import proxmin
import logging
from .bbox import overlapped_slices
from .component import CombinedComponent
from .model import UpdateException
logger = logging.getLogger("scarlet.blend")
@primitive
def _add_models(*models, full_model, slices):
"""Insert the models into the full model
`slices` is a tuple `(full_model_slice, model_slices)` used
to insert a model into the full_model in the region where the
two models overlap.
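
    Examples
    --------
    A minimal sketch with hypothetical shapes; each entry of ``slices``
    pairs the overlap region in ``full_model`` with the matching region
    of the corresponding model:

    >>> full = np.zeros((1, 4, 4))
    >>> model = np.ones((1, 2, 2))
    >>> full_slice = (slice(None), slice(1, 3), slice(1, 3))
    >>> model_slice = (slice(None), slice(0, 2), slice(0, 2))
    >>> full = _add_models(model, full_model=full, slices=((full_slice, model_slice),))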
"""
for i in range(len(models)):
full_model[slices[i][0]] += models[i][slices[i][1]]
return full_model
def _grad_add_models(ans, *models, full_model, slices, index):
"""Gradient for a single model
    Following autograd's convention for VJP-makers, the first argument ``ans``
    is the result of the forward pass, not a gradient. The full model is just
    the sum of the models, so the gradient is 1 for each model; we only have
    to slice the upstream gradient appropriately.
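
    Examples
    --------
    A sketch of the returned closure, reusing the hypothetical ``full``,
    ``model``, and slice tuples from the ``_add_models`` example:

    >>> vjp = _grad_add_models(full, model, full_model=full,
    ...                        slices=((full_slice, model_slice),), index=0)
    >>> g = vjp(np.ones(full.shape))  # upstream gradient of ones
    >>> g.shape == model.shape
    True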
"""
model = models[index]
full_model_slices = slices[index][0]
model_slices = slices[index][1]
def result(upstream_grad):
_result = np.zeros(model.shape, dtype=model.dtype)
_result[model_slices] = upstream_grad[full_model_slices]
return _result
return result
class Blend(CombinedComponent):
"""The blended scene
    The class represents a scene as a collection of components or sources
    and provides the functions to fit it to data.
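
    Examples
    --------
    A minimal sketch, assuming `sources` and `observation` have been
    initialized elsewhere:

    >>> blend = Blend(sources, observation)
    >>> n_iter, logL = blend.fit(max_iter=200, e_rel=1e-4)
    >>> model = blend.get_model()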
"""
def __init__(self, sources, observations):
"""Constructor
Form a blended scene from a collection of `~scarlet.component.Component`s
Parameters
----------
sources: list of `~scarlet.component.Component` or `~scarlet.component.ComponentTree`
            Initialized components or sources to fit to the observations
observations: a `scarlet.Observation` instance or a list thereof
Data package(s) to fit
"""
if hasattr(sources, "__iter__"):
self.sources = sources
else:
self.sources = (sources,)
if hasattr(observations, "__iter__"):
self.observations = observations
else:
self.observations = (observations,)
super().__init__(self.sources)
# only for backward compatibility, use log_likelihood instead
self.loss = []
    def fit(self, max_iter=200, e_rel=1e-3, min_iter=1, noise_factor=0, **alg_kwargs):
"""Fit the model for each source to the data
Parameters
----------
max_iter: int
Maximum number of iterations if the algorithm doesn't converge
e_rel: float
Relative error for convergence of the loss function
        min_iter: int
            Minimum number of iterations before convergence is checked
        noise_factor: float
            Forwarded to each observation's `get_log_likelihood`
        alg_kwargs: dict
            Keywords for the `proxmin.adaprox` optimizer
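
        Examples
        --------
        A typical call, assuming an initialized blend; `fit` returns the
        number of iterations run and the final log-likelihood:

        >>> n_iter, logL = blend.fit(max_iter=200, e_rel=1e-4)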
"""
it = 0
self._noise_factor = noise_factor
while it < max_iter:
try:
X = self.parameters + tuple(
p for obs in self.observations for p in obs.parameters
)
# compute the backward gradients
# but only for non-fixed parameters
require_grad = tuple(k for k, x in enumerate(X) if not x.fixed)
def expand_grads(*X, func=None):
G = func(*X)
expanded = [0.0] * len(X)
for k, j in enumerate(require_grad):
expanded[j] = G[k]
return expanded
grad_logL_func = grad(self._loss_func, require_grad)
grad_logL = lambda *X: expand_grads(*X, func=grad_logL_func)
                # same for the priors; easier here because we call them independently
grad_logP = lambda *X: tuple(
x.prior(x.view(np.ndarray))
if x.prior is not None and not x.fixed
else 0
for x in X
)
# combine for log posterior
_grad = lambda *X: tuple(
l + p for l, p in zip(grad_logL(*X), grad_logP(*X))
)
                # step sizes; a callable step is evaluated at the current iteration,
                # which allows e.g. randomly skipping parameters
_step = lambda *X, it: tuple(
x.step(x, it=it) if hasattr(x.step, "__call__") else x.step
for x in X
)
_prox = tuple(x.constraint for x in X)
# good defaults for adaprox
scheme = alg_kwargs.pop("scheme", "amsgrad")
prox_max_iter = alg_kwargs.pop("prox_max_iter", 10)
callback = partial(
self._callback,
e_rel=e_rel,
callback=alg_kwargs.pop("callback", None),
min_iter=min_iter,
)
# do we have a current state of the optimizer to warm start?
for x in X:
if x.m is None:
x.m = np.zeros(x.shape)
if x.v is None:
x.v = np.zeros(x.shape)
if x.vhat is None:
x.vhat = np.zeros(x.shape)
M = tuple(x.m for x in X)
V = tuple(x.v for x in X)
Vhat = tuple(x.vhat for x in X)
proxmin.adaprox(
X,
_grad,
_step,
prox=_prox,
max_iter=max_iter - it,
e_rel=e_rel,
check_convergence=False,
scheme=scheme,
prox_max_iter=prox_max_iter,
callback=callback,
M=M,
V=V,
Vhat=Vhat,
**alg_kwargs
)
logger.info(
"scarlet ran for {0} iterations to logL = {1}".format(
len(self.log_likelihood), self.log_likelihood[-1]
)
)
# set convergence and standard deviation from optimizer
                for p, v in zip(X, V):
                    p.std = 1 / np.sqrt(
                        ma.masked_equal(v, 0)
                    )  # this is a rough estimate!
return len(self.log_likelihood), self.log_likelihood[-1]
# model update forces restart
except UpdateException:
it = len(self.log_likelihood)
    def get_model(self, *parameters, frame=None):
"""Get the model of the entire blend
Parameters
----------
parameters: tuple of optimization parameters
frame: `scarlet.Frame`
Alternative Frame to project the model into
Returns
-------
model: array
(Bands, Height, Width) data cube
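
        Examples
        --------
        A sketch: evaluated in its model frame, the blend returns the full
        data cube:

        >>> model = blend.get_model()
        >>> model.shape == blend.frame.shape
        True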
"""
# boxed models of every source
models = self.get_models_of_children(*parameters, frame=None)
if frame is None:
frame = self.frame
# if this is the model frame then the slices are already cached
if frame == self.frame:
slices = tuple(
(src._model_frame_slices, src._model_slices) for src in self.sources
)
else:
slices = tuple(
overlapped_slices(frame.bbox, src.bbox) for src in self.sources
)
        # We have to register the gradients (VJPs) of the function that
        # inserts sources into the blend with autograd. This has to be done
        # every time the model is evaluated, because the number of sources
        # determines the number of arguments, and each argument must be
        # linked to the autograd primitive with its own VJP.
defvjp(
_add_models,
*([partial(_grad_add_models, index=k) for k in range(len(self.sources))])
)
full_model = np.zeros(frame.shape, dtype=frame.dtype)
full_model = _add_models(*models, full_model=full_model, slices=slices)
return full_model
@property
def log_likelihood(self):
"""Log likelihood at each iteration
        The fitting method computes and sums the negative log-likelihood of each
        observation given the current model and uses it as the loss function for
        the optimization.
Returns
-------
        log_likelihood: array of the log-likelihood at each iteration,
            i.e. the negative of the stored loss
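
        Examples
        --------
        A sketch, assuming the blend has been fit:

        >>> n_iter, final_logL = blend.fit(max_iter=200)
        >>> logL = blend.log_likelihood  # one entry per iteration
        >>> logL[-1] == final_logL
        True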
"""
return -np.array(self.loss)
def _loss_func(self, *parameters):
n_params = len(self.parameters)
model = self.get_model(*parameters[:n_params], frame=self.frame)
        # Calculate the total loss function from all of the observations
total_loss = 0
for observation in self.observations:
n_obs_params = len(observation.parameters)
obs_params = parameters[n_params : n_params + n_obs_params]
total_loss = total_loss - observation.get_log_likelihood(
model, *obs_params, noise_factor=self._noise_factor
)
n_params += n_obs_params
        # strip the autograd box so a plain float is stored in the history
        self.loss.append(total_loss._value)
return total_loss
def _callback(self, *parameters, it=None, e_rel=1e-3, callback=None, min_iter=1):
# raises ArithmeticError if some of the parameters have become inf/nan
for src in self.sources:
src.check_parameters()
        # raises UpdateException if model updates require interrupting the optimization
throw_exception = False
if it > 0 and it % 10 == 0:
for src in self.sources:
try:
src.update()
                except UpdateException:
                    throw_exception = True
if throw_exception:
raise UpdateException
if it > min_iter and abs(self.loss[-1] - self.loss[-2]) < e_rel * np.abs(
self.loss[-1]
):
raise StopIteration(
"scarlet.Blend.fit() converged"
) # clean return from proxmin
if callback is not None:
callback(*parameters, it=it)
@property
def bbox(self):
"""Bounding box of the blend
"""
return self.frame.bbox