fix(mlflow): check mlflow version to use the synchronous flag (#29195)

* fix(mlflow): check mlflow version to use the synchronous flag

* fix indent

* add log_params async and fix quality
This commit is contained in:
cchen-dialpad 2024-02-23 00:19:51 -08:00 committed by GitHub
parent 2cc8cf6ce7
commit 4524494072
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -29,6 +29,7 @@ from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union
import numpy as np
import packaging.version
from .. import __version__ as version
from ..utils import flatten_dict, is_datasets_available, is_pandas_available, is_torch_available, logging
@ -985,6 +986,12 @@ class MLflowCallback(TrainerCallback):
self._experiment_name = os.getenv("MLFLOW_EXPERIMENT_NAME", None)
self._flatten_params = os.getenv("MLFLOW_FLATTEN_PARAMS", "FALSE").upper() in ENV_VARS_TRUE_VALUES
self._run_id = os.getenv("MLFLOW_RUN_ID", None)
self._async_log = False
# "synchronous" flag is only available with mlflow version >= 2.8.0
# https://github.com/mlflow/mlflow/pull/9705
# https://github.com/mlflow/mlflow/releases/tag/v2.8.0
if packaging.version.parse(importlib.metadata.version("mlflow")) >= packaging.version.parse("2.8.0"):
self._async_log = True
logger.debug(
f"MLflow experiment_name={self._experiment_name}, run_name={args.run_name}, nested={self._nested_run},"
f" tags={self._nested_run}, tracking_uri={self._tracking_uri}"
@ -1023,7 +1030,12 @@ class MLflowCallback(TrainerCallback):
# MLflow cannot log more than 100 values in one go, so we have to split it
combined_dict_items = list(combined_dict.items())
for i in range(0, len(combined_dict_items), self._MAX_PARAMS_TAGS_PER_BATCH):
self._ml_flow.log_params(dict(combined_dict_items[i : i + self._MAX_PARAMS_TAGS_PER_BATCH]))
if self._async_log:
self._ml_flow.log_params(
dict(combined_dict_items[i : i + self._MAX_PARAMS_TAGS_PER_BATCH]), synchronous=False
)
else:
self._ml_flow.log_params(dict(combined_dict_items[i : i + self._MAX_PARAMS_TAGS_PER_BATCH]))
mlflow_tags = os.getenv("MLFLOW_TAGS", None)
if mlflow_tags:
mlflow_tags = json.loads(mlflow_tags)
@ -1047,7 +1059,11 @@ class MLflowCallback(TrainerCallback):
f'Trainer is attempting to log a value of "{v}" of type {type(v)} for key "{k}" as a metric. '
"MLflow's log_metric() only accepts float and int types so we dropped this attribute."
)
self._ml_flow.log_metrics(metrics=metrics, step=state.global_step, synchronous=False)
if self._async_log:
self._ml_flow.log_metrics(metrics=metrics, step=state.global_step, synchronous=False)
else:
self._ml_flow.log_metrics(metrics=metrics, step=state.global_step)
def on_train_end(self, args, state, control, **kwargs):
if self._initialized and state.is_world_process_zero: