diff --git a/docs/source/main_classes/callback.rst b/docs/source/main_classes/callback.rst
index f4160185bf2..16b1318b717 100644
--- a/docs/source/main_classes/callback.rst
+++ b/docs/source/main_classes/callback.rst
@@ -20,6 +20,7 @@ By default a :class:`~transformers.Trainer` will use the following callbacks:
   or tensorboardX).
 - :class:`~transformers.integrations.WandbCallback` if `wandb <https://www.wandb.com/>`__ is installed.
 - :class:`~transformers.integrations.CometCallback` if `comet_ml <https://www.comet.ml/site/>`__ is installed.
+- :class:`~transformers.integrations.MLflowCallback` if `mlflow <https://www.mlflow.org/>`__ is installed.
 
 The main class that implements callbacks is :class:`~transformers.TrainerCallback`. It gets the
 :class:`~transformers.TrainingArguments` used to instantiate the :class:`~transformers.Trainer`, can access that
@@ -46,6 +47,9 @@ Here is the list of the available :class:`~transformers.TrainerCallback` in the
 .. autoclass:: transformers.integrations.WandbCallback
     :members: setup
 
+.. autoclass:: transformers.integrations.MLflowCallback
+    :members: setup
+
 TrainerCallback
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
diff --git a/src/transformers/integrations.py b/src/transformers/integrations.py
index 93e1e6eab62..743b45a660d 100644
--- a/src/transformers/integrations.py
+++ b/src/transformers/integrations.py
@@ -53,6 +53,14 @@ except ImportError:
     except ImportError:
         _has_tensorboard = False
 
+try:
+    import mlflow  # noqa: F401
+
+    _has_mlflow = True
+except ImportError:
+    _has_mlflow = False
+
+
 # No transformer imports above this point
 from .file_utils import is_torch_tpu_available
@@ -85,6 +93,10 @@ def is_ray_available():
     return _has_ray
 
 
+def is_mlflow_available():
+    return _has_mlflow
+
+
 def hp_params(trial):
     if is_optuna_available():
         if isinstance(trial, optuna.Trial):
@@ -408,3 +420,80 @@ class CometCallback(TrainerCallback):
             experiment = comet_ml.config.get_global_experiment()
             if experiment is not None:
                 experiment._log_metrics(logs, step=state.global_step, epoch=state.epoch, framework="transformers")
+
+
+class MLflowCallback(TrainerCallback):
+    """
+    A :class:`~transformers.TrainerCallback` that sends the logs to `MLflow <https://www.mlflow.org/>`__.
+    """
+
+    MAX_LOG_SIZE = 100
+
+    def __init__(self):
+        assert _has_mlflow, "MLflowCallback requires mlflow to be installed. Run `pip install mlflow`."
+        self._initialized = False
+        self._log_artifacts = False
+
+    def setup(self, args, state, model):
+        """
+        Setup the optional MLflow integration.
+
+        Environment:
+            HF_MLFLOW_LOG_ARTIFACTS (:obj:`str`, `optional`):
+                Whether to use MLflow's `.log_artifact()` facility to log artifacts.
+
+                This only makes sense if logging to a remote server, e.g. s3 or GCS.
+                If set to `True` or `1`, will copy whatever is in TrainingArguments's ``output_dir``
+                to the local or remote artifact storage. Using it without a remote storage
+                will just copy the files to your artifact location.
+        """
+        log_artifacts = os.getenv("HF_MLFLOW_LOG_ARTIFACTS", "FALSE").upper()
+        if log_artifacts in {"TRUE", "1"}:
+            self._log_artifacts = True
+        if state.is_world_process_zero:
+            mlflow.start_run()
+            combined_dict = args.to_dict()
+            if hasattr(model, "config") and model.config is not None:
+                model_config = model.config.to_dict()
+                combined_dict = {**model_config, **combined_dict}
+            # MLflow cannot log more than 100 values in one go, so we have to split it in chunks of MAX_LOG_SIZE
+            combined_dict_items = list(combined_dict.items())
+            for i in range(0, len(combined_dict_items), MLflowCallback.MAX_LOG_SIZE):
+                mlflow.log_params(dict(combined_dict_items[i : i + MLflowCallback.MAX_LOG_SIZE]))
+        self._initialized = True
+
+    def on_train_begin(self, args, state, control, model=None, **kwargs):
+        if not self._initialized:
+            self.setup(args, state, model)
+
+    def on_log(self, args, state, control, logs, model=None, **kwargs):
+        if not self._initialized:
+            self.setup(args, state, model)
+        if state.is_world_process_zero:
+            for k, v in logs.items():
+                if isinstance(v, (int, float)):
+                    mlflow.log_metric(k, v, step=state.global_step)
+                else:
+                    logger.warning(
+                        "Trainer is attempting to log a value of "
+                        '"%s" of type %s for key "%s" as a metric. '
+                        "MLflow's log_metric() only accepts float and "
+                        "int types so we dropped this attribute.",
+                        v,
+                        type(v),
+                        k,
+                    )
+
+    def on_train_end(self, args, state, control, **kwargs):
+        if self._initialized and state.is_world_process_zero:
+            if self._log_artifacts:
+                logger.info("Logging artifacts. This may take time.")
+                mlflow.log_artifacts(args.output_dir)
+            mlflow.end_run()
+
+    def __del__(self):
+        # if the previous run is not terminated correctly, the fluent API will
+        # not let you start a new run before the previous one is killed
+        if mlflow.active_run() is not None:
+            mlflow.end_run(status="KILLED")
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index e1d1947ecbe..725c78b4098 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -41,6 +41,7 @@ from .integrations import (
     default_hp_search_backend,
     hp_params,
     is_comet_available,
+    is_mlflow_available,
     is_optuna_available,
     is_ray_available,
     is_tensorboard_available,
@@ -139,6 +140,11 @@ if is_comet_available():
     DEFAULT_CALLBACKS.append(CometCallback)
 
 
+if is_mlflow_available():
+    from .integrations import MLflowCallback
+
+    DEFAULT_CALLBACKS.append(MLflowCallback)
+
 if is_optuna_available():
     import optuna