Source code for starred.utils.parameters

from copy import deepcopy

import jax.numpy as jnp
import numpy as np

__all__ = ['Parameters']


class Parameters(object):
    """
    Parameters class.
    """

    def __init__(self, kwargs_init, kwargs_fixed, kwargs_up=None, kwargs_down=None):
        """
        :param kwargs_init: dictionary with information on the initial values of the parameters
        :param kwargs_fixed: dictionary containing the fixed parameters
        :param kwargs_up: dictionary with information on the upper bounds of the parameters
        :param kwargs_down: dictionary with information on the lower bounds of the parameters
        """
        self._kwargs_init = self.convert_in_jnp_array(kwargs_init)
        self._kwargs_fixed = self.convert_in_jnp_array(kwargs_fixed)
        if kwargs_up is not None:
            self._kwargs_up = self.convert_in_jnp_array(kwargs_up)
        else:
            self._kwargs_up = None
        if kwargs_down is not None:
            self._kwargs_down = self.convert_in_jnp_array(kwargs_down)
        else:
            self._kwargs_down = None
        self._kwargs_free_indices = self.get_nan_indices()
        # update the bounds in case only sub-parts of our parameter arrays are free:
        uku, ukd = self._update_bounds(self._kwargs_init, self._kwargs_fixed,
                                       self._kwargs_up, self._kwargs_down)
        self._kwargs_up = uku
        self._kwargs_down = ukd

    @property
    def optimized(self):
        """Checks whether the parameters have been optimized (i.e. a best-fit estimate has been set)."""
        return hasattr(self, '_map_values')
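
    # A minimal construction sketch (not from the package): the nested kwargs follow
    # the structure {model_key: {param_name: values}} expected by this class, which is
    # normally used through its model-specific subclasses. The key 'kwargs_moffat' and
    # the parameter names below are illustrative assumptions only.
    #
    # >>> kwargs_init = {'kwargs_moffat': {'fwhm': [3.0], 'beta': [2.5]}}
    # >>> kwargs_fixed = {'kwargs_moffat': {'beta': [2.5]}}             # 'fwhm' stays free
    # >>> kwargs_up = {'kwargs_moffat': {'fwhm': [10.0], 'beta': [5.0]}}
    # >>> kwargs_down = {'kwargs_moffat': {'fwhm': [0.5], 'beta': [1.0]}}
    # >>> params = Parameters(kwargs_init, kwargs_fixed, kwargs_up, kwargs_down)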

    def convert_in_jnp_array(self, kwargs):
        """Converts every entry of the nested kwargs dictionary into a jax.numpy array."""
        new_kwargs = {}
        for key in kwargs.keys():
            new_kwargs[key] = {}
            for key2 in kwargs[key].keys():
                new_kwargs[key][key2] = jnp.asarray(deepcopy(kwargs[key][key2]))
        return new_kwargs
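
    # Hedged illustration (not part of the package): the two-level dictionary
    # structure is preserved and only the leaves are converted.
    #
    # >>> p = object.__new__(Parameters)      # bypass __init__, purely for illustration
    # >>> out = p.convert_in_jnp_array({'kwargs_moffat': {'fwhm': [3.0, 3.2]}})
    # >>> isinstance(out['kwargs_moffat']['fwhm'], jnp.ndarray)
    # True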

    def initial_values(self, as_kwargs=False, copy=False):
        """Returns the initial values of the parameters."""
        if as_kwargs:
            return deepcopy(self._kwargs_init) if copy else self._kwargs_init
        else:
            return deepcopy(self._init_values) if copy else self._init_values

    def current_values(self, as_kwargs=False, restart=False, copy=False):
        """Returns the current values of the parameters."""
        if restart is True or not self.optimized:
            return self.initial_values(as_kwargs=as_kwargs, copy=copy)
        return self.best_fit_values(as_kwargs=as_kwargs, copy=copy)
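
    # Sketch of the fallback behaviour, assuming 'params' is an already constructed
    # instance on which no fit has been run yet (so set_best_fit was never called):
    #
    # >>> params.current_values(as_kwargs=True) is params.initial_values(as_kwargs=True)
    # True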

    def best_fit_values(self, as_kwargs=False, copy=False):
        """Maximum-a-posteriori estimate."""
        if as_kwargs:
            return deepcopy(self._kwargs_map) if copy else self._kwargs_map
        else:
            return deepcopy(self._map_values) if copy else self._map_values

    def set_best_fit(self, args):
        """Sets the maximum-a-posteriori estimate as the parameter values."""
        self._map_values = args
        self._kwargs_map = self.args2kwargs(self._map_values)

    def _update_arrays(self):
        self._init_values = self.kwargs2args(self._kwargs_init)
        self._kwargs_init = self.args2kwargs(self._init_values)  # for updating missing fields
        self._num_params = len(self._init_values)
        if self.optimized:
            self._map_values = self.kwargs2args(self._kwargs_map)

    def _set_params(self, kwargs, kwargs_key):
        """Builds the flat list of values for the free parameters of the model kwargs_key."""
        args = []
        kwargs_profile = kwargs[kwargs_key]
        kwargs_fixed_k = self._kwargs_fixed[kwargs_key]
        param_names = self.get_param_names_for_model(kwargs_key)
        for name in param_names:
            if name not in kwargs_fixed_k:
                if isinstance(kwargs_profile[name], list):
                    args += kwargs_profile[name]
                elif isinstance(kwargs_profile[name], (np.ndarray, np.generic, jnp.ndarray, jnp.generic)):
                    el = kwargs_profile[name].tolist()
                    if hasattr(el, '__len__'):
                        args += el
                    else:
                        args += [el]
                else:
                    args += [kwargs_profile[name]]
            else:
                # add the indices stored in self._kwargs_free_indices
                free_ind = self._kwargs_free_indices[kwargs_key][name]
                if len(free_ind) >= 1:
                    args += kwargs_profile[name][free_ind].tolist()
        return args

    @staticmethod
    def _update_bounds(kwargs_init, kwargs_fixed, kwargs_up, kwargs_down):
        """
        Called during initialization of this class. Updates the bounds (kwargs_up and kwargs_down)
        based on the fixed parameters (kwargs_fixed). NaN values in kwargs_fixed represent parameters
        that are still free to be optimized. The function adjusts the bounds to include only the
        indices where there are NaN values in kwargs_fixed.

        :param kwargs_init: initial parameters
        :param kwargs_fixed: fixed parameters, with NaN for values that are still optimized
        :param kwargs_up: upper bounds for parameters
        :param kwargs_down: lower bounds for parameters
        :return: updated kwargs_up and kwargs_down
        """
        for main_key in kwargs_fixed:
            for sub_key in kwargs_fixed[main_key]:
                fixed_values = np.array(kwargs_fixed[main_key][sub_key])  # ensuring it's a numpy array
                if isinstance(fixed_values, np.ndarray):
                    nan_indices = np.isnan(fixed_values)
                    if np.any(nan_indices):
                        if kwargs_up is not None:
                            # check whether the bounds have already been updated
                            if not np.sum(nan_indices) == len(kwargs_up[main_key][sub_key]):
                                kwargs_up[main_key][sub_key] = np.array(kwargs_up[main_key][sub_key])[nan_indices]
                        if kwargs_down is not None:
                            if not np.sum(nan_indices) == len(kwargs_down[main_key][sub_key]):
                                kwargs_down[main_key][sub_key] = np.array(kwargs_down[main_key][sub_key])[nan_indices]
        return kwargs_up, kwargs_down
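
    # Hedged, runnable example of the NaN convention handled by _update_bounds: for a
    # 3-element parameter whose second entry is fixed, the bounds are trimmed down to
    # the two entries that remain free. The key names here are illustrative only.
    #
    # >>> kf = {'kwargs_background': {'mean': np.array([np.nan, 0.0, np.nan])}}
    # >>> ku = {'kwargs_background': {'mean': np.array([1., 1., 1.])}}
    # >>> kd = {'kwargs_background': {'mean': np.array([-1., -1., -1.])}}
    # >>> new_up, new_down = Parameters._update_bounds({}, kf, ku, kd)
    # >>> new_up['kwargs_background']['mean'].tolist()
    # [1.0, 1.0]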

    def get_bounds(self):
        """Returns the upper and lower bounds of the parameters."""
        if self._kwargs_up is None or self._kwargs_down is None:
            return None
        else:
            list_down_limit = []
            list_up_limit = []
            for kwargs_key in self._kwargs_down.keys():
                param_names = self.get_param_names_for_model(kwargs_key)
                for name in param_names:
                    if name not in self._kwargs_fixed[kwargs_key].keys():
                        assert name in self._kwargs_up[kwargs_key].keys(), \
                            "Missing '%s' key in the kwargs_up['%s']" % (name, kwargs_key)
                        assert name in self._kwargs_down[kwargs_key].keys(), \
                            "Missing '%s' key in the kwargs_down['%s']" % (name, kwargs_key)
                        up = self._kwargs_up[kwargs_key][name]
                        down = self._kwargs_down[kwargs_key][name]
                    else:
                        # even if the key is in kwargs_fixed, it may contain NaNs,
                        # meaning that a subset of the array is not fixed
                        free = self._kwargs_free_indices[kwargs_key][name]
                        if len(free) == 0:
                            continue
                        up = self._kwargs_up[kwargs_key][name]
                        down = self._kwargs_down[kwargs_key][name]

                    if isinstance(down, list):
                        list_down_limit += down
                    elif isinstance(down, (np.ndarray, jnp.ndarray)):
                        el = down.tolist()
                        if hasattr(el, '__len__'):
                            list_down_limit += el
                        else:
                            list_down_limit += [el]
                    else:
                        list_down_limit += [self._kwargs_down[kwargs_key][name]]

                    if isinstance(up, list):
                        list_up_limit += up
                    elif isinstance(up, (np.ndarray, jnp.ndarray)):
                        el = up.tolist()
                        if hasattr(el, '__len__'):
                            list_up_limit += el
                        else:
                            list_up_limit += [el]
                    else:
                        list_up_limit += [self._kwargs_up[kwargs_key][name]]

            return (jnp.array(list_down_limit).flatten(),
                    jnp.array(list_up_limit).flatten())
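
    # Illustrative note, assuming a subclass that implements get_param_names_for_model:
    # the two returned arrays are flat and are meant to share the ordering of the
    # argument vector built by kwargs2args, so they can be fed to a bounded optimizer.
    #
    # >>> down, up = params.get_bounds()
    # >>> down.shape == up.shape
    # True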

    def update_kwargs(self, kwargs_init=None, kwargs_fixed=None, kwargs_up=None, kwargs_down=None):
        """Updates the kwargs with the provided values."""
        if kwargs_init is not None:
            self._kwargs_init = kwargs_init
        if kwargs_fixed is not None:
            self._kwargs_fixed = kwargs_fixed
            self._kwargs_free_indices = self.get_nan_indices()
        if kwargs_up is not None:
            self._kwargs_up = kwargs_up
        if kwargs_down is not None:
            self._kwargs_down = kwargs_down

    def _param_names(self, kwargs, kwargs_key):
        """Returns the names of the free parameters for the model kwargs_key."""
        names = []
        kwargs_profile = kwargs[kwargs_key]
        kwargs_fixed_k = self._kwargs_fixed[kwargs_key]
        param_names = self.get_param_names_for_model(kwargs_key)
        for name in param_names:
            if name not in kwargs_fixed_k:
                value = kwargs_profile[name]
                if isinstance(value, (float, int)):
                    names.append(name)
                elif len(value) == 1 and isinstance(value, (list, np.ndarray, np.generic, jnp.ndarray, jnp.generic)):
                    names.append(name)
                elif isinstance(value, (list, np.ndarray, np.generic, jnp.ndarray, jnp.generic)):
                    names += [f'{name}_{i}' for i in range(len(value))]
                else:
                    names.append(name)
            else:
                num_free_param = len(self._kwargs_free_indices[kwargs_key][name])
                if num_free_param == 1:
                    names.append(f'{name}_{self._kwargs_free_indices[kwargs_key][name]}')
                else:
                    names += [f'{name}_{i}' for i in self._kwargs_free_indices[kwargs_key][name]]
        return names

    def get_nan_indices(self):
        """Returns, for every entry of kwargs_fixed, the indices that are NaN, i.e. still free to be optimized."""
        kwargs_indice_free = {}
        for kwargs_key in self._kwargs_fixed.keys():
            kwargs_indice_free[kwargs_key] = {}
            for key in self._kwargs_fixed[kwargs_key]:
                indices = jnp.where(jnp.isnan(self._kwargs_fixed[kwargs_key][key]))[0]
                kwargs_indice_free[kwargs_key][key] = indices
        return kwargs_indice_free
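
    # Hedged, runnable sketch of the NaN bookkeeping (key names are illustrative):
    # NaN entries in kwargs_fixed mark sub-elements that stay free, and their
    # positions are recorded per model key at construction time.
    #
    # >>> kwargs_fixed = {'kwargs_background': {'mean': [np.nan, 0.0, np.nan]}}
    # >>> p = Parameters(kwargs_fixed, kwargs_fixed)     # init == fixed, for brevity
    # >>> p._kwargs_free_indices['kwargs_background']['mean'].tolist()
    # [0, 2]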