mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Make MLFlow version detection more robust and handles mlflow-skinny (#29957)
* Make MLFlow version detection more robust and handles mlflow-skinny * Make function name more clear and refactor the logic * Further refactor
This commit is contained in:
parent
a907a903d6
commit
836e88caee
@ -131,13 +131,6 @@ def is_mlflow_available():
|
||||
return importlib.util.find_spec("mlflow") is not None
|
||||
|
||||
|
||||
def get_mlflow_version():
|
||||
try:
|
||||
return importlib.metadata.version("mlflow")
|
||||
except importlib.metadata.PackageNotFoundError:
|
||||
return importlib.metadata.version("mlflow-skinny")
|
||||
|
||||
|
||||
def is_dagshub_available():
|
||||
return None not in [importlib.util.find_spec("dagshub"), importlib.util.find_spec("mlflow")]
|
||||
|
||||
@ -1005,12 +998,12 @@ class MLflowCallback(TrainerCallback):
|
||||
self._experiment_name = os.getenv("MLFLOW_EXPERIMENT_NAME", None)
|
||||
self._flatten_params = os.getenv("MLFLOW_FLATTEN_PARAMS", "FALSE").upper() in ENV_VARS_TRUE_VALUES
|
||||
self._run_id = os.getenv("MLFLOW_RUN_ID", None)
|
||||
self._async_log = False
|
||||
|
||||
# "synchronous" flag is only available with mlflow version >= 2.8.0
|
||||
# https://github.com/mlflow/mlflow/pull/9705
|
||||
# https://github.com/mlflow/mlflow/releases/tag/v2.8.0
|
||||
if packaging.version.parse(get_mlflow_version()) >= packaging.version.parse("2.8.0"):
|
||||
self._async_log = True
|
||||
self._async_log = packaging.version.parse(self._ml_flow.__version__) >= packaging.version.parse("2.8.0")
|
||||
|
||||
logger.debug(
|
||||
f"MLflow experiment_name={self._experiment_name}, run_name={args.run_name}, nested={self._nested_run},"
|
||||
f" tags={self._nested_run}, tracking_uri={self._tracking_uri}"
|
||||
|
Loading…
Reference in New Issue
Block a user