Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-01 02:31:11 +06:00)
enable average tokens across devices (#34373)
* enable average tokens across devices
* reduce earlier in case model needs it
* simplify if statement
* reformat code to make ruff happy
* add doc for argument: average_tokens_across_devices
* cannot find world size when pytorch is unavailable
* format code

Co-authored-by: Zach Mueller <muellerzr@gmail.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in: parent a17f287ac0, commit d21dbd1520
src/transformers/trainer.py

@@ -3631,7 +3631,12 @@ class Trainer:
             with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                 scaled_loss.backward()
         else:
-            loss *= self.args.gradient_accumulation_steps
+            if num_items_in_batch is not None:
+                if self.compute_loss_func or self.model_accepts_loss_kwargs:
+                    loss *= self.args.gradient_accumulation_steps
+                # Average tokens across devices is orthogonal to gradient accumulation
+                if self.args.average_tokens_across_devices:
+                    loss *= self.args.world_size
             self.accelerator.backward(loss, **kwargs)
 
         return loss.detach() / self.args.gradient_accumulation_steps
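The usual reading of the `loss *= self.args.world_size` line: once `num_items_in_batch` holds the globally summed token count (see the next hunk), each rank's loss is divided by a denominator that covers every device, and the subsequent cross-rank averaging would shrink it again; multiplying by `world_size` cancels that averaging. A minimal arithmetic sketch with made-up numbers, not Trainer code:

```python
# Illustrative sketch only: why the loss is multiplied by world_size once
# num_items_in_batch holds the *global* token count. All numbers are made up.
world_size = 2
rank_loss_sum = [12.0, 4.0]       # per-rank sum of token losses
rank_tokens = [30, 10]            # per-rank token counts
global_tokens = sum(rank_tokens)  # what num_items_in_batch becomes after the gather

# Target: the true global per-token loss.
target = sum(rank_loss_sum) / global_tokens  # 16 / 40 = 0.4

# Cross-rank averaging divides the summed contributions by world_size, so
# scaling each rank's loss by world_size beforehand cancels that division.
scaled = [loss * world_size / global_tokens for loss in rank_loss_sum]
assert abs(sum(scaled) / world_size - target) < 1e-9
```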
@@ -3646,6 +3651,9 @@ class Trainer:
             labels = inputs.pop("labels")
         else:
             labels = None
+        if self.args.average_tokens_across_devices and num_items_in_batch is not None:
+            num_items_in_batch_tensor = torch.tensor(num_items_in_batch, device=self.args.device)
+            num_items_in_batch = int(self.accelerator.gather(num_items_in_batch_tensor).sum().cpu())
         if self.model_accepts_loss_kwargs:
             loss_kwargs = {}
             if num_items_in_batch is not None:
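The added branch replaces the local token count with the sum over all processes, so every rank normalizes the loss by the same denominator. A rough stand-alone equivalent written against `torch.distributed`, assuming a process group is already initialized (e.g. under `torchrun`); `global_num_items` is a hypothetical helper, not Trainer API:

```python
import torch
import torch.distributed as dist


def global_num_items(num_items_in_batch: int, device: torch.device) -> int:
    """Hypothetical helper: sum the local token count over all ranks so every
    process sees the same denominator when normalizing the loss."""
    count = torch.tensor(num_items_in_batch, device=device)
    # Roughly what accelerator.gather(...).sum() achieves in the hunk above.
    dist.all_reduce(count, op=dist.ReduceOp.SUM)
    return int(count.item())
```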
src/transformers/training_args.py

@@ -1532,6 +1532,15 @@ class TrainingArguments:
         },
     )
 
+    average_tokens_across_devices: Optional[bool] = field(
+        default=False,
+        metadata={
+            "help": "Whether or not to average tokens across devices. If enabled, will use all_reduce to "
+            "synchronize num_tokens_in_batch for precise loss calculation. Reference: "
+            "https://github.com/huggingface/transformers/issues/34242"
+        },
+    )
+
     def __post_init__(self):
         # Parse in args that could be `dict` sent in from the CLI as a string
         for field in _VALID_DICT_FIELDS:
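A minimal usage sketch for the new argument; every value except `average_tokens_across_devices` is a placeholder:

```python
from transformers import TrainingArguments

# Placeholder values; average_tokens_across_devices is the point here.
args = TrainingArguments(
    output_dir="out",
    per_device_train_batch_size=8,
    gradient_accumulation_steps=4,
    average_tokens_across_devices=True,  # all_reduce token counts across ranks
)
```

Per the help text, the flag synchronizes `num_tokens_in_batch` across devices so the loss normalization matches what a single device would compute; it only has an effect in multi-process runs (see the `__post_init__` guard in the next hunk).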
@@ -1765,6 +1774,19 @@ class TrainingArguments:
         if self.framework == "pt" and is_torch_available():
             self.device
 
+        # Disable average tokens when using single device
+        if self.average_tokens_across_devices:
+            try:
+                if self.world_size == 1:
+                    logger.warning(
+                        "average_tokens_across_devices is set to True but it is invalid when world size is"
+                        "1. Turn it to False automatically."
+                    )
+                    self.average_tokens_across_devices = False
+            except ImportError as e:
+                logger.warning(f"Can not specify world size due to {e}. Turn average_tokens_across_devices to False.")
+                self.average_tokens_across_devices = False
+
         if self.torchdynamo is not None:
             warnings.warn(
                 "`torchdynamo` is deprecated and will be removed in version 5 of 🤗 Transformers. Use"
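As a consequence of this guard, enabling the flag in a plain single-process run should leave it disabled after `__post_init__`. A hedged illustration, assuming the script is launched without a distributed launcher so `world_size == 1`:

```python
from transformers import TrainingArguments

# Assumes a plain single-process run (no torchrun), i.e. world_size == 1.
args = TrainingArguments(output_dir="out", average_tokens_across_devices=True)
# __post_init__ logs the warning above and falls back to False on one device.
assert args.average_tokens_across_devices is False
```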