Enable Gradient Accumulation fix across all models + trainer fully in forward() (#34283)

* Enable grad accum fix across all models + trainer fully in forward()

* handle peft case

* Account for DDP: need to run scale tests

* Use accelerator state

* Quality

* Guard

* Experiment w/ only fairseq fix

* Fairseq only

* Revert multiply_grads fix

* Mult by grad accum to fully bring back solution

* Style

* Good to go now

* Skip fx tests for now

* Bookmark

* Working now
This commit is contained in:
Zach Mueller 2024-10-23 11:24:57 -04:00 committed by GitHub
parent 1fb575fcf0
commit d9f733625c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
25 changed files with 81 additions and 31 deletions

View File

@ -1114,6 +1114,7 @@ class CohereForCausalLM(CoherePreTrainedModel, GenerationMixin):
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
**loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
@ -1172,7 +1173,7 @@ class CohereForCausalLM(CoherePreTrainedModel, GenerationMixin):
loss = None
if labels is not None:
loss = self.loss_function(logits, labels, self.vocab_size)
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
if not return_dict:
output = (logits,) + outputs[1:]

View File

@ -1030,6 +1030,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin):
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
**loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
@ -1087,7 +1088,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin):
loss = None
if labels is not None:
loss = self.loss_function(logits, labels, self.vocab_size)
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
if not return_dict:
output = (logits,) + outputs[1:]

View File

@ -961,6 +961,7 @@ class GemmaForCausalLM(LlamaForCausalLM):
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
**loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
```python
@ -1003,7 +1004,7 @@ class GemmaForCausalLM(LlamaForCausalLM):
loss = None
if labels is not None:
loss = self.loss_function(logits, labels, self.vocab_size)
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
if not return_dict:
output = (logits,) + outputs[1:]

View File

@ -1002,6 +1002,7 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin):
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
**loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
@ -1068,7 +1069,7 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin):
loss = None
if labels is not None:
loss = self.loss_function(logits, labels, self.vocab_size)
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
if not return_dict:
output = (logits,) + outputs[1:]

View File

@ -756,6 +756,7 @@ class Gemma2ForCausalLM(GemmaForCausalLM):
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
**loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
```python
@ -807,7 +808,7 @@ class Gemma2ForCausalLM(GemmaForCausalLM):
loss = None
if labels is not None:
loss = self.loss_function(logits, labels, self.vocab_size)
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
if not return_dict:
output = (logits,) + outputs[1:]

View File

@ -1014,6 +1014,7 @@ class GlmForCausalLM(GlmPreTrainedModel, GenerationMixin):
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
**loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
@ -1071,7 +1072,7 @@ class GlmForCausalLM(GlmPreTrainedModel, GenerationMixin):
loss = None
if labels is not None:
loss = self.loss_function(logits, labels, self.vocab_size)
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
if not return_dict:
output = (logits,) + outputs[1:]

View File

@ -1450,6 +1450,7 @@ class JambaForCausalLM(JambaPreTrainedModel, GenerationMixin):
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: Optional[Union[int, None]] = None,
**loss_kwargs,
) -> Union[Tuple, MoeCausalLMOutputWithPast]:
r"""
Args:
@ -1515,7 +1516,7 @@ class JambaForCausalLM(JambaPreTrainedModel, GenerationMixin):
loss = None
if labels is not None:
loss = self.loss_function(logits, labels, self.vocab_size)
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
aux_loss = None
if output_router_logits:

View File

@ -1240,6 +1240,7 @@ class MixtralForCausalLM(MixtralPreTrainedModel, GenerationMixin):
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
**loss_kwargs,
) -> Union[Tuple, MoeCausalLMOutputWithPast]:
r"""
Args:
@ -1303,7 +1304,7 @@ class MixtralForCausalLM(MixtralPreTrainedModel, GenerationMixin):
loss = None
if labels is not None:
loss = self.loss_function(logits, labels, self.vocab_size)
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
aux_loss = None
if output_router_logits:

View File

@ -1887,6 +1887,7 @@ class MllamaForCausalLM(MllamaPreTrainedModel, GenerationMixin):
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
**loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
@ -1949,7 +1950,7 @@ class MllamaForCausalLM(MllamaPreTrainedModel, GenerationMixin):
loss = None
if labels is not None:
loss = self.loss_function(logits, labels, self.vocab_size)
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
if not return_dict:
output = (logits,) + outputs[1:]

View File

