Error with checking args.eval_accumulation_steps to gather tensors (#25819)
* Update trainer.py: fix the check on args.eval_accumulation_steps used to decide when to gather tensors.

  The deprecated evaluation code has the correct check (line 3772):

      if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0:

  but the current code (line 3196) drops the step-count condition and gathers whenever gradients are synced:

      if args.eval_accumulation_steps is not None and self.accelerator.sync_gradients:

  We still need to check (step + 1) % args.eval_accumulation_steps == 0, so line 3196 becomes:

      if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0 and self.accelerator.sync_gradients:

* Fix error with checking args.eval_accumulation_steps to gather tensors
parent 33aa0af70c
commit 483861d52d
@@ -3193,7 +3193,11 @@ class Trainer:
             self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)
 
             # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
-            if args.eval_accumulation_steps is not None and self.accelerator.sync_gradients:
+            if (
+                args.eval_accumulation_steps is not None
+                and (step + 1) % args.eval_accumulation_steps == 0
+                and self.accelerator.sync_gradients
+            ):
                 if losses_host is not None:
                     losses = nested_numpify(losses_host)
                     all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
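For context, a minimal sketch of the behavior the fix restores: tensors buffered during evaluation are offloaded to the CPU only once every eval_accumulation_steps steps, rather than on every step where gradients happen to be synced. This is illustrative Python, not the Trainer code; compute_loss, evaluation_loop, and the list-based buffer are assumptions standing in for the real device tensors and the nested_numpify/gather logic.

    import numpy as np

    def compute_loss(batch):
        # Stand-in for a model forward pass; returns one scalar loss per batch.
        return float(np.mean(batch))

    def evaluation_loop(batches, eval_accumulation_steps=4):
        losses_host = []   # per-step losses buffered "on device"
        all_losses = None  # losses already offloaded to the CPU-side accumulator

        for step, batch in enumerate(batches):
            losses_host.append(compute_loss(batch))

            # The condition from the fix: offload only once every
            # eval_accumulation_steps steps.
            if eval_accumulation_steps is not None and (step + 1) % eval_accumulation_steps == 0:
                losses = np.asarray(losses_host)
                all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
                losses_host = []  # free the per-step buffer after offloading

        # Final flush for any leftover steps, mirroring the gather the real
        # evaluation loop performs after exiting the batch loop.
        if losses_host:
            losses = np.asarray(losses_host)
            all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
        return all_losses

Without the (step + 1) % eval_accumulation_steps == 0 condition, the offload would run on every synced step, defeating the point of accumulating before gathering.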