diff --git a/src/transformers/integrations.py b/src/transformers/integrations.py
index 19bffe1f7a6..e05d1331f4d 100644
--- a/src/transformers/integrations.py
+++ b/src/transformers/integrations.py
@@ -105,6 +105,11 @@ def is_deepspeed_available():
     return importlib.util.find_spec("deepspeed") is not None
 
 
+def is_neptune_available():
+    """Return whether a top-level ``neptune`` package can be located on the import path."""
+    return importlib.util.find_spec("neptune") is not None
+
+
 def hp_params(trial):
     if is_optuna_available():
         import optuna
@@ -921,10 +926,98 @@ class MLflowCallback(TrainerCallback):
             self._ml_flow.end_run()
 
 
+class NeptuneCallback(TrainerCallback):
+    """
+    A :class:`~transformers.TrainerCallback` that sends the logs to `Neptune <https://neptune.ai>`__.
+    """
+
+    def __init__(self):
+        # With ``python -O`` the assert is stripped and the import below raises instead.
+        assert (
+            is_neptune_available()
+        ), "NeptuneCallback requires neptune-client to be installed. Run `pip install neptune-client`."
+        import neptune.new as neptune
+
+        self._neptune = neptune
+        # Assigned by ``setup`` on the world-zero process only; stays ``None`` everywhere else.
+        self._neptune_run = None
+        self._initialized = False
+        self._log_artifacts = False
+
+    def setup(self, args, state, model):
+        """
+        Setup the Neptune integration.
+
+        On the world-zero process, start a Neptune run and store the training arguments under its
+        ``parameters`` field, merged with the model configuration when the model has one (training arguments
+        win on key collisions). Every process is marked as initialized afterwards.
+
+        Environment:
+            NEPTUNE_PROJECT (:obj:`str`, `required`):
+                The project ID for neptune.ai account. Should be in format `workspace_name/project_name`
+            NEPTUNE_API_TOKEN (:obj:`str`, `required`):
+                API-token for neptune.ai account
+            NEPTUNE_CONNECTION_MODE (:obj:`str`, `optional`):
+                Neptune connection mode. `async` by default
+            NEPTUNE_RUN_NAME (:obj:`str`, `optional`):
+                The name of run process on Neptune dashboard
+        """
+        if state.is_world_process_zero:
+            self._neptune_run = self._neptune.init(
+                project=os.getenv("NEPTUNE_PROJECT"),
+                api_token=os.getenv("NEPTUNE_API_TOKEN"),
+                mode=os.getenv("NEPTUNE_CONNECTION_MODE", "async"),
+                name=os.getenv("NEPTUNE_RUN_NAME", None),
+            )
+            combined_dict = args.to_dict()
+            if hasattr(model, "config") and model.config is not None:
+                model_config = model.config.to_dict()
+                combined_dict = {**model_config, **combined_dict}
+            self._neptune_run["parameters"] = combined_dict
+        self._initialized = True
+
+    def on_train_begin(self, args, state, control, model=None, **kwargs):
+        """Set up the integration when training starts, unless that already happened."""
+        if not self._initialized:
+            self.setup(args, state, model)
+
+    def on_log(self, args, state, control, logs=None, model=None, **kwargs):
+        """Send each entry of ``logs`` to the Neptune run, tagged with ``state.global_step`` as its step."""
+        if not self._initialized:
+            self.setup(args, state, model)
+        # Only the world-zero process owns a run; missing or empty ``logs`` leave nothing to send.
+        if not state.is_world_process_zero or not logs:
+            return
+        for k, v in logs.items():
+            self._neptune_run[k].log(v, step=state.global_step)
+
+    def __del__(self):
+        """
+        Stop the Neptune run, if one was started, when the callback is finalized.
+
+        Environment:
+            NEPTUNE_STOP_TIMEOUT (:obj:`int`, `optional`):
+                Number of seconds to wait for all Neptune.ai tracking calls to finish, before stopping the tracked
+                run. If not set, or not a valid integer, it will wait for all tracking calls to finish.
+        """
+        # ``getattr`` also covers instances whose ``__init__`` failed before the attribute was assigned.
+        run = getattr(self, "_neptune_run", None)
+        if run is None:
+            return
+        stop_timeout = os.getenv("NEPTUNE_STOP_TIMEOUT")
+        try:
+            stop_timeout = int(stop_timeout) if stop_timeout else None
+        except ValueError:
+            # A malformed value must not keep the run from being stopped.
+            stop_timeout = None
+        run.stop(seconds=stop_timeout)
+
+
 INTEGRATION_TO_CALLBACK = {
     "azure_ml": AzureMLCallback,
     "comet_ml": CometCallback,
     "mlflow": MLflowCallback,
+    "neptune": NeptuneCallback,
     "tensorboard": TensorBoardCallback,
     "wandb": WandbCallback,
 }