@ -1028,6 +1028,7 @@ class NemotronForCausalLM(NemotronPreTrainedModel, GenerationMixin):
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
**loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
@ -1085,7 +1086,7 @@ class NemotronForCausalLM(NemotronPreTrainedModel, GenerationMixin):
loss = None
if labels is not None:
loss = self.loss_function(logits, labels, self.vocab_size)
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
if not return_dict:
output = (logits,) + outputs[1:]

View File

@ -1068,6 +1068,7 @@ class OlmoForCausalLM(OlmoPreTrainedModel, GenerationMixin):
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
**loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
@ -1126,7 +1127,7 @@ class OlmoForCausalLM(OlmoPreTrainedModel, GenerationMixin):
loss = None
if labels is not None:
loss = self.loss_function(logits, labels, self.vocab_size)
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
if not return_dict:
output = (logits,) + outputs[1:]

View File

@ -1228,6 +1228,7 @@ class OlmoeForCausalLM(OlmoePreTrainedModel, GenerationMixin):
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
**loss_kwargs,
) -> Union[Tuple, MoeCausalLMOutputWithPast]:
r"""
Args:
@ -1290,7 +1291,7 @@ class OlmoeForCausalLM(OlmoePreTrainedModel, GenerationMixin):
loss = None
if labels is not None:
loss = self.loss_function(logits, labels, self.vocab_size)
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
aux_loss = None
if output_router_logits:

View File

@ -1192,6 +1192,7 @@ class PhiForCausalLM(PhiPreTrainedModel, GenerationMixin):
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
**loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
@ -1250,7 +1251,7 @@ class PhiForCausalLM(PhiPreTrainedModel, GenerationMixin):
loss = None
if labels is not None:
loss = self.loss_function(logits, labels, self.vocab_size)
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
if not return_dict:
output = (logits,) + outputs[1:]

View File

@ -1209,6 +1209,7 @@ class Phi3ForCausalLM(Phi3PreTrainedModel, GenerationMixin):
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
**loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
@ -1275,7 +1276,7 @@ class Phi3ForCausalLM(Phi3PreTrainedModel, GenerationMixin):
loss = None
if labels is not None:
loss = self.loss_function(logits, labels, self.vocab_size)
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
if not return_dict:
output = (logits,) + outputs[1:]

View File

@ -1377,6 +1377,7 @@ class PhimoeForCausalLM(PhimoePreTrainedModel, GenerationMixin):
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
**loss_kwargs,
) -> Union[Tuple, MoeCausalLMOutputWithPast]:
r"""
Args:
@ -1442,7 +1443,7 @@ class PhimoeForCausalLM(PhimoePreTrainedModel, GenerationMixin):
loss = None
if labels is not None:
loss = self.loss_function(logits, labels, self.vocab_size)
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
aux_loss = None
if output_router_logits:

View File

@ -1121,6 +1121,7 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel, GenerationMixin):
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
**loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
@ -1179,7 +1180,7 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel, GenerationMixin):
loss = None
if labels is not None:
loss = self.loss_function(logits, labels, self.vocab_size)
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
if not return_dict:
output = (logits,) + outputs[1:]

View File

@ -1305,6 +1305,7 @@ class Qwen2MoeForCausalLM(Qwen2MoePreTrainedModel, GenerationMixin):
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
**loss_kwargs,
) -> Union[Tuple, MoeCausalLMOutputWithPast]:
r"""
Args:
@ -1367,7 +1368,7 @@ class Qwen2MoeForCausalLM(Qwen2MoePreTrainedModel, GenerationMixin):
loss = None
if labels is not None:
loss = self.loss_function(logits, labels, self.vocab_size)
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
aux_loss = None
if output_router_logits:

View File

@ -2027,6 +2027,7 @@ class RTDetrForObjectDetection(RTDetrPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**loss_kwargs,
) -> Union[Tuple[torch.FloatTensor], RTDetrObjectDetectionOutput]:
r"""
labels (`List[Dict]` of len `(batch_size,)`, *optional*):
@ -2128,6 +2129,7 @@ class RTDetrForObjectDetection(RTDetrPreTrainedModel):
enc_topk_logits=enc_topk_logits,
enc_topk_bboxes=enc_topk_bboxes,
denoising_meta_values=denoising_meta_values,
**loss_kwargs,
)
if not return_dict:

View File

@ -1418,6 +1418,7 @@ class ZambaForCausalLM(ZambaPreTrainedModel, GenerationMixin):
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
**loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
@ -1477,7 +1478,7 @@ class ZambaForCausalLM(ZambaPreTrainedModel, GenerationMixin):
loss = None
if labels is not None:
loss = self.loss_function(logits, labels, self.vocab_size)
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
if not return_dict:
output = (logits,) + outputs[1:]

