Fix two bugs with --logging_first_step (#8193)

* make sure that logging_first_step evaluates
* fix bug with incorrect loss on logging_first_step
* fix style
* logging_first_step only logs, not evals
parent 689ff74f99
commit 8f1c960ee7
src/transformers/trainer.py

@@ -729,6 +729,7 @@ class Trainer:
         tr_loss = torch.tensor(0.0).to(self.args.device)
         self._logging_loss_scalar = 0
+        self._globalstep_last_logged = 0
         self._total_flos = self.state.total_flos
         model.zero_grad()

@@ -849,7 +850,9 @@ class Trainer:
             if self.control.should_log:
                 logs: Dict[str, float] = {}
                 tr_loss_scalar = tr_loss.item()
-                logs["loss"] = (tr_loss_scalar - self._logging_loss_scalar) / self.args.logging_steps
+                logs["loss"] = (tr_loss_scalar - self._logging_loss_scalar) / (
+                    self.state.global_step - self._globalstep_last_logged
+                )
                 # backward compatibility for pytorch schedulers
                 logs["learning_rate"] = (
                     self.lr_scheduler.get_last_lr()[0]

@@ -857,6 +860,7 @@ class Trainer:
                     else self.lr_scheduler.get_lr()[0]
                 )
                 self._logging_loss_scalar = tr_loss_scalar
+                self._globalstep_last_logged = self.state.global_step

                 self.log(logs)
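To see why the divisor change matters, here is a minimal sketch of the averaging before and after the fix. This is illustrative standalone code with hypothetical loss values, not Trainer internals:

    # Cumulative loss and step counters at the first log with
    # --logging_first_step enabled (hypothetical values).
    tr_loss_scalar = 12.5          # running loss so far
    logging_loss_scalar = 0.0      # running loss at the previous log
    global_step = 1                # first optimizer step
    globalstep_last_logged = 0     # no log has happened yet
    logging_steps = 500

    # Old (buggy) behavior: always dividing by logging_steps makes the
    # first logged loss ~500x too small.
    buggy_loss = (tr_loss_scalar - logging_loss_scalar) / logging_steps

    # Fixed behavior: divide by the number of steps actually covered
    # since the last log.
    fixed_loss = (tr_loss_scalar - logging_loss_scalar) / (
        global_step - globalstep_last_logged
    )

    print(buggy_loss)  # 0.025
    print(fixed_loss)  # 12.5

Normalizing by steps-since-last-log also keeps later logs correct in any case where the logging interval and the elapsed steps disagree, rather than special-casing the first step.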
src/transformers/training_args.py

@@ -250,7 +250,7 @@ class TrainingArguments:
     warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})

     logging_dir: Optional[str] = field(default_factory=default_logdir, metadata={"help": "Tensorboard log dir."})
-    logging_first_step: bool = field(default=False, metadata={"help": "Log and eval the first global_step"})
+    logging_first_step: bool = field(default=False, metadata={"help": "Log the first global_step"})
     logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."})
     save_steps: int = field(default=500, metadata={"help": "Save checkpoint every X updates steps."})
     save_total_limit: Optional[int] = field(
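For reference, the flag is set through TrainingArguments as usual. A minimal sketch, assuming a model and datasets are configured elsewhere (the output directory here is hypothetical):

    from transformers import TrainingArguments

    args = TrainingArguments(
        output_dir="./out",       # hypothetical output directory
        logging_first_step=True,  # log metrics at global_step 1
        logging_steps=500,        # then every 500 update steps
    )

Per the updated help string, logging_first_step now only logs; it no longer triggers an evaluation.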