diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 4315e54a42f..9176bd72a55 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -3631,7 +3631,12 @@ class Trainer:
             with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                 scaled_loss.backward()
         else:
-            loss *= self.args.gradient_accumulation_steps
+            if num_items_in_batch is not None:
+                if self.compute_loss_func or self.model_accepts_loss_kwargs:
+                    loss *= self.args.gradient_accumulation_steps
+                # Average tokens across devices is orthogonal to gradient accumulation
+                if self.args.average_tokens_across_devices:
+                    loss *= self.args.world_size
             self.accelerator.backward(loss, **kwargs)
 
             return loss.detach() / self.args.gradient_accumulation_steps
@@ -3646,6 +3651,9 @@ class Trainer:
             labels = inputs.pop("labels")
         else:
             labels = None
+        if self.args.average_tokens_across_devices and num_items_in_batch is not None:
+            num_items_in_batch_tensor = torch.tensor(num_items_in_batch, device=self.args.device)
+            num_items_in_batch = int(self.accelerator.gather(num_items_in_batch_tensor).sum().cpu())
         if self.model_accepts_loss_kwargs:
             loss_kwargs = {}
             if num_items_in_batch is not None:
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index c98e8bc41b9..3e5c6cc2f37 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -1532,6 +1532,15 @@ class TrainingArguments:
         },
     )
 
+    average_tokens_across_devices: Optional[bool] = field(
+        default=False,
+        metadata={
+            "help": "Whether or not to average tokens across devices. If enabled, will use all_reduce to "
+            "synchronize num_tokens_in_batch for precise loss calculation. Reference: "
+            "https://github.com/huggingface/transformers/issues/34242"
+        },
+    )
+
     def __post_init__(self):
         # Parse in args that could be `dict` sent in from the CLI as a string
         for field in _VALID_DICT_FIELDS:
@@ -1765,6 +1774,19 @@ class TrainingArguments:
         if self.framework == "pt" and is_torch_available():
             self.device
 
+        # Disable average tokens when using single device
+        if self.average_tokens_across_devices:
+            try:
+                if self.world_size == 1:
+                    logger.warning(
+                        "average_tokens_across_devices is set to True but it is invalid when world size is "
+                        "1. Setting it to False automatically."
+                    )
+                    self.average_tokens_across_devices = False
+            except ImportError as e:
+                logger.warning(f"Cannot determine world size due to {e}. Setting average_tokens_across_devices to False.")
+                self.average_tokens_across_devices = False
+
         if self.torchdynamo is not None:
             warnings.warn(
                 "`torchdynamo` is deprecated and will be removed in version 5 of 🤗 Transformers. Use"
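
Note (not part of the patch): a minimal numeric sketch of why training_step rescales the per-device loss by world_size once compute_loss has gathered num_items_in_batch globally. DDP/Accelerate averages gradients across processes, so the extra factor cancels that averaging and the result equals the per-token loss over the global batch. The device count and per-token losses below are made-up values.

    world_size = 2                                         # assumed number of devices
    token_losses = [[1.0, 3.0, 2.0], [4.0]]                # made-up per-token losses on device 0 and device 1
    num_items_global = sum(len(t) for t in token_losses)   # what accelerator.gather(...).sum() yields: 4

    # With average_tokens_across_devices, every device divides its local
    # token-loss sum by the *global* token count instead of its local count.
    local_losses = [sum(t) / num_items_global for t in token_losses]

    # DDP/Accelerate averages gradients (equivalently, losses) over devices,
    # i.e. divides by world_size; multiplying each local loss by world_size
    # beforehand cancels that division.
    averaged = sum(loss * world_size for loss in local_losses) / world_size

    true_per_token_loss = sum(sum(t) for t in token_losses) / num_items_global
    assert abs(averaged - true_per_token_loss) < 1e-12     # both equal 2.5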
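
For reference, the new flag is just another TrainingArguments field; the other arguments in this sketch are placeholders, and per the __post_init__ check above it silently falls back to False when world_size is 1.

    from transformers import TrainingArguments

    args = TrainingArguments(
        output_dir="out",                     # placeholder
        per_device_train_batch_size=8,        # placeholder
        gradient_accumulation_steps=4,        # placeholder
        average_tokens_across_devices=True,   # new flag from this patch; only effective when world_size > 1
    )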