Trainer: add logging through Weights & Biases (#3916)

* feat: add logging through Weights & Biases

* feat(wandb): make logging compatible with all scripts

* style(trainer.py): fix formatting

* [Trainer] Tweak wandb integration

Co-authored-by: Julien Chaumond <chaumond@gmail.com>
Boris Dayma 2020-05-04 21:42:27 -05:00 committed by GitHub
parent 858b1d1e5a
commit 818463ee8e
3 changed files with 41 additions and 2 deletions
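
For orientation, the flow this commit wires into Trainer can be sketched outside the library. This is a minimal sketch that only assumes the public wandb API (wandb.init, wandb.log); the run name and metric values are invented, and running it requires wandb to be installed and logged in:

    import wandb

    # wandb.init starts a run; passing the training arguments as `config`
    # makes every hyperparameter searchable in the W&B dashboard.
    wandb.init(name="runs/example", config={"learning_rate": 5e-5, "num_train_epochs": 3})

    # During training, scalar metrics are sent once per logging step,
    # keyed to the optimizer step so curves line up across runs.
    for global_step in range(1, 4):
        wandb.log({"loss": 1.0 / global_step, "learning_rate": 5e-5}, step=global_step)
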

.gitignore vendored

@@ -131,6 +131,7 @@ proc_data
 # examples
 runs
 /runs_old
+/wandb
 examples/runs
 # data

src/transformers/trainer.py

@@ -52,6 +52,18 @@ def is_tensorboard_available():
     return _has_tensorboard


+try:
+    import wandb
+
+    _has_wandb = True
+except ImportError:
+    _has_wandb = False
+
+
+def is_wandb_available():
+    return _has_wandb


 logger = logging.getLogger(__name__)
@@ -151,6 +163,10 @@ class Trainer:
             logger.warning(
                 "You are instantiating a Trainer but Tensorboard is not installed. You should consider installing it."
             )
+        if not is_wandb_available():
+            logger.info(
+                "You are instantiating a Trainer but wandb is not installed. Install it to use Weights & Biases logging."
+            )
         set_seed(self.args.seed)
         # Create output directory if needed
         if self.args.local_rank in [-1, 0]:
@@ -209,6 +225,12 @@
         )
         return optimizer, scheduler

+    def _setup_wandb(self):
+        # Start a wandb run and log config parameters
+        wandb.init(name=self.args.logging_dir, config=vars(self.args))
+        # keep track of model topology and gradients
+        # wandb.watch(self.model)
+
     def train(self, model_path: Optional[str] = None):
         """
         Main training entry point.
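
What _setup_wandb amounts to, sketched outside the Trainer: a single wandb.init call whose config payload is the full set of training arguments. A minimal sketch; the SimpleNamespace stand-in and its values are invented for illustration, and vars() works here because the arguments object exposes plain attributes:

    from types import SimpleNamespace

    import wandb

    # Stand-in for TrainingArguments; values invented for illustration.
    args = SimpleNamespace(
        logging_dir="runs/example",
        learning_rate=5e-5,
        num_train_epochs=3,
    )

    # The run takes its name from logging_dir and every argument becomes a
    # config entry, so runs can be filtered by hyperparameter in the W&B UI.
    wandb.init(name=args.logging_dir, config=vars(args))

    # wandb.watch(model) would additionally stream gradients and parameter
    # histograms; the commit leaves that call commented out.
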
@@ -263,6 +285,9 @@
         if self.tb_writer is not None:
             self.tb_writer.add_text("args", self.args.to_json_string())
             self.tb_writer.add_hparams(self.args.to_sanitized_dict(), metric_dict={})
+        if is_wandb_available():
+            self._setup_wandb()
+
         # Train!
         logger.info("***** Running training *****")
@@ -351,6 +376,9 @@
                     if self.tb_writer:
                         for k, v in logs.items():
                             self.tb_writer.add_scalar(k, v, global_step)
+                    if is_wandb_available():
+                        wandb.log(logs, step=global_step)
+
                     epoch_iterator.write(json.dumps({**logs, **{"step": global_step}}))

                 if self.args.save_steps > 0 and global_step % self.args.save_steps == 0:
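
The per-step logging sends the same logs dict that already feeds TensorBoard, with step=global_step so W&B plots points against the optimizer step rather than call order. A small self-contained sketch with invented values:

    import wandb

    # Assumes a run was started earlier (the Trainer does this in _setup_wandb).
    wandb.init(name="runs/example", config={})

    logs = {"loss": 0.42, "learning_rate": 4.8e-5}  # invented values
    global_step = 500

    # One call per logging step; `step` pins the metrics to the optimizer step,
    # keeping the x-axis consistent even when some steps are skipped.
    wandb.log(logs, step=global_step)
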
@@ -467,7 +495,7 @@
             shutil.rmtree(checkpoint)

     def evaluate(
-        self, eval_dataset: Optional[Dataset] = None, prediction_loss_only: Optional[bool] = None
+        self, eval_dataset: Optional[Dataset] = None, prediction_loss_only: Optional[bool] = None,
     ) -> Dict[str, float]:
         """
         Run evaluation and return metrics.

src/transformers/training_args.py

@@ -2,7 +2,7 @@ import dataclasses
 import json
 import logging
 from dataclasses import dataclass, field
-from typing import Optional, Tuple
+from typing import Any, Dict, Optional, Tuple

 from .file_utils import cached_property, is_torch_available, torch_required
@@ -138,3 +138,13 @@ class TrainingArguments:
         Serializes this instance to a JSON string.
         """
         return json.dumps(dataclasses.asdict(self), indent=2)
+
+    def to_sanitized_dict(self) -> Dict[str, Any]:
+        """
+        Sanitized serialization to use with TensorBoard's hparams
+        """
+        d = dataclasses.asdict(self)
+        valid_types = [bool, int, float, str]
+        if is_torch_available():
+            valid_types.append(torch.Tensor)
+        return {k: v if type(v) in valid_types else str(v) for k, v in d.items()}
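
to_sanitized_dict exists because SummaryWriter.add_hparams only accepts bool, int, float, str (and torch.Tensor) values, so anything else is stringified before it is handed over. A hedged sketch of the same idea on a toy dataclass; the fields below are invented and not the real TrainingArguments:

    import dataclasses
    from dataclasses import dataclass, field
    from typing import Any, Dict, List, Optional


    @dataclass
    class ToyArgs:
        learning_rate: float = 5e-5
        output_dir: str = "out"
        eval_steps: Optional[int] = None  # not a primitive value: becomes "None"
        label_names: List[str] = field(default_factory=lambda: ["labels"])  # becomes "['labels']"


    def to_sanitized_dict(args) -> Dict[str, Any]:
        d = dataclasses.asdict(args)
        valid_types = [bool, int, float, str]
        # Anything the hparams back end cannot store natively is stringified.
        return {k: v if type(v) in valid_types else str(v) for k, v in d.items()}


    print(to_sanitized_dict(ToyArgs()))
    # {'learning_rate': 5e-05, 'output_dir': 'out', 'eval_steps': 'None', 'label_names': "['labels']"}
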