# Source code for tuner

import copy
import logging


class MLTuner:
    """Simple hyper parameter tuner that varies one parameter at a time.

    Starting from the first candidate value of every parameter, :meth:`tune`
    walks over the parameters in order. For each parameter it tries every
    candidate value while keeping all other parameters at the best values
    found so far.

    The tuner keeps its state in ``self.search_space``, a dictionary with the
    keys ``"best_ev"``, ``"best_params"``, ``"is_first"`` and ``"progress"``.
    That dictionary is handed to ``progress_callback`` after each evaluation.
    Passing a saved copy of it back as ``search_space`` resumes an
    interrupted run.

    Sample parameter space (as passed to :meth:`tune`):

    .. code-block:: python

        param_space_minimal_prm = {
            "dense_layers": [4, 8, 12],
            "dense_neurons": [256, 512, 768],
            "learning_rate": [0.001, 0.002],
            "regu1": [1e-8, 1e-7],
        }

    :param search_space: Tuner state previously handed to
        ``progress_callback``, used to resume a run. ``None`` starts a fresh
        search.
    :param progress_callback: Callback function that is called after each
        evaluation with the updated search-space state as its only parameter.
    """

    def __init__(self, search_space=None, progress_callback=None):
        # XXX Get rid of search_space?!
        self.log = logging.getLogger("MLTuner")
        self.progress_callback = progress_callback
        if search_space is None:
            self.search_space = {}
            self.search_space["best_ev"] = 0
            self.search_space["is_first"] = True
            self.search_space["progress"] = 0
        else:
            self.search_space = search_space
            # Tolerate hand-built or partial states instead of failing later
            # with a KeyError inside tune().
            self.search_space.setdefault("best_ev", 0)
            self.search_space.setdefault("progress", 0)
            # A resumed run does not force an evaluation of the starting
            # configuration (see the "is_first" handling in tune()).
            self.search_space["is_first"] = False

    def tune(self, param_space, eval_func):
        """Tune hyper parameters.

        Example parameter space:

        .. code-block:: python

            param_space = {
                "dense_layers": [4, 8, 12],
                "dense_neurons": [256, 512, 768],
                "learning_rate": [0.001, 0.002],
                "regu1": [1e-8, 1e-7],
            }

        `eval_func` is called with a dictionary of hyper parameters with
        exactly one value for each key, e.g.:

        .. code-block:: python

            params = {
                "dense_layers": 8,
                "dense_neurons": 256,
                "learning_rate": 0.001,
                "regu1": 1e-8,
            }

        ``search_space["progress"]`` counts the candidate ``(key, value)``
        positions already handled, in the iteration order of `param_space`.
        This includes values skipped because they were already tested. On a
        resumed run exactly that many positions are fast-forwarded, so no
        untested candidate is skipped.

        :param param_space: Dictionary mapping each parameter name to a
            non-empty sequence of candidate values.
        :param eval_func: Function that is called to evaluate the hyper
            parameters. Larger return values are considered better.
        :returns: The best parameter set found (the ``best_params`` dict held
            in the tuner state).
        :raises ValueError: If a parameter has no candidate values.
        """
        state = self.search_space

        for key, vals in param_space.items():
            if len(vals) == 0:
                raise ValueError(f"No candidate values given for parameter {key!r}")

        # Parameters not yet known start at their first candidate value.
        best_params = state.setdefault("best_params", {})
        for key, vals in param_space.items():
            best_params.setdefault(key, vals[0])

        position = 0
        for key, vals in param_space.items():
            # Only `key` is varied; every other parameter stays at its best
            # value found so far.
            params = copy.deepcopy(state["best_params"])
            for val in vals:
                position += 1
                if position <= state["progress"]:
                    # Already handled by an earlier (interrupted) run.
                    self.log.debug(f"Fast forwarding: {key} {val}")
                    continue
                state["progress"] = position

                is_first = state["is_first"]
                if is_first:
                    # Very first evaluation of a fresh run: the starting
                    # configuration itself has not been tested yet.
                    state["is_first"] = False
                elif val == state["best_params"][key]:
                    continue  # Was already tested.

                params[key] = val
                self.log.debug(f"Testing: {key}={val} with {params}")
                ev = eval_func(params)
                self.log.debug(f"Eval: {ev}")
                # The first result always becomes the baseline. Otherwise,
                # scores not above the initial best_ev of 0 (e.g. negative
                # losses) would never be recorded.
                if is_first or ev > state["best_ev"]:
                    state["best_ev"] = ev
                    state["best_params"] = copy.deepcopy(params)
                    self.log.info(f"Best parameter set with ev={ev}: {params}")
                if self.progress_callback is not None:
                    self.progress_callback(state)

        self.log.info(
            f"Best parameter set with {state['best_ev']} eval: {state['best_params']}"
        )
        return state["best_params"]