import autograd.numpy as np
from autograd.extend import defvjp, primitive
from .model import Model
from . import interpolation
from .parameter import Parameter
from . import fft
from .bbox import Box, overlapped_slices
from scarlet.operators_pybind11 import apply_filter
class Renderer(Model):
    """Base class for transformations of a model into a data frame.

    A renderer is a parameterized function: calling it builds the
    transformation for the given parameter values via `get_model` and applies
    it to the model. `get_model` is not defined in this class; it has to be
    provided by subclasses (or a base class).

    Parameters
    ----------
    data_frame: a `scarlet.Frame` instance
        The frame of the data the model is rendered into
    model_frame: a `scarlet.Frame` instance
        The frame of the model
    parameters: list
        Parameters handed to the `Model` constructor
    """

    def __init__(self, data_frame, model_frame, *parameters):
        self.data_frame = data_frame
        self.model_frame = model_frame
        # mapping of model to data frame channels
        self.channel_map = self.get_channel_map(data_frame, model_frame)
        super().__init__(*parameters)

    def __call__(self, model, *parameters):
        """Build the transformation for `parameters` and apply it to `model`

        The transformation is also stored as `self.transform`.
        """
        self.transform = self.get_model(*parameters)
        return self.transform(model)

    def get_channel_map(self, data_frame, model_frame):
        """Compute the mapping between channels in the model frame and this observation

        Parameters
        ----------
        data_frame: a `scarlet.Frame` instance
            The frame of the data
        model_frame: a `scarlet.Frame` instance
            The frame to match

        Returns
        -------
        channel_map: None, slice, or array
            None for identical channels in both frames; slice when the data
            channels are a contiguous run of model channels in the same order;
            array for linear mapping of model channels onto observation channels

        Raises
        ------
        ValueError
            If a data channel does not exist in the model frame
        """
        data_channels = list(data_frame.channels)
        model_channels = list(model_frame.channels)
        if data_channels == model_channels:
            return None

        indices = []
        for channel in data_channels:
            # `list.index` raises ValueError (it never returns -1), so the
            # missing-channel case has to be handled here
            try:
                indices.append(model_channels.index(channel))
            except ValueError:
                raise ValueError(
                    f"Could not find channel {channel} in model frame"
                ) from None

        min_channel = min(indices)
        max_channel = max(indices)
        # A slice keeps the model order, so it is only valid if the data
        # channels are contiguous *and* already in ascending model order.
        if indices == list(range(min_channel, max_channel + 1)):
            return slice(min_channel, max_channel + 1)

        # full-fledged linear mixing model to allow for spectrophotometry later
        channel_map = np.zeros((data_frame.C, model_frame.C))
        for i, j in enumerate(indices):
            channel_map[i, j] = 1

        # TODO: for overlap computation:
        # * turn channels into dict channel['g'] = (lambdas, R_lambdas)
        # * extrapolate obs R_lambda onto model lambdas
        # * compute np.dot(model_R_lambda, obs_R_lambda) for every
        #   combination of obs and model channels
        return channel_map

    def map_channels(self, model):
        """Map to model channels onto the observation channels

        Parameters
        ----------
        model: array
            The hyperspectral model, with channels along the first axis

        Returns
        -------
        obs_model: array
            `model` mapped onto the observation channels
        """
        if self.channel_map is None:
            return model
        if isinstance(self.channel_map, slice):
            return model[self.channel_map]
        # Contract the model-channel axis of the mixing matrix with the FIRST
        # axis of the model. `np.dot` would contract the second-to-last axis of
        # a 3D model (a spatial axis) instead.
        return np.tensordot(self.channel_map, model, axes=1)
class NullRenderer(Renderer):
    """Renderer whose transformation returns the model unchanged.

    The constructor still goes through `Renderer.__init__`, but the
    transformation returned by `get_model` does not apply any channel mapping.

    Parameters
    ----------
    data_frame: a `scarlet.Frame` instance
        The frame of the data
    model_frame: a `scarlet.Frame` instance
        The frame of the model
    """

    def __init__(self, data_frame, model_frame):
        super().__init__(data_frame, model_frame)

    def get_model(self, *parameters):
        """Return the identity transformation

        `parameters` are accepted for interface compatibility and ignored.
        """

        def nothing(model):
            return model

        return nothing
