# Copyright (c) 2021 Philip May
# Copyright (c) 2021 Philip May, Deutsche Telekom AG
# This software is distributed under the terms of the MIT license
# which is available at https://opensource.org/licenses/MIT
"""Wrapper to log to Optuna and MLflow at the same time."""
import logging
import os
import platform
import sys
import textwrap
import traceback
import warnings
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Sequence, Union
import mlflow
import optuna
from mlflow.entities import RunStatus
from optuna.distributions import CategoricalChoiceType
from hpoflow.mlflow import (
check_repo_is_dirty,
normalize_mlflow_entry_name,
normalize_mlflow_entry_names_in_dict,
)
from hpoflow.utils import func_no_exception_caller
# import StudyDirection
# pylint: disable=ungrouped-imports
try:
from optuna.study._study_direction import StudyDirection
except ImportError:
from optuna._study_direction import StudyDirection # type: ignore
# Module-level logger named after this module.
_logger = logging.getLogger(__name__)
# Maximum tag value length taken from MLflow's validation module; tag values
# longer than this are shortened before they are sent to MLflow.
_max_mlflow_tag_length = mlflow.utils.validation.MAX_TAG_VAL_LENGTH
[docs]class OptunaMLflow:
"""Wrapper to log to Optuna and MLflow at the same time."""
def __init__(
self,
tracking_uri: Optional[str] = None,
num_name_digits: int = 3,
enforce_clean_git: bool = False,
optuna_result_name: str = "optuna_result",
):
"""Constructor.
Args:
tracking_uri: The MLflow tracking URL. Defaults to ``None`` which logs to the default
locale folder ``./mlruns`` or uses the ``MLFLOW_TRACKING_URI`` environment variable
if it is available. Also see :func:`mlflow.set_tracking_uri`.
num_name_digits: Number of digits for the MLflow ``run_name``.
enforce_clean_git: Check and enforce that the GIT repository has no uncommitted changes
(see :meth:`git.repo.base.Repo.is_dirty`).
optuna_result_name: Name of the metric which is logged to MLflo and is returned by the
objective function.
"""
# TODO: add checks for num_name_digits and optuna_result_name
self._tracking_uri = tracking_uri
self._num_name_digits = num_name_digits
self._enforce_clean_git = enforce_clean_git
self._optuna_result_name = optuna_result_name
self._hostname: Optional[str] = None
def __call__(
    self,
    # we use a strange type annotation here
    # see https://stackoverflow.com/questions/33533148/how-do-i-type-hint-a-method-with-the-type-of-the-enclosing-class  # noqa: E501
    func: Callable[[Union[optuna.trial.Trial, "OptunaMLflow"]], float],
) -> Callable[[optuna.trial.Trial], float]:
    """Wrap an Optuna objective function so that every trial is mirrored to MLflow.

    The wrapped function starts an MLflow run named after the zero-padded trial number,
    calls ``func`` with this object (instead of the raw trial), logs the returned value as
    a metric and ends the run with a status that reflects the outcome.

    Args:
        func: The Optuna objective function to decorate.

    Returns:
        The decorated objective function that accepts an Optuna trial.
    """

    @wraps(func)
    def objective_decorator(trial: optuna.trial.Trial) -> float:
        """Run one trial with MLflow bookkeeping around the objective function."""
        # The per-trial state has to be reset here and not in __init__
        # because __init__ only runs once, when the decorator is applied.
        # pylint: disable=attribute-defined-outside-init
        self._trial = trial
        self._iter_metrics: Dict[str, List[float]] = {}
        self._next_iter_num: int = 0

        # check if GIT repo is clean
        if self._enforce_clean_git:
            check_repo_is_dirty()
        # TODO: set a tag if it is clean or when not checked

        # Open the MLflow run. A failure here is logged and reported as a
        # warning; the trial itself is still executed afterwards.
        try:
            if self._tracking_uri is not None:
                mlflow.set_tracking_uri(self._tracking_uri)
            mlflow.set_experiment(self._trial.study.study_name)
            run_name_template = "{{:0{}d}}".format(self._num_name_digits)
            mlflow.start_run(run_name=run_name_template.format(self._trial.number))
        except Exception as mlflow_error:
            mlflow_msg = "Exception raised during MLflow communication! Exception: {}".format(
                mlflow_error
            )
            _logger.error(mlflow_msg, exc_info=True)
            warnings.warn(mlflow_msg, RuntimeWarning)

        _logger.info("Run %s started.", self._trial.number)
        self.set_tags({"hostname": self._get_hostname(), "process_id": os.getpid()})

        try:
            # call objective function
            result = func(self)

            # the result goes to MLflow only and is not stored in Optuna
            self.log_metric(self._optuna_result_name, result, optuna_log=False)

            # Collect tags from the trial: the study direction (as a string
            # without the common enum prefix) and the parameter distributions.
            trial_tags = {}
            direction = self._trial.study.direction
            if isinstance(direction, StudyDirection):
                trial_tags["direction"] = str(direction).rsplit(".", maxsplit=1)[-1]
            for param_name, distribution in self._trial.distributions.items():
                trial_tags[param_name + "_distribution"] = str(distribution)
            self.set_tags(trial_tags, optuna_log=False)

            # end run
            self._end_run(RunStatus.to_string(RunStatus.FINISHED))
            _logger.info("Run finished.")
            return result
        except (Exception, KeyboardInterrupt) as trial_error:
            error_msg = "Exception raised while executing Optuna trial! Exception: {}".format(
                trial_error
            )
            _logger.error(error_msg, exc_info=True)

            # store the formatted traceback as a tag in Optuna and MLflow
            exc_type = sys.exc_info()[0]
            self.set_tag("exception", traceback.format_exc())

            # a keyboard interrupt marks the run as killed, anything else as failed
            was_killed = exc_type is KeyboardInterrupt
            final_status = RunStatus.KILLED if was_killed else RunStatus.FAILED
            self._end_run(RunStatus.to_string(final_status))
            _logger.info("Run killed." if was_killed else "Run failed.")

            warnings.warn(error_msg, RuntimeWarning)
            raise  # raise exception again

    return objective_decorator
