import autograd.numpy as np
from autograd.extend import defvjp, primitive
from .model import Model
from . import interpolation
from .parameter import Parameter
from . import fft
from .bbox import Box, overlapped_slices
from scarlet.operators_pybind11 import apply_filter
class Renderer(Model):
    def __init__(self, data_frame, model_frame, *parameters):
        self.data_frame = data_frame
        self.model_frame = model_frame
        # mapping of model to data frame channels
        self.channel_map = self.get_channel_map(data_frame, model_frame)
        super().__init__(*parameters)
    # renderer is a parameterized transformation function
    def __call__(self, model, *parameters):
        self.transform = self.get_model(*parameters)
        return self.transform(model)
    def get_channel_map(self, data_frame, model_frame):
        """Compute the mapping between channels in the model frame and this observation

        Parameters
        ----------
        data_frame: a `scarlet.Frame` instance
            The frame of the observation
        model_frame: a `scarlet.Frame` instance
            The frame to match

        Returns
        -------
        channel_map: None, slice, or array
            None for identical channels in both frames; slice for concatenated channels;
            array for linear mapping of model channels onto observation channels
        """
        if list(data_frame.channels) == list(model_frame.channels):
            return None
        channel_map = [
            list(model_frame.channels).index(c) for c in list(data_frame.channels)
        ]
        min_channel = min(channel_map)
        max_channel = max(channel_map)
        if max_channel + 1 - min_channel == len(channel_map):
            channel_map = slice(min_channel, max_channel + 1)
            return channel_map

        # full-fledged linear mixing model to allow for spectrophotometry later
        channel_map = np.zeros((data_frame.C, model_frame.C))
        for i, c in enumerate(list(data_frame.channels)):
            # index() raises ValueError if c is not among the model channels
            j = list(model_frame.channels).index(c)
            channel_map[i, j] = 1
            # TODO: for overlap computation:
            # * turn channels into dict channel['g'] = (lambdas, R_lambdas)
            # * extrapolate obs R_lambda onto model lambdas
            # * compute np.dot(model_R_lambda, obs_R_lambda) for every
            #   combination of obs and model channels
        return channel_map 
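
    # Illustrative sketch of the three channel_map forms (hypothetical frames,
    # not part of this module), with model channels ["g", "r", "i", "z"]:
    #   data channels ["g", "r", "i", "z"] -> None
    #   data channels ["r", "i"]           -> slice(1, 3)
    #   data channels ["z", "g"]           -> the 2x4 mixing matrix
    #                                         [[0, 0, 0, 1], [1, 0, 0, 0]]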
    def map_channels(self, model):
        """Map model channels onto the observation channels

        Parameters
        ----------
        model: array
            The hyperspectral model

        Returns
        -------
        obs_model: array
            `model` mapped onto the observation channels
        """
        if self.channel_map is None:
            return model
        if isinstance(self.channel_map, slice):
            return model[self.channel_map]
        return np.dot(self.channel_map, model)  
class NullRenderer(Renderer):
    def __init__(self, data_frame, model_frame):
        super().__init__(data_frame, model_frame)

    def get_model(self, *parameters):
        def nothing(model):
            return model

        return nothing
@primitive
def convolve(image, psf, bounds):
    """Convolve an image with a PSF in real space
    """
    result = np.empty(image.shape, dtype=image.dtype)
    for band in range(len(image)):
        if hasattr(image[band], "_value"):
            # This is an ArrayBox
            img = image[band]._value
        else:
            img = image[band]
        apply_filter(
            img,
            psf[band].reshape(-1),
            bounds[0],
            bounds[1],
            bounds[2],
            bounds[3],
            result[band],
        )
    return result
def _grad_convolve(convolved, image, psf, bounds):
    """Gradient of a real space convolution
    """
    # the VJP of a convolution is a correlation: convolve with the flipped PSF
    return lambda input_grad: convolve(input_grad, psf[:, ::-1, ::-1], bounds)
# Register this function in autograd
defvjp(convolve, _grad_convolve)
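
# Sketch of the identity behind this VJP (for intuition only): for
# y = convolve(x, k), the vector-Jacobian product maps an upstream gradient g
# onto a convolution of g with the kernel flipped along both spatial axes,
# which is what _grad_convolve implements via psf[:, ::-1, ::-1].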
# match the spatial shapes of model and data
@primitive
def match_shape(model, data_frame, slices):
    data_slices, model_slices = slices
    data_shape = data_frame.shape
    # check whether the data frame is only partially covered
    if any(
        data_slices[d].stop - data_slices[d].start != data_shape[d]
        for d in range(-2, 0)
    ):
        # embed the overlap in a zero array of the data frame's shape
        matched = np.zeros(data_frame.shape, dtype=data_frame.dtype)
        matched[data_slices] = model[model_slices]
        return matched
    return model[model_slices]
def _grad_match_shape(matched, model, data_frame, slices):
    # the gradient is sliced just like the model
    data_slices, model_slices = slices

    def result(upstream_grad):
        _result = np.zeros(model.shape, dtype=model.dtype)
        _result[model_slices] = upstream_grad[data_slices]
        return _result

    return result
defvjp(match_shape, _grad_match_shape)
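
# Illustrative sketch (hypothetical shapes, not part of this module): if a
# 60x60 data frame lies fully inside a 100x100 model frame, match_shape is a
# pure crop, model[model_slices]; if the data frame extends past the model,
# the overlap is embedded in a zero array of the data frame's shape. The
# gradient reverses this: the sliced upstream gradient is embedded back into
# a zero array of the model's shape.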
class ConvolutionRenderer(Renderer):
    def __init__(
        self,
        data_frame,
        model_frame,
        *parameters,
        convolution_type="fft",
        padding=10,
        psf_shift=None,
    ):
        if psf_shift is not None:
            psf_shift = Parameter(psf_shift, name="psf_shift", step=1.0e-2)
            parameters = (*parameters, psf_shift)
        super().__init__(data_frame, model_frame, *parameters)
        assert convolution_type in [
            "real",
            "fft",
        ], "`convolution_type` must be either 'real' or 'fft'"
        self._convolution_type = convolution_type
        # 2D spatial region covered by data
        pixel_in_model_frame = data_frame.convert_pixel_to(model_frame)
        # since there cannot be rotation or scaling, it's only translation
        ll = np.round(pixel_in_model_frame.min(axis=0)).astype("int")
        ur = np.round(pixel_in_model_frame.max(axis=0)).astype("int") + 1
        bounds = (ll[0], ur[0]), (ll[1], ur[1])
        # properly treats truncation in both boxes
        data_box = model_frame.bbox[0] @ Box.from_bounds(*bounds)
        self.slices = overlapped_slices(data_box, model_frame.bbox)
        # construct diff kernel
        psf_fft = fft.Fourier(data_frame.psf.get_model().astype(model_frame.dtype))
        model_psf_fft = fft.Fourier(
            model_frame.psf.get_model().astype(model_frame.dtype)
        )
        self.diff_kernel = fft.match_psf(psf_fft, model_psf_fft, padding=padding)
    @property
    def convolution_bounds(self):
        """Build the slices needed for convolution in real space
        """
        if not hasattr(self, "_convolution_bounds"):
            coords = interpolation.get_filter_coords(self.diff_kernel[0])
            self._convolution_bounds = interpolation.get_filter_bounds(
                coords.reshape(-1, 2)
            )
        return self._convolution_bounds
    def convolve(self, model, convolution_type=None, psf_shift=None):
        """Convolve the model with the difference kernel
        """
        if convolution_type is None:
            convolution_type = self._convolution_type
        if psf_shift is not None:
            kernel = fft.shift(
                self.diff_kernel.image,
                psf_shift,
                fft_shape=None,
                axes=(-2, -1),
                return_Fourier=True,
            )
        else:
            kernel = self.diff_kernel.image
        if convolution_type == "real":
            result = convolve(model, kernel, self.convolution_bounds)
        elif convolution_type == "fft":
            result = fft.convolve(fft.Fourier(model), kernel, axes=(1, 2)).image
        else:
            raise ValueError(
                "`convolution_type` must be either 'real' or 'fft', got {}".format(
                    convolution_type
                )
            )
        return result 
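
    # Hedged note: 'fft' multiplies the spectra of model and difference kernel
    # (fast for large kernels), while 'real' applies the kernel directly with
    # the compiled apply_filter (attractive for compact kernels); both should
    # agree up to boundary handling.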
    def __call__(self, model, *parameters):
        self.transform = self.get_model(*parameters)
        return self.transform(model, *parameters)
    def get_model(self, *parameters):
        def transform(model, *parameters):
            # restrict to observed channels
            model_ = self.map_channels(model)
            # get the shift
            shift = self.get_parameter("psf_shift", *parameters)
            # convolve observed channels
            model_ = self.convolve(model_, psf_shift=shift)
            # adjust spatial shapes
            model_ = match_shape(model_, self.data_frame, self.slices)
            return model_
        return transform  
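
# Minimal usage sketch (hypothetical frames and model, not part of this
# module):
#
#     renderer = ConvolutionRenderer(data_frame, model_frame)
#     image = renderer(model)  # channel-mapped, convolved, shape-matched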
class ResolutionRenderer(Renderer):
    def __init__(self, data_frame, model_frame, padding=10):
        super().__init__(data_frame, model_frame)
        # check if data is rotated w.r.t. model_frame
        self.angle, self.h = interpolation.get_angles(data_frame.wcs, model_frame.wcs)
        self.isrot = (np.abs(self.angle[1]) ** 2) > np.finfo(float).eps
        # Get pixel coordinates aligned with the x and y axes of this
        # observation in the model frame
        lr_shape = data_frame.shape[1:]
        pixels = np.stack((np.arange(lr_shape[0]), np.arange(lr_shape[1])), axis=1)
        coord_hr = data_frame.convert_pixel_to(model_frame, pixel=pixels)
        # TODO: should coords define a _slices_for_model/data?
        # lr_inside_hr = model_frame.bbox.contains(coord_hr)
        # compute diff kernel in model_frame pixels
        diff_psf, psf_lr_hr = self.build_diffkernel(data_frame, model_frame)
        # 1D convolutions of the model are done along its smaller axis;
        # the psf is therefore convolved along the frame's longer axis.
        # the smaller frame axis:
        self.small_axis = data_frame.Nx <= data_frame.Ny
        self._fft_shape = fft._get_fft_shape(
            psf_lr_hr, np.zeros(model_frame.shape), padding=3, axes=[-2, -1], max=False,
        )
        # Cutting diff_psf if needed and keeping the parity
        if (self._fft_shape[-2] < diff_psf.shape[-2]) or (
            self._fft_shape[-1] < diff_psf.shape[-1]
        ):
            diff_psf = fft._centered(
                diff_psf, np.array([diff_psf.shape[0] + 1, *self._fft_shape]) - 1
            )
        self.diff_kernel = fft.Fourier(
            fft._pad(diff_psf.image, self._fft_shape, axes=(-2, -1))
        )
        # Center of the FFT shape for matched diff kernel
        center_y = int(
            self._fft_shape[0] / 2.0 - (self._fft_shape[0] - model_frame.Ny) / 2.0
        ) + ((self._fft_shape[0] % 2) != 0) * ((model_frame.Ny % 2) == 0)
        center_x = int(
            self._fft_shape[1] / 2.0 - (self._fft_shape[1] - model_frame.Nx) / 2.0
        ) - ((self._fft_shape[1] % 2) != 0) * ((model_frame.Nx % 2) == 0)
        # Compute the shifts of all LR pixels into centered HR coordinates
        # 1) aligned case
        if not self.isrot:
            axes = [int(not self.small_axis) + 1]
            self.shifts = coord_hr.T
            self.shifts[0] -= center_y
            self.shifts[1] -= center_x
            self.other_shifts = np.copy(self.shifts)
        # 2) rotated case
        else:
            # Unrotated coordinates:
            Y_unrot = (
                (coord_hr[:, 0] - center_y) * self.angle[0]
                - (coord_hr[:, 1] - center_x) * self.angle[1]
            ).reshape(lr_shape[0])
            X_unrot = (
                (coord_hr[:, 1] - center_x) * self.angle[0]
                + (coord_hr[:, 0] - center_y) * self.angle[1]
            ).reshape(lr_shape[1])
            # Removing redundancy
            self.Y_unrot = Y_unrot
            self.X_unrot = X_unrot
            if self.small_axis:
                self.shifts = np.array(
                    [self.Y_unrot * self.angle[0], -self.Y_unrot * self.angle[1],]
                )
                self.other_shifts = np.array(
                    [self.angle[1] * self.X_unrot, self.angle[0] * self.X_unrot,]
                )
            else:
                self.shifts = np.array(
                    [self.angle[1] * self.X_unrot, self.angle[0] * self.X_unrot,]
                )
                self.other_shifts = np.array(
                    [self.Y_unrot * self.angle[0], -self.Y_unrot * self.angle[1],]
                )
            axes = (1, 2)
        # Computes the resampling/convolution matrix
        resconv_op = self.sinc_shift(self.diff_kernel, self.shifts, axes)
        self._resconv_op = np.array(resconv_op, dtype=model_frame.dtype) * self.h ** 2
        if self.isrot or self.small_axis:
            self._resconv_op = self._resconv_op.reshape(*self._resconv_op.shape[:2], -1)
        else:
            self._resconv_op = self._resconv_op.reshape(
                self._resconv_op.shape[0], -1, self._resconv_op.shape[-1]
            )
    def build_diffkernel(self, data_frame, model_frame):
        """Build the differential convolution kernel between the observation and the model psf

        Parameters
        ----------
        data_frame: Frame object
            the frame of the observation
        model_frame: Frame object
            the frame of the model

        Returns
        -------
        diff_psf: Fourier
            the differential psf between the observation and model frame psf
        psf_hr: array
            the normalized model frame psf
        """
        # Compute diff kernel at hr
        wcs_hr = model_frame.wcs
        wcs_lr = data_frame.wcs
        # PSF models
        psf_hr = model_frame.psf.get_model()
        psf_lr = data_frame.psf.get_model().astype(model_frame.dtype)
        # Computes spatially matching observation and model psfs. The observation
        # psf is also resampled to the model frame resolution
        # Odd pad shape
        pad_shape = (
            np.array(
                (self.data_frame.shape[-2:] + np.array(psf_lr.shape[-2:])) / 2
            ).astype(int)
            * 2
            + 1
        )
        h_lr = interpolation.get_pixel_size(interpolation.get_affine(wcs_lr))
        h_hr = interpolation.get_pixel_size(interpolation.get_affine(wcs_hr))
        # Interpolation of the low res psf
        # TODO: isn't that just inverse of self.angle, self.h?
        angle, h_ratio = interpolation.get_angles(wcs_hr, wcs_lr)
        psf_lr_hr = interpolation.sinc_interp_inplace(
            psf_lr, h_lr, h_hr, angle, pad_shape=pad_shape
        )
        # Normalisation
        psf_hr /= np.sum(psf_hr)
        psf_lr_hr /= np.sum(psf_lr_hr)
        # build diff kernel in Fourier space
        diff_psf = fft.match_psf(fft.Fourier(psf_lr_hr), fft.Fourier(psf_hr))
        return diff_psf, psf_hr 
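
    # The relation encoded above (for intuition only): with P_obs the observed
    # psf resampled to the model frame and P_model the model psf, match_psf
    # solves P_obs = D * P_model for the kernel D in Fourier space, so
    # convolving a model at P_model resolution with D degrades it to P_obs.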
    def sinc_shift(self, imgs, shifts, axes):
        """Perform two 1D sinc convolutions and shifts along one rotated axis in Fourier space

        Parameters
        ----------
        imgs: Fourier
            a Fourier object of 2D data to sinc convolve and shift
            to the adequate shape
        shifts: array
            an array of the shift values for each row and column of the data in imgs
        axes: array
            the axes along which to apply the sinc convolution

        Returns
        -------
        result: array
            the shifted and sinc convolved array in configuration space
        """
        # fft
        axes = tuple(np.array(axes) - 1)  # map spatial axes (1, 2) to (0, 1)
        fft_shape = np.array(self._fft_shape)[list(axes)]
        imgs_fft = imgs.fft(fft_shape, np.array(axes) + 1)
        transformed_shape = np.array(imgs_fft.shape[1:])
        transformed_shape[list(axes)] = fft_shape
        # frequency sampling
        if len(axes) == 1:
            shifter = np.array(interpolation.mk_shifter(self._fft_shape, real=True))
        else:
            shifter = np.array(interpolation.mk_shifter(self._fft_shape))
        # Shift
        if 0 in axes:
            # Fourier shift along the y-axis
            shift_y = np.exp(shifter[0][np.newaxis, :] * shifts[0][:, np.newaxis])
            imgs_shiftfft = (
                imgs_fft[:, np.newaxis, :, :] * shift_y[np.newaxis, :, :, np.newaxis]
            )
            fft_axes = [len(imgs_shiftfft.shape) - 2]
            if 1 in axes:
                # Fourier shift along the x-axis
                shift_x = np.exp(shifter[1][np.newaxis, :] * shifts[1][:, np.newaxis])
                imgs_shiftfft = imgs_shiftfft * shift_x[np.newaxis, :, np.newaxis, :]
                fft_axes = np.array(axes) + len(imgs_shiftfft.shape) - 2
            inv_shape = tuple(imgs_shiftfft.shape[:2]) + tuple(transformed_shape)
        elif 1 in axes:
            # Fourier shift along the x-axis
            shift_x = np.exp(shifter[1][:, np.newaxis] * shifts[1][np.newaxis, :])
            imgs_shiftfft = (
                imgs_fft[:, :, :, np.newaxis] * shift_x[np.newaxis, np.newaxis, :, :]
            )
            inv_shape = (
                tuple([imgs_shiftfft.shape[0]])
                + tuple(transformed_shape)
                + tuple([imgs_shiftfft.shape[-1]])
            )
            fft_axes = [len(imgs_shiftfft.shape) - 2]
        # Inverse Fourier transform.
        # The n-dimensional transform could pose a problem for very large data
        op = fft.Fourier.from_fft(imgs_shiftfft, fft_shape, inv_shape, fft_axes).image
        return op 
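
    # The identity used above (for intuition only): a translation by s along
    # one axis multiplies the spectrum by exp(-2j * pi * f * s); mk_shifter
    # precomputes these per-frequency factors, so each row/column shift becomes
    # a single broadcast multiply against the kernel's FFT.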
    def get_model(self, *parameters):
        def transform(model):
            """Resample and convolve a model in the observation frame

            Parameters
            ----------
            model: array
                The model in some other data frame

            Returns
            -------
            image_model: array
                `model` mapped into the observation frame
            """
            # restrict to observed portion
            model_ = self.map_channels(model)
            C = model_.shape[0]
            dtype = model_.dtype
            # FFT of the model, padded to the precomputed fft shape
            model_ = fft.Fourier(fft._pad(model_, self._fft_shape, axes=(-2, -1)))
            if self.isrot:
                axes = (1, 2)
            else:
                axes = [int(self.small_axis) + 1]
            model_conv = self.sinc_shift(model_, -self.other_shifts, axes)
            # Transposes are all over the place to make arrays F-contiguous
            # -> faster than np.einsum
            if self.isrot:
                model_conv = model_conv.reshape(*model_conv.shape[:2], -1)
                if self.small_axis:
                    return np.array(
                        [
                            np.dot(self._resconv_op[c], model_conv[c].T)
                            for c in range(C)
                        ],
                        dtype=dtype,
                    )
                else:
                    return np.array(
                        [
                            np.dot(self._resconv_op[c], model_conv[c].T).T
                            for c in range(C)
                        ],
                        dtype=dtype,
                    )
            if self.small_axis:
                model_conv = model_conv.reshape(
                    model_conv.shape[0], -1, model_conv.shape[-1]
                )
                return np.array(
                    [
                        np.dot(model_conv[c].T, self._resconv_op[c].T).T
                        for c in range(C)
                    ],
                    dtype=dtype,
                )
            else:
                model_conv = model_conv.reshape(*model_conv.shape[:2], -1)
                return np.array(
                    [
                        np.dot(self._resconv_op[c].T, model_conv[c].T).T
                        for c in range(C)
                    ],
                    dtype=dtype,
                )
        return transform
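
    # Minimal usage sketch (hypothetical frames and model, not part of this
    # module):
    #
    #     renderer = ResolutionRenderer(data_frame, model_frame)
    #     image = renderer(model)  # channel-mapped, resampled, convolved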