@primitive
def convolve(image, psf, bounds):
    """Convolve an image with a PSF in real space

    Parameters
    ----------
    image: array
        The image; it is iterated along its first axis
    psf: array
        One kernel per entry of `image`; each kernel is flattened before
        being handed to `apply_filter`
    bounds: sequence
        The first four entries are forwarded, in order, to `apply_filter`

    Returns
    -------
    result: array
        Array with the shape and dtype of `image`, filled band by band
        by `apply_filter`
    """
    output = np.empty(image.shape, dtype=image.dtype)
    for index in range(len(image)):
        plane = image[index]
        # an autograd ArrayBox carries the raw array in `_value`
        raw_plane = getattr(plane, "_value", plane)
        flat_kernel = psf[index].reshape(-1)
        apply_filter(
            raw_plane,
            flat_kernel,
            bounds[0],
            bounds[1],
            bounds[2],
            bounds[3],
            output[index],
        )
    return output
def _grad_convolve(convolved, image, psf, slices):
"""Gradient of a real space convolution
"""
return lambda input_grad: convolve(input_grad, psf[:, ::-1, ::-1], slices)
# Register `_grad_convolve` with autograd as the gradient maker of `convolve`
defvjp(convolve, _grad_convolve)
@primitive
def match_shape(model, data_frame, slices):
    """Match the spatial shapes of model and data

    Parameters
    ----------
    model: array
        The model to cut out from
    data_frame: object with `shape` and `dtype`
        The frame of the data
    slices: tuple
        Pair `(data_slices, model_slices)`

    Returns
    -------
    matched: array
        `model[model_slices]` if the data slices span the full extent of the
        last two data axes; otherwise a zero array with the shape and dtype of
        `data_frame` into which `model[model_slices]` is inserted at
        `data_slices`
    """
    data_slices, model_slices = slices
    full_shape = data_frame.shape

    # does the overlap cover less than the full data along a spatial axis?
    is_partial = [
        (data_slices[axis].stop - data_slices[axis].start) != full_shape[axis]
        for axis in (-2, -1)
    ]
    if not any(is_partial):
        return model[model_slices]

    padded = np.zeros(data_frame.shape, dtype=data_frame.dtype)
    padded[data_slices] = model[model_slices]
    return padded
