Source code for starred.procedures.deconvolution_routines

import warnings

from starred.deconvolution.loss import Loss as deconvLoss
from starred.psf.loss import Loss as psfLoss
from starred.deconvolution.parameters import ParametersDeconv
from starred.utils.noise_utils import propagate_noise
from starred.optim.optimization import Optimizer
from starred.plots import plot_function as pltf
from starred.psf.parameters import ParametersPSF

from copy import deepcopy
import numpy as np


[docs] def multi_steps_deconvolution(data, model, parameters, sigma_2, s, subsampling_factor, fitting_sequence=[['background'], ['pts-source-astrometry'], []], optim_list=['l-bfgs-b', 'adabelief', 'adabelief'], kwargs_optim_list=None, lambda_scales=1., lambda_hf=1000., lambda_positivity_bkg=100, lambda_positivity_ps=100, lambda_pts_source=0., prior_list=None, regularize_full_model=False, update_lambda_pts_source=False, adjust_sky=False, noise_propagation='MC', regularization_terms='l1_starlet', verbose=True): """ A high-level function to run several time the optimization algorithms and ensure to find the optimal solution. :param data: 3D array containing the images, one per epoch. shape (epochs, im_size, im_size) :param model: Deconv class, deconvolution model :param parameters: ParametersDeconv class :param sigma_2: 3D array containing the noisemaps, one per epoch. shape (epochs, im_size, im_size) :param s: 3D array containing the narrow PSFs, one per epoch. shape (epochs, im_size_up, im_size_up) where im_size_up needs be a multiple of im_size. :param subsampling_factor: integer, ratio of the size of the data pixels to that of the PSF pixels. :param fitting_sequence: list, List of lists, containing the element of the model to keep fixed. Example : [['pts-source-astrometry','pts-source-photometry','background'],['pts-source-astrometry','pts-source-photometry'], ...] :param optim_list: List of optimiser. Recommended if background is kept constant : 'l-bfgs-b', 'adabelief' otherwise. :param kwargs_optim_list: List of dictionary, containing the setting for the different optimiser. :param lambda_scales: Lagrange parameter that weights intermediate scales in the transformed domain :param lambda_hf: Lagrange parameter weighting the highest frequency scale :param lambda_pts_source: Lagrange parameter regularising the pts source channel :param lambda_positivity_bkg: Lagrange parameter weighting the positivity of the background. 0 means no positivity constraint. :param lambda_positivity_ps: Lagrange parameter weighting the positivity of the point sources. 0 means no positivity constraint. :param regularize_full_model: option to regularise just the background (False, recommended) or background + point source channel (True) :param update_lambda_pts_source : bool, option to refine the points channel regularization strength at every iteration :param prior_list: list of Prior object, Gaussian prior on parameters to be applied at each step :param adjust_sky: bool, True if you want to fit some sky subtraction :param noise_propagation: 'MC' or 'SLIT', method to compute the noise propagation in wavelet space. Default: 'MC'. :param verbose: bool. Verbosity. Return model, parameters, loss, kwargs_partial_list, fig_list, LogL_list, loss_history_list """ # Check the sequence assert len(fitting_sequence) == len(optim_list), "Fitting sequence and optimiser list have different lenght !" if kwargs_optim_list is not None: assert len(fitting_sequence) == len( kwargs_optim_list), "Fitting sequence and kwargs optimiser list have different lenght !" else: warnings.warn('No optimiser kwargs provided. Default configuration is used.') kwargs_optim_list = [{} for _ in range(len(fitting_sequence))] if prior_list is None: prior_list = [None for _ in range(len(fitting_sequence))] else: assert len(fitting_sequence) == len(prior_list), "Fitting sequence and prior list have different lenght !" # setup the model kwargs_init, kwargs_fixed_default, kwargs_up, kwargs_down = deepcopy(parameters._kwargs_init), deepcopy( parameters._kwargs_fixed), \ deepcopy(parameters._kwargs_up), deepcopy(parameters._kwargs_down) if not adjust_sky: kwargs_fixed_default['kwargs_background']['mean'] = np.zeros(model.epochs) kwargs_partial_list = [kwargs_init] fig_list = [] loss_history_list = [] loss_list = [] W = propagate_noise(model, np.sqrt(sigma_2), kwargs_init, wavelet_type_list=['starlet'], method=noise_propagation, num_samples=500, seed=1, likelihood_type='chi2', verbose=False, upsampling_factor=subsampling_factor)[0] for i, steps in enumerate(fitting_sequence): kwargs_fixed = deepcopy(kwargs_fixed_default) background_free = True for fixed_feature in steps: if fixed_feature == 'pts-source-astrometry': kwargs_fixed['kwargs_analytic']['c_x'] = kwargs_partial_list[i]['kwargs_analytic']['c_x'] kwargs_fixed['kwargs_analytic']['c_y'] = kwargs_partial_list[i]['kwargs_analytic']['c_y'] elif fixed_feature == 'pts-source-photometry': kwargs_fixed['kwargs_analytic']['a'] = kwargs_partial_list[i]['kwargs_analytic']['a'] elif fixed_feature == 'background': background_free = False kwargs_fixed['kwargs_background']['h'] = kwargs_partial_list[i]['kwargs_background']['h'] elif fixed_feature == 'sersic': for k in kwargs_partial_list[i]['kwargs_sersic'].keys(): kwargs_fixed['kwargs_sersic'][k] = kwargs_partial_list[i]['kwargs_sersic'][k] elif fixed_feature == 'dithering': kwargs_fixed['kwargs_analytic']['dx'] = kwargs_partial_list[i]['kwargs_analytic']['dx'] kwargs_fixed['kwargs_analytic']['dy'] = kwargs_partial_list[i]['kwargs_analytic']['dy'] else: raise ValueError( 'Steps is not defined. Choose between "pts-source-astrometry", "pts-source-photometry", "background", "dithering" or "sersic".') # need to recompile the parameter class, since we have changed the number of free parameters parameters = ParametersDeconv(kwargs_partial_list[i], kwargs_fixed, kwargs_up=kwargs_up, kwargs_down=kwargs_down) # for speed-up we turn of the regularization to avoid the starlet decomposition, if the background is fixed if background_free: lambda_scales_eff, lambda_hf_eff, lambda_positivity_bkg_eff = deepcopy(lambda_scales), deepcopy( lambda_hf), deepcopy(lambda_positivity_bkg) else: lambda_scales_eff, lambda_hf_eff, lambda_positivity_bkg_eff = 0., 0., 0. # refine the regularization of the point source channel correction from previous iteration if i > 0 and update_lambda_pts_source and lambda_pts_source != 0: lambda_pts_prev = lambda_pts_source_eff lambda_pts_source_eff = loss.update_lambda_pts_source(kwargs_partial_list[i]) if verbose: print(f'Updating lambda_pts_source from {lambda_pts_prev} to {lambda_pts_source_eff}') else: lambda_pts_source_eff = lambda_pts_source # run the optimization print('Step %i, fixing :' % (i + 1), steps) loss = deconvLoss(data, model, parameters, sigma_2, regularization_terms=regularization_terms, regularization_strength_scales=lambda_scales_eff, regularization_strength_hf=lambda_hf_eff, regularization_strength_positivity=lambda_positivity_bkg_eff, regularization_strength_positivity_ps=lambda_positivity_ps, regularization_strength_pts_source=lambda_pts_source_eff, regularize_full_model=regularize_full_model, prior=prior_list[i], W=W, ) optim = Optimizer(loss, parameters, method=optim_list[i]) best_fit, loss_best_fit, extra_fields, runtime = optim.minimize(**kwargs_optim_list[i]) # Saving partial results kwargs_partial_steps = deepcopy(parameters.best_fit_values(as_kwargs=True)) fig_list.append( pltf.plot_deconvolution(model, data, sigma_2, s, kwargs_partial_steps, epoch=0, units='e-', cut_dict=None)) loss_history_list.append(extra_fields['loss_history']) loss_list.append(loss_best_fit) kwargs_partial_list.append(deepcopy(kwargs_partial_steps)) if verbose: print('Step %i/%i took %2.f seconds' % (i + 1, len(fitting_sequence), runtime)) print('Kwargs partial at step %i/%i' % (i + 1, len(fitting_sequence)), kwargs_partial_steps) print('Loss : ', loss_best_fit) print('Overall Reduced Chi2 : ', loss.reduced_chi2(kwargs_partial_steps)) return model, parameters, loss, kwargs_partial_list, fig_list, loss_list, loss_history_list