Source code for scarlet.morphology

import autograd.numpy as np
import numpy.ma as ma

from .bbox import Box
from .constraint import (
    ConstraintChain,
    L0Constraint,
    PositivityConstraint,
    MonotonicityConstraint,
    SymmetryConstraint,
    CenterOnConstraint,
    NormalizationConstraint,
)
from .frame import Frame
from .model import Model, UpdateException
from .parameter import Parameter, relative_step
from .psf import PSF
from .wavelet import Starlet, starlet_reconstruction, get_multiresolution_support
from . import fft
from . import initialization


class Morphology(Model):
    """Base class for morphology models

    A morphology describes the 2D image of the spatial dependence of
    `~scarlet.FactorizedComponent`.

    Parameters
    ----------
    frame: `~scarlet.Frame`
        Characterization of the model
    parameters: list of `~scarlet.Parameter`
        Parameters handed on to the `Model` base class
    bbox: `~scarlet.Box`
        2D bounding box of this model. Defaults to `frame.bbox` when omitted.
    """

    def __init__(self, frame, *parameters, bbox=None):
        assert isinstance(frame, Frame)
        self.frame = frame

        # without an explicit box the model covers the whole frame
        box = frame.bbox if bbox is None else bbox
        assert isinstance(box, Box)
        self.bbox = box

        super().__init__(*parameters)
class ImageMorphology(Morphology):
    """Morphology from an image

    The class uses an arbitrary image as non-parametric model. To allow for
    subpixel offsets, a Fourier-based shifting transformation is available.

    Parameters
    ----------
    frame: `~scarlet.Frame`
        Characterization of the model
    image: 2D array or `~scarlet.Parameter`
        Image parameter
    bbox: `~scarlet.Box`
        2D bounding box for location of the image in `frame`
    shifting: bool
        Whether the image is shifted by `shift` when the model is evaluated
    shift: None or `~scarlet.Parameter`
        2D shift parameter (in units of image pixels)
    resizing: bool
        Whether to resize the box dynamically
    """

    def __init__(
        self, frame, image, bbox=None, shifting=False, shift=None, resizing=True
    ):
        if isinstance(image, Parameter):
            assert image.name == "image"
        else:
            constraint = PositivityConstraint()
            image = Parameter(
                image, name="image", step=relative_step, constraint=constraint
            )

        if bbox is None:
            assert frame.bbox[1:].shape == image.shape
            bbox = Box(image.shape)
        else:
            assert bbox.shape == image.shape

        self.resizing = resizing
        self.shifting = shifting

        # create the shift parameter to allow for dynamic resizing
        if shift is None:
            # The default shift is only a free parameter when shifting is
            # requested; otherwise get_model never uses it, so it is fixed.
            shift = Parameter(
                np.zeros(2), name="shift", step=1e-2, fixed=not self.shifting
            )
        else:
            assert shift.shape == (2,)
            if isinstance(shift, Parameter):
                assert shift.name == "shift"
            else:
                shift = Parameter(shift, name="shift", step=1e-2)

        parameters = (image, shift)
        super().__init__(frame, *parameters, bbox=bbox)

    def get_model(self, *parameters):
        """Return the image, shifted by the shift parameter if `shifting` is set"""
        image = self.get_parameter(0, *parameters)
        shift = self.get_parameter(1, *parameters)

        if self.shifting:
            image = fft.shift(image, shift, return_Fourier=False)

        return image

    def update(self):
        """Shrink or grow the image box to follow the extent of the model

        Does nothing when `resizing` is off or the image parameter is fixed.

        Raises
        ------
        UpdateException
            After the image parameter and `bbox` have been replaced by
            resized versions.
        """
        image = self._parameters[0]
        size = max(image.shape)

        if not self.resizing or image.fixed:
            return

        # NOTE(review): `image.step` is used as a number below; a callable
        # step (e.g. `relative_step` from __init__) is not handled here.
        # TODO confirm how callable steps should be resized.

        # shrink the box? peel the onion
        # Count complete rings of zeros from the outside in. The far edge of
        # ring `dist` is index `-dist - 1` (index `-0` would alias index 0).
        # The bound keeps an all-zero image from running past the array.
        dist = 0
        max_dist = min(image.shape) // 2
        while (
            dist < max_dist
            and np.all(image[dist, :] == 0)
            and np.all(image[-dist - 1, :] == 0)
            and np.all(image[:, dist] == 0)
            and np.all(image[:, -dist - 1] == 0)
        ):
            dist += 1

        newsize = initialization.get_minimal_boxsize(size - 2 * dist)
        trim = (size - newsize) // 2
        # trim == 0 would make the slice [0:-0] empty, so skip shrinking then
        if newsize < size and trim > 0:
            dist = trim
            inner = (slice(dist, -dist), slice(dist, -dist))
            # Create new parameter for smaller image
            image = Parameter(
                image[inner],
                name=image.name,
                prior=image.prior,
                constraint=image.constraint,
                step=image.step / 2,
                fixed=image.fixed,
                m=image.m[inner] if image.m is not None else None,
                v=image.v[inner] if image.v is not None else None,
                vhat=image.vhat[inner] if image.vhat is not None else None,
            )

            # set new parameters
            self._parameters = (image,) + self._parameters[1:]
            # adjust bbox
            self.bbox.origin = tuple(o + dist for o in self.bbox.origin)
            self.bbox.shape = (newsize, newsize)
            raise UpdateException

        # grow the box?
        # because the PSF moves power across the box, the gradients at the edge
        # accumulate flux from beyond the box
        if image.m is not None:
            # next adam gradient update: -step * m / sqrt(v),
            # masked where v == 0 to avoid dividing by zero
            gu = -image.m / np.sqrt(ma.masked_equal(image.v, 0)) * image.step
            gu_pull = gu * (image > 0)

            # check if model has flux at the edge at all
            edge_pull = np.array(
                (
                    gu_pull[:, 0].mean(),
                    gu_pull[:, -1].mean(),
                    gu_pull[0, :].mean(),
                    gu_pull[-1, :].mean(),
                )
            )

            # 0.1 compared to 1 at center
            if np.any(edge_pull > 0.1):
                # find next larger boxsize
                newsize = initialization.get_minimal_boxsize(size + 1)
                pad_width = (newsize - size) // 2

                # Create new parameter for extended image
                image = Parameter(
                    np.pad(image, pad_width, mode="linear_ramp"),
                    name=image.name,
                    prior=image.prior,
                    constraint=image.constraint,
                    step=image.step / 2,
                    fixed=image.fixed,
                    m=np.pad(image.m, pad_width, mode="constant")
                    if image.m is not None
                    else None,
                    v=np.pad(image.v, pad_width, mode="constant")
                    if image.v is not None
                    else None,
                    vhat=np.pad(image.vhat, pad_width, mode="constant")
                    if image.vhat is not None
                    else None,
                )

                # set new parameters
                self._parameters = (image,) + self._parameters[1:]
                # adjust bbox
                self.bbox.origin = tuple(o - pad_width for o in self.bbox.origin)
                self.bbox.shape = (newsize, newsize)
                raise UpdateException
