from functools import partial
import numpy.ma as ma
import autograd.numpy as np
from autograd import grad
from autograd.extend import defvjp, primitive
import proxmin
import logging
from .bbox import overlapped_slices
from .component import CombinedComponent
from .model import UpdateException
logger = logging.getLogger("scarlet.blend")
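# `_add_models` is an autograd primitive: the forward pass sums the source
# models into the full model, and `_grad_add_models` supplies the matching
# vector-Jacobian products, registered with `defvjp` in `Blend.get_model`.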
@primitive
def _add_models(*models, full_model, slices):
"""Insert the models into the full model
    `slices` is a tuple of `(full_model_slice, model_slices)` pairs,
    one per model, used to insert each model into `full_model`
    in the region where the two overlap.
"""
    for model, (full_slice, model_slice) in zip(models, slices):
        full_model[full_slice] += model[model_slice]
return full_model
def _grad_add_models(ans, *models, full_model, slices, index):
    """Gradient for a single model
    The full model is just the sum of the models,
    so the gradient is 1 for each model;
    we just have to slice it appropriately.
    Following the `defvjp` convention, `ans` receives the result of the
    forward pass, and the returned closure maps the upstream gradient
    onto the model with the given `index`.
    """
model = models[index]
    full_model_slices, model_slices = slices[index]
def result(upstream_grad):
_result = np.zeros(model.shape, dtype=model.dtype)
_result[model_slices] = upstream_grad[full_model_slices]
return _result
return result
class Blend(CombinedComponent):
"""The blended scene
    The class represents a scene as a collection of components and
    provides the functions to fit it to data.
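
    A minimal usage sketch, assuming `sources` and `observation` have been
    initialized elsewhere::

        blend = Blend(sources, observation)
        blend.fit(max_iter=200)
        model = blend.get_model()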
"""
def __init__(self, sources, observations):
"""Constructor
Form a blended scene from a collection of `~scarlet.component.Component`s
Parameters
----------
sources: list of `~scarlet.component.Component` or `~scarlet.component.ComponentTree`
            Initialized components or sources to fit to the observations
observations: a `scarlet.Observation` instance or a list thereof
Data package(s) to fit
"""
if hasattr(sources, "__iter__"):
self.sources = sources
else:
self.sources = (sources,)
if hasattr(observations, "__iter__"):
self.observations = observations
else:
self.observations = (observations,)
super().__init__(self.sources)
# only for backward compatibility, use log_likelihood instead
self.loss = []
    def fit(self, max_iter=200, e_rel=1e-3, min_iter=1, noise_factor=0, **alg_kwargs):
"""Fit the model for each source to the data
Parameters
----------
max_iter: int
Maximum number of iterations if the algorithm doesn't converge
e_rel: float
Relative error for convergence of the loss function
        min_iter: int
            Minimum number of iterations before the convergence criterion is checked
        noise_factor: float
            Passed through to each observation's `get_log_likelihood`
        alg_kwargs: dict
            Keywords for the `proxmin.adaprox` optimizer
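
        Returns
        -------
        n_iter: int
            Number of iterations run
        logL: float
            Log-likelihood of the final model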
"""
it = 0
self._noise_factor = noise_factor
while it < max_iter:
try:
X = self.parameters + tuple(
p for obs in self.observations for p in obs.parameters
)
# compute the backward gradients
# but only for non-fixed parameters
require_grad = tuple(k for k, x in enumerate(X) if not x.fixed)
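                # autograd computes gradients only for the indices in `require_grad`;
                # expand_grads re-aligns them with the full parameter list X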
def expand_grads(*X, func=None):
G = func(*X)
expanded = [0.0] * len(X)
for k, j in enumerate(require_grad):
expanded[j] = G[k]
return expanded
grad_logL_func = grad(self._loss_func, require_grad)
grad_logL = lambda *X: expand_grads(*X, func=grad_logL_func)
                # same for the prior; easier here because we call them independently
grad_logP = lambda *X: tuple(
x.prior(x.view(np.ndarray))
if x.prior is not None and not x.fixed
else 0
for x in X
)
# combine for log posterior
_grad = lambda *X: tuple(
l + p for l, p in zip(grad_logL(*X), grad_logP(*X))
)
# step sizes
                _step = lambda *X, it: tuple(
                    x.step(x, it=it) if callable(x.step) else x.step
                    for x in X
                )
# proxes
_prox = tuple(x.constraint for x in X)
# good defaults for adaprox
scheme = alg_kwargs.pop("scheme", "amsgrad")
prox_max_iter = alg_kwargs.pop("prox_max_iter", 10)
callback = partial(
self._callback,
e_rel=e_rel,
callback=alg_kwargs.pop("callback", None),
min_iter=min_iter,
)
# do we have a current state of the optimizer to warm start?
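                # m/v/vhat are the first- and second-moment accumulators of adaprox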
for x in X:
if x.m is None:
x.m = np.zeros(x.shape)
if x.v is None:
x.v = np.zeros(x.shape)
if x.vhat is None:
x.vhat = np.zeros(x.shape)
M = tuple(x.m for x in X)
V = tuple(x.v for x in X)
Vhat = tuple(x.vhat for x in X)
proxmin.adaprox(
X,
_grad,
_step,
prox=_prox,
max_iter=max_iter - it,
e_rel=e_rel,
check_convergence=False,
scheme=scheme,
prox_max_iter=prox_max_iter,
callback=callback,
M=M,
V=V,
Vhat=Vhat,
**alg_kwargs
)
logger.info(
"scarlet ran for {0} iterations to logL = {1}".format(
len(self.log_likelihood), self.log_likelihood[-1]
)
)
                # set the standard deviation from the optimizer's second moments
                for p, v in zip(X, V):
                    # this is a rough estimate!
                    p.std = 1 / np.sqrt(ma.masked_equal(v, 0))
return len(self.log_likelihood), self.log_likelihood[-1]
# model update forces restart
except UpdateException:
it = len(self.log_likelihood)
    def get_model(self, *parameters, frame=None):
"""Get the model of the entire blend
Parameters
----------
parameters: tuple of optimization parameters
frame: `scarlet.Frame`
Alternative Frame to project the model into
Returns
-------
model: array
(Bands, Height, Width) data cube
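
        As a sketch, the model can be projected into another frame, e.g. the
        frame of an observation (assuming `observation.frame` provides it)::

            model = blend.get_model(frame=observation.frame)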
"""
        # boxed models of every source, each evaluated in its own bbox
models = self.get_models_of_children(*parameters, frame=None)
if frame is None:
frame = self.frame
# if this is the model frame then the slices are already cached
if frame == self.frame:
slices = tuple(
(src._model_frame_slices, src._model_slices) for src in self.sources
)
else:
slices = tuple(
overlapped_slices(frame.bbox, src.bbox) for src in self.sources
)
        # We have to register the gradient of the function that inserts
        # sources into the blend with autograd.
        # This has to be done every time the model is evaluated,
        # since the number of sources determines the number of arguments,
        # which must be linked to the autograd primitive function.
defvjp(
_add_models,
*([partial(_grad_add_models, index=k) for k in range(len(self.sources))])
)
full_model = np.zeros(frame.shape, dtype=frame.dtype)
full_model = _add_models(*models, full_model=full_model, slices=slices)
return full_model
@property
def log_likelihood(self):
"""Log likelihood at each iteration
        The fitting method computes and sums the negative log-likelihood of each
        observation given the current model and uses it as the loss function for
        the optimization.
        Returns
        -------
        log_likelihood: array
            Log-likelihood of the model at each iteration
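
        For example, to inspect convergence after fitting::

            blend.fit(max_iter=200)
            print("final logL = {}".format(blend.log_likelihood[-1]))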
"""
return -np.array(self.loss)
def _loss_func(self, *parameters):
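        """Negative log-likelihood of all observations; each call appends to `self.loss`"""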
n_params = len(self.parameters)
model = self.get_model(*parameters[:n_params], frame=self.frame)
        # Calculate the total loss function from all of the observations
total_loss = 0
for observation in self.observations:
n_obs_params = len(observation.parameters)
obs_params = parameters[n_params : n_params + n_obs_params]
total_loss = total_loss - observation.get_log_likelihood(
model, *obs_params, noise_factor=self._noise_factor
)
n_params += n_obs_params
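        # total_loss is an autograd box during differentiation; record its raw value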
self.loss.append(total_loss._value)
return total_loss
def _callback(self, *parameters, it=None, e_rel=1e-3, callback=None, min_iter=1):
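        """Callback for `proxmin.adaprox`: check parameters, apply source updates, test convergence"""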
# raises ArithmeticError if some of the parameters have become inf/nan
for src in self.sources:
src.check_parameters()
        # raises UpdateException if model updates require interrupting the optimization
throw_exception = False
if it > 0 and it % 10 == 0:
for src in self.sources:
                try:
                    src.update()
                except UpdateException:
                    throw_exception = True
if throw_exception:
raise UpdateException
if it > min_iter and abs(self.loss[-1] - self.loss[-2]) < e_rel * np.abs(
self.loss[-1]
):
raise StopIteration(
"scarlet.Blend.fit() converged"
) # clean return from proxmin
if callback is not None:
callback(*parameters, it=it)
@property
def bbox(self):
"""Bounding box of the blend
"""
return self.frame.bbox