View File

@ -582,6 +582,16 @@ class Trainer:
self.model_wrapped = model
self.model = model
# Just in case the model was wrapped outside of the `Trainer`
unwrapped_model = self.accelerator.unwrap_model(model)
model_forward = (
unwrapped_model.forward
if not _is_peft_model(unwrapped_model)
else unwrapped_model.get_base_model().forward
)
self.model_accepts_loss_kwargs = "loss_kwargs" in inspect.signature(model_forward).parameters
self.neftune_noise_alpha = args.neftune_noise_alpha
self.compute_metrics = compute_metrics
@ -2417,8 +2427,14 @@ class Trainer:
for inputs in batch_samples:
step += 1
total_batched_samples += 1
is_last_step_and_steps_less_than_grad_acc = (
steps_in_epoch <= args.gradient_accumulation_steps and (step + 1) == steps_in_epoch
)
do_sync_step = is_last_step_and_steps_less_than_grad_acc or (
total_batched_samples % args.gradient_accumulation_steps == 0
)
# Since we perform prefetching, we need to manually set sync_gradients
if total_batched_samples % args.gradient_accumulation_steps != 0:
if not do_sync_step:
self.accelerator.gradient_state._set_sync_gradients(False)
else:
self.accelerator.gradient_state._set_sync_gradients(True)
@ -2473,16 +2489,7 @@ class Trainer:
self.current_flos += float(self.floating_point_ops(inputs))
is_last_step_and_steps_less_than_grad_acc = (
steps_in_epoch <= args.gradient_accumulation_steps and (step + 1) == steps_in_epoch
)
if (
(total_batched_samples) % args.gradient_accumulation_steps == 0
or
# last step in epoch but step is always smaller than gradient_accumulation_steps
is_last_step_and_steps_less_than_grad_acc
):
if do_sync_step:
# Since we perform prefetching, we need to manually set sync_gradients to True
self.accelerator.gradient_state._set_sync_gradients(True)
@ -3610,8 +3617,11 @@ class Trainer:
labels = inputs.pop("labels")
else:
labels = None
# if num_items_in_batch is not None:
# inputs["num_items_in_batch"] = num_items_in_batch
if self.model_accepts_loss_kwargs:
loss_kwargs = {}
if num_items_in_batch is not None:
loss_kwargs["num_items_in_batch"] = num_items_in_batch
inputs = {**inputs, **loss_kwargs}
outputs = model(**inputs)
# Save past state if it exists
# TODO: this needs to be fixed and made cleaner later.

View File

@ -304,6 +304,10 @@ class CohereModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
config_and_inputs[0].position_embedding_type = type
self.model_tester.create_and_check_model(*config_and_inputs)
@unittest.skip(reason="PR #34283 made changes to the forward function.")
def test_torch_fx_output_loss(self):
super().test_torch_fx_output_loss()
@require_bitsandbytes
@require_torch_sdpa
@require_torch_multi_gpu

View File

@ -356,6 +356,10 @@ class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
config_and_inputs[0].position_embedding_type = type
self.model_tester.create_and_check_model(*config_and_inputs)
@unittest.skip(reason="PR #34283 made changes to the forward function.")
def test_torch_fx_output_loss(self):
super().test_torch_fx_output_loss()
def test_Mistral_sequence_classification_model(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
print(config)

View File

@ -356,6 +356,10 @@ class MixtralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
config_and_inputs[0].position_embedding_type = type
self.model_tester.create_and_check_model(*config_and_inputs)
@unittest.skip(reason="PR #34283 made changes to the forward function.")
def test_torch_fx_output_loss(self):
super().test_torch_fx_output_loss()
def test_Mixtral_sequence_classification_model(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
print(config)

View File

@ -368,6 +368,10 @@ class Qwen2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
config_and_inputs[0].position_embedding_type = type
self.model_tester.create_and_check_model(*config_and_inputs)
@unittest.skip(reason="PR #34283 made changes to the forward function.")
def test_torch_fx_output_loss(self):
super().test_torch_fx_output_loss()
def test_Qwen2_sequence_classification_model(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
print(config)

View File

@ -391,6 +391,10 @@ class Qwen2MoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
config_and_inputs[0].position_embedding_type = type
self.model_tester.create_and_check_model(*config_and_inputs)
@unittest.skip(reason="PR #34283 made changes to the forward function.")
def test_torch_fx_output_loss(self):
super().test_torch_fx_output_loss()
def test_Qwen2Moe_sequence_classification_model(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
print(config)