From 63ca6d9771b13b603deb228420623681188a4dc2 Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Tue, 29 Oct 2024 08:26:04 +0100 Subject: [PATCH] Fix CI (#34458) * fix * fix mistral --- src/transformers/generation/flax_utils.py | 2 ++ tests/generation/test_flax_utils.py | 4 ++++ tests/test_modeling_common.py | 2 +- 3 files changed, 7 insertions(+), 1 deletion(-) diff --git a/src/transformers/generation/flax_utils.py b/src/transformers/generation/flax_utils.py index 08480ac983e..88535b44e9c 100644 --- a/src/transformers/generation/flax_utils.py +++ b/src/transformers/generation/flax_utils.py @@ -397,6 +397,8 @@ class FlaxGenerationMixin: "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)" ) generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length + else: # by default let's always generate 10 new tokens + generation_config.max_length = generation_config.max_length + input_ids_seq_length if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length: raise ValueError( diff --git a/tests/generation/test_flax_utils.py b/tests/generation/test_flax_utils.py index 647482b88cd..bb0c1828763 100644 --- a/tests/generation/test_flax_utils.py +++ b/tests/generation/test_flax_utils.py @@ -101,6 +101,10 @@ class FlaxGenerationTesterMixin: pt_model = pt_model_class(config).eval() pt_model = load_flax_weights_in_pytorch_model(pt_model, flax_model.params) + # Generate max 5 tokens only otherwise seems to be numerical error accumulation + pt_model.generation_config.max_length = 5 + flax_model.generation_config.max_length = 5 + flax_generation_outputs = flax_model.generate(input_ids).sequences pt_generation_outputs = pt_model.generate(torch.tensor(input_ids, dtype=torch.long)) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 51d51dfcc28..d88b0dc5f02 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -3002,7 +3002,7 @@ class ModelTesterMixin: def test_inputs_embeds_matches_input_ids_with_generate(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - for model_class in self.all_model_classes: + for model_class in self.all_generative_model_classes: if model_class.__name__ not in [ *get_values(MODEL_FOR_CAUSAL_LM_MAPPING_NAMES), *get_values(MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES),