Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-02 19:21:31 +06:00)
feat(TFTrainer): improve logging (#4946)
* feat(tftrainer): improve logging
* fix(trainer): consider case with evaluation only
* refactor(tftrainer): address comments
* refactor(tftrainer): move self.epoch_logging to __init__
This commit is contained in:
parent 7b5a1e7d51
commit 1bf4098e03
src/transformers/trainer.py

@@ -1,4 +1,3 @@
-import json
 import logging
 import math
 import os
@@ -553,6 +552,9 @@ class Trainer:
     def _log(self, logs: Dict[str, float], iterator: Optional[tqdm] = None) -> None:
         if self.epoch is not None:
             logs["epoch"] = self.epoch
+        if self.global_step is None:
+            # when logging evaluation metrics without training
+            self.global_step = 0
         if self.tb_writer:
             for k, v in logs.items():
                 if isinstance(v, (int, float)):
@@ -571,11 +573,11 @@ class Trainer:
         if is_wandb_available():
            if self.is_world_master():
                 wandb.log(logs, step=self.global_step)
-        output = json.dumps({**logs, **{"step": self.global_step}})
+        output = {**logs, **{"step": self.global_step}}
         if iterator is not None:
             iterator.write(output)
         else:
-            print(output)
+            logger.info(output)

     def _training_step(
         self, model: nn.Module, inputs: Dict[str, torch.Tensor], optimizer: torch.optim.Optimizer
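One effect of the Trainer change above is that evaluation-only runs, where global_step is still None because no training step has happened, now log cleanly: the step defaults to 0 and the metrics are emitted as a plain dict through logger.info instead of print. A tiny standalone sketch of that guard, with an illustrative logger name and made-up metric values:

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("trainer_log_sketch")  # illustrative name, not from the repo


def log_metrics(logs, global_step=None):
    if global_step is None:
        # when logging evaluation metrics without training, default the step to 0
        global_step = 0
    logger.info({**logs, **{"step": global_step}})


log_metrics({"eval_loss": 0.42, "epoch": 1.0})  # made-up evaluation-only metrics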
src/transformers/trainer_tf.py

@@ -14,6 +14,23 @@ from .trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, PredictionOutput
 from .training_args_tf import TFTrainingArguments


+try:
+    import wandb
+
+    wandb.ensure_configured()
+    if wandb.api.api_key is None:
+        _has_wandb = False
+        wandb.termwarn("W&B installed but not logged in. Run `wandb login` or set the WANDB_API_KEY env variable.")
+    else:
+        _has_wandb = False if os.getenv("WANDB_DISABLED") else True
+except (ImportError, AttributeError):
+    _has_wandb = False
+
+
+def is_wandb_available():
+    return _has_wandb
+
+
 logger = logging.getLogger(__name__)


@@ -27,7 +44,7 @@ class TFTrainer:
     tb_writer: Optional[tf.summary.SummaryWriter] = None
     optimizers: Tuple[tf.keras.optimizers.Optimizer, tf.keras.optimizers.schedules.LearningRateSchedule] = None
     global_step: Optional[int] = None
-    epoch: Optional[float] = None
+    epoch_logging: Optional[float] = None

     def __init__(
         self,
@@ -48,11 +65,20 @@
         self.prediction_loss_only = prediction_loss_only
         self.optimizers = optimizers
         self.gradient_accumulator = GradientAccumulator()
+        self.global_step = 0
+        self.epoch_logging = 0

         if tb_writer is not None:
             self.tb_writer = tb_writer
         else:
             self.tb_writer = tf.summary.create_file_writer(self.args.logging_dir)
+        if is_wandb_available():
+            self._setup_wandb()
+        else:
+            logger.info(
+                "You are instantiating a Trainer but W&B is not installed. To use wandb logging, "
+                "run `pip install wandb; wandb login` see https://docs.wandb.com/huggingface."
+            )

     def get_train_tfdataset(self) -> tf.data.Dataset:
         if self.train_dataset is None:
@@ -118,6 +144,22 @@ class TFTrainer:

         return optimizer, scheduler

+    def _setup_wandb(self):
+        """
+        Setup the optional Weights & Biases (`wandb`) integration.
+
+        One can override this method to customize the setup if needed. Find more information at https://docs.wandb.com/huggingface
+        You can also override the following environment variables:
+
+        Environment:
+            WANDB_PROJECT:
+                (Optional): str - "huggingface" by default, set this to a custom string to store results in a different project
+            WANDB_DISABLED:
+                (Optional): boolean - defaults to false, set to "true" to disable wandb entirely
+        """
+        logger.info('Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"')
+        wandb.init(project=os.getenv("WANDB_PROJECT", "huggingface"), config=vars(self.args))
+
     @tf.function
     def _evaluate_steps(self, per_replica_features, per_replica_labels):
         """
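As the _setup_wandb docstring above describes, the W&B integration is steered through environment variables set before the trainer is built. A minimal sketch of that configuration (the project name is illustrative):

import os

# Store runs under a custom W&B project instead of the default "huggingface".
os.environ["WANDB_PROJECT"] = "my-tf-experiments"  # illustrative project name

# Or disable the integration entirely.
# os.environ["WANDB_DISABLED"] = "true"

# _setup_wandb() reads these via os.getenv when the TFTrainer is instantiated.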
@@ -208,6 +250,17 @@ class TFTrainer:

         return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)

+    def _log(self, logs: Dict[str, float]) -> None:
+        if self.tb_writer:
+            with self.tb_writer.as_default():
+                for k, v in logs.items():
+                    tf.summary.scalar(k, v, step=self.global_step)
+            self.tb_writer.flush()
+        if is_wandb_available():
+            wandb.log(logs, step=self.global_step)
+        output = {**logs, **{"step": self.global_step}}
+        logger.info(output)
+
     def evaluate(
         self, eval_dataset: Optional[tf.data.Dataset] = None, prediction_loss_only: Optional[bool] = None
     ) -> Dict[str, float]:
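The payload that the new TFTrainer._log hands to TensorBoard, W&B, and logger.info is simply the metrics dict merged with the current step. A standalone illustration of that merge, with made-up values:

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("tf_trainer_log_sketch")  # illustrative name, not from the repo

logs = {"loss": 0.3421, "learning_rate": 4.8e-05, "epoch": 1.5}  # made-up training metrics
global_step = 150

# Same dict-merge idiom as in _log above: attach the step and log the dict itself.
output = {**logs, **{"step": global_step}}
logger.info(output)  # {'loss': 0.3421, 'learning_rate': 4.8e-05, 'epoch': 1.5, 'step': 150}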
@@ -218,6 +271,10 @@ class TFTrainer:

         output = self._prediction_loop(eval_ds, description="Evaluation")

+        logs = {**output.metrics}
+        logs["epoch"] = self.epoch_logging
+        self._log(logs)
+
         return output.metrics

     def train(self) -> None:
@@ -269,44 +326,38 @@
         logger.info("  Num Epochs = %d", epochs)
         logger.info("  Total optimization steps = %d", self.train_steps)

-        for epoch in range(start_epoch, int(epochs + 1)):
-            for training_loss in self._training_steps(train_ds, optimizer):
-                step = iterations.numpy()
+        for epoch_iter in range(start_epoch, int(epochs + 1)):
+            for step, training_loss in enumerate(self._training_steps(train_ds, optimizer)):
+                self.global_step = iterations.numpy()
+                self.epoch_logging = epoch_iter - 1 + (step + 1) / self.train_steps

                 if self.args.debug:
-                    with self.tb_writer.as_default():
-                        tf.summary.scalar("loss", training_loss, step=step)
-
-                if step == 1 and self.args.debug:
-                    with self.tb_writer.as_default():
-                        tf.summary.trace_export(name="training", step=step, profiler_outdir=self.args.logging_dir)
-
-                if self.args.evaluate_during_training and step % self.args.eval_steps == 0:
                     logs = {}
-                    results = self.evaluate()
-
-                    for key, value in results.items():
-                        eval_key = "eval_{}".format(key)
-                        logs[eval_key] = value
-
-                    logs["learning_rate"] = lr_scheduler(step).numpy()
-
-                    logger.info("Epoch {} Step {} Validation Metrics {}".format(epoch, step, logs))
+                    logs["loss"] = training_loss.numpy()
+                    logs["epoch"] = self.epoch_logging
+                    self._log(logs)

+                if self.global_step == 1 and self.args.debug:
                     with self.tb_writer.as_default():
-                        for k, v in logs.items():
-                            tf.summary.scalar(k, v, step=step)
+                        tf.summary.trace_export(
+                            name="training", step=self.global_step, profiler_outdir=self.args.logging_dir
+                        )

-                    self.tb_writer.flush()
+                if self.args.evaluate_during_training and self.global_step % self.args.eval_steps == 0:
+                    self.evaluate()

-                if step % self.args.logging_steps == 0:
-                    logger.info("Epoch {} Step {} Train Loss {:.4f}".format(epoch, step, training_loss.numpy()))
+                if self.global_step % self.args.logging_steps == 0:
+                    logs = {}
+                    logs["loss"] = training_loss.numpy()
+                    logs["learning_rate"] = lr_scheduler(self.global_step).numpy()
+                    logs["epoch"] = self.epoch_logging
+                    self._log(logs)

-                if step % self.args.save_steps == 0:
+                if self.global_step % self.args.save_steps == 0:
                     ckpt_save_path = self.model.ckpt_manager.save()
-                    logger.info("Saving checkpoint for step {} at {}".format(step, ckpt_save_path))
+                    logger.info("Saving checkpoint for step {} at {}".format(self.global_step, ckpt_save_path))

-                if step % self.train_steps == 0:
+                if self.global_step % self.train_steps == 0:
                     break

     def _training_steps(self, ds, optimizer):
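The fractional epoch written to the logs is epoch_iter - 1 + (step + 1) / self.train_steps, so it increases smoothly from 0 at the first step to the total number of epochs. A quick worked example with an illustrative number of steps per epoch:

train_steps = 100  # illustrative value for self.train_steps


def epoch_logging(epoch_iter, step):
    # Same formula as in the training loop above.
    return epoch_iter - 1 + (step + 1) / train_steps


print(epoch_logging(1, 0))   # 0.01 -> first step of the first epoch
print(epoch_logging(1, 99))  # 1.0  -> end of the first epoch
print(epoch_logging(2, 49))  # 1.5  -> halfway through the second epoch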