Source code for starred.deconvolution.parameters

import jax.numpy as jnp
import numpy as np

from starred.utils.parameters import Parameters

# Public API of this module: only the deconvolution parameter container is exported.
__all__ = ['ParametersDeconv']


class ParametersDeconv(Parameters):
    """
    Deconvolution parameters class.

    Converts between nested keyword dictionaries (with the top-level keys
    ``'kwargs_analytic'``, ``'kwargs_background'`` and ``'kwargs_sersic'``) and a
    flat array of positional arguments. The flat array is always laid out in the
    order analytic -> background -> sersic, and, within each group, in the order
    of the corresponding ``param_names_*`` list below.

    NOTE(review): ``_kwargs_fixed``, ``_kwargs_free_indices``, ``_set_params``,
    ``_param_names`` and ``_update_arrays`` are not defined in this class; they
    are expected to be provided by the ``Parameters`` base class.
    """

    # Parameter names of each group; the list order fixes the flat-array layout.
    param_names_analytic = ['a', 'c_x', 'c_y', 'dx', 'dy', 'alpha']
    param_names_background = ['mean', 'h']
    param_names_sersic = ['amp', 'R_sersic', 'n_sersic', 'center_x', 'center_y', 'e1', 'e2']

    def __init__(self, kwargs_init, kwargs_fixed, kwargs_up=None, kwargs_down=None):
        """
        :param kwargs_init: dictionary with information on the initial values of the parameters
        :param kwargs_fixed: dictionary containing the fixed parameters
        :param kwargs_up: dictionary with information on the upper bounds of the parameters
        :param kwargs_down: dictionary with information on the lower bounds of the parameters

        """
        super().__init__(kwargs_init, kwargs_fixed, kwargs_up=kwargs_up, kwargs_down=kwargs_down)

        # The sizes are not passed explicitly: they are read off the lengths of
        # the initial values provided in ``kwargs_init``.
        kwargs_analytic = kwargs_init['kwargs_analytic']
        self.epochs = len(kwargs_analytic['dx'])
        self.M = len(kwargs_analytic['c_x'])
        self.N_sersic = len(kwargs_init['kwargs_sersic']['amp'])
        self.background_param_number = len(kwargs_init['kwargs_background']['h'])

        # Called last, once all the sizes above are set.
        self._update_arrays()

    def args2kwargs(self, args):
        """
        Obtain a dictionary of keyword arguments from positional arguments.

        :param args: flat sequence of parameter values, laid out as analytic, then background, then sersic
        :return: dictionary with the keys ``'kwargs_analytic'``, ``'kwargs_background'`` and ``'kwargs_sersic'``

        """
        # ``i`` is the running read offset into ``args``; each call consumes its
        # own slice and returns the offset at which the next group starts.
        i = 0
        kwargs_analytic, i = self._get_params(args, i, 'kwargs_analytic')
        kwargs_background, i = self._get_params(args, i, 'kwargs_background')
        kwargs_sersic, i = self._get_params(args, i, 'kwargs_sersic')

        return {
            'kwargs_analytic': kwargs_analytic,
            'kwargs_background': kwargs_background,
            'kwargs_sersic': kwargs_sersic,
        }

    def kwargs2args(self, kwargs):
        """
        Obtain an array of positional arguments from a dictionary of keyword arguments.

        :param kwargs: dictionary with the keys ``'kwargs_analytic'``, ``'kwargs_background'`` and ``'kwargs_sersic'``
        :return: ``jnp.array`` built from the concatenated results of ``_set_params`` for the three groups

        """
        # Same group order as in ``args2kwargs`` so that both methods agree on the layout.
        args = self._set_params(kwargs, 'kwargs_analytic')
        args += self._set_params(kwargs, 'kwargs_background')
        args += self._set_params(kwargs, 'kwargs_sersic')
        return jnp.array(args)

    def get_param_names_for_model(self, kwargs_key):
        """
        Returns the names of the parameters according to the key provided.

        :param kwargs_key: ``'kwargs_analytic'``, ``'kwargs_sersic'`` or ``'kwargs_background'``
        :return: list of parameter names; any key other than the analytic and sersic ones yields the background names

        """
        if kwargs_key == 'kwargs_analytic':
            return self.param_names_analytic
        if kwargs_key == 'kwargs_sersic':
            return self.param_names_sersic
        return self.param_names_background

    def _param_flat_size(self, name):
        """
        Number of entries a fully free parameter occupies in the flat argument array.

        :param name: name of the parameter
        :return: number of consecutive values to read from the flat array
        :raises NameError: if ``name`` is not a parameter known to this class

        """
        if name == 'a':
            return self.M * self.epochs
        if name == 'h':
            return self.background_param_number
        if name in ('mean', 'dx', 'dy', 'alpha'):
            return self.epochs
        if name in ('c_x', 'c_y'):
            return self.M
        if name in self.param_names_sersic:
            return self.N_sersic
        raise NameError(f"The ParametersDeconv class does not know about this parameter: {name}.")

    def _get_params(self, args, i, kwargs_key):
        """
        Getting the parameters.

        :param args: flat sequence of parameter values
        :param i: offset in ``args`` at which the values of this group start
        :param kwargs_key: key of the parameter group to extract
        :return: tuple ``(kwargs, i)`` with the dictionary of the group and the offset just after the consumed values

        """
        kwargs = {}
        kwargs_fixed_k = self._kwargs_fixed[kwargs_key]

        for name in self.get_param_names_for_model(kwargs_key):
            if name not in kwargs_fixed_k:
                # Fully free parameter: its values are one contiguous slice of ``args``.
                num_param = self._param_flat_size(name)
                kwargs[name] = args[i:i + num_param]
                i += num_param
                continue

            # Fixed parameter: start from the fixed value ...
            kwargs[name] = kwargs_fixed_k[name]
            free_ind = self._kwargs_free_indices[kwargs_key][name]
            # ... and, if some of its entries are nevertheless free, overwrite only
            # those entries with the next values of ``args``. ``len()`` is used
            # rather than truthiness because ``free_ind`` may be an array.
            if len(free_ind) > 0:
                num_param = len(free_ind)
                kwargs[name] = jnp.array(kwargs[name]).at[free_ind].set(args[i:i + num_param])
                i += num_param

        return kwargs, i

    def get_all_free_param_names(self, kwargs):
        """
        Collect the parameter names reported by ``_param_names`` for every group.

        :param kwargs: dictionary with the keys ``'kwargs_analytic'``, ``'kwargs_background'`` and ``'kwargs_sersic'``
        :return: concatenation of the ``_param_names`` results, in the order analytic, background, sersic

        """
        args = self._param_names(kwargs, 'kwargs_analytic')
        args += self._param_names(kwargs, 'kwargs_background')
        args += self._param_names(kwargs, 'kwargs_sersic')
        return args