# Source code for starred.psf.parameters

import jax.numpy as jnp
import numpy as np

from starred.utils.parameters import Parameters

__all__ = ['ParametersPSF']


class ParametersPSF(Parameters):
    """
    Point Spread Function parameters class.

    Converts between the nested keyword-argument form of the PSF parameters
    (groups ``kwargs_moffat``, ``kwargs_gaussian``, ``kwargs_background`` and
    ``kwargs_distortion``) and a flat array of free parameters.
    """

    param_names_moffat = ['fwhm_x', 'fwhm_y', 'phi', 'beta', 'C']
    param_names_background = ['background', 'mean']
    param_names_gaussian = ['a', 'x0', 'y0']
    param_names_distortion = ['dilation_x', 'dilation_y', 'shear']

    # Order in which the kwargs groups are laid out in the flat args array.
    # args2kwargs, kwargs2args and get_all_free_param_names all iterate this
    # tuple so the packing and unpacking orders cannot drift apart.
    _psf_kwargs_keys = ('kwargs_moffat', 'kwargs_gaussian', 'kwargs_background', 'kwargs_distortion')

    def __init__(self, kwargs_init, kwargs_fixed, kwargs_up=None, kwargs_down=None, include_moffat=True):
        """
        :param kwargs_init: dictionary with information on the initial values of the parameters
        :param kwargs_fixed: dictionary containing the fixed parameters
        :param kwargs_up: dictionary with information on the upper bounds of the parameters
        :param kwargs_down: dictionary with information on the lower bounds of the parameters
        :param include_moffat: if False, the Moffat group has no parameters and its entry in the
            init/fixed/bound dictionaries is replaced by an empty dict; ``kwargs_init`` then does
            not need to provide a ``kwargs_moffat`` entry
        """
        super(ParametersPSF, self).__init__(kwargs_init, kwargs_fixed, kwargs_up=kwargs_up, kwargs_down=kwargs_down)
        # We guess the number of sources, etc. from the provided kwargs.
        self.M = len(kwargs_init['kwargs_gaussian']['x0'])  # number of sources, one per provided image
        self.background_param_number = len(kwargs_init['kwargs_background']['background'])

        if not include_moffat:
            self.param_names_moffat = []
            # Work on shallow copies so the dictionaries the caller passed in are not mutated;
            # bounds may legitimately be None (their default), in which case they are left as is.
            self._kwargs_init = self._with_empty_moffat(self._kwargs_init)
            self._kwargs_fixed = self._with_empty_moffat(self._kwargs_fixed)
            self._kwargs_up = self._with_empty_moffat(self._kwargs_up)
            self._kwargs_down = self._with_empty_moffat(self._kwargs_down)
        elif 'phi' not in kwargs_init['kwargs_moffat']:
            # Without 'phi' the Moffat profile is the circular (non-elliptical) variant.
            # TODO: change the parametrization so this distinction is unnecessary
            # (a circular Moffat is an elliptical one with fixed parameters).
            self.param_names_moffat = ['fwhm', 'beta', 'C']

        # Called last, once M, background_param_number and the Moffat names (all read by
        # _get_params) are final. NOTE(review): presumably _update_arrays packs/unpacks via
        # kwargs2args/args2kwargs — confirm against the base class.
        self._update_arrays()

    @staticmethod
    def _with_empty_moffat(kwargs):
        """Return a shallow copy of ``kwargs`` whose ``kwargs_moffat`` group is empty (None is passed through)."""
        if kwargs is None:
            return None
        return {**kwargs, 'kwargs_moffat': {}}

    def args2kwargs(self, args):
        """Obtain a dictionary of keyword arguments from positional arguments.

        :param args: flat array of free parameters, laid out group by group in the order
            moffat, gaussian, background, distortion
        :return: dict mapping each group key to the dict of its parameter values
        """
        kwargs = {}
        i = 0
        for kwargs_key in self._psf_kwargs_keys:
            kwargs[kwargs_key], i = self._get_params(args, i, kwargs_key)
        return kwargs

    def kwargs2args(self, kwargs):
        """Obtain an array of positional arguments from a dictionary of keyword arguments.

        :param kwargs: dict with the groups kwargs_moffat, kwargs_gaussian, kwargs_background
            and kwargs_distortion
        :return: ``jnp.array`` of the concatenated values produced by ``_set_params`` for each group
        """
        args = []
        for kwargs_key in self._psf_kwargs_keys:
            args += self._set_params(kwargs, kwargs_key)
        return jnp.array(args)

    def get_param_names_for_model(self, kwargs_key):
        """Returns the names of the parameters according to the key provided.

        :raises KeyError: if ``kwargs_key`` is not one of the four PSF groups
        """
        if kwargs_key == 'kwargs_moffat':
            return self.param_names_moffat
        elif kwargs_key == 'kwargs_gaussian':
            return self.param_names_gaussian
        elif kwargs_key == 'kwargs_background':
            return self.param_names_background
        elif kwargs_key == 'kwargs_distortion':
            return self.param_names_distortion
        else:
            raise KeyError(f'`{kwargs_key}` is not in the kwargs')

    def _free_param_count(self, name):
        """Number of entries a fully free parameter ``name`` occupies in the flat args array."""
        if name == 'background':
            return self.background_param_number
        if name in ('mean', 'a', 'x0', 'y0'):
            return self.M  # one value per source/image
        if name in ('dilation_x', 'dilation_y', 'shear'):
            # Original note: "2d order 2 polynomial without constant term".
            # NOTE(review): 2 coefficients per term — confirm this matches the distortion model.
            return 2
        return 1

    def _get_params(self, args, i, kwargs_key):
        """Unpack the parameters of one kwargs group from ``args``, starting at offset ``i``.

        Parameters absent from ``kwargs_fixed`` are sliced directly from ``args``. Fixed
        parameters take their value from ``kwargs_fixed``; the entries listed for them in
        ``self._kwargs_free_indices`` are overwritten with the next values of ``args``.

        :return: tuple (dict of parameter values for this group, offset just past the consumed entries)
        """
        kwargs = {}
        kwargs_fixed_k = self._kwargs_fixed[kwargs_key]
        for name in self.get_param_names_for_model(kwargs_key):
            if name not in kwargs_fixed_k:
                num_param = self._free_param_count(name)
                kwargs[name] = args[i:i + num_param]
                i += num_param
                continue
            value = kwargs_fixed_k[name]
            free_ind = self._kwargs_free_indices[kwargs_key][name]
            if len(free_ind) > 0:
                num_param = len(free_ind)
                # asarray so partially fixed values given as numpy arrays/lists support `.at`
                value = jnp.asarray(value).at[free_ind].set(args[i:i + num_param])
                i += num_param
            kwargs[name] = value
        return kwargs, i

    def get_all_free_param_names(self, kwargs):
        """Return the names of the free parameters of all groups, in flat-args order.

        :param kwargs: dict of kwargs groups, forwarded to ``_param_names`` for each group
        :return: list concatenating ``_param_names`` for moffat, gaussian, background, distortion
        """
        names = []
        for kwargs_key in self._psf_kwargs_keys:
            names += self._param_names(kwargs, kwargs_key)
        return names