feat(wandb): logging and configuration improvements (#10826)

* feat: ensure unique artifact id

* feat: allow manual init

* fix: simplify reinit logic

* fix: no dropped value + immediate commits

* fix: wandb use in sagemaker

* docs: improve documentation and formatting

* fix: typos

* docs: improve formatting
Boris Dayma 2021-03-22 09:45:17 -05:00 committed by GitHub
parent b230181d41
commit 125ccead71
2 changed files with 29 additions and 57 deletions
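
To illustrate the "allow manual init" change (and the run reuse it relies on), here is a minimal sketch, not part of the diff: project and run names are placeholders, and the toy model/data exist only to make the example self-contained.

```python
# Hedged sketch of the manual-init flow this commit enables: when wandb.init() has
# already been called, WandbCallback reuses the existing run (it now only calls
# wandb.init() itself when wandb.run is None). Names and data are placeholders.
import os

import torch
import wandb
from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments

os.environ["WANDB_LOG_MODEL"] = "true"  # also upload the final model as a "model-<run_id>" artifact

wandb.init(project="my-project", name="manual-run")  # run created manually, before the Trainer
wandb.config.update({"note": "config set by hand"})  # manual config survives; the Trainer merges its own

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased")

# tiny in-memory dataset so the sketch runs end to end
texts, labels = ["great", "terrible"] * 8, [1, 0] * 8
enc = tokenizer(texts, padding=True, return_tensors="pt")
dataset = [
    {"input_ids": i, "attention_mask": a, "labels": torch.tensor(l)}
    for i, a, l in zip(enc["input_ids"], enc["attention_mask"], labels)
]

args = TrainingArguments(output_dir="out", run_name="manual-run", num_train_epochs=1, logging_steps=1)
Trainer(model=model, args=args, train_dataset=dataset).train()  # metrics go to the manually created run
```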


@@ -240,34 +240,11 @@ Whenever you use `Trainer` or `TFTrainer` classes, your losses, evaluation metri
 Advanced configuration is possible by setting environment variables:

-<table>
-<thead>
-<tr>
-<th style="text-align:left">Environment Variables</th>
-<th style="text-align:left">Options</th>
-</tr>
-</thead>
-<tbody>
-<tr>
-<td style="text-align:left">WANDB_LOG_MODEL</td>
-<td style="text-align:left">Log the model as artifact at the end of training (<b>false</b> by default)</td>
-</tr>
-<tr>
-<td style="text-align:left">WANDB_WATCH</td>
-<td style="text-align:left">
-<ul>
-<li><b>gradients</b> (default): Log histograms of the gradients</li>
-<li><b>all</b>: Log histograms of gradients and parameters</li>
-<li><b>false</b>: No gradient or parameter logging</li>
-</ul>
-</td>
-</tr>
-<tr>
-<td style="text-align:left">WANDB_PROJECT</td>
-<td style="text-align:left">Organize runs by project</td>
-</tr>
-</tbody>
-</table>
+| Environment Variable | Value |
+|---|---|
+| WANDB_LOG_MODEL | Log the model as artifact at the end of training (`false` by default) |
+| WANDB_WATCH | one of `gradients` (default) to log histograms of gradients, `all` to log histograms of both gradients and parameters, or `false` for no histogram logging |
+| WANDB_PROJECT | Organize runs by project |

 Set run names with `run_name` argument present in scripts or as part of `TrainingArguments`.
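
The same configuration can be set from Python before the `Trainer` is created; a short sketch with placeholder project and run names:

```python
# Sketch of the configuration described in the table above, set before training starts.
import os

os.environ["WANDB_PROJECT"] = "my-project"  # organize runs by project
os.environ["WANDB_WATCH"] = "all"           # histograms of gradients and parameters
os.environ["WANDB_LOG_MODEL"] = "true"      # upload the final model as an artifact

from transformers import TrainingArguments

args = TrainingArguments(output_dir="out", run_name="my-first-run")  # run name shown in W&B
```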


@@ -19,7 +19,6 @@ import io
 import json
 import numbers
 import os
-import re
 import tempfile
 from copy import deepcopy
 from pathlib import Path
@@ -559,20 +558,12 @@ class WandbCallback(TrainerCallback):
         if has_wandb:
             import wandb

-            wandb.ensure_configured()
-            if wandb.api.api_key is None:
-                has_wandb = False
-                logger.warning(
-                    "W&B installed but not logged in. Run `wandb login` or set the WANDB_API_KEY env variable."
-                )
-                self._wandb = None
-            else:
-                self._wandb = wandb
+            self._wandb = wandb
         self._initialized = False
         # log outputs
         self._log_model = os.getenv("WANDB_LOG_MODEL", "FALSE").upper() in ENV_VARS_TRUE_VALUES.union({"TRUE"})

-    def setup(self, args, state, model, reinit, **kwargs):
+    def setup(self, args, state, model, **kwargs):
         """
         Setup the optional Weights & Biases (`wandb`) integration.
@@ -581,7 +572,8 @@ class WandbCallback(TrainerCallback):
         Environment:
             WANDB_LOG_MODEL (:obj:`bool`, `optional`, defaults to :obj:`False`):
-                Whether or not to log model as artifact at the end of training.
+                Whether or not to log model as artifact at the end of training. Use along with
+                `TrainingArguments.load_best_model_at_end` to upload best model.
             WANDB_WATCH (:obj:`str`, `optional` defaults to :obj:`"gradients"`):
                 Can be :obj:`"gradients"`, :obj:`"all"` or :obj:`"false"`. Set to :obj:`"false"` to disable gradient
                 logging or :obj:`"all"` to log gradients and parameters.
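
A hedged sketch of the combination the updated docstring recommends; the evaluation settings below are illustrative, not mandated by the diff:

```python
# With WANDB_LOG_MODEL=true and load_best_model_at_end=True, the artifact uploaded
# in on_train_end is the best checkpoint rather than just the last one. Values are illustrative.
import os

from transformers import TrainingArguments

os.environ["WANDB_LOG_MODEL"] = "true"

args = TrainingArguments(
    output_dir="out",
    evaluation_strategy="steps",       # periodic evaluation so a "best" model exists
    eval_steps=500,
    save_steps=500,                    # keep checkpoints in sync with evaluation
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
)
```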
@@ -610,13 +602,19 @@ class WandbCallback(TrainerCallback):
             else:
                 run_name = args.run_name

-            self._wandb.init(
-                project=os.getenv("WANDB_PROJECT", "huggingface"),
-                config=combined_dict,
-                name=run_name,
-                reinit=reinit,
-                **init_args,
-            )
+            if self._wandb.run is None:
+                self._wandb.init(
+                    project=os.getenv("WANDB_PROJECT", "huggingface"),
+                    name=run_name,
+                    **init_args,
+                )
+            # add config parameters (run may have been created manually)
+            self._wandb.config.update(combined_dict, allow_val_change=True)
+
+            # define default x-axis (for latest wandb versions)
+            if getattr(self._wandb, "define_metric", None):
+                self._wandb.define_metric("train/global_step")
+                self._wandb.define_metric("*", step_metric="train/global_step", step_sync=True)

             # keep track of model topology and gradients, unsupported on TPU
             if not is_torch_tpu_available() and os.getenv("WANDB_WATCH") != "false":
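
Outside the callback, the same default x-axis setup can be reproduced against the public wandb API; a sketch with a placeholder project name, showing why values logged later for an earlier step are no longer dropped:

```python
# Sketch of what the define_metric() calls above set up: every metric is plotted
# against "train/global_step" instead of wandb's internal step counter.
import wandb

run = wandb.init(project="my-project")
if getattr(wandb, "define_metric", None):  # only available on recent wandb versions
    wandb.define_metric("train/global_step")
    wandb.define_metric("*", step_metric="train/global_step", step_sync=True)

# the step is now logged as an ordinary value instead of via `step=`
wandb.log({"train/loss": 0.42, "train/global_step": 100})
wandb.log({"eval/accuracy": 0.81, "train/global_step": 100})  # same step, still recorded
run.finish()
```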
@@ -628,23 +626,20 @@
         if self._wandb is None:
             return
         hp_search = state.is_hyper_param_search
-        if not self._initialized or hp_search:
-            self.setup(args, state, model, reinit=hp_search, **kwargs)
+        if hp_search:
+            self._wandb.finish()
+        if not self._initialized:
+            self.setup(args, state, model, **kwargs)

     def on_train_end(self, args, state, control, model=None, tokenizer=None, **kwargs):
         if self._wandb is None:
             return
-        # commit last step
-        if state.is_world_process_zero:
-            self._wandb.log({})
         if self._log_model and self._initialized and state.is_world_process_zero:
             from .trainer import Trainer

             fake_trainer = Trainer(args=args, model=model, tokenizer=tokenizer)
             with tempfile.TemporaryDirectory() as temp_dir:
                 fake_trainer.save_model(temp_dir)
-                # use run name and ensure it's a valid Artifact name
-                artifact_name = re.sub(r"[^a-zA-Z0-9_\.\-]", "", self._wandb.run.name)
                 metadata = (
                     {
                         k: v
@@ -657,7 +652,7 @@
                         "train/total_floss": state.total_flos,
                     }
                 )
-                artifact = self._wandb.Artifact(name=f"run-{artifact_name}", type="model", metadata=metadata)
+                artifact = self._wandb.Artifact(name=f"model-{self._wandb.run.id}", type="model", metadata=metadata)
                 for f in Path(temp_dir).glob("*"):
                     if f.is_file():
                         with artifact.new_file(f.name, mode="wb") as fa:
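
Because the artifact is now keyed on the run id rather than the editable run name, it can be fetched deterministically later; a sketch with placeholder entity, project, and run id:

```python
# Sketch: downloading the model artifact that WANDB_LOG_MODEL uploads, using the
# new "model-<run_id>" naming. Entity, project, and run id are placeholders.
import wandb

api = wandb.Api()
artifact = api.artifact("my-entity/my-project/model-1a2b3c4d:latest", type="model")
model_dir = artifact.download()  # local directory with the files written by save_model()
print(model_dir)
```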
@@ -668,10 +663,10 @@
         if self._wandb is None:
             return
         if not self._initialized:
-            self.setup(args, state, model, reinit=False)
+            self.setup(args, state, model)
         if state.is_world_process_zero:
             logs = rewrite_logs(logs)
-            self._wandb.log(logs, step=state.global_step)
+            self._wandb.log({**logs, "train/global_step": state.global_step})


 class CometCallback(TrainerCallback):