Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
commit 8bcf9c8dd4 (parent f868cf731a)
@@ -1404,18 +1404,14 @@ class JetMoeForCausalLM(JetMoePreTrainedModel):
         past_length = 0
         if past_key_values is not None:
-            if isinstance(past_key_values, Cache):
-                past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
-                max_cache_length = (
-                    torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
-                    if past_key_values.get_max_length() is not None
-                    else None
-                )
-                cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
-            # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
-            else:
-                cache_length = past_length = past_key_values[0][0].shape[2]
-                max_cache_length = None
+            # Past key values are always initialized with a `Cache` object -> no need for if-else anymore
+            past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
+            max_cache_length = (
+                torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
+                if past_key_values.get_max_length() is not None
+                else None
+            )
+            cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
 
         # Keep only the unprocessed tokens:
         # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
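Reviewer note on the hunk above: the deleted `isinstance` branch existed only to read the sequence length out of the legacy tuple cache format. A minimal sketch of the invariant the kept branch relies on, assuming torch plus a transformers version contemporary with this commit (where `DynamicCache` lives in `transformers.cache_utils` and still exposes `get_max_length`):

import torch
from transformers.cache_utils import DynamicCache

cache = DynamicCache()
# Shapes follow the (batch, num_kv_heads, seq_len, head_dim) convention.
key = value = torch.zeros(1, 4, 7, 16)
cache.update(key, value, layer_idx=0)

# The kept code path reads the cached length through the Cache API ...
print(cache.get_seq_length())  # 7
# ... instead of poking into the legacy tuple layout the removed `else` handled:
legacy = cache.to_legacy_cache()
print(legacy[0][0].shape[2])  # 7, same number via tuple indexing
# A DynamicCache is unbounded, so max_cache_length stays None and the
# torch.min clamp on cache_length is skipped:
print(cache.get_max_length())  # None

With a bounded cache (e.g. `StaticCache`), `get_max_length()` returns the preallocated size instead, and the `torch.min` clamp takes effect.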
@@ -1446,7 +1442,7 @@ class JetMoeForCausalLM(JetMoePreTrainedModel):
             position_ids = position_ids[:, -input_ids.shape[1] :]
 
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and past_key_values is None:
+        if inputs_embeds is not None and past_length == 0:
             model_inputs = {"inputs_embeds": inputs_embeds}
         else:
             # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
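Reviewer note: the one-line condition change above is the core fix. Per the comment added in the first hunk, after #30536 past key values are always initialized with a `Cache` object, even on the very first step, so `past_key_values is None` is no longer a usable "1st generation step" test and `inputs_embeds` would be silently dropped on the prompt pass. A sketch under the same `DynamicCache` assumption as above:

from transformers.cache_utils import DynamicCache

past_key_values = DynamicCache()  # what `generate` now passes in at step 0

# Old check: never true once a Cache object is always present.
print(past_key_values is None)  # False

# New check: an empty cache has seen zero tokens, which is exactly the
# "1st generation step" condition the comment in the diff describes.
past_length = past_key_values.get_seq_length()
print(past_length == 0)  # True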
@@ -472,14 +472,6 @@ class JetMoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
     def test_flash_attn_2_inference_equivalence_right_padding(self):
         self.skipTest("JetMoe flash attention does not support right padding")
 
-    @unittest.skip("TODO: @ArthurZucker - Breaks after #30536 ")
-    def test_beam_sample_generate(self):
-        pass
-
-    @unittest.skip("TODO: @ArthurZucker - Breaks after #30536 ")
-    def test_generate_from_inputs_embeds_decoder_only(self):
-        pass
-
 
 @require_torch
 class JetMoeIntegrationTest(unittest.TestCase):
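Reviewer note: the two deleted methods were override-to-skip stubs; removing them re-enables the versions inherited from `GenerationTesterMixin`, now that the `past_length == 0` fix above addresses the #30536 breakage. A hypothetical miniature of that pattern (names invented for illustration):

import unittest


class GenerationMixinSketch:
    # Stands in for a shared test defined on a mixin such as GenerationTesterMixin.
    def test_generate(self):
        self.assertTrue(True)


class ModelTestSketch(GenerationMixinSketch, unittest.TestCase):
    # An override like the two deleted above: it shadows the inherited test
    # purely so the @unittest.skip decorator can disable it.
    @unittest.skip("temporarily broken upstream")
    def test_generate(self):
        pass


# Deleting the override (what this commit does) makes unittest collect the
# mixin's implementation again; no other change is needed to re-enable it.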