Added DagsHubCallback (#21404)

* integrated logger

* bugfix

* added data

* bugfix

* model + state artifacts should log

* fixed paths

* i lied, trying again

* updated function call

* typo

this is painful :( what a stupid error

* typo

this is painful :( what a stupid error

* pivoted to adding a directory

* silly path bug

* multiple experiments

* migrated to getattr

* syntax fix

* syntax fix

* fixed repo pointer

* fixed path error

* added dataset if dataloader is present, uploaded artifacts

* variable in scope

* removed unnecessary line

* updated error type

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* trimmed unused variables, imports

* style formatting

* removed type conversion reliance

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* reverted accidental line deletion

---------

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Jinen Setpal 2023-02-01 13:51:46 -05:00 committed by GitHub
parent 8d580779a3
commit 3fadb4b211
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -119,6 +119,10 @@ def is_mlflow_available():
return importlib.util.find_spec("mlflow") is not None
def is_dagshub_available():
    """
    Report whether the DagsHub integration can be used.

    Returns `True` only when an import spec can be found for both the `dagshub` and the `mlflow` packages.
    """
    # Probe both packages up front (no short-circuit) before combining the results.
    located_specs = [importlib.util.find_spec(package) for package in ("dagshub", "mlflow")]
    return all(spec is not None for spec in located_specs)
def is_fairscale_available():
    """
    Report whether an import spec can be found for the `fairscale` package.
    """
    fairscale_spec = importlib.util.find_spec("fairscale")
    return fairscale_spec is not None
@ -522,6 +526,8 @@ def get_available_reporting_integrations():
integrations.append("azure_ml")
if is_comet_available():
integrations.append("comet_ml")
if is_dagshub_available():
integrations.append("dagshub")
if is_mlflow_available():
integrations.append("mlflow")
if is_neptune_available():
@ -1045,6 +1051,55 @@ class MLflowCallback(TrainerCallback):
self._ml_flow.end_run()
class DagsHubCallback(MLflowCallback):
    """
    A [`TrainerCallback`] that logs to [DagsHub](https://dagshub.com/).

    Builds on [`MLflowCallback`]; in addition, when artifact logging is enabled, the contents of the training output
    directory are added to the DagsHub repository at the end of training.
    """

    def __init__(self):
        super().__init__()
        if not is_dagshub_available():
            raise ImportError("DagsHubCallback requires dagshub to be installed. Run `pip install dagshub`.")

        # Imported here, after the availability check, so the module can be loaded without dagshub installed.
        from dagshub.upload import Repo

        self.Repo = Repo

    def setup(self, *args, **kwargs):
        """
        Setup the DagsHub's Logging integration.

        All positional and keyword arguments are forwarded unchanged to the parent `setup`.

        Environment:
            HF_DAGSHUB_LOG_ARTIFACTS (`str`, *optional*):
                Whether to save the data and model artifacts for the experiment. Defaults to `False`.
            HF_DAGSHUB_MODEL_NAME (`str`, *optional*):
                Stored as `self.name`. Defaults to `"main"`.
            MLFLOW_TRACKING_URI (`str`):
                Remote tracking URI. The second-to-last `/`-separated segment is used as the repository owner and
                the last segment, up to its first `.`, as the repository name.
            BRANCH (`str`, *optional*):
                Branch of the DagsHub repository to use. Defaults to `"main"`.

        Raises:
            RuntimeError: If `MLFLOW_TRACKING_URI` is not set.
        """
        self.log_artifacts = os.getenv("HF_DAGSHUB_LOG_ARTIFACTS", "FALSE").upper() in ENV_VARS_TRUE_VALUES
        self.name = os.getenv("HF_DAGSHUB_MODEL_NAME") or "main"
        self.remote = os.getenv("MLFLOW_TRACKING_URI")

        # Validate before parsing: the URI is dereferenced below, so a missing value must fail here with a
        # helpful message rather than with an AttributeError on `None`.
        if self.remote is None:
            raise RuntimeError(
                "DagsHubCallback requires the `MLFLOW_TRACKING_URI` environment variable to be set. Did you run"
                " `dagshub.init()`?"
            )

        # The remote is a URL, whose separator is always "/" (`os.sep` is "\\" on Windows). A trailing slash is
        # dropped so that the last segment is the repository rather than an empty string.
        remote_segments = self.remote.rstrip("/").split("/")
        self.repo = self.Repo(
            owner=remote_segments[-2],
            name=remote_segments[-1].split(".")[0],
            branch=os.getenv("BRANCH") or "main",
        )
        self.path = Path("artifacts")

        super().setup(*args, **kwargs)

    def on_train_end(self, args, state, control, **kwargs):
        """
        Upload the training output directory to the DagsHub repository when artifact logging is enabled.

        If a truthy `train_dataloader` attribute is present on the callback, its dataset is first saved as
        `dataset.pt` inside `args.output_dir` so that it is uploaded together with the other files.
        """
        # NOTE(review): this override does not call `super().on_train_end(...)` -- confirm that skipping the
        # parent's end-of-training handling is intended.
        if self.log_artifacts:
            if getattr(self, "train_dataloader", None):
                torch.save(self.train_dataloader.dataset, os.path.join(args.output_dir, "dataset.pt"))

            self.repo.directory(str(self.path)).add_dir(args.output_dir)
class NeptuneMissingConfiguration(Exception):
def __init__(self):
super().__init__(
@ -1465,6 +1520,7 @@ INTEGRATION_TO_CALLBACK = {
"wandb": WandbCallback,
"codecarbon": CodeCarbonCallback,
"clearml": ClearMLCallback,
"dagshub": DagsHubCallback,
}