mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 19:21:31 +06:00
fix(mlflow): check mlflow version to use the synchronous flag (#29195)
* fix(mlflow): check mlflow version to use the flag * fix indent * add log_params async and fix quality
This commit is contained in:
parent
2cc8cf6ce7
commit
4524494072
@ -29,6 +29,7 @@ from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import packaging.version
|
||||
|
||||
from .. import __version__ as version
|
||||
from ..utils import flatten_dict, is_datasets_available, is_pandas_available, is_torch_available, logging
|
||||
@ -985,6 +986,12 @@ class MLflowCallback(TrainerCallback):
|
||||
self._experiment_name = os.getenv("MLFLOW_EXPERIMENT_NAME", None)
|
||||
self._flatten_params = os.getenv("MLFLOW_FLATTEN_PARAMS", "FALSE").upper() in ENV_VARS_TRUE_VALUES
|
||||
self._run_id = os.getenv("MLFLOW_RUN_ID", None)
|
||||
self._async_log = False
|
||||
# "synchronous" flag is only available with mlflow version >= 2.8.0
|
||||
# https://github.com/mlflow/mlflow/pull/9705
|
||||
# https://github.com/mlflow/mlflow/releases/tag/v2.8.0
|
||||
if packaging.version.parse(importlib.metadata.version("mlflow")) >= packaging.version.parse("2.8.0"):
|
||||
self._async_log = True
|
||||
logger.debug(
|
||||
f"MLflow experiment_name={self._experiment_name}, run_name={args.run_name}, nested={self._nested_run},"
|
||||
f" tags={self._nested_run}, tracking_uri={self._tracking_uri}"
|
||||
@ -1023,7 +1030,12 @@ class MLflowCallback(TrainerCallback):
|
||||
# MLflow cannot log more than 100 values in one go, so we have to split it
|
||||
combined_dict_items = list(combined_dict.items())
|
||||
for i in range(0, len(combined_dict_items), self._MAX_PARAMS_TAGS_PER_BATCH):
|
||||
self._ml_flow.log_params(dict(combined_dict_items[i : i + self._MAX_PARAMS_TAGS_PER_BATCH]))
|
||||
if self._async_log:
|
||||
self._ml_flow.log_params(
|
||||
dict(combined_dict_items[i : i + self._MAX_PARAMS_TAGS_PER_BATCH]), synchronous=False
|
||||
)
|
||||
else:
|
||||
self._ml_flow.log_params(dict(combined_dict_items[i : i + self._MAX_PARAMS_TAGS_PER_BATCH]))
|
||||
mlflow_tags = os.getenv("MLFLOW_TAGS", None)
|
||||
if mlflow_tags:
|
||||
mlflow_tags = json.loads(mlflow_tags)
|
||||
@ -1047,7 +1059,11 @@ class MLflowCallback(TrainerCallback):
|
||||
f'Trainer is attempting to log a value of "{v}" of type {type(v)} for key "{k}" as a metric. '
|
||||
"MLflow's log_metric() only accepts float and int types so we dropped this attribute."
|
||||
)
|
||||
self._ml_flow.log_metrics(metrics=metrics, step=state.global_step, synchronous=False)
|
||||
|
||||
if self._async_log:
|
||||
self._ml_flow.log_metrics(metrics=metrics, step=state.global_step, synchronous=False)
|
||||
else:
|
||||
self._ml_flow.log_metrics(metrics=metrics, step=state.global_step)
|
||||
|
||||
def on_train_end(self, args, state, control, **kwargs):
|
||||
if self._initialized and state.is_world_process_zero:
|
||||
|
Loading…
Reference in New Issue
Block a user