#####################################
# MLflow wrapper functions
#####################################
def log_metric(
    self, key: str, value: float, step: Optional[int] = None, optuna_log: Optional[bool] = True
) -> None:
    """Log a metric under the current run.

    Wrapper of the corresponding MLflow function (see :func:`mlflow.log_metric`). The data is
    logged to MLflow and also added to Optuna as a user attribute (see
    :meth:`optuna.trial.Trial.set_user_attr`).

    Args:
        key: Name of the metric. The name is passed through ``normalize_mlflow_entry_name``
            before it is sent to MLflow; Optuna receives it unchanged.
        value: Value of the metric.
        step: Optional step of the metric. It is forwarded to :func:`mlflow.log_metric`.
        optuna_log: If ``False`` this is not logged to Optuna. This is an internal parameter
            that should be ignored by the API user.
    """
    if optuna_log:
        self._trial.set_user_attr(key, value)
    _logger.info("Metric: %s: %s at step: %s", key, value, step)
    # Forward the caller's step; it was previously hard-coded to ``None`` so
    # the step information never reached MLflow.
    # NOTE(review): func_no_exception_caller presumably shields the trial from
    # MLflow errors - confirm in hpoflow.utils.
    func_no_exception_caller(
        mlflow.log_metric, normalize_mlflow_entry_name(key), value, step=step
    )
def log_metrics(
    self,
    metrics: Dict[str, float],
    step: Optional[int] = None,
    optuna_log: Optional[bool] = True,
) -> None:
    """Log several metrics for the current run at once.

    Wrapper of the corresponding MLflow function (see :func:`mlflow.log_metrics`). The data is
    logged to MLflow and also added to Optuna as user attributes (see
    :meth:`optuna.trial.Trial.set_user_attr`).

    Args:
        metrics: Mapping from metric name to metric value. The names are passed through
            ``normalize_mlflow_entry_names_in_dict`` before they are sent to MLflow.
        step: Optional step shared by all metrics. It is forwarded to
            :func:`mlflow.log_metrics`.
        optuna_log: If ``False`` this is not logged to Optuna. This is an internal parameter
            that should be ignored by the API user.
    """
    for metric_name, metric_value in metrics.items():
        if optuna_log:
            self._trial.set_user_attr(metric_name, metric_value)
        _logger.info("Metric: %s: %s at step: %s", metric_name, metric_value, step)

    # all metrics are sent to MLflow with one single call
    mlflow_metrics = normalize_mlflow_entry_names_in_dict(metrics)
    func_no_exception_caller(mlflow.log_metrics, mlflow_metrics, step=step)
