"""Source code for scarlet.constraint."""

from functools import partial

import numpy as np
import proxmin

from . import operator
from .cache import Cache


class Constraint:
    """Base class for proximal constraints.

    A constraint encodes an expected property of the solution.
    Mathematically, it is the consequence of adding a potentially
    non-differentiable penalty to the model fitting loss function.
    Since proximal gradient methods are used, every constraint acts
    as a proximal operator with the signature

        f(X, step) -> X'

    where X' is the closest point to X that satisfies the feasibility
    criterion of the penalty function. For reference, every operator
    of the `proxmin` package yields a valid `Constraint`.
    """

    def __init__(self, f=None):
        """Wrap a proximal mapping.

        Parameters
        ----------
        f: proximal mapping or None
            Signature: f(X, step) -> X'.
            If None, this constraint acts as the identity.
        """
        self.f = f

    def __call__(self, X, step):
        """Apply the proximal mapping.

        Parameters
        ----------
        X: array
            Optimization parameter
        step: float or array of same shape as X
            Step size for the proximal mapping

        Returns
        -------
        X': closest feasible match to X
        """
        if self.f is None:
            return X
        return self.f(X, step)
class ConstraintChain:
    """An ordered list of `Constraint`s.

    Uses the concept of alternating projections onto convex sets to
    find solutions that are feasible according to a list of constraints.

    Parameters
    ----------
    constraints: list of `Constraint`
    repeat: int
        How often the constraint chain is repeated to ensure feasibility
    """

    def __init__(self, *constraints, repeat=1):
        assert isinstance(repeat, int) and repeat >= 1
        self.constraints = constraints
        self.repeat = repeat

    def __call__(self, X, step):
        # Apply every constraint in order, `repeat` times over.
        for _ in range(self.repeat):
            for constraint in self.constraints:
                X = constraint(X, step)
        return X
class PositivityConstraint(Constraint):
    """Allow only values not smaller than `zero`."""

    def __init__(self, zero=0):
        # Lower bound; every element of X below it is raised to it.
        self.zero = zero

    def __call__(self, X, step):
        return np.maximum(X, self.zero)
class NormalizationConstraint(Constraint):
    def __init__(self, type="sum"):
        """Normalize X to unity.

        Parameters
        ----------
        type: in ['sum', 'max']
            Whether the sum or the maximum is set to unity.
        """
        type = type.lower()
        assert type in ["sum", "max"]
        self.type = type

    def __call__(self, X, step):
        # Divide in place so the caller's array is normalized directly.
        norm = X.sum() if self.type == "sum" else X.max()
        X /= norm
        return X
class L0Constraint(Constraint):
    def __init__(self, thresh, type="absolute"):
        """L0 norm (sparsity) penalty

        Parameters
        ----------
        thresh: float
            regularization strength
        type: ['relative', 'absolute']
            if the penalty is expressed in units of the function value
            (relative) or in units of the variable X (absolute).
        """
        # Hard thresholding implements the proximal operator of the L0 norm.
        prox = partial(proxmin.operators.prox_hard, thresh=thresh, type=type)
        super().__init__(prox)
class L1Constraint(Constraint):
    def __init__(self, thresh, type="absolute"):
        """L1 norm (sparsity) penalty

        Parameters
        ----------
        thresh:
            regularization strength
        type: ['relative', 'absolute']
            if the penalty is expressed in units of the function value
            (relative) or in units of the variable X (absolute).
        """
        # Soft thresholding implements the proximal operator of the L1 norm.
        prox = partial(proxmin.operators.prox_soft, thresh=thresh, type=type)
        super().__init__(prox)
