fix: loss computation after embeddings resize - mllama (#36840)

* move loss to generation class

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* code cleanup

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* test for resize and loss computation

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* fix tests

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* fix:test for resize and loss

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* fix resize embedding mllama test

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* review changes

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

---------

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
Sukriti Sharma 2025-03-21 07:47:59 -06:00 committed by GitHub
parent 4542b8fb27
commit 90e2df5d55
2 changed files with 37 additions and 2 deletions
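Why the change is needed: previously the loss was computed by the wrapped language model, which can see a stale vocab size after resize_token_embeddings is called on the outer model. The commit moves the computation into MllamaForConditionalGeneration and reads config.get_text_config().vocab_size at call time. Below is a minimal illustration of the shift-and-flatten cross-entropy this relies on; it is an illustration only, not the library's loss helper, and the shapes and sizes are made up.

import torch
import torch.nn.functional as F

batch, seq_len, old_vocab, added_tokens = 1, 6, 32, 10
new_vocab = old_vocab + added_tokens                     # vocab after resize_token_embeddings
logits = torch.randn(batch, seq_len, new_vocab)          # produced by the resized lm_head
labels = torch.randint(0, new_vocab, (batch, seq_len))   # may contain the newly added token ids

# Standard causal-LM shift: predict token t+1 from position t, then flatten
# over the *resized* vocab size so every label id is a valid class index.
shift_logits = logits[:, :-1, :].reshape(-1, new_vocab)
shift_labels = labels[:, 1:].reshape(-1)
loss = F.cross_entropy(shift_logits, shift_labels, ignore_index=-100)
print(loss)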


@@ -2056,6 +2056,7 @@ class MllamaForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin):
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
+        **loss_kwargs,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -2157,15 +2158,31 @@ class MllamaForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin):
             past_key_values=past_key_values,
             use_cache=use_cache,
             inputs_embeds=inputs_embeds,
-            labels=labels,
             output_hidden_states=output_hidden_states,
             output_attentions=output_attentions,
             return_dict=return_dict,
             cache_position=cache_position,
             logits_to_keep=logits_to_keep,
+            **loss_kwargs,
         )
 
-        return outputs
+        # Temporary fix to calculate the loss in main class, as the model's vocab size may be resized
+        loss = None
+        logits = outputs[0]
+        if labels is not None:
+            loss = self.loss_function(logits, labels, self.config.get_text_config().vocab_size, **loss_kwargs)
+
+        if not return_dict:
+            return (loss,) + outputs if loss is not None else outputs
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=outputs.logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
 
     def prepare_inputs_for_generation(
         self,

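A note on `logits = outputs[0]` above: since labels are no longer forwarded to self.language_model, its output carries no loss, and transformers ModelOutput indexing skips fields that are None, so position 0 resolves to the logits. A small self-contained check of that behaviour (the tensor shape is arbitrary):

import torch
from transformers.modeling_outputs import CausalLMOutputWithPast

# With loss=None the output's first stored field is `logits`, so out[0] is the
# logits tensor; this mirrors why the wrapper can safely do `logits = outputs[0]`.
out = CausalLMOutputWithPast(loss=None, logits=torch.randn(1, 4, 42))
assert out[0].shape == (1, 4, 42)
print(out[0].shape)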

@@ -321,6 +321,24 @@ class MllamaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTester
         out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
         torch.testing.assert_close(out_embeds, out_ids)
 
+    def test_resize_embeddings_results_in_successful_loss(self):
+        # resizing embeddings should result in successful loss computation
+        config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            model_vocab_size = config.get_text_config().vocab_size
+            inputs = self._prepare_for_class(inputs, model_class, return_labels=True)
+            # Resize embeddings and call forward
+            model.resize_token_embeddings(model_vocab_size + 10)
+            output = model(
+                input_ids=inputs["input_ids"],
+                attention_mask=inputs["attention_mask"],
+                labels=inputs["labels"],
+                return_dict=True,
+            )
+            self.assertTrue("loss" in output)
+
     def _check_attentions_for_generate(
         self, batch_size, attentions, prompt_length, output_length, config, decoder_past_key_values
     ):
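To run just the new test locally, something like the following should work; the test file path and keyword are assumptions based on the usual transformers test layout and may differ.

# Run only the new resize/loss test; the path and -k filter are assumptions.
import pytest

pytest.main(
    [
        "tests/models/mllama/test_modeling_mllama.py",
        "-k", "resize_embeddings_results_in_successful_loss",
        "-x",
    ]
)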