# Copyright (c) 2021 Timothy Wolff-Piggott
# This software is distributed under the terms of the MIT license
# which is available at https://opensource.org/licenses/MIT
"""Integration of Optuna and Transformers."""
import logging
import os
from numbers import Number
from typing import Dict, Union
import mlflow
import transformers
from transformers import TrainerControl, TrainerState, TrainingArguments
from hpoflow.optuna_mlflow import OptunaMLflow
_logger = logging.getLogger(__name__)
class OptunaMLflowCallback(transformers.TrainerCallback):
    """Integration of Optuna and Transformers.

    Class based on :class:`transformers.TrainerCallback`; integrates with OptunaMLflow to send
    the logs to ``MLflow`` and ``Optuna`` during model training.
    """

    def __init__(
        self,
        trial: OptunaMLflow,
        log_training_args: bool = True,
        log_model_config: bool = True,
    ):
        """Constructor.

        Args:
            trial: The OptunaMLflow object.
            log_training_args: Whether to log all Transformers TrainingArguments as MLflow params.
            log_model_config: Whether to log the Transformers model config as MLflow params.
        """
        self._trial = trial
        self._log_training_args = log_training_args
        self._log_model_config = log_model_config
        # setup() is run lazily by the first on_train_begin / on_log event.
        self._initialized = False
        # Switched on in setup() when HF_MLFLOW_LOG_ARTIFACTS is truthy.
        self._log_artifacts = False

    @staticmethod
    def _add_prefix(params: Dict[str, object], prefix: str) -> Dict[str, object]:
        """Return a new dict with ``prefix`` prepended to every key of ``params`` (order kept)."""
        return {f"{prefix}{key}": value for key, value in params.items()}

    @staticmethod
    def _drop_overlong_params(params: Dict[str, object]) -> Dict[str, object]:
        """Return a copy of ``params`` without values too long for an MLflow param.

        Each dropped entry is reported with a warning that states the actual limit
        (``mlflow.utils.validation.MAX_PARAM_VAL_LENGTH``), which differs between
        MLflow versions.
        """
        max_len = mlflow.utils.validation.MAX_PARAM_VAL_LENGTH
        kept: Dict[str, object] = {}
        for name, value in params.items():
            # internally, all values are converted to str in MLflow
            if len(str(value)) > max_len:
                _logger.warning(
                    "Trainer is attempting to log a value of "
                    "'%s' for key '%s' as a parameter. "
                    "MLflow's log_param() only accepts values no longer than "
                    "%d characters so we dropped this attribute.",
                    value,
                    name,
                    max_len,
                )
            else:
                kept[name] = value
        return kept

    def setup(
        self,
        args: TrainingArguments,
        state: TrainerState,
        model: Union[transformers.PreTrainedModel, transformers.TFPreTrainedModel, None],
    ):
        """Set up the optional MLflow integration and log the run parameters.

        The environment variable ``HF_MLFLOW_LOG_ARTIFACTS`` controls artifact logging. If it
        is ``TRUE``, ``1`` or ``YES`` (case-insensitive), the contents of
        ``TrainingArguments.output_dir`` are logged with :func:`mlflow.log_artifacts` at the
        end of training. This only makes sense when logging to a remote artifact store, e.g.
        S3 or GCS. Without remote storage, the files are just copied to your artifact location.

        On the main process (``state.is_world_process_zero``), the following are logged as
        trial params, each as enabled by the constructor flags:

        - the training arguments, with keys prefixed ``hf_train_arg_``;
        - the model config, with keys prefixed ``hf_model_cfg_``.

        Values too long for MLflow are dropped. The rest are sent in batches of at most
        ``MAX_PARAMS_TAGS_PER_BATCH``.

        Args:
            args: The Transformers training arguments.
            state: The current trainer state.
            model: The model being trained, or ``None``.
        """
        log_artifacts = os.getenv("HF_MLFLOW_LOG_ARTIFACTS", "FALSE").upper()
        if log_artifacts in {"TRUE", "1", "YES"}:
            self._log_artifacts = True  # False is default
        if state.is_world_process_zero:
            combined_dict: Dict[str, object] = {}
            if self._log_training_args:
                training_args = self._add_prefix(args.to_dict(), "hf_train_arg_")
                _logger.debug("Logging training arguments. training_args: %s", training_args)
                combined_dict.update(training_args)
            if (
                model is not None
                and self._log_model_config
                and getattr(model, "config", None) is not None
            ):
                model_config = self._add_prefix(
                    model.config.to_dict(), "hf_model_cfg_"  # type: ignore
                )
                _logger.debug("Logging model config. model_config: %s", model_config)
                combined_dict.update(model_config)
            # TODO: call a DRY function in the mlflow module
            combined_dict = self._drop_overlong_params(combined_dict)
            # TODO: call a DRY function in the mlflow module
            # MLflow cannot log more than MAX_PARAMS_TAGS_PER_BATCH params in one go, so split.
            batch_size = mlflow.utils.validation.MAX_PARAMS_TAGS_PER_BATCH
            items = list(combined_dict.items())
            for start in range(0, len(items), batch_size):
                self._trial.log_params(dict(items[start : start + batch_size]))
        self._initialized = True

    def on_train_begin(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        model: Union[transformers.PreTrainedModel, transformers.TFPreTrainedModel, None] = None,
        **kwargs,
    ) -> None:
        """Event called at the beginning of training.

        Call setup if not yet initialized.
        """
        if not self._initialized:
            self.setup(args, state, model)

    def on_log(  # type: ignore
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        logs: Dict[str, Number],
        model: Union[transformers.PreTrainedModel, transformers.TFPreTrainedModel, None] = None,
        **kwargs,
    ):
        """Event called after logging the last logs.

        Log all numeric values from the Transformers logs as metrics at ``state.global_step``.
        Values that are not ``int``/``float`` are dropped with a warning. Calls setup first if
        not yet initialized.
        """
        if not self._initialized:
            self.setup(args, state, model)
        if state.is_world_process_zero:
            metrics_to_log: Dict[str, float] = {}
            for k, v in logs.items():
                if isinstance(v, (int, float)):
                    metrics_to_log[k] = v
                else:
                    _logger.warning(
                        "Trainer is attempting to log a value of "
                        "'%s' of type %s for key '%s' as a metric. "
                        "MLflow's log_metric() only accepts float and "
                        "int types so we dropped this attribute.",
                        v,
                        type(v),
                        k,
                    )
            self._trial.log_metrics(metrics_to_log, step=state.global_step)

    def on_train_end(
        self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs
    ):
        """Event called at the end of training.

        On the main process, and only if setup has run, log ``args.output_dir`` as MLflow
        artifacts when artifact logging is enabled.
        """
        if self._initialized and state.is_world_process_zero:
            if self._log_artifacts:
                _logger.info("Logging artifacts. This may take time.")
                mlflow.log_artifacts(args.output_dir)