class ThresholdConstraint(Constraint):
    """Set a cutoff threshold for pixels below the noise

    Use the log histogram of pixel values to determine when the source
    is fitting noise. This function works well to prevent faint sources
    from growing large footprints, but for large diffuse galaxies with a
    wide range of pixel values this does not work as well.

    The region that contains flux above the threshold is contained in
    `component.bboxes["thresh"]`.
    """

    def __call__(self, X, step):
        thresh, _ = self.threshold(X)
        return proxmin.operators.prox_hard_plus(X, step, thresh=thresh, type="absolute")

    def threshold(self, morph):
        """Find the threshold value for a given morphology"""
        positive = morph[morph > 0]
        nbins = 50
        # Shrink the histogram for sources with a small number of pixels.
        if positive.size < 500:
            nbins = max(int(positive.size / 10), 1)
            if nbins == 1:
                return 0, nbins
        hist, edges = np.histogram(np.log10(positive).reshape(-1), nbins)
        empty = np.where(hist == 0)[0]
        # No empty histogram bin means all pixels are used: no threshold.
        if len(empty) == 0:
            return 0, nbins
        # Cut at the last empty log-bin: pixels above it are kept.
        return 10 ** edges[empty[-1]], nbins
class MonotonicityConstraint(Constraint):
    """Make morphology monotonically decrease from the center

    See `~scarlet.operator.prox_monotonic` for a description of the
    other parameters.
    """

    def __init__(self, neighbor_weight="flat", min_gradient=0.1, use_mask=False):
        self.neighbor_weight = neighbor_weight
        self.min_gradient = min_gradient
        self.use_mask = use_mask

    def __call__(self, morph, step):
        shape = morph.shape
        center = (shape[0] // 2, shape[1] // 2)

        # Building this operator is expensive, so load it from the
        # cache when one already exists for this configuration.
        prox_name = "operator.prox_weighted_monotonic"
        key = (shape, center, self.neighbor_weight, self.min_gradient)
        try:
            prox = Cache.check(prox_name, key)
        except KeyError:
            prox = operator.prox_weighted_monotonic(
                shape,
                neighbor_weight=self.neighbor_weight,
                min_gradient=self.min_gradient,
                center=center,
            )
            Cache.set(prox_name, key, prox)

        # Copy before applying prox: the mask branch below must operate
        # on the pre-prox morphology values.
        unproxed = morph.copy()
        result = prox(morph, step)
        if self.use_mask:
            valid, masked, _bounds = operator.prox_monotonic_mask(
                unproxed,
                step,
                center=center,
                center_radius=0,
                variance=0,
                max_iter=0,
            )
            result[valid] = masked[valid]
        return result
class MonotonicMaskConstraint(Constraint):
    """Make morphology monotonic by branching from the center"""

    def __init__(self, center, center_radius=1, variance=0.0, max_iter=3):
        self.center = center
        self.center_radius = center_radius
        self.variance = variance
        self.max_iter = max_iter
        # Bind all configuration into the proximal operator up front.
        self.prox = partial(
            operator.prox_monotonic_mask,
            center=center,
            center_radius=center_radius,
            variance=variance,
            max_iter=max_iter,
        )

    def __call__(self, morph, step):
        # The operator returns (valid, morph, bounds); only the
        # morphology itself is propagated.
        _valid, updated, _bounds = self.prox(morph, step)
        return updated
class SymmetryConstraint(Constraint):
    """Make the source symmetric about its center

    See `~scarlet.operator.prox_uncentered_symmetry` for a description
    of the parameters.
    """

    def __init__(self, strength=1):
        self.strength = strength

    def __call__(self, morph, step):
        return operator.prox_soft_symmetry(morph, step, strength=self.strength)
class CenterOnConstraint(Constraint):
    """Sets the center pixel to a tiny non-zero value"""

    def __init__(self, tiny=1e-6):
        # Floor applied to the central pixel.
        self.tiny = tiny

    def __call__(self, morph, step):
        center = (morph.shape[0] // 2, morph.shape[1] // 2)
        # Mutate in place: raise the center pixel to at least `tiny`.
        morph[center] = max(morph[center], self.tiny)
        return morph