Mirror of https://github.com/huggingface/transformers.git
fix: loss computation after embeddings resize - mllama (#36840)
* move loss to generation class
* code cleanup
* test for resize and loss computation
* fix tests
* fix: test for resize and loss
* fix resize embedding mllama test
* review changes

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
Parent: 4542b8fb27
Commit: 90e2df5d55
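The change moves the loss computation into `MllamaForConditionalGeneration` so that it always uses `self.config.get_text_config().vocab_size`, which reflects any `resize_token_embeddings` call. As a hedged toy illustration of why the vocabulary size matters here (plain tensors only, not the actual mllama failure trace; `old_vocab` and `new_vocab` are made-up numbers):

```python
import torch
import torch.nn.functional as F

# After resizing, the lm_head produces logits with new_vocab columns and the
# labels may reference the newly added token ids.
old_vocab, new_vocab, seq_len = 32, 42, 5
logits = torch.randn(1, seq_len, new_vocab)
labels = torch.randint(0, new_vocab, (1, seq_len))

# Flattening against a stale vocab size cannot work; the post-resize size does.
try:
    F.cross_entropy(logits.view(-1, old_vocab), labels.view(-1))
except RuntimeError as err:
    print("stale vocab size fails:", err)

loss = F.cross_entropy(logits.view(-1, new_vocab), labels.view(-1))
print("post-resize vocab size works:", loss.item())
```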
```diff
@@ -2056,6 +2056,7 @@ class MllamaForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin):
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
+        **loss_kwargs,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
```
```diff
@@ -2157,15 +2158,31 @@ class MllamaForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin):
             past_key_values=past_key_values,
             use_cache=use_cache,
             inputs_embeds=inputs_embeds,
-            labels=labels,
             output_hidden_states=output_hidden_states,
             output_attentions=output_attentions,
             return_dict=return_dict,
             cache_position=cache_position,
             logits_to_keep=logits_to_keep,
+            **loss_kwargs,
         )

-        return outputs
+        # Temporary fix to calculate the loss in main class, as the model's vocab size may be resized
+        loss = None
+        logits = outputs[0]
+
+        if labels is not None:
+            loss = self.loss_function(logits, labels, self.config.get_text_config().vocab_size, **loss_kwargs)
+
+        if not return_dict:
+            return (loss,) + outputs if loss is not None else outputs
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=outputs.logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )

     def prepare_inputs_for_generation(
         self,
```
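For orientation, here is a minimal sketch of what a causal-LM loss over the resized vocabulary looks like. It is not the `self.loss_function` implementation used by transformers; `causal_lm_loss_sketch` is an illustrative name, and `num_items_in_batch` just stands in for the kind of extra keyword that `**loss_kwargs` can forward.

```python
import torch
import torch.nn.functional as F

def causal_lm_loss_sketch(logits, labels, vocab_size, num_items_in_batch=None,
                          ignore_index=-100, **kwargs):
    # Shift so that tokens < n predict token n, then flatten with the *current*
    # vocab size, i.e. the value config.get_text_config().vocab_size holds after
    # a resize_token_embeddings call.
    shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
    shift_labels = labels[..., 1:].contiguous().view(-1)
    reduction = "sum" if num_items_in_batch is not None else "mean"
    loss = F.cross_entropy(shift_logits.float(), shift_labels,
                           ignore_index=ignore_index, reduction=reduction)
    if num_items_in_batch is not None:
        loss = loss / num_items_in_batch
    return loss

# Example: logits already have the resized width (e.g. vocab_size + 10 columns).
logits = torch.randn(2, 6, 40)
labels = torch.randint(0, 40, (2, 6))
print(causal_lm_loss_sketch(logits, labels, vocab_size=40))
```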
```diff
@@ -321,6 +321,24 @@ class MllamaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTester
         out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
         torch.testing.assert_close(out_embeds, out_ids)

+    def test_resize_embeddings_results_in_successful_loss(self):
+        # resizing embeddings should result in successful loss computation
+        config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            model_vocab_size = config.get_text_config().vocab_size
+            inputs = self._prepare_for_class(inputs, model_class, return_labels=True)
+            # Resize embeddings and call forward
+            model.resize_token_embeddings(model_vocab_size + 10)
+            output = model(
+                input_ids=inputs["input_ids"],
+                attention_mask=inputs["attention_mask"],
+                labels=inputs["labels"],
+                return_dict=True,
+            )
+            self.assertTrue("loss" in output)
+
     def _check_attentions_for_generate(
         self, batch_size, attentions, prompt_length, output_length, config, decoder_past_key_values
     ):
```
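For context, a hedged end-to-end sketch of the scenario the new test covers, seen from a user's perspective. The checkpoint id and the added token are only examples, and text-only input is assumed to be accepted by the processor; this is not part of the commit itself.

```python
from transformers import AutoProcessor, MllamaForConditionalGeneration

model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"  # example checkpoint
processor = AutoProcessor.from_pretrained(model_id)
model = MllamaForConditionalGeneration.from_pretrained(model_id)

# Add a custom token and resize the embeddings. After this commit, the loss is
# computed in MllamaForConditionalGeneration against the resized vocab size.
processor.tokenizer.add_tokens(["<my_new_token>"])
model.resize_token_embeddings(len(processor.tokenizer))

enc = processor(text="Hello <my_new_token> world", return_tensors="pt")
out = model(
    input_ids=enc["input_ids"],
    attention_mask=enc["attention_mask"],
    labels=enc["input_ids"],
    return_dict=True,
)
print(out.loss)
```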