mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-22 22:09:23 +06:00
Mlflow integration callback (#8016)
* Add MLflow integration class Add integration code for MLflow in integrations.py along with the code that checks that MLflow is installed. * Add MLflowCallback import Add import of MLflowCallback in trainer.py * Handle model argument Allow the callback to handle model argument and store model config items as hyperparameters. * Log parameters to MLflow in batches MLflow cannot log more than a hundred parameters at once. Code added to split the parameters into batches of 100 items and log the batches one by one. * Fix style * Add docs on MLflow callback * Fix issue with unfinished runs The "fluent" api used in MLflow integration allows only one run to be active at any given moment. If the Trainer is disposed of and a new one is created, but the training is not finished, it will refuse to log the results when the next trainer is created.
This commit is contained in:
parent
8be9cb0aef
commit
c48b16b8da
@ -20,6 +20,7 @@ By default a :class:`~transformers.Trainer` will use the following callbacks:
|
||||
or tensorboardX).
|
||||
- :class:`~transformers.integrations.WandbCallback` if `wandb <https://www.wandb.com/>`__ is installed.
|
||||
- :class:`~transformers.integrations.CometCallback` if `comet_ml <https://www.comet.ml/site/>`__ is installed.
|
||||
- :class:`~transformers.integrations.MLflowCallback` if `mlflow <https://www.mlflow.org/>`__ is installed.
|
||||
|
||||
The main class that implements callbacks is :class:`~transformers.TrainerCallback`. It gets the
|
||||
:class:`~transformers.TrainingArguments` used to instantiate the :class:`~transformers.Trainer`, can access that
|
||||
@ -46,6 +47,9 @@ Here is the list of the available :class:`~transformers.TrainerCallback` in the
|
||||
.. autoclass:: transformers.integrations.WandbCallback
|
||||
:members: setup
|
||||
|
||||
.. autoclass:: transformers.integrations.MLflowCallback
|
||||
:members: setup
|
||||
|
||||
|
||||
TrainerCallback
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
@ -53,6 +53,14 @@ except ImportError:
|
||||
except ImportError:
|
||||
_has_tensorboard = False
|
||||
|
||||
try:
|
||||
import mlflow # noqa: F401
|
||||
|
||||
_has_mlflow = True
|
||||
except ImportError:
|
||||
_has_mlflow = False
|
||||
|
||||
|
||||
# No transformer imports above this point
|
||||
|
||||
from .file_utils import is_torch_tpu_available
|
||||
@ -85,6 +93,10 @@ def is_ray_available():
|
||||
return _has_ray
|
||||
|
||||
|
||||
def is_mlflow_available():
    """Return True if the ``mlflow`` package was importable at module load time."""
    return _has_mlflow
|
||||
|
||||
|
||||
def hp_params(trial):
|
||||
if is_optuna_available():
|
||||
if isinstance(trial, optuna.Trial):
|
||||
@ -408,3 +420,80 @@ class CometCallback(TrainerCallback):
|
||||
experiment = comet_ml.config.get_global_experiment()
|
||||
if experiment is not None:
|
||||
experiment._log_metrics(logs, step=state.global_step, epoch=state.epoch, framework="transformers")
|
||||
|
||||
|
||||
class MLflowCallback(TrainerCallback):
    """
    A :class:`~transformers.TrainerCallback` that sends the logs to `MLflow
    <https://www.mlflow.org/>`__.
    """

    # MLflow's log_params() rejects batches of more than 100 parameters,
    # so hyperparameters are logged in chunks of this size.
    MAX_LOG_SIZE = 100

    def __init__(self):
        assert _has_mlflow, "MLflow requires mlflow to be installed. Run `pip install mlflow`."
        # Deferred setup: start_run()/log_params() happen on first train/log event.
        self._initialized = False
        # Whether to upload the contents of output_dir at the end of training.
        self._log_artifacts = False

    def setup(self, args, state, model):
        """
        Setup the optional MLflow integration.

        Starts an MLflow run (world-process-zero only) and logs the combined
        training arguments and model config as run parameters.

        Environment:
            HF_MLFLOW_LOG_ARTIFACTS (:obj:`str`, `optional`):
                Whether to use MLflow .log_artifact() facility to log artifacts.

                This only makes sense if logging to a remote server, e.g. s3 or GCS.
                If set to `True` or `1`, will copy whatever is in TrainerArgument's output_dir
                to the local or remote artifact storage. Using it without a remote storage
                will just copy the files to your artifact location.
        """
        log_artifacts = os.getenv("HF_MLFLOW_LOG_ARTIFACTS", "FALSE").upper()
        if log_artifacts in {"TRUE", "1"}:
            self._log_artifacts = True
        if state.is_world_process_zero:
            mlflow.start_run()
            combined_dict = args.to_dict()
            # Model config values are merged in, with training args taking precedence
            # on key collisions.
            if hasattr(model, "config") and model.config is not None:
                model_config = model.config.to_dict()
                combined_dict = {**model_config, **combined_dict}
            # MLflow cannot log more than 100 values in one go, so we have to split it
            combined_dict_items = list(combined_dict.items())
            for i in range(0, len(combined_dict_items), MLflowCallback.MAX_LOG_SIZE):
                mlflow.log_params(dict(combined_dict_items[i : i + MLflowCallback.MAX_LOG_SIZE]))
        self._initialized = True

    def on_train_begin(self, args, state, control, model=None, **kwargs):
        if not self._initialized:
            self.setup(args, state, model)

    def on_log(self, args, state, control, logs, model=None, **kwargs):
        if not self._initialized:
            self.setup(args, state, model)
        if state.is_world_process_zero:
            for k, v in logs.items():
                # MLflow metrics must be numeric; anything else is dropped with a warning.
                if isinstance(v, (int, float)):
                    mlflow.log_metric(k, v, step=state.global_step)
                else:
                    logger.warning(
                        "Trainer is attempting to log a value of "
                        '"%s" of type %s for key "%s" as a metric. '
                        "MLflow's log_metric() only accepts float and "
                        "int types so we dropped this attribute.",
                        v,
                        type(v),
                        k,
                    )

    def on_train_end(self, args, state, control, **kwargs):
        if self._initialized and state.is_world_process_zero:
            if self._log_artifacts:
                logger.info("Logging artifacts. This may take time.")
                mlflow.log_artifacts(args.output_dir)
            mlflow.end_run()

    def __del__(self):
        # if the previous run is not terminated correctly, the fluent API will
        # not let you start a new run before the previous one is killed
        # BUG FIX: the original tested `mlflow.active_run is not None`, which checks
        # the function object (always truthy) instead of calling it, so end_run()
        # was invoked unconditionally. Call active_run() to get the actual run.
        if mlflow.active_run() is not None:
            mlflow.end_run(status="KILLED")
||||
|
@ -41,6 +41,7 @@ from .integrations import (
|
||||
default_hp_search_backend,
|
||||
hp_params,
|
||||
is_comet_available,
|
||||
is_mlflow_available,
|
||||
is_optuna_available,
|
||||
is_ray_available,
|
||||
is_tensorboard_available,
|
||||
@ -139,6 +140,11 @@ if is_comet_available():
|
||||
|
||||
DEFAULT_CALLBACKS.append(CometCallback)
|
||||
|
||||
if is_mlflow_available():
|
||||
from .integrations import MLflowCallback
|
||||
|
||||
DEFAULT_CALLBACKS.append(MLflowCallback)
|
||||
|
||||
if is_optuna_available():
|
||||
import optuna
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user