class PointSourceMorphology(Morphology):
    """Morphology from a PSF

    The model image is `frame.psf`, evaluated at `center`.

    Parameters
    ----------
    frame: `~scarlet.Frame`
        Characterization of the model; its `psf` must be a `PSF` instance
    center: array or `~scarlet.Parameter`
        2D center parameter (in units of frame pixels)
    """

    def __init__(self, frame, center):
        assert frame.psf is not None and isinstance(frame.psf, PSF)
        self.psf = frame.psf

        # place the PSF box on the pixel nearest to the center;
        # the leading axis of the box is left unshifted
        nearest_pixel = np.round(center).astype("int")
        box_shift = (0,) + tuple(nearest_pixel)
        bbox = self.psf.bbox + box_shift

        # the only parameter is the 2D center
        if not isinstance(center, Parameter):
            center = Parameter(center, name="center", step=3e-2)
        else:
            assert center.name == "center"
        self.center = center

        super().__init__(frame, self.center, bbox=bbox)

    def get_model(self, *parameters):
        """Evaluate the PSF at the offset of the center from the box center"""
        center = self.get_parameter(0, *parameters)
        spatial_bounds = self.bbox.bounds[1:]
        box_center = np.mean(spatial_bounds, axis=1)
        # no "internal" PSF parameters here
        return self.psf.get_model(offset=center - box_center)