def log_param(self, key: str, value: Any, optuna_log: Optional[bool] = True) -> None:
    """Log a single parameter under the current run.

    Wrapper of the corresponding MLflow function (see :func:`mlflow.log_param`). The data is
    logged to MLflow and also added to Optuna as a user attribute (see
    :meth:`optuna.trial.Trial.set_user_attr`).

    Args:
        key: Name of the parameter. The name is passed through
            ``normalize_mlflow_entry_name`` before it is sent to MLflow.
        value: Value of the parameter.
        optuna_log: If ``False`` this is not logged to Optuna. This is an internal parameter
            that should be ignored by the API user.
    """
    if optuna_log:
        self._trial.set_user_attr(key, value)
    _logger.info("Param: %s: %s", key, value)

    mlflow_key = normalize_mlflow_entry_name(key)
    func_no_exception_caller(mlflow.log_param, mlflow_key, value)
def log_params(self, params: Dict[str, Any], optuna_log: Optional[bool] = True) -> None:
    """Log a batch of params for the current run.

    Wrapper of the corresponding MLflow function (see :func:`mlflow.log_params`). The data is
    logged to MLflow and also added to Optuna as user attributes (see
    :meth:`optuna.trial.Trial.set_user_attr`).

    Args:
        params: Mapping from parameter name to parameter value. The names are passed through
            ``normalize_mlflow_entry_names_in_dict`` before they are sent to MLflow.
        optuna_log: If ``False`` this is not logged to Optuna. This is an internal parameter
            that should be ignored by the API user. It defaults to ``True`` which keeps the
            previous behavior and mirrors the other logging wrappers of this class.
    """
    for key, value in params.items():
        if optuna_log:
            self._trial.set_user_attr(key, value)
        _logger.info("Param: %s: %s", key, value)
    func_no_exception_caller(mlflow.log_params, normalize_mlflow_entry_names_in_dict(params))
def set_tag(self, key: str, value: Any, optuna_log: Optional[bool] = True) -> None:
    """Set a tag under the current run.

    Wrapper of the corresponding MLflow function (see :func:`mlflow.set_tag`). The data is
    logged to MLflow and also added to Optuna as a user attribute (see
    :meth:`optuna.trial.Trial.set_user_attr`).

    Args:
        key: Name of the tag. The name is passed through ``normalize_mlflow_entry_name``
            before it is sent to MLflow.
        value: Value of the tag. Optuna receives it unchanged. For MLflow it is converted to
            a string and cut to the maximum tag length if it is too long.
        optuna_log: If ``False`` this is not logged to Optuna. This is an internal parameter
            that should be ignored by the API user.
    """
    if optuna_log:
        self._trial.set_user_attr(key, value)
    _logger.info("Tag: %s: %s", key, value)

    tag_value = str(value)  # make sure it is a string
    if len(tag_value) > _max_mlflow_tag_length:
        # Cut by plain slicing instead of textwrap.shorten(): shorten() collapses
        # all whitespace (which flattens multi-line tracebacks) and returns only
        # the placeholder when the first word is longer than the limit.
        placeholder = " [...]"
        keep = max(_max_mlflow_tag_length - len(placeholder), 0)
        tag_value = (tag_value[:keep] + placeholder)[:_max_mlflow_tag_length]
    func_no_exception_caller(mlflow.set_tag, normalize_mlflow_entry_name(key), tag_value)