def _grad_match_shape(upstream_grad, model, data_frame, slices):
# just slices gradients like the model
data_slices, model_slices = slices
def result(upstream_grad):
_result = np.zeros(model.shape, dtype=model.dtype)
_result[model_slices] = upstream_grad[data_slices]
return _result
return result
# Register `_grad_match_shape` with autograd as the gradient maker of `match_shape`
defvjp(match_shape, _grad_match_shape)
class ConvolutionRenderer(Renderer):
    """Renderer that maps channels, convolves with a difference kernel and
    matches the spatial shape of the data.

    Parameters
    ----------
    data_frame: a `scarlet.Frame` instance
        The frame of the data
    model_frame: a `scarlet.Frame` instance
        The frame of the model
    parameters: list
        Parameters passed on to `Renderer`
    convolution_type: str
        Either 'real' or 'fft'
    padding: int
        Forwarded to `fft.match_psf` when the difference kernel is built
    psf_shift: array or None
        If given, it is wrapped in a `Parameter` named "psf_shift" and
        appended to `parameters`
    """

    def __init__(
        self,
        data_frame,
        model_frame,
        *parameters,
        convolution_type="fft",
        padding=10,
        psf_shift=None,
    ):
        if psf_shift is not None:
            psf_shift = Parameter(psf_shift, name="psf_shift", step=1.0e-2)
            parameters = parameters + (psf_shift,)

        super().__init__(data_frame, model_frame, *parameters)

        allowed_types = ("real", "fft")
        assert (
            convolution_type in allowed_types
        ), "`convolution` must be either 'real' or 'fft'"
        self._convolution_type = convolution_type

        # 2D spatial region covered by data
        data_pixels = data_frame.convert_pixel_to(model_frame)

        # since there cannot be rotation or scaling, it's only translation
        lower_left = np.round(data_pixels.min(axis=0)).astype("int")
        upper_right = np.round(data_pixels.max(axis=0)).astype("int") + 1
        spatial_bounds = (
            (lower_left[0], upper_right[0]),
            (lower_left[1], upper_right[1]),
        )

        # properly treats truncation in both boxes
        data_box = model_frame.bbox[0] @ Box.from_bounds(*spatial_bounds)
        self.slices = overlapped_slices(data_box, model_frame.bbox)

        # construct diff kernel
        data_psf = fft.Fourier(data_frame.psf.get_model().astype(model_frame.dtype))
        model_psf = fft.Fourier(
            model_frame.psf.get_model().astype(model_frame.dtype)
        )
        self.diff_kernel = fft.match_psf(data_psf, model_psf, padding=padding)

    @property
    def convolution_bounds(self):
        """Build the slices needed for convolution in real space

        Computed on first access from the first entry of the difference
        kernel and cached in `_convolution_bounds`.
        """
        if hasattr(self, "_convolution_bounds"):
            return self._convolution_bounds

        filter_coords = interpolation.get_filter_coords(self.diff_kernel[0])
        self._convolution_bounds = interpolation.get_filter_bounds(
            filter_coords.reshape(-1, 2)
        )
        return self._convolution_bounds

    def convolve(self, model, convolution_type=None, psf_shift=None):
        """Convolve the model with the difference kernel

        Parameters
        ----------
        model: array
            The model to convolve
        convolution_type: str or None
            'real' or 'fft'; None selects the type given to the constructor
        psf_shift: array or None
            If given, the difference kernel is shifted by this amount with
            `fft.shift` before the convolution

        Returns
        -------
        result: array
            The convolved model

        Raises
        ------
        ValueError
            If the convolution type is neither 'real' nor 'fft'
        """
        mode = self._convolution_type if convolution_type is None else convolution_type

        if psf_shift is None:
            kernel = self.diff_kernel.image
        else:
            kernel = fft.shift(
                self.diff_kernel.image,
                psf_shift,
                fft_shape=None,
                axes=(-2, -1),
                return_Fourier=True,
            )

        if mode == "real":
            return convolve(model, kernel, self.convolution_bounds)
        if mode == "fft":
            return fft.convolve(fft.Fourier(model), kernel, axes=(1, 2)).image
        raise ValueError(
            "`convolution` must be either 'real' or 'fft', got {}".format(mode)
        )

    def __call__(self, model, *parameters):
        """Build the transformation and apply it to `model` with `parameters`

        The transformation is also stored as `self.transform`.
        """
        self.transform = self.get_model(*parameters)
        return self.transform(model, *parameters)

    def get_model(self, *parameters):
        """Return the rendering function

        The returned function takes the model and the parameter values; the
        `parameters` given to `get_model` itself are not used.
        """

        def transform(model, *parameters):
            # restrict to observed channels
            observed = self.map_channels(model)
            # get the shift
            psf_shift = self.get_parameter("psf_shift", *parameters)
            # convolve observed channels
            convolved = self.convolve(observed, psf_shift=psf_shift)
            # adjust spatial shapes
            return match_shape(convolved, self.data_frame, self.slices)

        return transform
