mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 02:31:11 +06:00
Neptune.ai integration (#11937)
Adds an option that turns on neptune.ai logging: `--report_to 'neptune'`. Additional environment variables: NEPTUNE_PROJECT, NEPTUNE_API_TOKEN, NEPTUNE_CONNECTION_MODE (optional), NEPTUNE_RUN_NAME (optional), NEPTUNE_STOP_TIMEOUT (optional).
This commit is contained in:
parent
ae6ce28f31
commit
9996558bff
@ -105,6 +105,10 @@ def is_deepspeed_available():
|
|||||||
return importlib.util.find_spec("deepspeed") is not None
|
return importlib.util.find_spec("deepspeed") is not None
|
||||||
|
|
||||||
|
|
||||||
|
def is_neptune_available():
    """
    Return :obj:`True` when a top-level ``neptune`` package can be located by the import system.

    Only the module spec is looked up; the package itself is not imported here.
    """
    return importlib.util.find_spec("neptune") is not None
|
||||||
|
|
||||||
|
|
||||||
def hp_params(trial):
|
def hp_params(trial):
|
||||||
if is_optuna_available():
|
if is_optuna_available():
|
||||||
import optuna
|
import optuna
|
||||||
@ -921,10 +925,80 @@ class MLflowCallback(TrainerCallback):
|
|||||||
self._ml_flow.end_run()
|
self._ml_flow.end_run()
|
||||||
|
|
||||||
|
|
||||||
|
class NeptuneCallback(TrainerCallback):
    """
    A :class:`~transformers.TrainerCallback` that sends the logs to `Neptune <https://neptune.ai>`__.
    """

    def __init__(self):
        # Defined before anything that can fail so that `__del__` always finds the attribute, even when
        # construction aborts early. The run itself is created lazily by `setup`.
        self._neptune_run = None

        assert (
            is_neptune_available()
        ), "NeptuneCallback requires neptune-client to be installed. Run `pip install neptune-client`."
        import neptune.new as neptune

        self._neptune = neptune
        self._initialized = False
        # NOTE(review): this flag is not read anywhere in this class.
        self._log_artifacts = False

    def setup(self, args, state, model):
        """
        Setup the Neptune integration.

        On the world-process-zero process this creates the Neptune run and stores the training arguments, merged
        with the model config when the model has one, under the ``parameters`` field of the run. On key collisions
        the training arguments take precedence over the model config.

        Environment:
            NEPTUNE_PROJECT (:obj:`str`, `required`):
                The project ID for neptune.ai account. Should be in format `workspace_name/project_name`
            NEPTUNE_API_TOKEN (:obj:`str`, `required`):
                API-token for neptune.ai account
            NEPTUNE_CONNECTION_MODE (:obj:`str`, `optional`):
                Neptune connection mode. `async` by default
            NEPTUNE_RUN_NAME (:obj:`str`, `optional`):
                The name of run process on Neptune dashboard
        """
        if state.is_world_process_zero:
            self._neptune_run = self._neptune.init(
                project=os.getenv("NEPTUNE_PROJECT"),
                api_token=os.getenv("NEPTUNE_API_TOKEN"),
                mode=os.getenv("NEPTUNE_CONNECTION_MODE", "async"),
                name=os.getenv("NEPTUNE_RUN_NAME", None),
            )
            combined_dict = args.to_dict()
            if hasattr(model, "config") and model.config is not None:
                model_config = model.config.to_dict()
                # `args` entries are unpacked last so they override same-named model config entries.
                combined_dict = {**model_config, **combined_dict}
            self._neptune_run["parameters"] = combined_dict
        # Marked as initialized on every process so that `setup` is attempted only once per callback.
        self._initialized = True

    def on_train_begin(self, args, state, control, model=None, **kwargs):
        """Run :meth:`setup` at the start of training if it has not run yet."""
        if not self._initialized:
            self.setup(args, state, model)

    def on_log(self, args, state, control, logs, model=None, **kwargs):
        """Send every entry of ``logs`` to the Neptune run, using ``state.global_step`` as the step."""
        if not self._initialized:
            self.setup(args, state, model)
        if state.is_world_process_zero:
            for k, v in logs.items():
                self._neptune_run[k].log(v, step=state.global_step)

    def __del__(self):
        """
        Stop the tracked run, if one was created, when the callback is garbage collected.

        Environment:
            NEPTUNE_STOP_TIMEOUT (:obj:`int`, `optional`):
                Number of seconds to wait for all Neptune.ai tracking calls to finish, before stopping the tracked
                run. If not set (or not a valid integer) it will wait for all tracking calls to finish.
        """
        # `getattr` covers instances whose `__init__` never ran; `None` means no run was ever created.
        run = getattr(self, "_neptune_run", None)
        if run is None:
            return

        try:
            stop_timeout = os.getenv("NEPTUNE_STOP_TIMEOUT")
            stop_timeout = int(stop_timeout) if stop_timeout else None
        except (AttributeError, ValueError):
            # ValueError: the variable is set but is not an integer. AttributeError: `os` may already be torn
            # down when finalizers run at interpreter exit. Either way, still stop the run without a timeout.
            stop_timeout = None

        # Drop the reference first so the run cannot be stopped a second time through this callback.
        self._neptune_run = None
        run.stop(seconds=stop_timeout)
|
||||||
|
|
||||||
|
|
||||||
# Maps each integration name to the callback class that implements it.
INTEGRATION_TO_CALLBACK = {
    "azure_ml": AzureMLCallback,
    "comet_ml": CometCallback,
    "mlflow": MLflowCallback,
    "neptune": NeptuneCallback,
    "tensorboard": TensorBoardCallback,
    "wandb": WandbCallback,
}
|
||||||
|
Loading…
Reference in New Issue
Block a user