def log_iter(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
    """Log an iteration or a fold as a nested run (see :func:`mlflow.log_metrics`).

    Every metric value is appended to a per-metric list that is stored in Optuna as the user
    attribute ``<key>_iter`` (see :meth:`optuna.trial.Trial.set_user_attr`). The metrics are
    also logged to MLflow in a nested run named ``<trial number>-<step>``.

    Args:
        metrics: Mapping from metric name to the value of this iteration.
        step: Number of the iteration. If ``None`` an internal counter is used and
            incremented afterwards.
    """
    for metric_name, metric_value in metrics.items():
        # fetch the list of values seen so far (created on first use) and extend it
        history: List[float] = self._iter_metrics.setdefault(metric_name, [])
        history.append(metric_value)
        self._trial.set_user_attr("{}_iter".format(metric_name), history)
        _logger.info("Iteration metric: %s: %s at step: %s", metric_name, metric_value, step)

    run_name_template = "{{:0{0}d}}-{{:0{0}d}}".format(self._num_name_digits)

    # fall back to the internal counter when the caller gives no step
    if step is None:
        step = self._next_iter_num
        self._next_iter_num += 1

    func_no_exception_caller(
        self._log_iter,
        run_name=run_name_template.format(self._trial.number, step),
        metrics=metrics,
        step=step,
    )
def _log_iter(self, run_name: str, metrics: Dict[str, float], step: int):
    """Log the metrics of one iteration or fold in a nested MLflow run.

    The data is logged only to MLflow and not to Optuna.

    Args:
        run_name: Name of the nested MLflow run.
        metrics: Mapping from metric name to metric value.
        step: Step that is passed on together with the metrics.
    """
    nested_run = mlflow.start_run(run_name=run_name, nested=True)
    # the context manager closes the nested run when the block is left
    with nested_run:
        self.log_metrics(metrics, step=step, optuna_log=False)
@staticmethod
def _end_run(status: str, exc_text: Optional[str] = None) -> None:
    """End the active MLflow run (see :func:`mlflow.end_run`) and log the final status.

    Args:
        status: The status of the run (see :class:`mlflow.entities.RunStatus`).
        exc_text: Optional exception text. If it is given the final status is logged on
            error level together with this text, otherwise on info level.
    """
    func_no_exception_caller(mlflow.end_run, status)

    if exc_text is not None:
        _logger.error("Run finished with status: %s, exc_text: %s", status, exc_text)
        return
    _logger.info("Run finished with status: %s", status)
#####################################
# util functions
#####################################
[docs] def _get_hostname(self) -> str:
"""Get the hostname."""
if self._hostname is None:
self._hostname = "unknown"
try:
self._hostname = platform.node()
except Exception as e:
warn_msg = "Exception while getting hostname! {}".format(e)
_logger.warning(warn_msg)
warnings.warn(warn_msg, RuntimeWarning)
return self._hostname
#####################################
# Optuna wrapper functions
#####################################
def report(self, value: float, step: int) -> None:
    """Report an intermediate objective function value for a given step.

    Wrapper of the corresponding Optuna function (see :meth:`optuna.trial.Trial.report`).

    Args:
        value: A value returned from the evaluation.
        step: Step of the trial (e.g., Epoch of neural network training). Note that pruners
            assume that ``step`` starts at zero.
    """
    current_trial = self._trial
    current_trial.report(value, step)
def should_prune(self) -> bool:
    """Ask Optuna whether the current trial should be pruned.

    Wrapper of the corresponding Optuna function (see :meth:`optuna.trial.Trial.should_prune`).

    Returns:
        The value returned by the wrapped trial.
    """
    prune_decision = self._trial.should_prune()
    return prune_decision
def suggest_categorical(
    self, name: str, choices: Sequence[CategoricalChoiceType]
) -> CategoricalChoiceType:
    """Suggest a value for the categorical parameter and log it to MLflow.

    Wrapper of the corresponding Optuna function (see
    :meth:`optuna.trial.Trial.suggest_categorical`).

    Args:
        name: A parameter name.
        choices: Parameter value candidates.

    Returns:
        The value suggested by Optuna.
    """
    suggestion = self._trial.suggest_categorical(name, choices)
    # Optuna already knows the value, so it is only logged to MLflow
    self.log_param(name, suggestion, optuna_log=False)
    return suggestion
def suggest_int(self, name: str, low: int, high: int, step: int = 1, log: bool = False) -> int:
    """Suggest a value for the integer parameter and log it to MLflow.

    Wrapper of the corresponding Optuna function (see :meth:`optuna.trial.Trial.suggest_int`).

    Args:
        name: A parameter name.
        low: Lower endpoint of the range of suggested values. ``low`` is included in the range.
        high: Upper endpoint of the range of suggested values. ``high`` is included in the
            range.
        step: A step of discretization.
        log: A flag to sample the value from the log domain or not.

    Returns:
        The value suggested by Optuna.
    """
    # ``step`` and ``log`` are passed by keyword: newer Optuna releases declare
    # them keyword-only (positional use is deprecated), while older releases
    # accept keywords as well.
    result = self._trial.suggest_int(name, low, high, step=step, log=log)
    # Optuna already knows the value, so it is only logged to MLflow
    self.log_param(name, result, optuna_log=False)
    return result