This module offers tools for Optuna.


Use pip to install the necessary dependencies for this module: pip install mltb2[optuna]

class mltb2.optuna.SignificanceRepeatedTrainingPruner(alpha: float = 0.1, n_warmup_steps: int = 4)

Bases: BasePruner

Optuna pruner which uses statistical significance as a heuristic for decision-making.

This is an Optuna pruner which uses statistical significance as a heuristic for decision-making. It is designed to prune repeated trainings, such as the folds of a cross-validation. A t-test is used as the significance test. Our experiments have shown that an alpha value between 0.3 and 0.4 is reasonable.

Optuna's standard pruners assume that the model is trained only once per hyperparameter set. They decide on the basis of intermediate results reported during that single training, for example once per epoch. In contrast, this pruner does not work on such intermediate results but on the results of a cross-validation, or more precisely, on the results of the individual folds.
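The t-test based decision described above can be sketched as a one-sided two-sample test on fold scores. This is a minimal sketch, not mltb2's implementation: the helper `is_significantly_worse`, the use of `scipy.stats.ttest_ind` (Student's t-test with equal variances), the comparison against the fold scores of a single reference trial, and the example scores are assumptions made for illustration, for a metric such as accuracy that is maximized.

```python
from scipy import stats


def is_significantly_worse(trial_scores, best_scores, alpha=0.35):
    """Hypothetical helper: are trial_scores significantly lower than best_scores?

    Assumes a maximized metric and a one-sided Student's t-test
    (equal variances), as provided by scipy.stats.ttest_ind.
    """
    result = stats.ttest_ind(trial_scores, best_scores, alternative="less")
    # a p-value below alpha is treated as a significant difference
    return bool(result.pvalue < alpha)


# made-up fold accuracies for illustration
best_fold_scores = [0.93, 0.95, 0.94, 0.96, 0.95]
poor_fold_scores = [0.80, 0.82, 0.79, 0.81, 0.83]
close_fold_scores = [0.94, 0.95, 0.93, 0.96, 0.94]

print(is_significantly_worse(poor_fold_scores, best_fold_scores))  # True
print(is_significantly_worse(close_fold_scores, best_fold_scores))  # False
```

The second set of scores is slightly worse on average, but the one-sided p-value is roughly 0.39, so under this sketch such a trial would keep running.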

Below is a minimal example:

from mltb2.optuna import SignificanceRepeatedTrainingPruner
import logging
import numpy as np
import optuna
from sklearn.datasets import load_iris
from sklearn.model_selection import StratifiedKFold
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score

# configure the logger to see the debug output from the pruner
logging.basicConfig(level=logging.DEBUG)

dataset = load_iris()

x, y = dataset['data'], dataset['target']

def train(trial):
    parameter = {
        'min_samples_split': trial.suggest_int('min_samples_split', 2, 20),
        'n_estimators': trial.suggest_int('n_estimators', 20, 100),
    }

    validation_result_list = []

    skf = StratifiedKFold(n_splits=10)
    for fold_index, (train_index, val_index) in enumerate(skf.split(x, y)):
        X_train, X_val = x[train_index], x[val_index]
        y_train, y_val = y[train_index], y[val_index]

        rf = RandomForestClassifier(**parameter)
        rf.fit(X_train, y_train)
        y_pred = rf.predict(X_val)

        acc = accuracy_score(y_val, y_pred)
        validation_result_list.append(acc)

        # report result of this fold
        trial.report(acc, fold_index)

        # check if we should prune
        if trial.should_prune():
            # prune here - we are done with this CV
            # (skip the remaining folds and return the mean of the completed ones)
            break

    return np.mean(validation_result_list)

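study = optuna.create_study(
    # storage="sqlite:///optuna.db",  # we use in-memory storage here
    direction="maximize",  # accuracy is maximized
    # add pruner to optuna
    pruner=SignificanceRepeatedTrainingPruner(
        alpha=0.4,
        n_warmup_steps=4,
    ),
)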

study.optimize(train, n_trials=10)

Parameters:

  • alpha (float) – The alpha level for the statistical significance test. The larger this value is, the more aggressively this pruner works. The smaller this value is, the stronger the statistical difference between the two distributions must be for Optuna to prune. alpha must satisfy 0 < alpha < 1. Our experiments have shown that an alpha value between 0.3 and 0.4 is reasonable.

  • n_warmup_steps (int) – Pruning is disabled until the trial reaches or exceeds the given number of steps.
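How the two parameters interact can be illustrated with a small sketch. The function `prune_decision`, the one-sided `scipy.stats.ttest_ind` test, the exact warm-up comparison (at least `n_warmup_steps` fold results before pruning is considered), and the example scores are all assumptions for illustration; the sketch is not mltb2's actual code.

```python
from scipy import stats


def prune_decision(trial_scores, best_scores, alpha, n_warmup_steps):
    """Hypothetical sketch of how alpha and n_warmup_steps gate pruning.

    Assumes a maximized metric and a one-sided two-sample t-test.
    """
    # warm-up: with too few fold results reported, never prune
    if len(trial_scores) < n_warmup_steps:
        return False
    pvalue = stats.ttest_ind(trial_scores, best_scores, alternative="less").pvalue
    # a larger alpha accepts weaker evidence, so pruning happens more often
    return bool(pvalue < alpha)


# made-up fold accuracies; the trial is slightly worse on average
best_fold_scores = [0.93, 0.95, 0.94, 0.96, 0.95]
trial_fold_scores = [0.94, 0.95, 0.93, 0.95, 0.94]

for alpha in (0.1, 0.2, 0.3, 0.4):
    print(alpha, prune_decision(trial_fold_scores, best_fold_scores, alpha, n_warmup_steps=4))
# 0.1 False, 0.2 False, 0.3 True, 0.4 True (the one-sided p-value is roughly 0.27)

# only three folds reported so far: the warm-up check prevents pruning
print(prune_decision(trial_fold_scores[:3], best_fold_scores, alpha=0.4, n_warmup_steps=4))  # False
```

With the same evidence, a small alpha keeps the trial running while an alpha in the recommended 0.3 to 0.4 range stops it early.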

prune(study: Study, trial: FrozenTrial) → bool

Judge whether the trial should be pruned based on the reported values.

Note that this method is not supposed to be called by library users. Instead, optuna.trial.Trial.report() and optuna.trial.Trial.should_prune() provide the user interface for implementing a pruning mechanism in an objective function.

Parameters:

  • study (Study) – Study object of the target study.

  • trial (FrozenTrial) – FrozenTrial object of the target trial. Take a copy before modifying this object.


Returns:

A boolean value indicating whether the trial should be pruned.

Return type:

bool