Fix jetmoe model (#31279)

* Fix jetmoe model

* Remove skip-tests
commit 8bcf9c8dd4
parent f868cf731a
Author: Cyril Vallez
Date:   2024-06-07 11:51:41 +02:00
2 changed files with 9 additions and 21 deletions

src/transformers/models/jetmoe/modeling_jetmoe.py

@@ -1404,18 +1404,14 @@ class JetMoeForCausalLM(JetMoePreTrainedModel):
         past_length = 0
         if past_key_values is not None:
-            if isinstance(past_key_values, Cache):
-                past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
-                max_cache_length = (
-                    torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
-                    if past_key_values.get_max_length() is not None
-                    else None
-                )
-                cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
-            # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
-            else:
-                cache_length = past_length = past_key_values[0][0].shape[2]
-                max_cache_length = None
+            # Past key values are always initialized with a `Cache` object -> no need for if-else anymore
+            past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
+            max_cache_length = (
+                torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
+                if past_key_values.get_max_length() is not None
+                else None
+            )
+            cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)

         # Keep only the unprocessed tokens:
         # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
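
The removed branch handled legacy tuple caches; the surviving path relies on the `Cache` interface alone. A minimal sketch of that interface, assuming a transformers release from this era where `generate` always hands the model a `DynamicCache` (this snippet is illustrative and not part of the commit):

    import torch
    from transformers import DynamicCache

    cache = DynamicCache()
    print(cache.get_seq_length())  # 0 -> `past_length` starts at zero on the first step
    print(cache.get_max_length())  # None for a dynamic cache -> `max_cache_length` stays None

    # Simulate caching one layer's key/value states for 3 tokens
    # (batch=1, heads=4, seq=3, head_dim=8)
    key = torch.randn(1, 4, 3, 8)
    value = torch.randn(1, 4, 3, 8)
    cache.update(key, value, layer_idx=0)
    print(cache.get_seq_length())  # 3 -> the cache tracks the seen tokens itself

Because every cache reaching `prepare_inputs_for_generation` now supports these methods, the `isinstance(past_key_values, Cache)` check and the tuple fallback were dead code.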
@@ -1446,7 +1442,7 @@ class JetMoeForCausalLM(JetMoePreTrainedModel):
             position_ids = position_ids[:, -input_ids.shape[1] :]

         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and past_key_values is None:
+        if inputs_embeds is not None and past_length == 0:
             model_inputs = {"inputs_embeds": inputs_embeds}
         else:
             # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
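
This second change is the behavioral fix: since `past_key_values` is now always a `Cache` object (possibly empty), the old `past_key_values is None` test never fired, so `inputs_embeds` were silently dropped even on the first step. Checking `past_length == 0` detects the prefill step regardless of how the cache is represented. A toy sketch of the branch (the helper name is hypothetical, chosen for illustration):

    def pick_model_inputs(input_ids, inputs_embeds, past_length):
        # Prefill: nothing cached yet, so feed the caller-provided embeddings
        if inputs_embeds is not None and past_length == 0:
            return {"inputs_embeds": inputs_embeds}
        # Decoding: only the newly generated ids are fed back in
        return {"input_ids": input_ids}

    assert "inputs_embeds" in pick_model_inputs([1], [[0.1]], past_length=0)
    assert "input_ids" in pick_model_inputs([1], [[0.1]], past_length=5)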

tests/models/jetmoe/test_modeling_jetmoe.py

@@ -472,14 +472,6 @@ class JetMoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
     def test_flash_attn_2_inference_equivalence_right_padding(self):
         self.skipTest("JetMoe flash attention does not support right padding")

-    @unittest.skip("TODO: @ArthurZucker - Breaks after #30536 ")
-    def test_beam_sample_generate(self):
-        pass
-
-    @unittest.skip("TODO: @ArthurZucker - Breaks after #30536 ")
-    def test_generate_from_inputs_embeds_decoder_only(self):
-        pass
-

 @require_torch
 class JetMoeIntegrationTest(unittest.TestCase):
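
Removing the two skips re-enables, among others, generation from input embeddings, which now goes through the fixed `past_length == 0` branch. A hedged end-to-end sketch of what that test exercises (the checkpoint name follows the Hub convention for JetMoe and the model is large, so treat this as illustrative rather than as the test's exact code):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("jetmoe/jetmoe-8b")
    model = AutoModelForCausalLM.from_pretrained("jetmoe/jetmoe-8b", torch_dtype=torch.bfloat16)

    inputs = tokenizer("The capital of France is", return_tensors="pt")
    # Passing embeddings instead of ids exercises the fixed first-step branch
    embeds = model.get_input_embeddings()(inputs.input_ids)
    out = model.generate(inputs_embeds=embeds, max_new_tokens=8)
    print(tokenizer.decode(out[0], skip_special_tokens=True))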