mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 21:00:08 +06:00
Trainer: add logging through Weights & Biases (#3916)
* feat: add logging through Weights & Biases * feat(wandb): make logging compatible with all scripts * style(trainer.py): fix formatting * [Trainer] Tweak wandb integration Co-authored-by: Julien Chaumond <chaumond@gmail.com>
This commit is contained in:
parent
858b1d1e5a
commit
818463ee8e
1
.gitignore
vendored
1
.gitignore
vendored
@ -131,6 +131,7 @@ proc_data
|
||||
# examples
|
||||
runs
|
||||
/runs_old
|
||||
/wandb
|
||||
examples/runs
|
||||
|
||||
# data
|
||||
|
@ -52,6 +52,18 @@ def is_tensorboard_available():
|
||||
return _has_tensorboard
|
||||
|
||||
|
||||
try:
|
||||
import wandb
|
||||
|
||||
_has_wandb = True
|
||||
except ImportError:
|
||||
_has_wandb = False
|
||||
|
||||
|
||||
def is_wandb_available():
|
||||
return _has_wandb
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -151,6 +163,10 @@ class Trainer:
|
||||
logger.warning(
|
||||
"You are instantiating a Trainer but Tensorboard is not installed. You should consider installing it."
|
||||
)
|
||||
if not is_wandb_available():
|
||||
logger.info(
|
||||
"You are instantiating a Trainer but wandb is not installed. Install it to use Weights & Biases logging."
|
||||
)
|
||||
set_seed(self.args.seed)
|
||||
# Create output directory if needed
|
||||
if self.args.local_rank in [-1, 0]:
|
||||
@ -209,6 +225,12 @@ class Trainer:
|
||||
)
|
||||
return optimizer, scheduler
|
||||
|
||||
def _setup_wandb(self):
|
||||
# Start a wandb run and log config parameters
|
||||
wandb.init(name=self.args.logging_dir, config=vars(self.args))
|
||||
# keep track of model topology and gradients
|
||||
# wandb.watch(self.model)
|
||||
|
||||
def train(self, model_path: Optional[str] = None):
|
||||
"""
|
||||
Main training entry point.
|
||||
@ -263,6 +285,9 @@ class Trainer:
|
||||
|
||||
if self.tb_writer is not None:
|
||||
self.tb_writer.add_text("args", self.args.to_json_string())
|
||||
self.tb_writer.add_hparams(self.args.to_sanitized_dict(), metric_dict={})
|
||||
if is_wandb_available():
|
||||
self._setup_wandb()
|
||||
|
||||
# Train!
|
||||
logger.info("***** Running training *****")
|
||||
@ -351,6 +376,9 @@ class Trainer:
|
||||
if self.tb_writer:
|
||||
for k, v in logs.items():
|
||||
self.tb_writer.add_scalar(k, v, global_step)
|
||||
if is_wandb_available():
|
||||
wandb.log(logs, step=global_step)
|
||||
|
||||
epoch_iterator.write(json.dumps({**logs, **{"step": global_step}}))
|
||||
|
||||
if self.args.save_steps > 0 and global_step % self.args.save_steps == 0:
|
||||
@ -467,7 +495,7 @@ class Trainer:
|
||||
shutil.rmtree(checkpoint)
|
||||
|
||||
def evaluate(
|
||||
self, eval_dataset: Optional[Dataset] = None, prediction_loss_only: Optional[bool] = None
|
||||
self, eval_dataset: Optional[Dataset] = None, prediction_loss_only: Optional[bool] = None,
|
||||
) -> Dict[str, float]:
|
||||
"""
|
||||
Run evaluation and return metrics.
|
||||
|
@ -2,7 +2,7 @@ import dataclasses
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, Tuple
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
from .file_utils import cached_property, is_torch_available, torch_required
|
||||
|
||||
@ -138,3 +138,13 @@ class TrainingArguments:
|
||||
Serializes this instance to a JSON string.
|
||||
"""
|
||||
return json.dumps(dataclasses.asdict(self), indent=2)
|
||||
|
||||
def to_sanitized_dict(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Sanitized serialization to use with TensorBoard’s hparams
|
||||
"""
|
||||
d = dataclasses.asdict(self)
|
||||
valid_types = [bool, int, float, str]
|
||||
if is_torch_available():
|
||||
valid_types.append(torch.Tensor)
|
||||
return {k: v if type(v) in valid_types else str(v) for k, v in d.items()}
|
||||
|
Loading…
Reference in New Issue
Block a user