Source code for starred.optim.optimization

import time
import warnings

import jax
import numpy as np
import optax
from starred.optim.inference_base import FisherCovariance
from tqdm import tqdm
import jaxopt

from starred.optim.inference_base import InferenceBase


__all__ = ['Optimizer']

SUPPORTED_METHOD_OPTAX = ['adam', 'adabelief', 'radam']
SUPPORTED_METHOD_JAXOPT = ['BFGS', 'GradientDescent', 'LBFGS', 'l-bfgs-b', 'LevenbergMarquardt']
SUPPORTED_METHOD_SCIPYMINIMIZE = ['Nelder-Mead', 'Powell', 'CG', 'Newton-CG', 'TNC', 'COBYLA', 'SLSQP', 'trust-constr',
                                  'dogleg', 'trust-ncg', 'trust-exact', 'trust-krylov']
SUPPORTED_BOUNDED_METHOD = ['l-bfgs-b']

# GRADIENT_METHOD = ['Nelder-Mead', 'Powell', 'CG' 'Newton-CG', 'trust-krylov', 'trust-exact', 'trust-constr']
# HVP_METHOD = ['Nelder-Mead', 'Powell', 'CG','Newton-CG', 'trust-constr', 'trust-krylov', 'trust-constr']


class Optimizer(InferenceBase):
    """
    Class that handles optimization tasks, `i.e.`, finding best-fit point estimates of parameters.
    It currently handles a subset of ``scipy.optimize.minimize`` routines, using first and second
    order derivatives when required.
    """

    def __init__(self, loss_class, param_class, method='BFGS'):
        super().__init__(loss_class, param_class)
        if method not in SUPPORTED_METHOD_OPTAX + SUPPORTED_METHOD_JAXOPT + SUPPORTED_METHOD_SCIPYMINIMIZE:
            raise NotImplementedError(f'Minimisation method {method} is not supported yet.')
        else:
            self.method = method

        self._metrics = MinimizeMetrics(self.loss, self.method)

    @property
    def loss_history(self):
        """Returns the loss history."""
        if len(self._metrics.loss_history) == 0:
            raise ValueError("You must run the optimizer at least once to access the history")
        return self._metrics.loss_history

    @property
    def param_history(self):
        """Returns the parameter history."""
        if len(self._metrics.param_history) == 0:
            raise ValueError("You must run the optimizer at least once to access the history")
        return self._metrics.param_history
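
    # Illustrative note (not part of the original module): `loss_class` and `param_class` are
    # handed straight to InferenceBase, and `method` selects the backend -- optax for
    # 'adam'/'adabelief'/'radam', jaxopt or scipy.optimize.minimize for the other names.
    # A hypothetical construction, with `loss` and `parameters` objects built elsewhere:
    #     optimizer = Optimizer(loss, parameters, method='adam')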

    def minimize(self, restart_from_init=True, **kwargs_optimiser):
        """Minimizes the loss function and returns the best fit."""
        init_params = self._param.current_values(as_kwargs=False, restart=restart_from_init, copy=True)
        start = time.time()
        if self.method in SUPPORTED_METHOD_OPTAX:
            best_fit, extra_fields = self._run_optax(init_params, **kwargs_optimiser)
        elif self.method in SUPPORTED_METHOD_JAXOPT + SUPPORTED_METHOD_SCIPYMINIMIZE:
            best_fit, extra_fields = self._run_jaxopt(init_params, **kwargs_optimiser)
        else:
            raise NotImplementedError(f'Minimization method {self.method} is not supported yet.')
        runtime = time.time() - start

        self._param.set_best_fit(best_fit)
        logL_best_fit = self.loss(best_fit)

        return best_fit, logL_best_fit, extra_fields, runtime
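
    # Illustrative note (not part of the original module): the **kwargs_optimiser given to
    # `minimize` are forwarded unchanged -- to `_run_optax` for the optax methods
    # (e.g. max_iterations, init_learning_rate), or to the jaxopt solver constructor for the
    # jaxopt/scipy methods (e.g. maxiter, tol). A hypothetical call could look like:
    #     best_fit, logL, extra_fields, runtime = optimizer.minimize(max_iterations=500,
    #                                                                init_learning_rate=1e-3)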

    def _run_jaxopt(self, params, **kwargs_optimiser):
        extra_kwargs = {}
        if self.method in SUPPORTED_BOUNDED_METHOD:
            extra_kwargs['bounds'] = self.parameters.get_bounds()
        else:
            warnings.warn('You are using an unconstrained optimiser. Bounds are ignored.')

        if self.method == 'BFGS':
            solver = jaxopt.BFGS(fun=self.loss, **kwargs_optimiser)
        elif self.method == 'GradientDescent':
            solver = jaxopt.GradientDescent(fun=self.loss, **kwargs_optimiser)
        elif self.method == 'LBFGS':
            solver = jaxopt.LBFGS(fun=self.loss, **kwargs_optimiser)
        elif self.method in SUPPORTED_METHOD_SCIPYMINIMIZE:
            solver = jaxopt.ScipyMinimize(fun=self.loss, method=self.method, callback=self._metrics,
                                          **kwargs_optimiser)
        elif self.method == 'l-bfgs-b':
            solver = jaxopt.ScipyBoundedMinimize(fun=self.loss, method="l-bfgs-b", callback=self._metrics,
                                                 **kwargs_optimiser)
        elif self.method == 'LevenbergMarquardt':
            solver = jaxopt.LevenbergMarquardt(residual_fun=lambda x: jax.numpy.array([self.loss(x)]),
                                               **kwargs_optimiser)

        res = solver.run(params, **extra_kwargs)
        extra_fields = {'result_class': res, 'stat': res.state, 'loss_history': self._metrics.loss_history}

        return res.params, extra_fields

    def _run_optax(self, params, max_iterations=100, min_iterations=None,
                   init_learning_rate=1e-2, schedule_learning_rate=True,
                   restart_from_init=False, stop_at_loss_increase=False,
                   progress_bar=True, return_param_history=False, decay_rate=0.99):
        if min_iterations is None:
            min_iterations = max_iterations

        if schedule_learning_rate is True:
            # Exponential decay of the learning rate
            scheduler = optax.exponential_decay(
                init_value=init_learning_rate,
                decay_rate=decay_rate,
                transition_steps=max_iterations)

            if self.method == 'adabelief':
                scale_algo = optax.scale_by_belief()
            elif self.method == 'radam':
                scale_algo = optax.scale_by_radam()
            elif self.method == 'adam':
                scale_algo = optax.scale_by_adam()
            else:
                raise NotImplementedError(f"Optax algorithm '{self.method}' is not supported")

            # Combining gradient transforms using `optax.chain`
            optim = optax.chain(
                # optax.clip_by_global_norm(1.0),  # clip the gradient by the global norm
                scale_algo,                          # use the updates from the chosen optimizer
                optax.scale_by_schedule(scheduler),  # use the learning rate from the scheduler
                optax.scale(-1.)                     # because gradient *descent*
            )
        else:
            if self.method == 'adabelief':
                optim = optax.adabelief(init_learning_rate)
            elif self.method == 'radam':
                optim = optax.radam(init_learning_rate)
            elif self.method == 'adam':
                optim = optax.adam(init_learning_rate)
            else:
                raise NotImplementedError(f"Optax algorithm '{self.method}' is not supported")

        # Initialise optimizer state
        opt_state = optim.init(params)
        prev_params, prev_loss_val = params, 1e10

        @jax.jit
        def gradient_step(params, opt_state):
            # loss_val, grads = jax.value_and_grad(self._loss)(params)
            loss_val, grads = self.value_and_gradient(params)
            updates, opt_state = optim.update(grads, opt_state, params)
            params = optax.apply_updates(params, updates)
            return params, opt_state, loss_val

        # Gradient descent loop
        for i in self._for_loop(range(int(max_iterations)), progress_bar,
                                total=int(max_iterations), desc=f"optax.{self.method}"):
            params, opt_state, loss_val = gradient_step(params, opt_state)
            if stop_at_loss_increase and i > min_iterations and loss_val > prev_loss_val:  # pragma: no cover
                params, loss_val = prev_params, prev_loss_val
                break
            else:
                self._metrics.loss_history.append(loss_val)
                prev_params, prev_loss_val = params, loss_val
                if return_param_history is True:
                    self._metrics.param_history.append(params)

        best_fit = params
        extra_fields = {'loss_history': np.array(self._metrics.loss_history)}
        if return_param_history is True:
            extra_fields['param_history'] = self._metrics.param_history

        return best_fit, extra_fields

    @staticmethod
    def _for_loop(iterable, progress_bar_bool, **tqdm_kwargs):
        if progress_bar_bool is True:
            return tqdm(iterable, **tqdm_kwargs)
        else:
            return iterable
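
    # Sketch (illustrative, not part of the original module) of the transform built above when
    # `schedule_learning_rate` is True: chaining the scaling rule, the schedule, and a sign flip
    # reproduces, e.g., Adam with an exponentially decaying step size:
    #     scheduler = optax.exponential_decay(init_value=1e-2, decay_rate=0.99,
    #                                         transition_steps=100)
    #     optim = optax.chain(optax.scale_by_adam(),
    #                         optax.scale_by_schedule(scheduler),
    #                         optax.scale(-1.))  # negate updates: gradient *descent*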

    def compute_fisher_matrix(self):
        if len(self._param.best_fit_values(as_kwargs=False)) > 500:
            warnings.warn('Computing the Fisher matrix for more than 500 dimensions will blow up your memory. '
                          'You should reduce the dimensionality of your model.')
        fish = FisherCovariance(self._param, self)
        fish.compute_fisher_information()

        return fish
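
# Note (illustrative, not part of the original module): `compute_fisher_matrix` evaluates the
# Fisher information over all free parameters at the best fit, so memory presumably grows
# quadratically with the number of parameters -- hence the warning above ~500 dimensions.
# A hypothetical use, after `minimize` has set the best fit:
#     fish = optimizer.compute_fisher_matrix()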


class MinimizeMetrics(object):
    """Simple callable class used as callback in ``scipy.optimize.minimize`` method."""

    def __init__(self, func, method):
        self.loss_history = []
        self.param_history = []
        self._func = func
        if method == 'trust-constr':
            # temporary workaround: jaxopt does not currently support the 2-argument callback,
            # see https://github.com/google/jaxopt/issues/428
            self._call = None
            # self._call = self._call_2args
        else:
            self._call = self._call_1arg

    def __call__(self, *args, **kwargs):
        return self._call(*args, **kwargs)

    def _call_1arg(self, x):
        self.loss_history.append(self._func(x))
        self.param_history.append(x)

    def _call_2args(self, x, state):
        # Input state parameter is necessary for 'trust-constr' method
        # You can use it to stop execution early by returning True
        self.loss_history.append(self._func(x))
        self.param_history.append(x)
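

A minimal end-to-end sketch of how this module can be used, assuming hypothetical `loss` and `parameters` objects compatible with InferenceBase (their construction is not shown here):

    from starred.optim.optimization import Optimizer

    # choose an optax method; its keyword arguments are forwarded to _run_optax
    optimizer = Optimizer(loss, parameters, method='adam')
    best_fit, logL_best_fit, extra_fields, runtime = optimizer.minimize(
        max_iterations=300, init_learning_rate=1e-2)

    # the loss values recorded during the run are available afterwards
    history = optimizer.loss_history

    # optional: Fisher-information estimate around the best fit
    fish = optimizer.compute_fisher_matrix()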