From 818463ee8eaf3a1cd5ddc2623789cbd7bb517d02 Mon Sep 17 00:00:00 2001
From: Boris Dayma
Date: Mon, 4 May 2020 21:42:27 -0500
Subject: [PATCH] Trainer: add logging through Weights & Biases (#3916)

* feat: add logging through Weights & Biases

* feat(wandb): make logging compatible with all scripts

* style(trainer.py): fix formatting

* [Trainer] Tweak wandb integration

Co-authored-by: Julien Chaumond
---
 .gitignore                        |  1 +
 src/transformers/trainer.py       | 30 +++++++++++++++++++++++++++++-
 src/transformers/training_args.py | 12 +++++++++++-
 3 files changed, 41 insertions(+), 2 deletions(-)

diff --git a/.gitignore b/.gitignore
index 5a9bc779b80..d366cdddd6e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -131,6 +131,7 @@ proc_data
 # examples
 runs
 /runs_old
+/wandb
 examples/runs
 
 # data
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 661c60f88ec..d00cabc072b 100644
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -52,6 +52,18 @@ def is_tensorboard_available():
     return _has_tensorboard
 
 
+try:
+    import wandb
+
+    _has_wandb = True
+except ImportError:
+    _has_wandb = False
+
+
+def is_wandb_available():
+    return _has_wandb
+
+
 logger = logging.getLogger(__name__)
 
 
@@ -151,6 +163,10 @@ class Trainer:
             logger.warning(
                 "You are instantiating a Trainer but Tensorboard is not installed. You should consider installing it."
            )
+        if not is_wandb_available():
+            logger.info(
+                "You are instantiating a Trainer but wandb is not installed. Install it to use Weights & Biases logging."
+            )
         set_seed(self.args.seed)
         # Create output directory if needed
         if self.args.local_rank in [-1, 0]:
@@ -209,6 +225,12 @@ class Trainer:
         )
         return optimizer, scheduler
 
+    def _setup_wandb(self):
+        # Start a wandb run and log config parameters
+        wandb.init(name=self.args.logging_dir, config=vars(self.args))
+        # keep track of model topology and gradients
+        # wandb.watch(self.model)
+
     def train(self, model_path: Optional[str] = None):
         """
         Main training entry point.
@@ -263,6 +285,9 @@ class Trainer:
 
         if self.tb_writer is not None:
             self.tb_writer.add_text("args", self.args.to_json_string())
+            self.tb_writer.add_hparams(self.args.to_sanitized_dict(), metric_dict={})
+        if is_wandb_available():
+            self._setup_wandb()
 
         # Train!
         logger.info("***** Running training *****")
@@ -351,6 +376,9 @@ class Trainer:
                     if self.tb_writer:
                         for k, v in logs.items():
                             self.tb_writer.add_scalar(k, v, global_step)
+                    if is_wandb_available():
+                        wandb.log(logs, step=global_step)
+
                     epoch_iterator.write(json.dumps({**logs, **{"step": global_step}}))
 
                     if self.args.save_steps > 0 and global_step % self.args.save_steps == 0:
@@ -467,7 +495,7 @@ class Trainer:
             shutil.rmtree(checkpoint)
 
     def evaluate(
-        self, eval_dataset: Optional[Dataset] = None, prediction_loss_only: Optional[bool] = None
+        self, eval_dataset: Optional[Dataset] = None, prediction_loss_only: Optional[bool] = None,
     ) -> Dict[str, float]:
         """
         Run evaluation and return metrics.
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index c4bc9b6456f..5bffd44b0d3 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -2,7 +2,7 @@ import dataclasses
 import json
 import logging
 from dataclasses import dataclass, field
-from typing import Optional, Tuple
+from typing import Any, Dict, Optional, Tuple
 
 from .file_utils import cached_property, is_torch_available, torch_required
 
@@ -138,3 +138,13 @@ class TrainingArguments:
         Serializes this instance to a JSON string.
         """
         return json.dumps(dataclasses.asdict(self), indent=2)
+
+    def to_sanitized_dict(self) -> Dict[str, Any]:
+        """
+        Sanitized serialization to use with TensorBoard's hparams
+        """
+        d = dataclasses.asdict(self)
+        valid_types = [bool, int, float, str]
+        if is_torch_available():
+            valid_types.append(torch.Tensor)
+        return {k: v if type(v) in valid_types else str(v) for k, v in d.items()}