# Copyright (c) 2021 Philip May
# This software is distributed under the terms of the MIT license
# which is available at https://opensource.org/licenses/MIT

"""Optuna only functionality."""

import logging

import numpy as np
import optuna
from optuna.pruners import BasePruner
from optuna.study import StudyDirection
from scipy import stats

# Module-level logger, named after this module per the stdlib logging convention.
_logger = logging.getLogger(__name__)
class SignificanceRepeatedTrainingPruner(BasePruner):
    """Pruner which uses statistical significance as a heuristic for decision-making.

    Pruner to use statistical significance to prune repeated trainings like
    in a cross validation. As the test method a t-test is used. Our
    experiments have shown that an ``alpha`` value between 0.3 and 0.4 is
    reasonable.
    """

    def __init__(self, alpha: float = 0.1, n_warmup_steps: int = 4) -> None:
        """Constructor.

        Args:
            alpha: The alpha level for the statistical significance test. The
                larger this value is, the more aggressively this pruner works.
                The smaller this value is, the stronger the statistical
                difference between the two distributions must be for Optuna
                to prune. ``alpha`` must be ``0 < alpha < 1``.
            n_warmup_steps: Pruning is disabled until the trial reaches or
                exceeds the given number of steps.

        Raises:
            ValueError: If ``n_warmup_steps`` is negative or ``alpha`` is
                outside the open interval ``(0, 1)``.
        """
        # input value check
        if n_warmup_steps < 0:
            raise ValueError(
                "'n_warmup_steps' must not be negative! n_warmup_steps: {}".format(n_warmup_steps)
            )
        if alpha >= 1:
            raise ValueError("'alpha' must be smaller than 1! {}".format(alpha))
        if alpha <= 0:
            raise ValueError("'alpha' must be greater than 0! {}".format(alpha))

        self.n_warmup_steps = n_warmup_steps
        self.alpha = alpha

    def prune(self, study: optuna.study.Study, trial: optuna.trial.FrozenTrial) -> bool:
        """Judge whether the trial should be pruned based on the reported values.

        Args:
            study: Study object of the target study.
            trial: FrozenTrial object of the target trial.

        Returns:
            ``True`` if the trial should be pruned, ``False`` otherwise.
        """
        # get best trial - best trial is not available for first trial
        best_trial = None
        try:
            best_trial = study.best_trial
        except ValueError:
            # study.best_trial raises ValueError when no trial is completed yet
            pass

        if best_trial is not None:
            trial_intermediate_values = list(trial.intermediate_values.values())
            _logger.debug("trial_intermediate_values: %s", trial_intermediate_values)

            # wait until the trial reaches or exceeds n_warmup_steps number of steps
            if len(trial_intermediate_values) >= self.n_warmup_steps:
                trial_mean = np.mean(trial_intermediate_values)

                best_trial_intermediate_values = list(best_trial.intermediate_values.values())
                best_trial_mean = np.mean(best_trial_intermediate_values)

                _logger.debug("trial_mean: %s", trial_mean)
                _logger.debug("best_trial_intermediate_values: %s", best_trial_intermediate_values)
                _logger.debug("best_trial_mean: %s", best_trial_mean)

                # only consider pruning when this trial looks WORSE than the best
                # trial in the study's optimization direction
                if (
                    trial_mean < best_trial_mean and study.direction == StudyDirection.MAXIMIZE
                ) or (
                    trial_mean > best_trial_mean and study.direction == StudyDirection.MINIMIZE
                ):
                    # two-sided independent-samples t-test between the two
                    # sets of intermediate values
                    pvalue = stats.ttest_ind(
                        trial_intermediate_values,
                        best_trial_intermediate_values,
                    ).pvalue
                    _logger.debug("pvalue: %s", pvalue)

                    if pvalue < self.alpha:
                        _logger.info("We prune this trial. pvalue: %s", pvalue)
                        return True
                else:
                    _logger.debug(
                        "This trial is better than best trial - we do not check for pruning."
                    )
            else:
                _logger.debug("This trial did not reach n_warmup_steps - we do no checks.")

        return False