mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Added DagshubCallback (#21404)
* integrated logger * bugifx * added data * bugfix * model + state artifacts should log * fixed paths * i lied, trying again * updated function call * typo this is painful :( what a stupid error * typo this is painful :( what a stupid error * pivoted to adding a directory * silly path bug * multiple experiments * migrated to getattr * syntax fix * syntax fix * fixed repo pointer * fixed path error * added dataset if dataloader is present, uploaded artifacts * variable in scope * removed unnecessary line * updated error type Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * trimmed unused variables, imports * style formatting * removed type conversion reliance Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * reverted accidental line deletion --------- Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
parent
8d580779a3
commit
3fadb4b211
@ -119,6 +119,10 @@ def is_mlflow_available():
|
||||
return importlib.util.find_spec("mlflow") is not None
|
||||
|
||||
|
||||
def is_dagshub_available():
|
||||
return None not in [importlib.util.find_spec("dagshub"), importlib.util.find_spec("mlflow")]
|
||||
|
||||
|
||||
def is_fairscale_available():
|
||||
return importlib.util.find_spec("fairscale") is not None
|
||||
|
||||
@ -522,6 +526,8 @@ def get_available_reporting_integrations():
|
||||
integrations.append("azure_ml")
|
||||
if is_comet_available():
|
||||
integrations.append("comet_ml")
|
||||
if is_dagshub_available():
|
||||
integrations.append("dagshub")
|
||||
if is_mlflow_available():
|
||||
integrations.append("mlflow")
|
||||
if is_neptune_available():
|
||||
@ -1045,6 +1051,55 @@ class MLflowCallback(TrainerCallback):
|
||||
self._ml_flow.end_run()
|
||||
|
||||
|
||||
class DagsHubCallback(MLflowCallback):
|
||||
"""
|
||||
A [`TrainerCallback`] that logs to [DagsHub](https://dagshub.com/).
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
if not is_dagshub_available():
|
||||
raise ImportError("DagsHubCallback requires dagshub to be installed. Run `pip install dagshub`.")
|
||||
|
||||
from dagshub.upload import Repo
|
||||
|
||||
self.Repo = Repo
|
||||
|
||||
def setup(self, *args, **kwargs):
|
||||
"""
|
||||
Setup the DagsHub's Logging integration.
|
||||
|
||||
Environment:
|
||||
HF_DAGSHUB_LOG_ARTIFACTS (`str`, *optional*):
|
||||
Whether to save the data and model artifacts for the experiment. Default to `False`.
|
||||
"""
|
||||
|
||||
self.log_artifacts = os.getenv("HF_DAGSHUB_LOG_ARTIFACTS", "FALSE").upper() in ENV_VARS_TRUE_VALUES
|
||||
self.name = os.getenv("HF_DAGSHUB_MODEL_NAME") or "main"
|
||||
self.remote = os.getenv("MLFLOW_TRACKING_URI")
|
||||
self.repo = self.Repo(
|
||||
owner=self.remote.split(os.sep)[-2],
|
||||
name=self.remote.split(os.sep)[-1].split(".")[0],
|
||||
branch=os.getenv("BRANCH") or "main",
|
||||
)
|
||||
self.path = Path("artifacts")
|
||||
|
||||
if self.remote is None:
|
||||
raise RuntimeError(
|
||||
"DagsHubCallback requires the `MLFLOW_TRACKING_URI` environment variable to be set. Did you run"
|
||||
" `dagshub.init()`?"
|
||||
)
|
||||
|
||||
super().setup(*args, **kwargs)
|
||||
|
||||
def on_train_end(self, args, state, control, **kwargs):
|
||||
if self.log_artifacts:
|
||||
if getattr(self, "train_dataloader", None):
|
||||
torch.save(self.train_dataloader.dataset, os.path.join(args.output_dir, "dataset.pt"))
|
||||
|
||||
self.repo.directory(str(self.path)).add_dir(args.output_dir)
|
||||
|
||||
|
||||
class NeptuneMissingConfiguration(Exception):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@ -1465,6 +1520,7 @@ INTEGRATION_TO_CALLBACK = {
|
||||
"wandb": WandbCallback,
|
||||
"codecarbon": CodeCarbonCallback,
|
||||
"clearml": ClearMLCallback,
|
||||
"dagshub": DagsHubCallback,
|
||||
}
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user