class StarletMorphology(Morphology):
    """Morphology from a starlet representation of an image

    The class uses the starlet parameterization as an overcomplete,
    non-parametric model.

    Parameters
    ----------
    frame: `~scarlet.Frame`
        Characterization of the model
    image: 2D array
        Initial image to construct starlet transform
    bbox: `~scarlet.Box`
        2D bounding box for location of the image in `frame`
    threshold: float
        Lower bound on threshold for all but the last starlet scale
    """

    def __init__(self, frame, image, bbox=None, threshold=0):
        if bbox is None:
            assert frame.bbox[1:].shape == image.shape
            bbox = Box(image.shape)

        # Starlet transform of morphologies (n1,n2) with 3 dimensions: (scales+1,n1,n2)
        self.transform = Starlet.from_image(image)
        # The starlet transform is the model
        coeffs = self.transform.coefficients
        # wavelet-scale norm
        starlet_norm = self.transform.norm
        # One threshold per wavelet scale: thresh*norm
        thresh_array = np.zeros(coeffs.shape) + threshold
        thresh_array *= starlet_norm[:, None, None]
        # We don't threshold the last scale
        thresh_array[-1] = 0

        # The thresholds go to the L0 constraint: a positivity constraint
        # would force the starlet coefficients to be non-negative instead
        # of thresholding them.
        constraint = L0Constraint(thresh_array)

        coeffs = Parameter(coeffs, name="coeffs", step=1e-2, constraint=constraint)
        super().__init__(frame, coeffs, bbox=bbox)

    def get_model(self, *parameters):
        """Return the image reconstructed from the starlet coefficients"""
        # Takes the inverse transform of parameters as starlet coefficients
        coeffs = self.get_parameter(0, *parameters)
        return starlet_reconstruction(coeffs)
class ExtendedSourceMorphology(ImageMorphology):
    def __init__(
        self,
        frame,
        center,
        image,
        bbox=None,
        monotonic="angle",
        symmetric=False,
        min_grad=0,
        shifting=False,
        resizing=True,
    ):
        """Non-parametric image morphology designed for galaxies as extended sources.

        Parameters
        ----------
        frame: `~scarlet.Frame`
            The frame of the full model
        center: tuple
            Center of the source
        image: `numpy.ndarray`
            Image of the source.
        bbox: `~scarlet.Box`
            2D bounding box for location of the image in `frame`
        monotonic: ['flat', 'angle', 'nearest'] or None
            Which version of monotonic decrease in flux from the center to enforce.
            For backwards compatibility, `True` is treated as 'angle' and
            `False` as None.
        symmetric: `bool`
            Whether or not to enforce symmetry.
        min_grad: float in [0,1)
            Minimal radial decline for monotonicity (in units of reference pixel value)
        shifting: `bool`
            Whether or not a subpixel shift is added as optimization parameter
        resizing: bool
            Whether to resize the box dynamically
        """
        # backwards compatibility: monotonic was boolean
        if monotonic is True:
            monotonic = "angle"
        elif monotonic is False:
            monotonic = None

        chain = []
        if monotonic is not None:
            # most astronomical sources are monotonically decreasing
            # from their center
            chain.append(
                MonotonicityConstraint(neighbor_weight=monotonic, min_gradient=min_grad)
            )
        if symmetric:
            # have 2-fold rotation symmetry around their center ...
            chain.append(SymmetryConstraint())

        # ... and are positive emitters
        chain.append(PositivityConstraint())
        # prevent a weak source from disappearing entirely
        chain.append(CenterOnConstraint())
        # break degeneracies between sed and morphology
        chain.append(NormalizationConstraint("max"))

        morph_constraint = ConstraintChain(*chain)
        image = Parameter(image, name="image", step=1e-2, constraint=morph_constraint)

        self.pixel_center = np.round(center).astype("int")
        # the subpixel remainder of the center becomes the shift parameter
        self.shift = (
            Parameter(center - self.pixel_center, name="shift", step=1e-1)
            if shifting
            else None
        )

        super().__init__(
            frame,
            image,
            bbox=bbox,
            shifting=shifting,
            shift=self.shift,
            resizing=resizing,
        )

    @property
    def center(self):
        """Pixel center, plus the shift parameter when one was created"""
        if self.shift is None:
            return self.pixel_center
        return self.pixel_center + self.shift