diff --git a/docs/source/main_classes/callback.rst b/docs/source/main_classes/callback.rst
index f4160185bf2..16b1318b717 100644
--- a/docs/source/main_classes/callback.rst
+++ b/docs/source/main_classes/callback.rst
@@ -20,6 +20,7 @@ By default a :class:`~transformers.Trainer` will use the following callbacks:
   or tensorboardX).
 - :class:`~transformers.integrations.WandbCallback` if `wandb <https://www.wandb.com/>`__ is installed.
 - :class:`~transformers.integrations.CometCallback` if `comet_ml <https://www.comet.ml/site/>`__ is installed.
+- :class:`~transformers.integrations.MLflowCallback` if `mlflow <https://www.mlflow.org/>`__ is installed.
 
 The main class that implements callbacks is :class:`~transformers.TrainerCallback`. It gets the
 :class:`~transformers.TrainingArguments` used to instantiate the :class:`~transformers.Trainer`, can access that
@@ -46,6 +47,9 @@ Here is the list of the available :class:`~transformers.TrainerCallback` in the
 .. autoclass:: transformers.integrations.WandbCallback
     :members: setup
 
+.. autoclass:: transformers.integrations.MLflowCallback
+    :members: setup
+
 TrainerCallback
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
diff --git a/src/transformers/integrations.py b/src/transformers/integrations.py
index 93e1e6eab62..743b45a660d 100644
--- a/src/transformers/integrations.py
+++ b/src/transformers/integrations.py
@@ -53,6 +53,14 @@ except ImportError:
 except ImportError:
     _has_tensorboard = False
 
+try:
+    import mlflow  # noqa: F401
+
+    _has_mlflow = True
+except ImportError:
+    _has_mlflow = False
+
+
 # No transformer imports above this point
 from .file_utils import is_torch_tpu_available
@@ -85,6 +93,10 @@ def is_ray_available():
     return _has_ray
 
 
+def is_mlflow_available():
+    return _has_mlflow
+
+
 def hp_params(trial):
     if is_optuna_available():
         if isinstance(trial, optuna.Trial):
@@ -408,3 +420,80 @@ class CometCallback(TrainerCallback):
             experiment = comet_ml.config.get_global_experiment()
             if experiment is not None:
                 experiment._log_metrics(logs, step=state.global_step, epoch=state.epoch, framework="transformers")
+
+
+class MLflowCallback(TrainerCallback):
+    """
+    A :class:`~transformers.TrainerCallback` that sends the logs to `MLflow
+    <https://www.mlflow.org/>`__.
+    """
+
+    MAX_LOG_SIZE = 100
+
+    def __init__(self):
+        assert _has_mlflow, "MLflowCallback requires mlflow to be installed. Run `pip install mlflow`."
+        self._initialized = False
+        self._log_artifacts = False
+
+    def setup(self, args, state, model):
+        """
+        Setup the optional MLflow integration.
+
+        Environment:
+            HF_MLFLOW_LOG_ARTIFACTS (:obj:`str`, `optional`):
+                Whether to use MLflow's ``.log_artifact()`` facility to log artifacts.
+
+                This only makes sense when logging to a remote server, e.g. S3 or GCS. If set to ``True`` or ``1``,
+                will copy whatever is in the :class:`~transformers.TrainingArguments`'s ``output_dir`` to the local
+                or remote artifact storage. Using it without a remote storage will just copy the files to your
+                artifact location.
+        """
+        log_artifacts = os.getenv("HF_MLFLOW_LOG_ARTIFACTS", "FALSE").upper()
+        if log_artifacts in {"TRUE", "1"}:
+            self._log_artifacts = True
+        if state.is_world_process_zero:
+            mlflow.start_run()
+            combined_dict = args.to_dict()
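+            # Log the model config alongside the training arguments; on duplicate keys,
+            # the training arguments win (they come last in the merge below).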
+            if hasattr(model, "config") and model.config is not None:
+                model_config = model.config.to_dict()
+                combined_dict = {**model_config, **combined_dict}
+            # MLflow cannot log more than 100 values in one go, so we have to split it
+            combined_dict_items = list(combined_dict.items())
+            for i in range(0, len(combined_dict_items), MLflowCallback.MAX_LOG_SIZE):
+                mlflow.log_params(dict(combined_dict_items[i : i + MLflowCallback.MAX_LOG_SIZE]))
+            self._initialized = True
+
+    def on_train_begin(self, args, state, control, model=None, **kwargs):
+        if not self._initialized:
+            self.setup(args, state, model)
+
+    def on_log(self, args, state, control, logs, model=None, **kwargs):
+        if not self._initialized:
+            self.setup(args, state, model)
+        if state.is_world_process_zero:
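+            # MLflow metrics must be numeric, so only int and float values are forwarded.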
+            for k, v in logs.items():
+                if isinstance(v, (int, float)):
+                    mlflow.log_metric(k, v, step=state.global_step)
+                else:
+                    logger.warning(
+                        "Trainer is attempting to log a value of "
+                        '"%s" of type %s for key "%s" as a metric. '
+                        "MLflow's log_metric() only accepts float and "
+                        "int types so we dropped this attribute.",
+                        v,
+                        type(v),
+                        k,
+                    )
+
+    def on_train_end(self, args, state, control, **kwargs):
+        if self._initialized and state.is_world_process_zero:
+            if self._log_artifacts:
+                logger.info("Logging artifacts. This may take time.")
+                mlflow.log_artifacts(args.output_dir)
+            mlflow.end_run()
+
+    def __del__(self):
+        # if the previous run is not terminated correctly, the fluent API will
+        # not let you start a new run before the previous one is killed
+        if mlflow.active_run() is not None:
+            mlflow.end_run(status="KILLED")
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index e1d1947ecbe..725c78b4098 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -41,6 +41,7 @@ from .integrations import (
     default_hp_search_backend,
     hp_params,
     is_comet_available,
+    is_mlflow_available,
     is_optuna_available,
     is_ray_available,
     is_tensorboard_available,
@@ -139,6 +140,11 @@ if is_comet_available():
     DEFAULT_CALLBACKS.append(CometCallback)
 
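+# If mlflow is installed, register its callback by default, mirroring the wandb/comet pattern above.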
+if is_mlflow_available():
+    from .integrations import MLflowCallback
+
+    DEFAULT_CALLBACKS.append(MLflowCallback)
+
 if is_optuna_available():
     import optuna
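
Usage sketch: with this patch applied and mlflow installed, the callback is picked up automatically,
so an ordinary training script starts reporting to MLflow with no extra wiring. In the sketch below,
`model` and `train_dataset` are placeholders for whatever model and dataset you already have; only the
HF_MLFLOW_LOG_ARTIFACTS variable is specific to this integration.

    import os

    # Opt in to artifact logging; MLflowCallback.setup() reads this before the run starts.
    os.environ["HF_MLFLOW_LOG_ARTIFACTS"] = "1"

    from transformers import Trainer, TrainingArguments

    args = TrainingArguments(output_dir="out", num_train_epochs=1, logging_steps=10)
    trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
    trainer.train()  # params, metrics and (optionally) output_dir artifacts land in the active MLflow run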