enable average tokens across devices (#34373)

* enable average tokens across devices

* reduce earlier in case model needs it

* simplify if statement

* reformat code to make ruff happy

* add doc for argument: average_tokens_across_devices

* cannot find world size when pytorch is unavailable

* format code

---------

Co-authored-by: Zach Mueller <muellerzr@gmail.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
kang sheng 2024-10-29 01:59:38 +08:00 committed by GitHub
parent a17f287ac0
commit d21dbd1520
2 changed files with 31 additions and 1 deletion

src/transformers/trainer.py

@@ -3631,7 +3631,12 @@ class Trainer:
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
scaled_loss.backward()
else:
loss *= self.args.gradient_accumulation_steps
if num_items_in_batch is not None:
if self.compute_loss_func or self.model_accepts_loss_kwargs:
loss *= self.args.gradient_accumulation_steps
# Average tokens across devices is orthogonal to gradient accumulation
if self.args.average_tokens_across_devices:
loss *= self.args.world_size
self.accelerator.backward(loss, **kwargs)
return loss.detach() / self.args.gradient_accumulation_steps
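
A quick arithmetic sketch (plain Python, hypothetical numbers) of why the loss is scaled by `world_size` once `num_items_in_batch` has been summed across devices: each device divides its summed loss by the global token count, and DDP then averages across devices, which divides by the world size a second time; multiplying by `world_size` before `backward` restores the token-weighted global mean.

```python
world_size = 2
# Per-device (sum of token losses, number of tokens) -- hypothetical values.
device_stats = [(12.0, 3), (4.0, 1)]

global_tokens = sum(n for _, n in device_stats)              # the all_reduce'd num_items_in_batch: 4
true_mean = sum(s for s, _ in device_stats) / global_tokens  # 16 / 4 = 4.0

# Each device divides its summed loss by the *global* token count...
per_device = [s / global_tokens for s, _ in device_stats]    # [3.0, 1.0]
# ...and DDP then averages across devices, dividing by world_size again:
ddp_mean = sum(per_device) / world_size                      # 2.0 -> too small

# Multiplying by world_size before backward restores the correct value:
corrected = sum(loss * world_size for loss in per_device) / world_size
assert corrected == true_mean                                # 4.0
```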
@@ -3646,6 +3651,9 @@ class Trainer:
labels = inputs.pop("labels")
else:
labels = None
if self.args.average_tokens_across_devices and num_items_in_batch is not None:
num_items_in_batch_tensor = torch.tensor(num_items_in_batch, device=self.args.device)
num_items_in_batch = int(self.accelerator.gather(num_items_in_batch_tensor).sum().cpu())
if self.model_accepts_loss_kwargs:
loss_kwargs = {}
if num_items_in_batch is not None:
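
A standalone sketch of the gather step above, assuming `torch` and `accelerate` are installed; on a single process `gather()` simply returns the local tensor, so the snippet also runs without a distributed launch.

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()

# Hypothetical per-device token count (what num_items_in_batch holds locally).
local_num_items = 128
local_tensor = torch.tensor(local_num_items, device=accelerator.device)

# gather() collects the value from every process; summing gives the global count.
global_num_items = int(accelerator.gather(local_tensor).sum().cpu())
print(global_num_items)  # == local_num_items * world_size when counts are equal
```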

src/transformers/training_args.py

@@ -1532,6 +1532,15 @@ class TrainingArguments:
},
)
average_tokens_across_devices: Optional[bool] = field(
default=False,
metadata={
"help": "Whether or not to average tokens across devices. If enabled, will use all_reduce to "
"synchronize num_tokens_in_batch for precise loss calculation. Reference: "
"https://github.com/huggingface/transformers/issues/34242"
},
)
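
A short usage sketch, with otherwise hypothetical settings, showing how the new flag is passed; it is an ordinary `TrainingArguments` field, and on a single device the `__post_init__` check below turns it back off with a warning.

```python
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",                        # hypothetical output directory
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,
    average_tokens_across_devices=True,      # flag added in this PR
)
print(args.average_tokens_across_devices)    # False on a single device, True otherwise
```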
def __post_init__(self):
# Parse in args that could be `dict` sent in from the CLI as a string
for field in _VALID_DICT_FIELDS:
@@ -1765,6 +1774,19 @@ class TrainingArguments:
if self.framework == "pt" and is_torch_available():
self.device
# Disable averaging tokens across devices when using a single device
if self.average_tokens_across_devices:
try:
if self.world_size == 1:
logger.warning(
"average_tokens_across_devices is set to True but it is invalid when the world size is "
"1. Setting it to False automatically."
)
self.average_tokens_across_devices = False
except ImportError as e:
logger.warning(f"Can not specify world size due to {e}. Turn average_tokens_across_devices to False.")
self.average_tokens_across_devices = False
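
A minimal standalone sketch of the same guard, with a hypothetical `get_world_size` callable, showing how the option is silently downgraded when only one process is running or when the world size cannot be determined.

```python
import logging

logger = logging.getLogger(__name__)

def resolve_average_tokens(requested: bool, get_world_size) -> bool:
    """Mirror of the __post_init__ guard above, factored out for illustration."""
    if not requested:
        return False
    try:
        if get_world_size() == 1:
            logger.warning(
                "average_tokens_across_devices is set to True but it has no effect when the "
                "world size is 1. Setting it to False automatically."
            )
            return False
    except ImportError as e:
        logger.warning(f"Cannot determine world size due to {e}. Setting average_tokens_across_devices to False.")
        return False
    return True

# Hypothetical single-process run: the flag is downgraded with a warning.
print(resolve_average_tokens(True, lambda: 1))  # False
```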
if self.torchdynamo is not None:
warnings.warn(
"`torchdynamo` is deprecated and will be removed in version 5 of 🤗 Transformers. Use"