[docs]class ResolutionRenderer(Renderer):
# Renderer whose transformation resamples and convolves a model into the
# observation frame (see the docstring of `transform` in `get_model`).
# The constructor precomputes the model-independent pieces: the difference
# kernel, the FFT shape, the per-pixel shifts and the operator `_resconv_op`.
def __init__(self, data_frame, model_frame, padding=10):
# NOTE(review): `padding` is accepted but never referenced in this constructor;
# the FFT shape below is computed with a hard-coded `padding=3` -- confirm intent.
super().__init__(data_frame, model_frame)
# check if data is rotated wrt to model_frame
self.angle, self.h = interpolation.get_angles(data_frame.wcs, model_frame.wcs)
# rotated if the squared second angle component exceeds machine precision
# (presumably angle = (cos, sin) of the rotation -- confirm in `interpolation`)
self.isrot = (np.abs(self.angle[1]) ** 2) > np.finfo(float).eps
# Get pixel coordinates aligned with x and y axes of this observation
# in model frame
lr_shape = data_frame.shape[1:]
# NOTE(review): np.stack needs both aranges to have the same length, so this
# appears to assume lr_shape[0] == lr_shape[1] (square data) -- confirm.
pixels = np.stack((np.arange(lr_shape[0]), np.arange(lr_shape[1])), axis=1)
coord_hr = data_frame.convert_pixel_to(model_frame, pixel=pixels)
# TODO: should coords define a _slices_for_model/data?
# lr_inside_hr = model_frame.bbox.contains(coord_hr)
# compute diff kernel in model_frame pixels
# NOTE(review): the second return value is named `psf_lr_hr` here, but
# `build_diffkernel` returns `psf_hr` in that position -- confirm which is meant.
diff_psf, psf_lr_hr = self.build_diffkernel(data_frame, model_frame)
# 1D convolutions of the model are done along the smaller axis, therefore,
# psf is convolved along the frame's longer axis.
# the smaller frame axis:
# (True when Nx <= Ny)
self.small_axis = data_frame.Nx <= data_frame.Ny
self._fft_shape = fft._get_fft_shape(
psf_lr_hr, np.zeros(model_frame.shape), padding=3, axes=[-2, -1], max=False,
)
# Cutting diff_psf if needed and keeping the parity
if (self._fft_shape[-2] < diff_psf.shape[-2]) or (
self._fft_shape[-1] < diff_psf.shape[-1]
):
diff_psf = fft._centered(
diff_psf, np.array([diff_psf.shape[0] + 1, *self._fft_shape]) - 1
)
# pad the kernel image to the FFT shape along the last two axes
self.diff_kernel = fft.Fourier(
fft._pad(diff_psf.image, self._fft_shape, axes=(-2, -1))
)
# Center of the FFT shape for matched diff kernel
# (the parity correction is added for y and subtracted for x)
center_y = int(
self._fft_shape[0] / 2.0 - (self._fft_shape[0] - model_frame.Ny) / 2.0
) + ((self._fft_shape[0] % 2) != 0) * ((model_frame.Ny % 2) == 0)
center_x = int(
self._fft_shape[1] / 2.0 - (self._fft_shape[1] - model_frame.Nx) / 2.0
) - ((self._fft_shape[1] % 2) != 0) * ((model_frame.Nx % 2) == 0)
# Compute the shifts of all LR pixels into centered HR coord
# 1 ) aligned case
if not self.isrot:
# the operator is built along a single axis: 2 if small_axis, else 1
axes = [int(not self.small_axis) + 1]
# `.T` is a view: the in-place subtractions below also modify `coord_hr`
self.shifts = coord_hr.T
self.shifts[0] -= center_y
self.shifts[1] -= center_x
self.other_shifts = np.copy(self.shifts)
# 2) rotated case
else:
# Unrotated coordinates:
Y_unrot = (
(coord_hr[:, 0] - center_y) * self.angle[0]
- (coord_hr[:, 1] - center_x) * self.angle[1]
).reshape(lr_shape[0])
X_unrot = (
(coord_hr[:, 1] - center_x) * self.angle[0]
+ (coord_hr[:, 0] - center_y) * self.angle[1]
).reshape(lr_shape[1])
# Removing redundancy
self.Y_unrot = Y_unrot
self.X_unrot = X_unrot
# `shifts` feeds the precomputed operator, `other_shifts` is applied to the
# model in `get_model`; which one uses Y and which X depends on small_axis
if self.small_axis:
self.shifts = np.array(
[self.Y_unrot * self.angle[0], -self.Y_unrot * self.angle[1],]
)
self.other_shifts = np.array(
[self.angle[1] * self.X_unrot, self.angle[0] * self.X_unrot,]
)
else:
self.shifts = np.array(
[self.angle[1] * self.X_unrot, self.angle[0] * self.X_unrot,]
)
self.other_shifts = np.array(
[self.Y_unrot * self.angle[0], -self.Y_unrot * self.angle[1],]
)
axes = (1, 2)
# Computes the resampling/convolution matrix
resconv_op = self.sinc_shift(self.diff_kernel, self.shifts, axes)
# cast to the model dtype and scale by h squared
self._resconv_op = np.array(resconv_op, dtype=model_frame.dtype) * self.h ** 2
# flatten the operator to 3 dimensions; which axes get merged depends on the case
if self.isrot:
self._resconv_op = self._resconv_op.reshape(*self._resconv_op.shape[:2], -1)
elif self.small_axis:
self._resconv_op = self._resconv_op.reshape(*self._resconv_op.shape[:2], -1)
else:
self._resconv_op = self._resconv_op.reshape(
self._resconv_op.shape[0], -1, self._resconv_op.shape[-1]
)
[docs] def build_diffkernel(self, data_frame, model_frame):
"""Builds the differential convolution kernel between the observation and the model psf
Parameters
----------
data_frame: Frame object
the frame of the observation
model_frame: Frame object
the frame of the model
Returns
-------
diff_psf: object returned by `fft.match_psf`
the differential psf between observation and frame psf.
psf_hr: array
the normalised psf of the model frame
"""
# Compute diff kernel at hr
wcs_hr = model_frame.wcs
wcs_lr = data_frame.wcs
# PSF models
psf_hr = model_frame.psf.get_model()
psf_lr = data_frame.psf.get_model().astype(model_frame.dtype)
# Computes spatially matching observation and model psf. The observation psf is also
# resampled to the model frame resolution
# Odd pad shape
# NOTE(review): this reads `self.data_frame`, not the `data_frame` argument;
# they are the same object when called from `__init__`.
pad_shape = (
np.array(
(self.data_frame.shape[-2:] + np.array(psf_lr.shape[-2:])) / 2
).astype(int)
* 2
+ 1
)
h_lr = interpolation.get_pixel_size(interpolation.get_affine(wcs_lr))
h_hr = interpolation.get_pixel_size(interpolation.get_affine(wcs_hr))
# Interpolation of the low res psf
# TODO: isn't that just inverse of self.angle, self.h?
# (`h_ratio` is not used below)
angle, h_ratio = interpolation.get_angles(wcs_hr, wcs_lr)
psf_lr_hr = interpolation.sinc_interp_inplace(
psf_lr, h_lr, h_hr, angle, pad_shape=pad_shape
)
# Normalisation
# NOTE(review): both divisions are in place; if `model_frame.psf.get_model()`
# returns a cached array, that array is modified as well -- confirm.
psf_hr /= np.sum(psf_hr)
psf_lr_hr /= np.sum(psf_lr_hr)
# build diff kernel in Fourier space
diff_psf = fft.match_psf(fft.Fourier(psf_lr_hr), fft.Fourier(psf_hr))
return diff_psf, psf_hr
[docs] def sinc_shift(self, imgs, shifts, axes):
"""Performs 2 1D sinc convolutions and shifting along one rotated axis in Fourier space.
Parameters
----------
imgs: Fourier
a Fourier object of 2D data to sinc convolve and shift
to the adequate shape.
shifts: array
an array of the shift values for each line and columns of data in imgs
axes: sequence of int
the axes (1 and/or 2; one is subtracted internally, presumably to skip the
leading channel axis) along which to apply sinc convolution.
Returns
-------
result: array
the shifted and sinc convolved array in configuration space
"""
# fft
# from here on `axes` indexes into `self._fft_shape` (0 or 1)
axes = tuple(np.array(axes) - 1)
fft_shape = np.array(self._fft_shape)[tuple([axes])]
imgs_fft = imgs.fft(fft_shape, np.array(axes) + 1)
# shape after the transform: transformed axes take the FFT lengths
transformed_shape = np.array(imgs_fft.shape[1:])
transformed_shape[tuple([axes])] = fft_shape
# frequency sampling
# a single transformed axis requests `real=True` from `mk_shifter`
if len(axes) == 1:
shifter = np.array(interpolation.mk_shifter(self._fft_shape, real=True))
else:
shifter = np.array(interpolation.mk_shifter(self._fft_shape))
# Shift
# Shift along the y-axis: one phase factor per entry of shifts[0], inserted as a
# new axis 1 of the transformed images
if 0 in axes:
# Fourier shift
shishift = np.exp(shifter[0][np.newaxis, :] * shifts[0][:, np.newaxis])
imgs_shiftfft = (
imgs_fft[:, np.newaxis, :, :] * shishift[np.newaxis, :, :, np.newaxis]
)
fft_axes = [len(imgs_shiftfft.shape) - 2]
# Shift along the x-axis
if 1 in axes:
# Fourier shift
shishift = np.exp(shifter[1][np.newaxis, :] * shifts[1][:, np.newaxis])
imgs_shiftfft = imgs_shiftfft * shishift[np.newaxis, :, np.newaxis, :]
fft_axes = np.array(axes) + len(imgs_shiftfft.shape) - 2
inv_shape = tuple(imgs_shiftfft.shape[:2]) + tuple(transformed_shape)
# Shift along the x-axis only: the shift index becomes a new trailing axis
elif 1 in axes:
# Fourier shift
shishift = np.exp(shifter[1][:, np.newaxis] * shifts[1][np.newaxis, :])
imgs_shiftfft = (
imgs_fft[:, :, :, np.newaxis] * shishift[np.newaxis, np.newaxis, :, :]
)
inv_shape = (
tuple([imgs_shiftfft.shape[0]])
+ tuple(transformed_shape)
+ tuple([imgs_shiftfft.shape[-1]])
)
fft_axes = [len(imgs_shiftfft.shape) - 2]
# Inverse Fourier transform.
# The n-dimensional transform could pose problem for very large data
op = fft.Fourier.from_fft(imgs_shiftfft, fft_shape, inv_shape, fft_axes).image
return op
[docs] def get_model(self, *parameters):
# Returns the rendering function; `parameters` are not used.
def transform(model):
"""Resample and convolve a model in the observation frame
Parameters
----------
model: array
The model in some other data frame.
Returns
-------
image_model: array
`model` mapped into the observation frame
"""
# restrict to observed portion
model_ = self.map_channels(model)
C = model_.shape[0]
dtype = model_.dtype
# FFT of model, padding the model to the precomputed FFT shape
model_ = fft.Fourier(fft._pad(model_, self._fft_shape, axes=(-2, -1)))
# the model is shifted along the axis complementary to the one used for
# `_resconv_op` in the constructor (both axes when rotated)
if self.isrot:
axes = (1, 2)
else:
axes = [int(self.small_axis) + 1]
model_conv = self.sinc_shift(model_, -self.other_shifts, axes)
# Transposes are all over the place to make arrays F-contiguous
# -> faster than np.einsum
# rotated case: one matrix product per channel
if self.isrot:
model_conv = model_conv.reshape(*model_conv.shape[:2], -1)
if self.small_axis:
return np.array(
[
np.dot(self._resconv_op[c], model_conv[c].T)
for c in range(C)
],
dtype=dtype,
)
else:
return np.array(
[
np.dot(self._resconv_op[c], model_conv[c].T).T
for c in range(C)
],
dtype=dtype,
)
# aligned case: the reshape mirrors the one applied to `_resconv_op`
# for the opposite value of small_axis
if self.small_axis:
model_conv = model_conv.reshape(
model_conv.shape[0], -1, model_conv.shape[-1]
)
return np.array(
[
np.dot(model_conv[c].T, self._resconv_op[c].T).T
for c in range(C)
],
dtype=dtype,
)
else:
model_conv = model_conv.reshape(*model_conv.shape[:2], -1)
return np.array(
[
np.dot(self._resconv_op[c].T, model_conv[c].T).T
for c in range(C)
],
dtype=dtype,
)
return transform