mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Neptune.ai integration (#11937)
Adds an option that turns on neptune.ai logging: `--report_to 'neptune'`. Additional environment variables: NEPTUNE_PROJECT, NEPTUNE_API_TOKEN, NEPTUNE_RUN_NAME (optional), NEPTUNE_STOP_TIMEOUT (optional).
This commit is contained in:
parent
ae6ce28f31
commit
9996558bff
@ -105,6 +105,10 @@ def is_deepspeed_available():
|
||||
return importlib.util.find_spec("deepspeed") is not None
|
||||
|
||||
|
||||
def is_neptune_available():
    """Return ``True`` when an importable ``neptune`` module can be found, ``False`` otherwise."""
    # find_spec only locates the module; it does not import it.
    neptune_spec = importlib.util.find_spec("neptune")
    return neptune_spec is not None
|
||||
|
||||
|
||||
def hp_params(trial):
|
||||
if is_optuna_available():
|
||||
import optuna
|
||||
@ -921,10 +925,80 @@ class MLflowCallback(TrainerCallback):
|
||||
self._ml_flow.end_run()
|
||||
|
||||
|
||||
class NeptuneCallback(TrainerCallback):
    """
    A :class:`~transformers.TrainerCallback` that sends the logs to `Neptune <https://neptune.ai>`.
    """

    def __init__(self):
        assert (
            is_neptune_available()
        ), "NeptuneCallback requires neptune-client to be installed. Run `pip install neptune-client`."
        # Deferred import: only attempted once the availability check above has passed.
        import neptune.new as neptune

        self._neptune = neptune
        # Flipped to True by `setup` so the run is only created once per callback instance.
        self._initialized = False
        self._log_artifacts = False

    def setup(self, args, state, model):
        """
        Setup the Neptune integration.

        On the world-zero process this creates the Neptune run and stores the training arguments (merged with the
        model config, when the model has one) under the ``parameters`` field of the run. On every process the
        callback is marked as initialized afterwards.

        Args:
            args: Training arguments; must provide ``to_dict()``.
            state: Trainer state; ``state.is_world_process_zero`` decides whether a run is created.
            model: The model being trained. Its ``config.to_dict()`` is logged when ``model.config`` is set.

        Environment:
            NEPTUNE_PROJECT (:obj:`str`, `required`):
                The project ID for neptune.ai account. Should be in format `workspace_name/project_name`
            NEPTUNE_API_TOKEN (:obj:`str`, `required`):
                API-token for neptune.ai account
            NEPTUNE_CONNECTION_MODE (:obj:`str`, `optional`):
                Neptune connection mode. `async` by default
            NEPTUNE_RUN_NAME (:obj:`str`, `optional`):
                The name of run process on Neptune dashboard
        """
        if state.is_world_process_zero:
            self._neptune_run = self._neptune.init(
                project=os.getenv("NEPTUNE_PROJECT"),
                api_token=os.getenv("NEPTUNE_API_TOKEN"),
                mode=os.getenv("NEPTUNE_CONNECTION_MODE", "async"),
                name=os.getenv("NEPTUNE_RUN_NAME"),
            )
            combined_dict = args.to_dict()
            if hasattr(model, "config") and model.config is not None:
                model_config = model.config.to_dict()
                # Training arguments come last so they win over model-config entries with the same key.
                combined_dict = {**model_config, **combined_dict}
            self._neptune_run["parameters"] = combined_dict
        self._initialized = True

    def on_train_begin(self, args, state, control, model=None, **kwargs):
        """Make sure the integration is set up before training starts."""
        if not self._initialized:
            self.setup(args, state, model)

    def on_log(self, args, state, control, logs, model=None, **kwargs):
        """Send every entry of ``logs`` to the Neptune run, using ``state.global_step`` as the step."""
        if not self._initialized:
            self.setup(args, state, model)
        # Only the world-zero process owns a run (see `setup`), so only it logs.
        if state.is_world_process_zero:
            for k, v in logs.items():
                self._neptune_run[k].log(v, step=state.global_step)

    def __del__(self):
        """
        Stop the Neptune run, if one was created, when the callback is garbage collected.

        Environment:
            NEPTUNE_STOP_TIMEOUT (:obj:`int`, `optional`):
                Number of seconds to wait for all Neptune.ai tracking calls to finish, before stopping the tracked
                run. If not set, or not a valid integer, it will wait for all tracking calls to finish.
        """
        try:
            raw_timeout = os.getenv("NEPTUNE_STOP_TIMEOUT")
            try:
                stop_timeout = int(raw_timeout) if raw_timeout else None
            except ValueError:
                # A malformed value must not prevent the run from being stopped: an exception escaping here
                # would skip the `stop` call below. Fall back to the default of waiting without a timeout.
                stop_timeout = None
            self._neptune_run.stop(seconds=stop_timeout)
        except AttributeError:
            # No run was ever created on this instance (e.g. `setup` never ran, or this is not the
            # world-zero process), so there is nothing to stop.
            pass
|
||||
|
||||
|
||||
# Registry mapping each integration name to the callback class that implements it.
INTEGRATION_TO_CALLBACK = {
    "azure_ml": AzureMLCallback,
    "comet_ml": CometCallback,
    "mlflow": MLflowCallback,
    "neptune": NeptuneCallback,
    "tensorboard": TensorBoardCallback,
    "wandb": WandbCallback,
}
|
||||
|
Loading…
Reference in New Issue
Block a user