Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-30 17:52:35 +06:00
feat(wandb): logging and configuration improvements (#10826)
* feat: ensure unique artifact id
* feat: allow manual init
* fix: simplify reinit logic
* fix: no dropped values + immediate commits
* fix: wandb use in SageMaker
* docs: improve documentation and formatting
* fix: typos
* docs: improve formatting
parent b230181d41
commit 125ccead71
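For context, the "allow manual init" change means `WandbCallback.setup()` now reuses an existing run instead of always calling `wandb.init()` itself. A minimal sketch of that usage; the project name, run name, and the surrounding model/dataset setup are illustrative, not part of this commit:

```python
import wandb
from transformers import Trainer, TrainingArguments

# Create the run manually before training; setup() now checks
# `wandb.run is None` and reuses this run instead of starting a new one.
wandb.init(project="my-project", name="manual-run")

args = TrainingArguments(output_dir="out", run_name="manual-run")
# `model` and `train_dataset` are assumed to be defined elsewhere.
trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
trainer.train()  # metrics and config land in the manually created run
```

The callback still updates `wandb.config` with the sanitized `TrainingArguments`, so a manually created run receives the same config parameters as an auto-created one.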
@@ -240,34 +240,11 @@ Whenever you use `Trainer` or `TFTrainer` classes, your losses, evaluation metri
 
 Advanced configuration is possible by setting environment variables:
 
-<table>
-<thead>
-<tr>
-<th style="text-align:left">Environment Variables</th>
-<th style="text-align:left">Options</th>
-</tr>
-</thead>
-<tbody>
-<tr>
-<td style="text-align:left">WANDB_LOG_MODEL</td>
-<td style="text-align:left">Log the model as artifact at the end of training (<b>false</b> by default)</td>
-</tr>
-<tr>
-<td style="text-align:left">WANDB_WATCH</td>
-<td style="text-align:left">
-<ul>
-<li><b>gradients</b> (default): Log histograms of the gradients</li>
-<li><b>all</b>: Log histograms of gradients and parameters</li>
-<li><b>false</b>: No gradient or parameter logging</li>
-</ul>
-</td>
-</tr>
-<tr>
-<td style="text-align:left">WANDB_PROJECT</td>
-<td style="text-align:left">Organize runs by project</td>
-</tr>
-</tbody>
-</table>
+| Environment Variable | Value |
+|---|---|
+| WANDB_LOG_MODEL | Log the model as artifact at the end of training (`false` by default) |
+| WANDB_WATCH | one of `gradients` (default) to log histograms of gradients, `all` to log histograms of both gradients and parameters, or `false` for no histogram logging |
+| WANDB_PROJECT | Organize runs by project |
 
 Set run names with `run_name` argument present in scripts or as part of `TrainingArguments`.
 
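Taken together, the variables in the new table can be set before launching a script. The values below only illustrate the documented options:

```python
import os

# Must be set before the Trainer is created: WANDB_LOG_MODEL is read in
# WandbCallback.__init__ and the other two in setup().
os.environ["WANDB_PROJECT"] = "my-project"  # organize runs by project
os.environ["WANDB_WATCH"] = "all"           # histograms of gradients and parameters
os.environ["WANDB_LOG_MODEL"] = "true"      # upload the final model as an artifact
```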
@@ -19,7 +19,6 @@ import io
 import json
 import numbers
 import os
-import re
 import tempfile
 from copy import deepcopy
 from pathlib import Path
@@ -559,20 +558,12 @@ class WandbCallback(TrainerCallback):
         if has_wandb:
             import wandb
 
-            wandb.ensure_configured()
-            if wandb.api.api_key is None:
-                has_wandb = False
-                logger.warning(
-                    "W&B installed but not logged in. Run `wandb login` or set the WANDB_API_KEY env variable."
-                )
-                self._wandb = None
-            else:
-                self._wandb = wandb
+            self._wandb = wandb
         self._initialized = False
         # log outputs
         self._log_model = os.getenv("WANDB_LOG_MODEL", "FALSE").upper() in ENV_VARS_TRUE_VALUES.union({"TRUE"})
 
-    def setup(self, args, state, model, reinit, **kwargs):
+    def setup(self, args, state, model, **kwargs):
         """
         Setup the optional Weights & Biases (`wandb`) integration.
 
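Dropping `wandb.ensure_configured()` and the eager API-key check is what fixes W&B use on SageMaker: credentials no longer have to be present when the callback is constructed, only by the time `wandb.init()` runs. A hedged sketch of supplying them programmatically; the key value is a placeholder:

```python
import os

# Any of these is now sufficient, as long as it happens before wandb.init():
os.environ["WANDB_API_KEY"] = "<your-api-key>"  # placeholder, e.g. from a secrets store
# ...or run `wandb login` once on the machine,
# ...or let the platform (e.g. SageMaker) inject credentials at runtime.
```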
@@ -581,7 +572,8 @@ class WandbCallback(TrainerCallback):
 
         Environment:
             WANDB_LOG_MODEL (:obj:`bool`, `optional`, defaults to :obj:`False`):
-                Whether or not to log model as artifact at the end of training.
+                Whether or not to log model as artifact at the end of training. Use along with
+                `TrainingArguments.load_best_model_at_end` to upload best model.
             WANDB_WATCH (:obj:`str`, `optional` defaults to :obj:`"gradients"`):
                 Can be :obj:`"gradients"`, :obj:`"all"` or :obj:`"false"`. Set to :obj:`"false"` to disable gradient
                 logging or :obj:`"all"` to log gradients and parameters.
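The docstring addition points at a concrete recipe: combine `WANDB_LOG_MODEL` with `load_best_model_at_end` so the artifact uploaded in `on_train_end` is the best checkpoint rather than the last one. An illustrative configuration; all values are examples, not defaults:

```python
import os
from transformers import TrainingArguments

os.environ["WANDB_LOG_MODEL"] = "true"

args = TrainingArguments(
    output_dir="out",
    evaluation_strategy="steps",      # evaluate periodically...
    eval_steps=500,
    save_steps=500,
    load_best_model_at_end=True,      # ...and reload the best checkpoint at the end,
    metric_for_best_model="eval_loss",
)
# on_train_end() then saves (and uploads) the reloaded best model.
```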
@@ -610,13 +602,19 @@ class WandbCallback(TrainerCallback):
             else:
                 run_name = args.run_name
 
-            self._wandb.init(
-                project=os.getenv("WANDB_PROJECT", "huggingface"),
-                config=combined_dict,
-                name=run_name,
-                reinit=reinit,
-                **init_args,
-            )
+            if self._wandb.run is None:
+                self._wandb.init(
+                    project=os.getenv("WANDB_PROJECT", "huggingface"),
+                    name=run_name,
+                    **init_args,
+                )
+            # add config parameters (run may have been created manually)
+            self._wandb.config.update(combined_dict, allow_val_change=True)
+
+            # define default x-axis (for latest wandb versions)
+            if getattr(self._wandb, "define_metric", None):
+                self._wandb.define_metric("train/global_step")
+                self._wandb.define_metric("*", step_metric="train/global_step", step_sync=True)
 
             # keep track of model topology and gradients, unsupported on TPU
             if not is_torch_tpu_available() and os.getenv("WANDB_WATCH") != "false":
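The `define_metric` calls register `train/global_step` as the default x-axis for every logged metric, which is why the hunk guards on `getattr` (older wandb versions lack the API). A standalone sketch of the same mechanism; the project name is illustrative:

```python
import wandb

wandb.init(project="demo")
# All metrics are plotted against train/global_step; step_sync=True fills in
# the latest train/global_step for log calls that omit it.
wandb.define_metric("train/global_step")
wandb.define_metric("*", step_metric="train/global_step", step_sync=True)

wandb.log({"train/global_step": 10, "train/loss": 0.7})
wandb.log({"eval/accuracy": 0.8})  # x-axis value inherited from the last step
wandb.finish()
```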
@@ -628,23 +626,20 @@ class WandbCallback(TrainerCallback):
         if self._wandb is None:
             return
         hp_search = state.is_hyper_param_search
-        if not self._initialized or hp_search:
-            self.setup(args, state, model, reinit=hp_search, **kwargs)
+        if hp_search:
+            self._wandb.finish()
+        if not self._initialized:
+            self.setup(args, state, model, **kwargs)
 
     def on_train_end(self, args, state, control, model=None, tokenizer=None, **kwargs):
         if self._wandb is None:
             return
-        # commit last step
-        if state.is_world_process_zero:
-            self._wandb.log({})
         if self._log_model and self._initialized and state.is_world_process_zero:
             from .trainer import Trainer
 
             fake_trainer = Trainer(args=args, model=model, tokenizer=tokenizer)
             with tempfile.TemporaryDirectory() as temp_dir:
                 fake_trainer.save_model(temp_dir)
-                # use run name and ensure it's a valid Artifact name
-                artifact_name = re.sub(r"[^a-zA-Z0-9_\.\-]", "", self._wandb.run.name)
                 metadata = (
                     {
                         k: v
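The simplified reinit logic matters most for hyperparameter search: finishing the active run before `setup()` runs again gives each trial its own run, instead of threading a `reinit` flag through the callback. A sketch, assuming a `trainer` built as in the earlier example and optuna installed:

```python
# Each trial ends the previous W&B run via wandb.finish() and re-runs setup(),
# so trials show up as separate runs grouped under args.run_name.
best_trial = trainer.hyperparameter_search(
    direction="minimize",
    backend="optuna",
    n_trials=5,
)
print(best_trial)
```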
@@ -657,7 +652,7 @@ class WandbCallback(TrainerCallback):
                     "train/total_floss": state.total_flos,
                 }
             )
-            artifact = self._wandb.Artifact(name=f"run-{artifact_name}", type="model", metadata=metadata)
+            artifact = self._wandb.Artifact(name=f"model-{self._wandb.run.id}", type="model", metadata=metadata)
             for f in Path(temp_dir).glob("*"):
                 if f.is_file():
                     with artifact.new_file(f.name, mode="wb") as fa:
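Naming artifacts `model-{run.id}` makes them unique even when two runs share a display name, and removes the need for `re.sub` sanitization (run ids are already artifact-safe), which is why `import re` was dropped above. A hedged sketch of fetching such an artifact later; the project name and run id are placeholders:

```python
import wandb

run = wandb.init(project="my-project", job_type="inference")
# "model-1a2b3c4d" stands in for model-<run.id> of the training run.
artifact = run.use_artifact("model-1a2b3c4d:latest", type="model")
model_dir = artifact.download()  # directory containing the files from save_model()
```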
@@ -668,10 +663,10 @@ class WandbCallback(TrainerCallback):
         if self._wandb is None:
             return
         if not self._initialized:
-            self.setup(args, state, model, reinit=False)
+            self.setup(args, state, model)
         if state.is_world_process_zero:
             logs = rewrite_logs(logs)
-            self._wandb.log(logs, step=state.global_step)
+            self._wandb.log({**logs, "train/global_step": state.global_step})
 
 
 class CometCallback(TrainerCallback):
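The `on_log` change is the "no dropped values + immediate commits" fix: passing `step=` makes wandb buffer rows and discard values whose step does not increase monotonically, while logging `train/global_step` as an ordinary value commits every call (the `define_metric` setup above keeps it as the x-axis). A standalone illustration; the project name is a placeholder:

```python
import wandb

wandb.init(project="demo")
wandb.define_metric("train/global_step")
wandb.define_metric("*", step_metric="train/global_step", step_sync=True)

# Two calls at the same training step: with `step=10` the second call would be
# merged into the pending row (or dropped if the step had already advanced);
# as plain values, each call commits its own row immediately.
wandb.log({"train/global_step": 10, "train/loss": 0.7})
wandb.log({"train/global_step": 10, "eval/loss": 0.9})
wandb.finish()
```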
|
Loading…
Reference in New Issue
Block a user