Fix jetmoe model (#31279)

* Fix jetmoe model

* Remove skip-tests
This commit is contained in:
Cyril Vallez 2024-06-07 11:51:41 +02:00 committed by GitHub
parent f868cf731a
commit 8bcf9c8dd4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 9 additions and 21 deletions

View File

@ -1404,18 +1404,14 @@ class JetMoeForCausalLM(JetMoePreTrainedModel):
past_length = 0 past_length = 0
if past_key_values is not None: if past_key_values is not None:
if isinstance(past_key_values, Cache): # Past key values are always initialized with a `Cache` object -> no need for if-else anymore
past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
max_cache_length = ( max_cache_length = (
torch.tensor(past_key_values.get_max_length(), device=input_ids.device) torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
if past_key_values.get_max_length() is not None if past_key_values.get_max_length() is not None
else None else None
) )
cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
# TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
else:
cache_length = past_length = past_key_values[0][0].shape[2]
max_cache_length = None
# Keep only the unprocessed tokens: # Keep only the unprocessed tokens:
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
@ -1446,7 +1442,7 @@ class JetMoeForCausalLM(JetMoePreTrainedModel):
position_ids = position_ids[:, -input_ids.shape[1] :] position_ids = position_ids[:, -input_ids.shape[1] :]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None: if inputs_embeds is not None and past_length == 0:
model_inputs = {"inputs_embeds": inputs_embeds} model_inputs = {"inputs_embeds": inputs_embeds}
else: else:
# The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise

View File

@ -472,14 +472,6 @@ class JetMoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
def test_flash_attn_2_inference_equivalence_right_padding(self): def test_flash_attn_2_inference_equivalence_right_padding(self):
self.skipTest("JetMoe flash attention does not support right padding") self.skipTest("JetMoe flash attention does not support right padding")
@unittest.skip("TODO: @ArthurZucker - Breaks after #30536 ")
def test_beam_sample_generate(self):
pass
@unittest.skip("TODO: @ArthurZucker - Breaks after #30536 ")
def test_generate_from_inputs_embeds_decoder_only(self):
pass
@require_torch @require_torch
class JetMoeIntegrationTest(unittest.TestCase): class JetMoeIntegrationTest(unittest.TestCase):