diff --git a/src/transformers/integrations.py b/src/transformers/integrations.py
index 00ebaa29afc..0c67b299dd9 100644
--- a/src/transformers/integrations.py
+++ b/src/transformers/integrations.py
@@ -119,6 +119,10 @@ def is_mlflow_available():
     return importlib.util.find_spec("mlflow") is not None
 
 
+def is_dagshub_available():
+    return None not in [importlib.util.find_spec("dagshub"), importlib.util.find_spec("mlflow")]
+
+
 def is_fairscale_available():
     return importlib.util.find_spec("fairscale") is not None
 
@@ -522,6 +526,8 @@ def get_available_reporting_integrations():
         integrations.append("azure_ml")
     if is_comet_available():
         integrations.append("comet_ml")
+    if is_dagshub_available():
+        integrations.append("dagshub")
     if is_mlflow_available():
         integrations.append("mlflow")
     if is_neptune_available():
@@ -1045,6 +1051,55 @@ class MLflowCallback(TrainerCallback):
             self._ml_flow.end_run()
 
 
+class DagsHubCallback(MLflowCallback):
+    """
+    A [`TrainerCallback`] that logs to [DagsHub](https://dagshub.com/).
+    """
+
+    def __init__(self):
+        super().__init__()
+        if not is_dagshub_available():
+            raise ImportError("DagsHubCallback requires dagshub to be installed. Run `pip install dagshub`.")
+
+        from dagshub.upload import Repo
+
+        self.Repo = Repo
+
+    def setup(self, *args, **kwargs):
+        """
+        Set up the DagsHub logging integration.
+
+        Environment:
+            HF_DAGSHUB_LOG_ARTIFACTS (`str`, *optional*):
+                Whether to save the data and model artifacts for the experiment. Defaults to `False`.
+        """
+
+        self.log_artifacts = os.getenv("HF_DAGSHUB_LOG_ARTIFACTS", "FALSE").upper() in ENV_VARS_TRUE_VALUES
+        self.name = os.getenv("HF_DAGSHUB_MODEL_NAME") or "main"
+        self.remote = os.getenv("MLFLOW_TRACKING_URI")
+        if self.remote is None:
+            raise RuntimeError(
+                "DagsHubCallback requires the `MLFLOW_TRACKING_URI` environment variable to be set. Did you run"
+                " `dagshub.init()`?"
+            )
+
+        self.repo = self.Repo(
+            owner=self.remote.split(os.sep)[-2],
+            name=self.remote.split(os.sep)[-1].split(".")[0],
+            branch=os.getenv("BRANCH") or "main",
+        )
+        self.path = Path("artifacts")
+
+        super().setup(*args, **kwargs)
+
+    def on_train_end(self, args, state, control, **kwargs):
+        if self.log_artifacts:
+            if getattr(self, "train_dataloader", None):
+                torch.save(self.train_dataloader.dataset, os.path.join(args.output_dir, "dataset.pt"))
+
+            self.repo.directory(str(self.path)).add_dir(args.output_dir)
+
+
 class NeptuneMissingConfiguration(Exception):
     def __init__(self):
         super().__init__(
@@ -1465,6 +1520,7 @@ INTEGRATION_TO_CALLBACK = {
     "wandb": WandbCallback,
     "codecarbon": CodeCarbonCallback,
     "clearml": ClearMLCallback,
+    "dagshub": DagsHubCallback,
 }
 
 
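
For reviewers, a minimal end-to-end usage sketch (not part of the patch). The repo owner/name, the distilbert-base-uncased checkpoint, and the toy dataset are placeholders; the sketch assumes `dagshub.init()` configures MLflow tracking for the repo (and therefore `MLFLOW_TRACKING_URI`), which is what `DagsHubCallback.setup()` checks for above, and that `report_to=["dagshub"]` resolves to `DagsHubCallback` through the `INTEGRATION_TO_CALLBACK` entry added in this diff.

# Minimal usage sketch; repo coordinates, checkpoint, and data are placeholders.
import os

import dagshub
from datasets import Dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)

# dagshub.init() points MLflow at the repo's tracking server, which is how
# MLFLOW_TRACKING_URI ends up set before DagsHubCallback.setup() runs.
dagshub.init(repo_name="your_repo", repo_owner="your_user")

# Opt in to uploading the output directory (and the training dataset) as artifacts.
os.environ["HF_DAGSHUB_LOG_ARTIFACTS"] = "1"

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)

# Tiny in-memory dataset, tokenized to fixed length so the default collator works.
train_ds = Dataset.from_dict({"text": ["good", "bad", "great", "awful"], "label": [1, 0, 1, 0]})
train_ds = train_ds.map(lambda ex: tokenizer(ex["text"], truncation=True, padding="max_length", max_length=16))

args = TrainingArguments(
    output_dir="outputs",
    report_to=["dagshub"],  # mapped to DagsHubCallback via INTEGRATION_TO_CALLBACK
    num_train_epochs=1,
    per_device_train_batch_size=2,
)

Trainer(model=model, args=args, train_dataset=train_ds).train()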