diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index efe953db051..6e6d5b8bdce 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -378,10 +378,14 @@ class GenerationMixin: # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens # Exception 1: when passing input_embeds, input_ids may be missing entries # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here - # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case + # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case. + # (we can't check exception 3 while compiling) if past_key_values is not None: model_inputs["past_key_values"] = past_key_values - if inputs_embeds is not None or cache_position[-1] >= input_ids.shape[1]: # Exception 1 or Exception 3 + if ( + inputs_embeds is not None # Exception 1 + or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1]) # Exception 3 + ): input_ids = input_ids[:, -cache_position.shape[0] :] elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) input_ids = input_ids[:, cache_position] @@ -414,7 +418,7 @@ class GenerationMixin: for model_input_name in ["position_ids", "token_type_ids"]: model_input = kwargs.get(model_input_name) if model_input is not None: - if past_key_values: + if past_key_values is not None: model_input = model_input[:, -input_ids.shape[1] :] model_input = model_input.clone(memory_format=torch.contiguous_format) model_inputs[model_input_name] = model_input @@ -568,27 +572,34 @@ class GenerationMixin: def _prepare_attention_mask_for_generation( self, - inputs: torch.Tensor, - pad_token_id: Optional[torch.Tensor], - eos_token_id: Optional[torch.Tensor], + inputs_tensor: torch.Tensor, + generation_config: GenerationConfig, + model_kwargs: Dict[str, Any], ) -> torch.LongTensor: + pad_token_id = generation_config._pad_token_tensor + eos_token_id = generation_config._eos_token_tensor + + # `input_ids` may be present in the model kwargs, instead of being the main input (e.g. 
multimodal model) + if "input_ids" in model_kwargs and model_kwargs["input_ids"].shape[1] > 0: + inputs_tensor = model_kwargs["input_ids"] + # No information for attention mask inference -> return default attention mask - default_attention_mask = torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device) + default_attention_mask = torch.ones(inputs_tensor.shape[:2], dtype=torch.long, device=inputs_tensor.device) if pad_token_id is None: return default_attention_mask - is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long] + is_input_ids = len(inputs_tensor.shape) == 2 and inputs_tensor.dtype in [torch.int, torch.long] if not is_input_ids: return default_attention_mask is_pad_token_in_inputs = (pad_token_id is not None) and ( - isin_mps_friendly(elements=inputs, test_elements=pad_token_id).any() + isin_mps_friendly(elements=inputs_tensor, test_elements=pad_token_id).any() ) is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or ~( isin_mps_friendly(elements=eos_token_id, test_elements=pad_token_id).any() ) can_infer_attention_mask = is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id - attention_mask_from_padding = inputs.ne(pad_token_id).long() + attention_mask_from_padding = inputs_tensor.ne(pad_token_id).long() attention_mask = ( attention_mask_from_padding * can_infer_attention_mask + default_attention_mask * ~can_infer_attention_mask @@ -2020,7 +2031,7 @@ class GenerationMixin: if not kwargs_has_attention_mask and requires_attention_mask and accepts_attention_mask: model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( - inputs_tensor, generation_config._pad_token_tensor, generation_config._eos_token_tensor + inputs_tensor, generation_config, model_kwargs ) elif kwargs_has_attention_mask: # TODO (joao): generalize this check with other types of inputs diff --git a/src/transformers/models/llava_next_video/modeling_llava_next_video.py b/src/transformers/models/llava_next_video/modeling_llava_next_video.py index c40ee1f70f9..85c109919da 100644 --- a/src/transformers/models/llava_next_video/modeling_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modeling_llava_next_video.py @@ -911,7 +911,8 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene if (pixel_values is not None or pixel_values_videos is not None) and inputs_embeds is not None: raise ValueError( - "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" + "You cannot specify both `pixel_values`/`pixel_values_videos` and `inputs_embeds` at the same time, " + "and must specify either one" ) legacy_processing = False diff --git a/src/transformers/models/llava_next_video/modular_llava_next_video.py b/src/transformers/models/llava_next_video/modular_llava_next_video.py index 1425a017dc0..2025140bb6e 100644 --- a/src/transformers/models/llava_next_video/modular_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modular_llava_next_video.py @@ -424,7 +424,8 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration): if (pixel_values is not None or pixel_values_videos is not None) and inputs_embeds is not None: raise ValueError( - "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" + "You cannot specify both `pixel_values`/`pixel_values_videos` and `inputs_embeds` at the same time, " + "and must specify either one" ) legacy_processing = False diff --git 
a/src/transformers/models/llava_onevision/modeling_llava_onevision.py b/src/transformers/models/llava_onevision/modeling_llava_onevision.py index f8bdb5bf8d5..2aa6b2fa1d6 100644 --- a/src/transformers/models/llava_onevision/modeling_llava_onevision.py +++ b/src/transformers/models/llava_onevision/modeling_llava_onevision.py @@ -657,7 +657,8 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene if (pixel_values is not None or pixel_values_videos is not None) and inputs_embeds is not None: raise ValueError( - "You cannot specify both pixel_values/pixel_values_videos and inputs_embeds at the same time, and must specify either one" + "You cannot specify both `pixel_values`/`pixel_values_videos` and `inputs_embeds` at the same time, " + "and must specify either one" ) if inputs_embeds is None: diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index c18e1d1c9d8..109ddfb626d 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -1562,7 +1562,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel, GenerationMixin): if model_kwargs.get("attention_mask", None) is None and requires_attention_mask: model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( - input_ids, generation_config._pad_token_tensor, generation_config._eos_token_tensor + input_ids, generation_config, model_kwargs ) # 5. Prepare `max_length` depending on other stopping criteria. @@ -2578,7 +2578,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel, GenerationMixin): if model_kwargs.get("attention_mask", None) is None and requires_attention_mask: model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( - inputs_tensor, generation_config._pad_token_tensor, generation_config._eos_token_tensor + inputs_tensor, generation_config, model_kwargs ) if "encoder_outputs" not in model_kwargs: diff --git a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py index d2f339afc41..61f2ce414e1 100644 --- a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py +++ b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py @@ -1484,7 +1484,7 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel, GenerationMixin): if model_kwargs.get("attention_mask", None) is None and requires_attention_mask: model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( - input_ids, generation_config._pad_token_tensor, generation_config._eos_token_tensor + input_ids, generation_config, model_kwargs ) # 5. Prepare `max_length` depending on other stopping criteria. 
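For reference, the reworked `_prepare_attention_mask_for_generation` (now called with the full `generation_config` and `model_kwargs`, as in the Musicgen call sites above) keeps the same inference rule as before: use `input_ids != pad_token_id` only when a pad token is set, actually appears in the input, and is distinct from the EOS token; otherwise fall back to an all-ones mask. A minimal standalone sketch of that rule — the function name and the scalar pad/eos handling are illustrative, not the exact tensor-based checks in the diff:

import torch

def infer_attention_mask(input_ids: torch.Tensor, pad_token_id=None, eos_token_id=None) -> torch.LongTensor:
    # default: attend to every position
    default_mask = torch.ones_like(input_ids, dtype=torch.long)
    is_input_ids = input_ids.dim() == 2 and input_ids.dtype in (torch.int, torch.long)
    if pad_token_id is None or not is_input_ids:
        return default_mask
    pad_token_in_inputs = (input_ids == pad_token_id).any()
    pad_is_not_eos = eos_token_id is None or pad_token_id != eos_token_id
    if pad_token_in_inputs and pad_is_not_eos:
        # pad positions are masked out, everything else is attended to
        return input_ids.ne(pad_token_id).long()
    return default_mask

# left-padded batch with pad_token_id=0, eos_token_id=2
batch = torch.tensor([[0, 0, 5, 6, 7], [3, 4, 5, 6, 7]])
print(infer_attention_mask(batch, pad_token_id=0, eos_token_id=2))
# tensor([[0, 0, 1, 1, 1],
#         [1, 1, 1, 1, 1]])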
@@ -2425,7 +2425,7 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel, GenerationMixin): if model_kwargs.get("attention_mask", None) is None and requires_attention_mask: model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( - inputs_tensor, generation_config._pad_token_tensor, generation_config._eos_token_tensor + inputs_tensor, generation_config, model_kwargs ) if "encoder_hidden_states" not in model_kwargs: diff --git a/src/transformers/models/video_llava/modeling_video_llava.py b/src/transformers/models/video_llava/modeling_video_llava.py index 02efc7c344f..a3b3de33fa6 100644 --- a/src/transformers/models/video_llava/modeling_video_llava.py +++ b/src/transformers/models/video_llava/modeling_video_llava.py @@ -534,7 +534,8 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi if (pixel_values_images is not None or pixel_values_videos is not None) and inputs_embeds is not None: raise ValueError( - "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" + "You cannot specify both `pixel_values_images`/`pixel_values_videos` and `inputs_embeds` at the same " + "time, and must specify either one" ) legacy_processing = False diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index d552bf73442..545b696d673 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -29,6 +29,7 @@ from transformers import AutoConfig, is_torch_available, pipeline, set_seed from transformers.testing_utils import ( is_flaky, require_accelerate, + require_flash_attn, require_optimum_quanto, require_torch, require_torch_gpu, @@ -136,6 +137,34 @@ class GenerationTesterMixin: return config, filtered_inputs_dict + def _check_similar_generate_outputs(self, output_1, output_2, atol=1e-5, rtol=1e-5): + """ + Checks whether a pair of generate outputs are similar. Two `generate` call outputs are considered similar in + the following siturations: + 1. The sequences are the same + 2. 
The sequences are different, but the scores up to (and including) the first mismatch are nearly identical + """ + # scores doesn't include data regarding decoder input tokens + decoder_input_length = output_1.sequences.shape[1] - len(output_1.scores) + output_matches = output_1.sequences == output_2.sequences + has_matching_outputs = output_matches.all() + has_matching_scores = None + if not has_matching_outputs: + for batch_idx in range(output_1.sequences.shape[0]): + batch_matches = output_matches[batch_idx] + if batch_matches.all(): + continue + first_mismatch_idx = batch_matches.int().argmin() # gets the index of the first False + first_mismatch_idx -= decoder_input_length + output_1_first_mismatch_scores = output_1.scores[first_mismatch_idx][batch_idx] + output_2_first_mismatch_scores = output_2.scores[first_mismatch_idx][batch_idx] + has_matching_scores = torch.allclose( + output_1_first_mismatch_scores, output_2_first_mismatch_scores, rtol=atol, atol=rtol + ) + if not has_matching_scores: + break + self.assertTrue(has_matching_outputs or has_matching_scores) + def _get_logits_processor_kwargs(self, do_sample=False, config=None): logits_processor_kwargs = { "bad_words_ids": [[1, 0]], @@ -426,7 +455,6 @@ class GenerationTesterMixin: def test_greedy_generate_dict_outputs(self): for model_class in self.all_generative_model_classes: config, inputs_dict = self.prepare_config_and_inputs_for_generate() - main_input = inputs_dict[model_class.main_input_name] model = model_class(config).to(torch_device).eval() output_generate = self._greedy_generate( @@ -453,13 +481,12 @@ class GenerationTesterMixin: # Retrocompatibility check self.assertIsInstance(output_generate, GreedySearchDecoderOnlyOutput) - self._check_outputs(output_generate, main_input, model.config) + self._check_outputs(output_generate, model.config) @pytest.mark.generate def test_greedy_generate_dict_outputs_use_cache(self): for model_class in self.all_generative_model_classes: config, inputs_dict = self.prepare_config_and_inputs_for_generate() - main_input = inputs_dict[model_class.main_input_name] if not hasattr(config, "use_cache"): self.skipTest(reason=f"{model_class.__name__} doesn't support caching") @@ -486,7 +513,7 @@ class GenerationTesterMixin: output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1] ) - self._check_outputs(output_generate, main_input, model.config, use_cache=True) + self._check_outputs(output_generate, model.config, use_cache=True) @pytest.mark.generate def test_sample_generate(self): @@ -505,7 +532,6 @@ class GenerationTesterMixin: def test_sample_generate_dict_output(self): for model_class in self.all_generative_model_classes: config, inputs_dict = self.prepare_config_and_inputs_for_generate() - main_input = inputs_dict[model_class.main_input_name] model = model_class(config).to(torch_device).eval() output_generate = self._sample_generate( @@ -533,7 +559,7 @@ class GenerationTesterMixin: # Retrocompatibility check self.assertIsInstance(output_generate, SampleDecoderOnlyOutput) - self._check_outputs(output_generate, main_input, model.config, num_return_sequences=2) + self._check_outputs(output_generate, model.config, num_return_sequences=2) @pytest.mark.generate def test_beam_search_generate(self): @@ -554,7 +580,6 @@ class GenerationTesterMixin: def test_beam_search_generate_dict_output(self): for model_class in self.all_generative_model_classes: config, inputs_dict = self.prepare_config_and_inputs_for_generate() - main_input = 
inputs_dict[model_class.main_input_name] model = model_class(config).to(torch_device).eval() beam_kwargs = self._get_beam_kwargs() @@ -583,14 +608,16 @@ class GenerationTesterMixin: self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput) self._check_outputs( - output_generate, main_input, model.config, num_return_sequences=beam_kwargs["num_beams"] + output_generate, + model.config, + num_return_sequences=beam_kwargs["num_return_sequences"], + num_beams=beam_kwargs["num_beams"], ) @pytest.mark.generate def test_beam_search_generate_dict_outputs_use_cache(self): for model_class in self.all_generative_model_classes: config, inputs_dict = self.prepare_config_and_inputs_for_generate() - main_input = inputs_dict[model_class.main_input_name] if not hasattr(config, "use_cache"): self.skipTest(reason=f"{model_class.__name__} doesn't support caching") @@ -623,10 +650,10 @@ class GenerationTesterMixin: self._check_outputs( output_generate, - main_input, model.config, use_cache=True, - num_return_sequences=beam_kwargs["num_beams"], + num_return_sequences=beam_kwargs["num_return_sequences"], + num_beams=beam_kwargs["num_beams"], ) @require_accelerate @@ -675,7 +702,6 @@ class GenerationTesterMixin: def test_beam_sample_generate_dict_output(self): for model_class in self.all_generative_model_classes: config, inputs_dict = self.prepare_config_and_inputs_for_generate() - main_input = inputs_dict[model_class.main_input_name] model = model_class(config).to(torch_device).eval() beam_kwargs = self._get_beam_kwargs() @@ -706,7 +732,10 @@ class GenerationTesterMixin: self.assertIsInstance(output_generate, BeamSampleDecoderOnlyOutput) self._check_outputs( - output_generate, main_input, model.config, num_return_sequences=beam_kwargs["num_beams"] + output_generate, + model.config, + num_return_sequences=beam_kwargs["num_return_sequences"], + num_beams=beam_kwargs["num_beams"], ) @pytest.mark.generate @@ -765,7 +794,6 @@ class GenerationTesterMixin: def test_group_beam_search_generate_dict_output(self): for model_class in self.all_generative_model_classes: config, inputs_dict = self.prepare_config_and_inputs_for_generate() - main_input = inputs_dict[model_class.main_input_name] model = model_class(config).to(torch_device).eval() beam_kwargs = self._get_diverse_beam_kwargs() @@ -794,7 +822,10 @@ class GenerationTesterMixin: self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput) self._check_outputs( - output_generate, main_input, model.config, num_return_sequences=beam_kwargs["num_beams"] + output_generate, + model.config, + num_return_sequences=beam_kwargs["num_return_sequences"], + num_beams=beam_kwargs["num_beams"], ) # TODO: @gante check why it is flaky @@ -859,7 +890,6 @@ class GenerationTesterMixin: def test_constrained_beam_search_generate_dict_output(self): for model_class in self.all_generative_model_classes: config, inputs_dict = self.prepare_config_and_inputs_for_generate() - main_input = inputs_dict[model_class.main_input_name] model = model_class(config).to(torch_device).eval() @@ -899,7 +929,10 @@ class GenerationTesterMixin: self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput) self._check_outputs( - output_generate, main_input, model.config, num_return_sequences=beam_kwargs["num_beams"] + output_generate, + model.config, + num_return_sequences=beam_kwargs["num_return_sequences"], + num_beams=beam_kwargs["num_beams"], ) @pytest.mark.generate @@ -942,7 +975,6 @@ class GenerationTesterMixin: self.skipTest(reason="Won't fix: old model with different cache format") 
config, inputs_dict = self.prepare_config_and_inputs_for_generate() - main_input = inputs_dict[model_class.main_input_name] # NOTE: contrastive search only works with cache on at the moment. if not hasattr(config, "use_cache"): @@ -968,7 +1000,7 @@ class GenerationTesterMixin: output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1] ) - self._check_outputs(output_generate, main_input, model.config, use_cache=True) + self._check_outputs(output_generate, model.config, use_cache=True) @pytest.mark.generate def test_contrastive_generate_low_memory(self): @@ -1064,14 +1096,10 @@ class GenerationTesterMixin: @pytest.mark.generate @parameterized.expand([("random",), ("same",)]) - @is_flaky() # Read NOTE (1) below. If there are API issues, all attempts will fail. def test_assisted_decoding_matches_greedy_search(self, assistant_type): # This test ensures that the assisted generation does not introduce output changes over greedy search. - # NOTE (1): The sentence above is true most of the time, there is a tiny difference in the logits due to matmul - # shape differences -- and it may result in a different output. The input shape difference happens in the - # main model, that runs the forward pass with several candidates at once (as opposed to generating one token at - # a time). See https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535 for more info. - # NOTE (2): It breaks the pattern in the tests above, for multiple reasons: + # See https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535 for more info. + # NOTE: It breaks the pattern in the tests above, for multiple reasons: # - assisted_decoding, contrarily to the other methods, can't be called on its own (e.g. needs to # prepare the assistant encoder outputs in the main generate body); # - assisted_decoding does not support `use_cache = False` @@ -1100,7 +1128,6 @@ class GenerationTesterMixin: # enable cache config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1) - main_input = inputs_dict[model_class.main_input_name] # NOTE: assisted generation only works with cache on at the moment. if not hasattr(config, "use_cache"): @@ -1141,12 +1168,10 @@ class GenerationTesterMixin: output_assisted = model.generate(**generation_kwargs, **inputs_dict) # The two outputs must match and their shape must be as expected - - self.assertListEqual(output_greedy.sequences.tolist(), output_assisted.sequences.tolist()) + self._check_similar_generate_outputs(output_greedy, output_assisted) for output in (output_greedy, output_assisted): - self._check_outputs(output, main_input, model.config, use_cache=True) + self._check_outputs(output, model.config, use_cache=True) - @is_flaky() @pytest.mark.generate def test_prompt_lookup_decoding_matches_greedy_search(self): # This test ensures that the prompt lookup generation does not introduce output changes over greedy search. @@ -1175,7 +1200,6 @@ class GenerationTesterMixin: # enable cache config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1) - main_input = inputs_dict[model_class.main_input_name] # NOTE: assisted generation only works with cache on at the moment. 
if not hasattr(config, "use_cache"): @@ -1208,10 +1232,9 @@ class GenerationTesterMixin: output_prompt_lookup = model.generate(**generation_kwargs, **inputs_dict) # The two outputs must match and their shape must be as expected - - self.assertListEqual(output_greedy.sequences.tolist(), output_prompt_lookup.sequences.tolist()) + self._check_similar_generate_outputs(output_greedy, output_prompt_lookup) for output in (output_greedy, output_prompt_lookup): - self._check_outputs(output, main_input, model.config, use_cache=True) + self._check_outputs(output, model.config, use_cache=True) @pytest.mark.generate def test_dola_decoding_sample(self): @@ -1231,7 +1254,6 @@ class GenerationTesterMixin: # enable cache if the model is not openai-gpt, xlnet, cpm, or xlm config, inputs_dict = self.prepare_config_and_inputs_for_generate() - main_input = inputs_dict[model_class.main_input_name] # Encoder-decoder models are not supported if config.is_encoder_decoder: @@ -1259,7 +1281,7 @@ class GenerationTesterMixin: "dola_layers": "low", } output_dola = model.generate(**generation_kwargs, **inputs_dict) - self._check_outputs(output_dola, main_input, model.config, use_cache=getattr(config, "use_cache", False)) + self._check_outputs(output_dola, model.config, use_cache=getattr(config, "use_cache", False)) @pytest.mark.generate def test_assisted_decoding_sample(self): @@ -1289,7 +1311,6 @@ class GenerationTesterMixin: # enable cache config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1) - main_input = inputs_dict[model_class.main_input_name] # NOTE: assisted generation only works with cache on at the moment. if not hasattr(config, "use_cache"): @@ -1321,7 +1342,7 @@ class GenerationTesterMixin: } output_assisted = model.generate(**generation_kwargs, **inputs_dict) - self._check_outputs(output_assisted, main_input, config, use_cache=True) + self._check_outputs(output_assisted, config, use_cache=True) @pytest.mark.generate def test_prompt_lookup_decoding_stops_at_eos(self): @@ -1547,75 +1568,93 @@ class GenerationTesterMixin: ) @pytest.mark.generate - @parameterized.expand([(1,), (2,)]) - def test_generate_from_inputs_embeds_decoder_only(self, num_beams): + @parameterized.expand([("greedy", 1), ("beam search", 2)]) + def test_generate_from_inputs_embeds(self, _, num_beams): + """Tests that we can generate from `inputs_embeds` instead of `input_ids` in LLMs, VLMs, etc""" # When supported, tests that the decoder model can generate from `inputs_embeds` instead of `input_ids` # if fails, you should probably update the `prepare_inputs_for_generation` function for model_class in self.all_generative_model_classes: config, inputs_dict = self.prepare_config_and_inputs_for_generate() - # Ignore: - # a) eos (to always output 20 tokens) and pad (so we don't try to infer the attn mask from the input_ids, - # which would cause a mismatch), - config.pad_token_id = config.eos_token_id = -1 - # b) embedding scaling, the scaling factor applied after embeding from input_ids (requires knowledge of the - # variable that holds the scaling factor, which is model-dependent) - if hasattr(config, "scale_embedding"): - config.scale_embedding = False - # This test is for decoder-only models (encoder-decoder models have native input embeddings support in the # decoder) if config.is_encoder_decoder: continue + config.is_decoder = True # Skip models without explicit support - config.is_decoder = True model = model_class(config).to(torch_device).eval() if "inputs_embeds" not in 
inspect.signature(model.prepare_inputs_for_generation).parameters.keys(): continue + # There are a few exception patterns in this test: + # 1 - Some models can't generate without `input_ids`, when `inputs_embeds` are passed + requires_inputs_ids = any( + model_name in model_class.__name__.lower() for model_name in ["idefics", "qwen2vl"] + ) + # 2 - Complex `inputs_embeds` computation, i.e. the correct computation of inputs embeds is more complex + # than calling the embedding layer with `input_ids`. Subcases of this exception: + # 2.A - Ignore `scale_embedding`, if the model supports it (it is controlled by a model-dependent flag) + if hasattr(config, "scale_embedding"): + config.scale_embedding = False + # 2.B - Some VLMs assume `inputs_embeds` and `pixel_values` are mutually exclusive AND fall in the + # exception above (complex `inputs_embeds` computation). Popping `pixel_values` allow us to run the + # checks without adding test complexity. Ditto for `pixel_values_videos` and `pixel_values_images` + pixel_values_is_mutually_exclusive = any( + model_name in model_class.__name__.lower() + for model_name in ["llava", "idefics2", "idefics3", "mllama", "paligemma"] + ) + if pixel_values_is_mutually_exclusive: + inputs_dict.pop("pixel_values", None) + inputs_dict.pop("pixel_values_videos", None) + inputs_dict.pop("pixel_values_images", None) + # 2.C - No easy fix, let's skip the check that compares the outputs from `input_ids` and `inputs_embeds` + has_complex_embeds_computation = any( + model_name in model_class.__name__.lower() for model_name in ["moshi"] + ) + # 3 - `inputs_dict` doesn't contain `attention_mask`. When `attention_mask` is not passed to generate, + # we infer it from `input_ids`. The last test case will fail if there is a pad token in the original input. + missing_attention_mask = "attention_mask" not in inputs_dict + + # Traditional way of generating text input_ids = inputs_dict.pop("input_ids") generation_kwargs = { "return_dict_in_generate": True, "output_scores": True, "num_beams": num_beams, "do_sample": False, + "max_new_tokens": 5, + "min_new_tokens": 5, # generate exactly 5 tokens } - - # Traditional way of generating text - outputs_from_ids = model.generate(input_ids, max_new_tokens=5, **generation_kwargs) + outputs_from_ids = model.generate(input_ids, **generation_kwargs, **inputs_dict) self.assertEqual(outputs_from_ids.sequences.shape, (input_ids.shape[0], input_ids.shape[1] + 5)) - # Same thing, but from input embeddings (`input_ids` is passed so the prompt is present in the output) + # Same thing, but from input embeddings (`input_ids` is passed so the prompt is present in the output). + # The output of the two calls should be the same. 
inputs_embeds = model.get_input_embeddings()(input_ids) outputs_from_embeds = model.generate( - input_ids, - inputs_embeds=inputs_embeds, - max_new_tokens=5, - **generation_kwargs, + input_ids, inputs_embeds=inputs_embeds, **generation_kwargs, **inputs_dict ) - self.assertListEqual(outputs_from_ids.sequences.tolist(), outputs_from_embeds.sequences.tolist()) + if not has_complex_embeds_computation: + self._check_similar_generate_outputs(outputs_from_ids, outputs_from_embeds) - # But if we pass different inputs_embeds, we should get different outputs (the output text may be the + # If we pass different inputs_embeds, we should get different outputs (the output text may be the # same, but the logits will almost surely be different) random_embeds = torch.rand_like(inputs_embeds) outputs_from_rand_embeds = model.generate( - input_ids, - inputs_embeds=random_embeds, - max_new_tokens=5, - **generation_kwargs, + input_ids, inputs_embeds=random_embeds, **generation_kwargs, **inputs_dict ) for i in range(len(outputs_from_rand_embeds.scores)): self.assertFalse(torch.allclose(outputs_from_embeds.scores[i], outputs_from_rand_embeds.scores[i])) - # input_ids is not a required input -- if we don't pass it, the newly generated tokens will be the same - outputs_from_embeds_wo_ids = model.generate( - inputs_embeds=inputs_embeds, max_new_tokens=5, **generation_kwargs - ) - self.assertListEqual( - outputs_from_embeds.sequences[:, inputs_embeds.shape[1] :].tolist(), - outputs_from_embeds_wo_ids.sequences.tolist(), - ) + # input_ids is not a required input on most models -- if we don't pass it, the newly generated tokens will + # be the same + if not (requires_inputs_ids or missing_attention_mask): + outputs_from_embeds_wo_ids = model.generate( + inputs_embeds=inputs_embeds, **generation_kwargs, **inputs_dict + ) + outputs_from_embeds.sequences = outputs_from_embeds.sequences[:, inputs_embeds.shape[1] :] + self._check_similar_generate_outputs(outputs_from_embeds_wo_ids, outputs_from_embeds) @pytest.mark.generate def test_generate_from_inputs_embeds_with_static_cache(self): @@ -1829,10 +1868,8 @@ class GenerationTesterMixin: @pytest.mark.generate def test_generate_with_static_cache(self): """ - Tests if StaticCache works if we set attn_implementation=static when generation. - This doesn't test if generation quality is good, but tests that models with - self._supports_static_cache don't throw an error when generating and return - a StaticCache object at the end. 
+ Tests that generating with static cache give almost same results as with dynamic cache, and the output cache + has the expected shapes """ for model_class in self.all_generative_model_classes: if not model_class._supports_static_cache: @@ -1851,13 +1888,15 @@ class GenerationTesterMixin: model = model_class(config).to(torch_device).eval() generation_kwargs = { - "max_length": None, "max_new_tokens": max_new_tokens, - "cache_implementation": "static", "return_dict_in_generate": True, # Required to return `past_key_values` + "output_scores": True, "use_cache": True, } + static_cache_generation = model.generate(**generation_kwargs, **inputs_dict, cache_implementation="static") + + # Check 1: The cache shapes must match the expected shapes max_cache_len = seq_length + max_new_tokens config = config.text_config if hasattr(config, "text_config") else config head_dim = ( @@ -1869,12 +1908,14 @@ class GenerationTesterMixin: else config.num_key_value_heads ) num_hidden_layers = config.num_hidden_layers - results = model.generate(**generation_kwargs, **inputs_dict) - cache_shape = (batch_size, num_key_value_heads, max_cache_len, head_dim) - self.assertTrue(isinstance(results.past_key_values, StaticCache)) - self.assertTrue(len(results.past_key_values.key_cache) == num_hidden_layers) - self.assertTrue(results.past_key_values.key_cache[0].shape == cache_shape) + self.assertTrue(isinstance(static_cache_generation.past_key_values, StaticCache)) + self.assertTrue(len(static_cache_generation.past_key_values.key_cache) == num_hidden_layers) + self.assertTrue(static_cache_generation.past_key_values.key_cache[0].shape == cache_shape) + + # Check 2: The outputs must be similar to the case with dynamic cache + dynamic_cache_generation = model.generate(**generation_kwargs, **inputs_dict) + self._check_similar_generate_outputs(dynamic_cache_generation, static_cache_generation) @require_optimum_quanto @pytest.mark.generate @@ -1908,25 +1949,32 @@ class GenerationTesterMixin: with self.assertRaises(ValueError): model.generate(**generation_kwargs, **inputs_dict) + @parameterized.expand( + [ + ("forward_only", False), # TODO (@joao): a few models failing. After fixed, this should not be "@slow" + ("end_to_end", True), # TODO (@joao): end-to-end compilation is broken with torch 2.5+, explore and fix + ] + ) @pytest.mark.generate @require_torch_gpu @slow - @is_flaky() # compilation may result in equivalent (!= same) FP ops, causing the argmax in `generate` to be flaky - def test_generate_compile_fullgraph(self): + def test_generate_compile(self, _, end_to_end): """ - Tests that `.generate` is compatible with torch.compile without graph breaks, keeping the same results. + Tests that `.generate` is compatible with torch.compile without graph breaks, keeping the same results. Tests + end-to-end compilation and forward pass compilation only. ⚠️ Runs two sequential generations to ensure the cache doesn't get stuck after the first compiled run! 
⚠️ """ for model_class in self.all_generative_model_classes: if not model_class._supports_static_cache: self.skipTest("This model doesn't support static cache") + # TODO (joao) -- fix and enable me :) - if any(model_name in model_class.__name__.lower() for model_name in ["whisper"]): + if end_to_end and any(model_name in model_class.__name__.lower() for model_name in ["whisper"]): self.skipTest("whisper model end-to-end generate compile not yet supported") config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() # TODO (joao) -- fix and enable me :) - if config.is_encoder_decoder: + if end_to_end and config.is_encoder_decoder: self.skipTest("Encoder-decoder model end-to-end generate compile not yet supported") model = model_class(config).to(torch_device) @@ -1941,27 +1989,33 @@ class GenerationTesterMixin: generation_kwargs = { "do_sample": False, "max_new_tokens": 10, + "return_dict_in_generate": True, + "output_scores": True, } + # end-to-end works best with dynamic cache, forward compilation works best with static cache + if not end_to_end: + generation_kwargs["cache_implementation"] = "static" - max_cache_len = input_ids.shape[1] + generation_kwargs["max_new_tokens"] - config = config.get_text_config() - past_key_values = StaticCache( - config, batch_size=half_batch_size, max_cache_len=max_cache_len, device=torch_device - ) - + # get eager + dynamic cache results for future comparison + dynamic_outputs = [] for model_inputs in input_ids_sets: - # eager dynamic cache - output_dynamic = model.generate(model_inputs, **generation_kwargs) + dynamic_outputs.append(model.generate(model_inputs, **generation_kwargs)) - # end-to-end compiled dynamic cache - torch.compiler.reset() - compiled_generate = torch.compile(model.generate, fullgraph=True, mode="reduce-overhead") - generation_config = copy.deepcopy(model.generation_config) - generation_config.update(**generation_kwargs) - output_compiled = compiled_generate( - model_inputs, generation_config=generation_config, past_key_values=past_key_values - ) - self.assertListEqual(output_dynamic.tolist(), output_compiled.tolist()) + # get compiled results + generation_config = copy.deepcopy(model.generation_config) + generation_config.update(**generation_kwargs) + torch.compiler.reset() + if end_to_end: + model.generate = torch.compile(model.generate, fullgraph=True, mode="reduce-overhead") + else: + model.forward = torch.compile(model.forward, fullgraph=True, mode="reduce-overhead") + + compiled_outputs = [] + for model_inputs in input_ids_sets: + compiled_outputs.append(model.generate(model_inputs, generation_config=generation_config)) + + for dynamic_result, compiled_result in zip(dynamic_outputs, compiled_outputs): + self._check_similar_generate_outputs(dynamic_result, compiled_result) @pytest.mark.generate def test_generate_methods_with_num_logits_to_keep(self): @@ -1989,7 +2043,6 @@ class GenerationTesterMixin: self.assertEqual(with_all_logits.tolist(), without_all_logits.tolist()) @pytest.mark.generate - @is_flaky() # assisted generation tests are flaky (minor fp ops differences) def test_assisted_decoding_with_num_logits_to_keep(self): for model_class in self.all_generative_model_classes: if "num_logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()): @@ -1998,6 +2051,9 @@ class GenerationTesterMixin: self.skipTest(reason="Stateful models don't support assisted generation") config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1) + # NOTE: assisted generation only works 
with cache on at the moment. + if not hasattr(config, "use_cache"): + self.skipTest(reason=f"{model_class.__name__} doesn't support caching") config.use_cache = True config.is_decoder = True @@ -2010,14 +2066,16 @@ class GenerationTesterMixin: "max_new_tokens": 10, "do_sample": False, "assistant_model": assistant_model, + "return_dict_in_generate": True, + "output_scores": True, } - assistant_model.generation_config.assistant_confidence_threshold = None # Setting num_logits_to_keep at 0 keeps all logits (old behavior) with_all_logits = model.generate(**generation_kwargs, **inputs_dict, num_logits_to_keep=0) # By default, num_logits_to_keep is automatically set to 1 if not provided (new behavior) without_all_logits = model.generate(**inputs_dict, **generation_kwargs) - self.assertEqual(with_all_logits.tolist(), without_all_logits.tolist()) + + self._check_similar_generate_outputs(with_all_logits, without_all_logits) @pytest.mark.generate def test_inherits_generation_mixin(self): @@ -2028,14 +2086,21 @@ class GenerationTesterMixin: for model_class in self.all_generative_model_classes: self.assertTrue("GenerationMixin" in str(model_class.__bases__)) - @require_torch_sdpa - @slow - def test_eager_matches_sdpa_generate(self): + def _test_attention_implementation(self, attn_implementation): + """ + Compares the output of generate with the eager attention implementation against other implementations. + NOTE: despite the test logic being the same, different implementations actually need diferent decorators, hence + this separate function. + """ max_new_tokens = 30 + support_flag = { + "sdpa": "_supports_sdpa", + "flash_attention_2": "_supports_flash_attn_2", + } for model_class in self.all_generative_model_classes: - if not model_class._supports_sdpa: - self.skipTest(f"{model_class.__name__} does not support SDPA") + if not getattr(model_class, support_flag[attn_implementation]): + self.skipTest(f"{model_class.__name__} does not support `attn_implementation={attn_implementation}`") config, original_inputs_dict = self.prepare_config_and_inputs_for_generate() inputs_dict = {} @@ -2062,17 +2127,9 @@ class GenerationTesterMixin: "do_sample": False, "return_dict_in_generate": True, "output_scores": True, + "use_cache": True, } - model_sdpa = model_class.from_pretrained( - tmpdirname, - torch_dtype=torch.float16, - low_cpu_mem_usage=True, - ).to(torch_device) - res_sdpa = model_sdpa.generate(**inputs_dict, **generate_kwargs) - del model_sdpa - gc.collect() - model_eager = model_class.from_pretrained( tmpdirname, torch_dtype=torch.float16, @@ -2083,42 +2140,46 @@ class GenerationTesterMixin: del model_eager gc.collect() - # Eager and SDPA are very similar, but not exactly the same. Because we are using random models, this - # test would be flaky if we only checked the sequences. Two situations in which this test passes: - # 1. The sequences are the same - # 2. 
The sequences are different, but the scores up until the first mismatch are nearly identical - output_matches = res_eager.sequences == res_sdpa.sequences - has_matching_outputs = output_matches.all() - has_matching_scores = None - if not has_matching_outputs: - input_length = main_input.shape[1] - for batch_idx in range(res_eager.sequences.shape[0]): - batch_matches = output_matches[batch_idx] - if batch_matches.all(): - continue - first_mismatch_idx = batch_matches.int().argmin() # gets the index of the first False - first_mismatch_idx -= input_length # scores doesn't include data regarding input tokens - sdpa_first_mismatch_scores = res_sdpa.scores[first_mismatch_idx][batch_idx] - eager_first_mismatch_scores = res_eager.scores[first_mismatch_idx][batch_idx] - has_matching_scores = torch.allclose( - sdpa_first_mismatch_scores, eager_first_mismatch_scores, rtol=1e-3, atol=1e-3 - ) - if not has_matching_scores: - break + model_attn = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + low_cpu_mem_usage=True, + attn_implementation=attn_implementation, + ).to(torch_device) + res_attn = model_attn.generate(**inputs_dict, **generate_kwargs) + del model_attn + gc.collect() - self.assertTrue(has_matching_outputs or has_matching_scores) + self._check_similar_generate_outputs(res_eager, res_attn, atol=1e-3, rtol=1e-3) - def _check_outputs(self, output, main_input, config, use_cache=False, num_return_sequences=1): - # we can be sure what is batch size from main input but seq length depends on model type and whether input is text/audio/image - # so we infer actual text seq length from model_tester, same was as it is done in `test_modeling_common.py` tests` - batch_size = main_input.shape[0] + @pytest.mark.generate + @require_torch_sdpa + @slow + def test_eager_matches_sdpa_generate(self): + """Tests that generate has equivalent outputs with SDPA and eager attention implementations.""" + self._test_attention_implementation("sdpa") + + @pytest.mark.flash_attn_test + @require_flash_attn + @require_torch_gpu + @slow + def test_eager_matches_fa2_generate(self): + """Tests that generate has equivalent outputs with FA2 and eager attention implementations.""" + # TODO (@joao @raushan) -- this test is failing the output checks on most models, investigate. 
After fixing, + # check whether we still need the overwrites + self._test_attention_implementation("flash_attention_2") + + def _check_outputs(self, output, config, use_cache=False, num_return_sequences=1, num_beams=1): + input_batch_size = int(output.sequences.shape[0] / num_return_sequences) + internal_batch_size = ( + input_batch_size * num_beams if num_beams > 1 else input_batch_size * num_return_sequences + ) seq_length = getattr(self.model_tester, "seq_length", None) seq_length = getattr(self.model_tester, "encoder_seq_length", seq_length) seq_length = getattr(self.model_tester, "text_seq_length", seq_length) config = config.text_config if hasattr(config, "text_config") else config - num_sequences_in_output = batch_size * num_return_sequences gen_len = ( output.sequences.shape[-1] - 1 if config.is_encoder_decoder else output.sequences.shape[-1] - seq_length @@ -2129,19 +2190,21 @@ class GenerationTesterMixin: seq_length = self.model_tester.get_subsampled_output_lengths(seq_length) # scores - self._check_scores(num_sequences_in_output, output.scores, length=gen_len, config=config) + self._check_scores(internal_batch_size, output.scores, length=gen_len, config=config) # unprocessed logits - self._check_logits(num_sequences_in_output, output.logits, config=config) + self._check_logits(internal_batch_size, output.logits, config=config) # Attentions if self.has_attentions: if config.is_encoder_decoder: # encoder - self._check_encoder_attention_for_generate(output.encoder_attentions, batch_size, config, seq_length) + self._check_encoder_attention_for_generate( + output.encoder_attentions, input_batch_size, config, seq_length + ) # decoder self._check_attentions_for_generate( - num_sequences_in_output, + internal_batch_size, output.decoder_attentions, min_length=1, max_length=output.sequences.shape[-1], @@ -2153,7 +2216,7 @@ class GenerationTesterMixin: attentions = output.attentions if not use_cache else output.attentions[1:] min_length = seq_length if not use_cache else seq_length + 1 self._check_attentions_for_generate( - num_sequences_in_output, + internal_batch_size, attentions=attentions, min_length=min_length, max_length=output.sequences.shape[-1], @@ -2165,12 +2228,12 @@ class GenerationTesterMixin: if config.is_encoder_decoder: # encoder self._check_encoder_hidden_states_for_generate( - output.encoder_hidden_states, batch_size, config, seq_length + output.encoder_hidden_states, input_batch_size, config, seq_length ) # decoder self._check_hidden_states_for_generate( - num_sequences_in_output, + internal_batch_size, output.decoder_hidden_states, min_length=1, max_length=output.sequences.shape[-1], @@ -2182,7 +2245,7 @@ class GenerationTesterMixin: hidden_states = output.hidden_states if not use_cache else output.hidden_states[1:] min_length = seq_length if not use_cache else seq_length + 1 self._check_hidden_states_for_generate( - num_sequences_in_output, + internal_batch_size, hidden_states, min_length=min_length, max_length=output.sequences.shape[-1], @@ -2213,7 +2276,7 @@ class GenerationTesterMixin: past_key_values = output.past_key_values past_sequence_length = output.sequences.shape[-1] - 1 self._check_past_key_values_for_generate( - num_sequences_in_output, + internal_batch_size, past_key_values, seq_length=past_sequence_length, config=config, diff --git a/tests/models/bart/test_modeling_bart.py b/tests/models/bart/test_modeling_bart.py index eda51d21199..e4d0df141be 100644 --- a/tests/models/bart/test_modeling_bart.py +++ b/tests/models/bart/test_modeling_bart.py @@ -1532,8 
+1532,3 @@ class BartStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, un @unittest.skip def test_save_load_fast_init_from_base(self): pass - - @unittest.skip(reason="Generate needs input ids") - def test_inputs_embeds_matches_input_ids_with_generate(self): - # generate only works with input ids for bartforcausalLM - pass diff --git a/tests/models/bert/test_modeling_bert.py b/tests/models/bert/test_modeling_bert.py index aa9835d8cd6..25566027742 100644 --- a/tests/models/bert/test_modeling_bert.py +++ b/tests/models/bert/test_modeling_bert.py @@ -511,11 +511,6 @@ class BertModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder() self.model_tester.create_and_check_model_as_decoder(*config_and_inputs) - @unittest.skip(reason="Generate needs input ids") - def test_inputs_embeds_matches_input_ids_with_generate(self): - # generate only works with input ids for bertforcausalLM - pass - def test_model_as_decoder_with_default_input_mask(self): # This regression test was failing with PyTorch < 1.3 ( diff --git a/tests/models/chameleon/test_modeling_chameleon.py b/tests/models/chameleon/test_modeling_chameleon.py index aad26ef147e..2a8e7633ba4 100644 --- a/tests/models/chameleon/test_modeling_chameleon.py +++ b/tests/models/chameleon/test_modeling_chameleon.py @@ -16,17 +16,14 @@ import unittest -import pytest import requests from parameterized import parameterized from transformers import ChameleonConfig, is_torch_available, is_vision_available, set_seed from transformers.testing_utils import ( require_bitsandbytes, - require_flash_attn, require_read_token, require_torch, - require_torch_gpu, slow, torch_device, ) @@ -329,43 +326,6 @@ class ChameleonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester # The output should be different for long inputs self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5)) - @require_flash_attn - @require_read_token - @require_torch_gpu - @require_bitsandbytes - @pytest.mark.flash_attn_test - @slow - def test_flash_attn_2_generate_padding_right(self): - """ - Overwritting the common test as the test is flaky on tiny models - """ - model = ChameleonForConditionalGeneration.from_pretrained( - "facebook/chameleon-7b", - load_in_4bit=True, - device_map={"": 0}, - ) - - processor = ChameleonProcessor.from_pretrained("facebook/chameleon-7b") - texts = ["hi", "Hello this is a very long sentence"] - - processor.tokenizer.padding_side = "right" - - inputs = processor(text=texts, return_tensors="pt", padding=True).to(0) - - output_native = model.generate(**inputs, max_new_tokens=20, do_sample=False) - output_native = processor.tokenizer.batch_decode(output_native) - - model = ChameleonForConditionalGeneration.from_pretrained( - "facebook/chameleon-7b", - load_in_4bit=True, - attn_implementation="flash_attention_2", - ) - - output_fa_2 = model.generate(**inputs, max_new_tokens=20, do_sample=False) - output_fa_2 = processor.tokenizer.batch_decode(output_fa_2) - - self.assertListEqual(output_native, output_fa_2) - @unittest.skip("Chameleon forces some token ids to be -inf!") def test_batching_equivalence(self): pass diff --git a/tests/models/gemma/test_modeling_gemma.py b/tests/models/gemma/test_modeling_gemma.py index a888bdcd3bc..e8483f8c7c7 100644 --- a/tests/models/gemma/test_modeling_gemma.py +++ b/tests/models/gemma/test_modeling_gemma.py @@ -319,9 +319,6 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, 
PipelineTesterMixi # This is because we are hitting edge cases with the causal_mask buffer model_split_percents = [0.5, 0.6] - # used in `test_torch_compile` - _torch_compile_test_ckpt = "google/gemma-2b" - # used in `test_torch_compile_for_training` _torch_compile_train_cls = GemmaForCausalLM if is_torch_available() else None @@ -419,51 +416,6 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi def test_past_key_values_format(self): pass - @require_flash_attn - @require_torch_gpu - @pytest.mark.flash_attn_test - @slow - def test_flash_attn_2_generate_use_cache(self): - import torch - - max_new_tokens = 30 - - for model_class in self.all_generative_model_classes: - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - dummy_input = inputs_dict[model_class.main_input_name] - if dummy_input.dtype in [torch.float32, torch.bfloat16]: - dummy_input = dummy_input.to(torch.float16) - - # make sure that all models have enough positions for generation - if hasattr(config, "max_position_embeddings"): - config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1 - - model = model_class(config) - - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname) - - dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input)) - # NOTE: Gemma apparently does not support right padding + use_cache with FA2. - dummy_attention_mask[:, -1] = 1 - - model = model_class.from_pretrained( - tmpdirname, - torch_dtype=torch.float16, - attn_implementation="flash_attention_2", - low_cpu_mem_usage=True, - ).to(torch_device) - - # Just test that a large cache works as expected - _ = model.generate( - dummy_input, - attention_mask=dummy_attention_mask, - max_new_tokens=max_new_tokens, - do_sample=False, - use_cache=True, - ) - @require_flash_attn @require_torch_gpu @pytest.mark.flash_attn_test diff --git a/tests/models/gemma2/test_modeling_gemma2.py b/tests/models/gemma2/test_modeling_gemma2.py index 94670803daa..7bca83f96d7 100644 --- a/tests/models/gemma2/test_modeling_gemma2.py +++ b/tests/models/gemma2/test_modeling_gemma2.py @@ -78,7 +78,6 @@ class Gemma2ModelTest(GemmaModelTest, unittest.TestCase): test_pruning = False _is_stateful = True model_split_percents = [0.5, 0.6] - _torch_compile_test_ckpt = "google/gemma-2-9b" def setUp(self): self.model_tester = Gemma2ModelTester(self) diff --git a/tests/models/glm/test_modeling_glm.py b/tests/models/glm/test_modeling_glm.py index 32bce7cbfa6..b92c5db815b 100644 --- a/tests/models/glm/test_modeling_glm.py +++ b/tests/models/glm/test_modeling_glm.py @@ -28,7 +28,6 @@ from transformers.testing_utils import ( require_flash_attn, require_torch, require_torch_accelerator, - require_torch_gpu, require_torch_sdpa, slow, torch_device, @@ -306,10 +305,6 @@ class GlmModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, test_headmasking = False test_pruning = False - # used in `test_torch_compile` - _torch_compile_test_ckpt = "THUDM/glm-4-9b" - _torch_compile_test_revision = "refs/pr/15" - def setUp(self): self.model_tester = GlmModelTester(self) self.config_tester = ConfigTester(self, config_class=GlmConfig, hidden_size=37) @@ -426,41 +421,6 @@ class GlmModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-3) - @require_flash_attn - @require_torch_gpu - @pytest.mark.flash_attn_test - @slow - def test_flash_attn_2_generate_padding_right(self): - 
"""Overwrite the common test as the test is flaky on tiny models.""" - model = GlmForCausalLM.from_pretrained( - "THUDM/glm-4-9b", - device_map={"": 0}, - torch_dtype=torch.bfloat16, - revision="refs/pr/15", - ) - - tokenizer = AutoTokenizer.from_pretrained("THUDM/glm-4-9b", revision="refs/pr/15") - tokenizer.padding_side = "right" - - texts = ["hi", "Hello this is a very long sentence"] - inputs = tokenizer(texts, return_tensors="pt", padding=True).to(0) - - output_native = model.generate(**inputs, max_new_tokens=15, do_sample=False) - output_native = tokenizer.batch_decode(output_native) - - model = GlmForCausalLM.from_pretrained( - "THUDM/glm-4-9b", - device_map={"": 0}, - attn_implementation="flash_attention_2", - torch_dtype=torch.bfloat16, - revision="refs/pr/15", - ) - - output_fa_2 = model.generate(**inputs, max_new_tokens=15, do_sample=False) - output_fa_2 = tokenizer.batch_decode(output_fa_2) - - self.assertListEqual(output_native, output_fa_2) - @parameterized.expand([("float16",), ("bfloat16",), ("float32",)]) @require_torch_sdpa @slow diff --git a/tests/models/gptj/test_modeling_gptj.py b/tests/models/gptj/test_modeling_gptj.py index 6f6fba50dc1..afc741cd502 100644 --- a/tests/models/gptj/test_modeling_gptj.py +++ b/tests/models/gptj/test_modeling_gptj.py @@ -17,14 +17,9 @@ import datetime import unittest -import pytest - -from transformers import BitsAndBytesConfig, GPTJConfig, is_torch_available +from transformers import GPTJConfig, is_torch_available from transformers.testing_utils import ( - require_bitsandbytes, - require_flash_attn, require_torch, - require_torch_gpu, slow, tooslow, torch_device, @@ -505,44 +500,6 @@ class GPTJModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin model = GPTJModel.from_pretrained(model_name, revision="float16", torch_dtype=torch.float16) self.assertIsNotNone(model) - @require_flash_attn - @require_torch_gpu - @require_bitsandbytes - @pytest.mark.flash_attn_test - @slow - def test_flash_attn_2_generate_padding_right(self): - """ - Overwritting the common test as the test is flaky on tiny models - """ - tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6b") - - texts = ["hi", "Hello this is a very long sentence"] - expected_outputs = [ - "hi<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>Q: I have a question about the new version of the game. 
I have a question about the", - "Hello this is a very long sentence.\n\nA:\n\nI think the best way to understand this is to think of it", - ] - - tokenizer.padding_side = "right" - tokenizer.pad_token = tokenizer.eos_token - - inputs = tokenizer(texts, return_tensors="pt", padding=True).to(0) - - quantization_config = BitsAndBytesConfig(load_in_4bit=True) - - model = GPTJForCausalLM.from_pretrained( - "EleutherAI/gpt-j-6b", - device_map={"": 0}, - attn_implementation="flash_attention_2", - revision="float16", - torch_dtype=torch.float16, - quantization_config=quantization_config, - ) - - output_fa_2 = model.generate(**inputs, max_new_tokens=20, do_sample=False) - output_fa_2 = tokenizer.batch_decode(output_fa_2) - - self.assertListEqual(expected_outputs, output_fa_2) - @require_torch class GPTJModelLanguageGenerationTest(unittest.TestCase): diff --git a/tests/models/granite/test_modeling_granite.py b/tests/models/granite/test_modeling_granite.py index 1bcb6641803..97b59f5aa50 100644 --- a/tests/models/granite/test_modeling_granite.py +++ b/tests/models/granite/test_modeling_granite.py @@ -17,12 +17,10 @@ import tempfile import unittest -import pytest from parameterized import parameterized -from transformers import AutoTokenizer, GraniteConfig, is_torch_available, set_seed +from transformers import GraniteConfig, is_torch_available, set_seed from transformers.testing_utils import ( - require_bitsandbytes, require_flash_attn, require_read_token, require_torch, @@ -303,9 +301,6 @@ class GraniteModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi # This is because we are hitting edge cases with the causal_mask buffer model_split_percents = [0.5, 0.7, 0.8] - # used in `test_torch_compile` - _torch_compile_test_ckpt = "ibm/PowerLM-3b" - def setUp(self): self.model_tester = GraniteModelTester(self) self.config_tester = ConfigTester(self, config_class=GraniteConfig, hidden_size=37) @@ -423,46 +418,6 @@ class GraniteModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi with self.assertRaises(AssertionError): torch.testing.assert_close(yarn_sin_long, original_sin_long) - @require_flash_attn - @require_torch_gpu - @require_bitsandbytes - @pytest.mark.flash_attn_test - @require_read_token - @slow - def test_flash_attn_2_generate_padding_right(self): - """ - Overwritting the common test as the test is flaky on tiny models - """ - model = GraniteForCausalLM.from_pretrained( - "ibm/PowerLM-3b", - load_in_4bit=True, - device_map={"": 0}, - ) - - tokenizer = AutoTokenizer.from_pretrained("ibm/PowerLM-3b") - - texts = ["hi", "Hello this is a very long sentence"] - - tokenizer.padding_side = "right" - tokenizer.pad_token = tokenizer.eos_token - - inputs = tokenizer(texts, return_tensors="pt", padding=True).to(0) - - output_native = model.generate(**inputs, max_new_tokens=20, do_sample=False) - output_native = tokenizer.batch_decode(output_native) - - model = GraniteForCausalLM.from_pretrained( - "ibm/PowerLM-3b", - load_in_4bit=True, - device_map={"": 0}, - attn_implementation="flash_attention_2", - ) - - output_fa_2 = model.generate(**inputs, max_new_tokens=20, do_sample=False) - output_fa_2 = tokenizer.batch_decode(output_fa_2) - - self.assertListEqual(output_native, output_fa_2) - @require_flash_attn @require_torch_gpu @slow diff --git a/tests/models/granitemoe/test_modeling_granitemoe.py b/tests/models/granitemoe/test_modeling_granitemoe.py index 124ce0c3bb5..f2f76b9fa75 100644 --- a/tests/models/granitemoe/test_modeling_granitemoe.py +++ 
b/tests/models/granitemoe/test_modeling_granitemoe.py @@ -17,12 +17,10 @@ import tempfile import unittest -import pytest from parameterized import parameterized from transformers import AutoTokenizer, GraniteMoeConfig, is_torch_available, set_seed from transformers.testing_utils import ( - require_bitsandbytes, require_flash_attn, require_read_token, require_torch, @@ -302,9 +300,6 @@ class GraniteMoeModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test # This is because we are hitting edge cases with the causal_mask buffer model_split_percents = [0.5, 0.7, 0.8] - # used in `test_torch_compile` - _torch_compile_test_ckpt = "ibm/PowerMoE-3b" - def setUp(self): self.model_tester = GraniteMoeModelTester(self) self.config_tester = ConfigTester(self, config_class=GraniteMoeConfig, hidden_size=37) @@ -422,46 +417,6 @@ class GraniteMoeModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test with self.assertRaises(AssertionError): torch.testing.assert_close(yarn_sin_long, original_sin_long) - @require_flash_attn - @require_torch_gpu - @require_bitsandbytes - @pytest.mark.flash_attn_test - @require_read_token - @slow - def test_flash_attn_2_generate_padding_right(self): - """ - Overwritting the common test as the test is flaky on tiny models - """ - model = GraniteMoeForCausalLM.from_pretrained( - "ibm-granite/granitemoe-3b", - load_in_4bit=True, - device_map={"": 0}, - ) - - tokenizer = AutoTokenizer.from_pretrained("ibm-granite/granitemoe-3b") - - texts = ["hi", "Hello this is a very long sentence"] - - tokenizer.padding_side = "right" - tokenizer.pad_token = tokenizer.eos_token - - inputs = tokenizer(texts, return_tensors="pt", padding=True).to(0) - - output_native = model.generate(**inputs, max_new_tokens=20, do_sample=False) - output_native = tokenizer.batch_decode(output_native) - - model = GraniteMoeForCausalLM.from_pretrained( - "ibm-granite/granitemoe-3b", - load_in_4bit=True, - device_map={"": 0}, - attn_implementation="flash_attention_2", - ) - - output_fa_2 = model.generate(**inputs, max_new_tokens=20, do_sample=False) - output_fa_2 = tokenizer.batch_decode(output_fa_2) - - self.assertListEqual(output_native, output_fa_2) - @require_flash_attn @require_torch_gpu @slow diff --git a/tests/models/idefics/test_modeling_idefics.py b/tests/models/idefics/test_modeling_idefics.py index c2f0ef8ccd0..d19d10932bf 100644 --- a/tests/models/idefics/test_modeling_idefics.py +++ b/tests/models/idefics/test_modeling_idefics.py @@ -770,13 +770,6 @@ class IdeficsForVisionText2TextTest(IdeficsModelTest, GenerationTesterMixin, uni def test_custom_4d_attention_mask(self): pass - @unittest.skip( - reason="IDEFICS has specific requirements for working with inputs embeds like passing also the ids and pixels" - ) - @parameterized.expand([(1,), (2,)]) - def test_generate_from_inputs_embeds_decoder_only(self, num_beams): - pass - @unittest.skip(reason="IDEFICS cannot compile due to dynamic control flow when checking inputs") def test_generate_compile_fullgraph(self): pass diff --git a/tests/models/idefics2/test_modeling_idefics2.py b/tests/models/idefics2/test_modeling_idefics2.py index 854b8b93457..042fecf4bd2 100644 --- a/tests/models/idefics2/test_modeling_idefics2.py +++ b/tests/models/idefics2/test_modeling_idefics2.py @@ -20,7 +20,6 @@ import tempfile import unittest from io import BytesIO -import pytest import requests from transformers import ( @@ -420,50 +419,6 @@ class Idefics2ForConditionalGenerationModelTest(GenerationTesterMixin, ModelTest def test_flash_attn_2_fp32_ln(self): pass 
- @pytest.mark.generate - def test_generate_from_inputs_embeds_decoder_only(self): - # overwrite because IDEFICS needs ids and embeds at the input to be not None - for model_class in self.all_generative_model_classes: - config, inputs_dict = self.prepare_config_and_inputs_for_generate() - - # Ignore: - # a) eos (to always output 20 tokens) and pad (so we don't try to infer the attn mask from the input_ids, - # which would cause a mismatch), - config.pad_token_id = config.eos_token_id = -1 - config.is_decoder = True - model = model_class(config).to(torch_device).eval() - input_ids = inputs_dict.pop("input_ids") - - # Traditional way of generating text - outputs_from_ids = model.generate( - input_ids, max_new_tokens=5, return_dict_in_generate=True, output_scores=True - ) - self.assertEqual(outputs_from_ids.sequences.shape, (input_ids.shape[0], input_ids.shape[1] + 5)) - - # Same thing, but from input embeddings (`input_ids` is passed so the prompt is present in the output) - inputs_embeds = model.get_input_embeddings()(input_ids) - outputs_from_embeds = model.generate( - input_ids, - inputs_embeds=inputs_embeds, - max_new_tokens=5, - return_dict_in_generate=True, - output_scores=True, - ) - self.assertListEqual(outputs_from_ids.sequences.tolist(), outputs_from_embeds.sequences.tolist()) - - # But if we pass different inputs_embeds, we should get different outputs (the output text may be the - # same, but the logits will almost surely be different) - random_embeds = torch.rand_like(inputs_embeds) - outputs_from_rand_embeds = model.generate( - input_ids, - inputs_embeds=random_embeds, - max_new_tokens=5, - return_dict_in_generate=True, - output_scores=True, - ) - for i in range(len(outputs_from_rand_embeds.scores)): - self.assertFalse(torch.allclose(outputs_from_embeds.scores[i], outputs_from_rand_embeds.scores[i])) - # We need to override as we need to prepare such that the image token is the last token def test_resize_tokens_embeddings(self): (original_config, inputs_dict) = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/models/idefics3/test_modeling_idefics3.py b/tests/models/idefics3/test_modeling_idefics3.py index f0366e7b539..5dc352d22fe 100644 --- a/tests/models/idefics3/test_modeling_idefics3.py +++ b/tests/models/idefics3/test_modeling_idefics3.py @@ -19,7 +19,6 @@ import gc import unittest from io import BytesIO -import pytest import requests from transformers import ( @@ -180,10 +179,6 @@ class Idefics3ModelTest(ModelTesterMixin, unittest.TestCase): def test_inputs_embeds_matches_input_ids(self): pass - @unittest.skip(reason="Model does not support padding right") - def test_flash_attn_2_generate_padding_right(self): - pass - @unittest.skip(reason="Model does not support padding right") def test_flash_attn_2_inference_padding_right(self): pass @@ -337,10 +332,6 @@ class Idefics3ForConditionalGenerationModelTest(GenerationTesterMixin, ModelTest def test_inputs_embeds(): pass - @unittest.skip(reason="Model does not support padding right") - def test_flash_attn_2_generate_padding_right(self): - pass - @unittest.skip(reason="Model does not support padding right") def test_flash_attn_2_inference_padding_right(self): pass @@ -367,50 +358,6 @@ class Idefics3ForConditionalGenerationModelTest(GenerationTesterMixin, ModelTest def test_flash_attn_2_fp32_ln(self): pass - @pytest.mark.generate - def test_generate_from_inputs_embeds_decoder_only(self): - # overwrite because IDEFICS needs ids and embeds at the input to be not None - for model_class in 
self.all_generative_model_classes: - config, inputs_dict = self.prepare_config_and_inputs_for_generate() - - # Ignore: - # a) eos (to always output 20 tokens) and pad (so we don't try to infer the attn mask from the input_ids, - # which would cause a mismatch), - config.pad_token_id = config.eos_token_id = -1 - config.is_decoder = True - model = model_class(config).to(torch_device).eval() - input_ids = inputs_dict.pop("input_ids") - - # Traditional way of generating text - outputs_from_ids = model.generate( - input_ids, max_new_tokens=5, return_dict_in_generate=True, output_scores=True - ) - self.assertEqual(outputs_from_ids.sequences.shape, (input_ids.shape[0], input_ids.shape[1] + 5)) - - # Same thing, but from input embeddings (`input_ids` is passed so the prompt is present in the output) - inputs_embeds = model.get_input_embeddings()(input_ids) - outputs_from_embeds = model.generate( - input_ids, - inputs_embeds=inputs_embeds, - max_new_tokens=5, - return_dict_in_generate=True, - output_scores=True, - ) - self.assertListEqual(outputs_from_ids.sequences.tolist(), outputs_from_embeds.sequences.tolist()) - - # But if we pass different inputs_embeds, we should get different outputs (the output text may be the - # same, but the logits will almost surely be different) - random_embeds = torch.rand_like(inputs_embeds) - outputs_from_rand_embeds = model.generate( - input_ids, - inputs_embeds=random_embeds, - max_new_tokens=5, - return_dict_in_generate=True, - output_scores=True, - ) - for i in range(len(outputs_from_rand_embeds.scores)): - self.assertFalse(torch.allclose(outputs_from_embeds.scores[i], outputs_from_rand_embeds.scores[i])) - # We need to override as we need to prepare such that the image token is the last token def test_resize_tokens_embeddings(self): (original_config, inputs_dict) = self.model_tester.prepare_config_and_inputs_for_common() @@ -526,31 +473,6 @@ class Idefics3ForConditionalGenerationModelTest(GenerationTesterMixin, ModelTest # Check that the model can still do a forward pass successfully (every parameter should be resized) model(**self._prepare_for_class(inputs_dict, model_class)) - def test_inputs_embeds_matches_input_ids_with_generate(self): - # overwrite because IDEFICS needs ids and embeds at the input to be not None - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - for model_class in self.all_model_classes: - model = model_class(config) - model.to(torch_device) - model.eval() - - inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class)) - pad_token_id = config.pad_token_id if config.pad_token_id is not None else 1 - - wte = model.get_input_embeddings() - - input_ids = inputs["input_ids"] - # some models infer position ids/attn mask differently when input ids - # by check if pad_token let's make sure no padding is in input ids - not_pad_token_id = pad_token_id + 1 if max(0, pad_token_id - 1) == 0 else pad_token_id - 1 - input_ids[input_ids == pad_token_id] = not_pad_token_id - del inputs["input_ids"] - inputs_embeds = wte(input_ids) - out_ids = model.generate(input_ids=input_ids, **inputs, max_new_tokens=2) - out_embeds = model.generate(input_ids=input_ids, inputs_embeds=inputs_embeds, **inputs, max_new_tokens=2) - - self.assertTrue(torch.allclose(out_embeds, out_ids)) - @require_torch class Idefics3ForConditionalGenerationIntegrationTest(unittest.TestCase): diff --git a/tests/models/jamba/test_modeling_jamba.py b/tests/models/jamba/test_modeling_jamba.py index 251f293f722..ef0b5831587 100644 --- 
a/tests/models/jamba/test_modeling_jamba.py +++ b/tests/models/jamba/test_modeling_jamba.py @@ -539,93 +539,6 @@ class JambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi # with attention mask _ = model(dummy_input, attention_mask=dummy_attention_mask) - @require_flash_attn - @require_torch_gpu - @pytest.mark.flash_attn_test - @slow - def test_flash_attn_2_generate_padding_right(self): - r""" - Overriding the test_flash_attn_2_generate_padding_right test as the Jamba model, like Mixtral, doesn't support - right padding + use cache with FA2 - """ - import torch - - for model_class in self.all_generative_model_classes: - config, _ = self.model_tester.prepare_config_and_inputs_for_common() - model = model_class(config) - - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname) - model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to( - torch_device - ) - - dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device) - dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device) - - model.generate(dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False) - - model = model_class.from_pretrained( - tmpdirname, - torch_dtype=torch.float16, - attn_implementation="flash_attention_2", - low_cpu_mem_usage=True, - ).to(torch_device) - - with self.assertRaises(ValueError): - _ = model.generate( - dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False - ) - - @require_flash_attn - @require_torch_gpu - @pytest.mark.flash_attn_test - @slow - def test_flash_attn_2_generate_use_cache(self): - r""" - Overriding the test_flash_attn_2_generate_use_cache test as the Jamba model, like Mixtral, doesn't support - right padding + use cache with FA2 - """ - import torch - - max_new_tokens = 30 - - for model_class in self.all_generative_model_classes: - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - dummy_input = inputs_dict[model_class.main_input_name] - if dummy_input.dtype in [torch.float32, torch.bfloat16]: - dummy_input = dummy_input.to(torch.float16) - - # make sure that all models have enough positions for generation - if hasattr(config, "max_position_embeddings"): - config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1 - - model = model_class(config) - - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname) - - dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input)) - # NOTE: Jamba does not support right padding + use_cache with FA2. 
- dummy_attention_mask[:, -1] = 1 - - model = model_class.from_pretrained( - tmpdirname, - torch_dtype=torch.float16, - attn_implementation="flash_attention_2", - low_cpu_mem_usage=True, - ).to(torch_device) - - # Just test that a large cache works as expected - _ = model.generate( - dummy_input, - attention_mask=dummy_attention_mask, - max_new_tokens=max_new_tokens, - do_sample=False, - use_cache=True, - ) - @require_flash_attn @require_torch_gpu @pytest.mark.flash_attn_test diff --git a/tests/models/jetmoe/test_modeling_jetmoe.py b/tests/models/jetmoe/test_modeling_jetmoe.py index a04d8bba741..dc510f0ff04 100644 --- a/tests/models/jetmoe/test_modeling_jetmoe.py +++ b/tests/models/jetmoe/test_modeling_jetmoe.py @@ -15,7 +15,6 @@ """Testing suite for the PyTorch JetMoe model.""" import gc -import tempfile import unittest import pytest @@ -377,85 +376,6 @@ class JetMoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix def test_past_key_values_format(self): pass - @require_flash_attn - @require_torch_gpu - @pytest.mark.flash_attn_test - @slow - def test_flash_attn_2_generate_padding_right(self): - import torch - - for model_class in self.all_generative_model_classes: - config, _ = self.model_tester.prepare_config_and_inputs_for_common() - model = model_class(config) - - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname) - model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to( - torch_device - ) - - dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device) - dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device) - - model.generate(dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False) - - model = model_class.from_pretrained( - tmpdirname, - torch_dtype=torch.float16, - attn_implementation="flash_attention_2", - low_cpu_mem_usage=True, - ).to(torch_device) - - with self.assertRaises(ValueError): - _ = model.generate( - dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False - ) - - @require_flash_attn - @require_torch_gpu - @pytest.mark.flash_attn_test - @slow - def test_flash_attn_2_generate_use_cache(self): - import torch - - max_new_tokens = 30 - - for model_class in self.all_generative_model_classes: - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - dummy_input = inputs_dict[model_class.main_input_name] - if dummy_input.dtype in [torch.float32, torch.bfloat16]: - dummy_input = dummy_input.to(torch.float16) - - # make sure that all models have enough positions for generation - if hasattr(config, "max_position_embeddings"): - config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1 - - model = model_class(config) - - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname) - - dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input)) - # NOTE: JetMoe apparently does not support right padding + use_cache with FA2. 
- dummy_attention_mask[:, -1] = 1 - - model = model_class.from_pretrained( - tmpdirname, - torch_dtype=torch.float16, - attn_implementation="flash_attention_2", - low_cpu_mem_usage=True, - ).to(torch_device) - - # Just test that a large cache works as expected - _ = model.generate( - dummy_input, - attention_mask=dummy_attention_mask, - max_new_tokens=max_new_tokens, - do_sample=False, - use_cache=True, - ) - @require_flash_attn @require_torch_gpu @pytest.mark.flash_attn_test diff --git a/tests/models/kosmos2/test_modeling_kosmos2.py b/tests/models/kosmos2/test_modeling_kosmos2.py index de6c0b15d66..0f0b595d3d2 100644 --- a/tests/models/kosmos2/test_modeling_kosmos2.py +++ b/tests/models/kosmos2/test_modeling_kosmos2.py @@ -438,12 +438,6 @@ class Kosmos2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase) # self.assertTrue(model.transformer.wte.weight.shape, model.lm_head.weight.shape) # self.assertTrue(check_same_values(model.transformer.wte, model.lm_head)) - @unittest.skip( - "KOSMOS-2 doesn't support inputs embeds. The test isn't skipped by checking ipnut args because KOSMOS-2 has `generate()` overwritten" - ) - def test_inputs_embeds_matches_input_ids_with_generate(self): - pass - @slow def test_model_from_pretrained(self): model_name = "microsoft/kosmos-2-patch14-224" diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index 824337d8bdd..375ec1dd3e6 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -26,7 +26,6 @@ from transformers import AutoTokenizer, LlamaConfig, StaticCache, is_torch_avail from transformers.generation.configuration_utils import GenerationConfig from transformers.testing_utils import ( backend_empty_cache, - require_bitsandbytes, require_flash_attn, require_read_token, require_torch, @@ -316,9 +315,6 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi # This is because we are hitting edge cases with the causal_mask buffer model_split_percents = [0.5, 0.7, 0.8] - # used in `test_torch_compile` - _torch_compile_test_ckpt = "meta-llama/Llama-2-7b-hf" - # used in `test_torch_compile_for_training` _torch_compile_train_cls = LlamaForCausalLM if is_torch_available() else None @@ -585,43 +581,6 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi with self.assertRaises(KeyError): config = _reinitialize_config(base_config, {"rope_scaling": {"rope_type": "linear"}}) # missing "factor" - @require_flash_attn - @require_torch_gpu - @require_bitsandbytes - @pytest.mark.flash_attn_test - @require_read_token - @slow - def test_flash_attn_2_generate_padding_right(self): - """ - Overwritting the common test as the test is flaky on tiny models - """ - model = LlamaForCausalLM.from_pretrained( - "meta-llama/Llama-2-7b-hf", - load_in_4bit=True, - device_map={"": 0}, - ) - - tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") - - texts = ["hi", "Hello this is a very long sentence"] - - tokenizer.padding_side = "right" - tokenizer.pad_token = tokenizer.eos_token - - inputs = tokenizer(texts, return_tensors="pt", padding=True).to(0) - - output_native = model.generate(**inputs, max_new_tokens=20, do_sample=False) - output_native = tokenizer.batch_decode(output_native) - - model = LlamaForCausalLM.from_pretrained( - "meta-llama/Llama-2-7b-hf", load_in_4bit=True, device_map={"": 0}, attn_implementation="flash_attention_2" - ) - - output_fa_2 = model.generate(**inputs, max_new_tokens=20, 
do_sample=False) - output_fa_2 = tokenizer.batch_decode(output_fa_2) - - self.assertListEqual(output_native, output_fa_2) - @require_flash_attn @require_torch_gpu @slow diff --git a/tests/models/mamba2/test_modeling_mamba2.py b/tests/models/mamba2/test_modeling_mamba2.py index 1a8cf047745..9b3a9563b58 100644 --- a/tests/models/mamba2/test_modeling_mamba2.py +++ b/tests/models/mamba2/test_modeling_mamba2.py @@ -204,8 +204,8 @@ class Mamba2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix pass @unittest.skip(reason="To fix, Mamba 2 cache slicing test case is an edge case") - @parameterized.expand([(1,), (2,)]) - def test_generate_from_inputs_embeds_decoder_only(self, num_beams): + @parameterized.expand([("greedy", 1), ("beam search", 2)]) + def test_generate_from_inputs_embeds(self, _, num_beams): pass @unittest.skip(reason="To fix, Mamba 2 cache slicing test case is an edge case") @@ -276,12 +276,6 @@ class Mamba2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True}) - @unittest.skip( - reason="Mamba2 does not support generating with input embeddings (custom cache_position computation)" - ) - def test_inputs_embeds_matches_input_ids_with_generate(self): - pass - @require_torch @slow diff --git a/tests/models/mimi/test_modeling_mimi.py b/tests/models/mimi/test_modeling_mimi.py index 074dceae155..df0007d666a 100644 --- a/tests/models/mimi/test_modeling_mimi.py +++ b/tests/models/mimi/test_modeling_mimi.py @@ -21,7 +21,6 @@ import unittest import numpy as np from datasets import Audio, load_dataset -from packaging import version from parameterized import parameterized from pytest import mark @@ -745,22 +744,6 @@ class MimiModelTest(ModelTesterMixin, unittest.TestCase): def test_sdpa_can_compile_dynamic(self): pass - # For now, Let's focus only on GPU for `torch.compile` - @slow - @require_torch_gpu - def test_torch_compile(self): - if version.parse(torch.__version__) < version.parse("2.3"): - self.skipTest(reason="This test requires torch >= 2.3 to run.") - - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - n_iter = 3 - for model_class in self.all_model_classes: - model = model_class(config).to(torch_device) - model.forward = torch.compile(model.forward) - for i in range(n_iter): - _ = model(inputs_dict["input_values"].to(torch_device)) - @is_flaky() def test_batching_equivalence(self): super().test_batching_equivalence() diff --git a/tests/models/mistral/test_modeling_mistral.py b/tests/models/mistral/test_modeling_mistral.py index f2ee714bcdb..1538735ad78 100644 --- a/tests/models/mistral/test_modeling_mistral.py +++ b/tests/models/mistral/test_modeling_mistral.py @@ -15,7 +15,6 @@ """Testing suite for the PyTorch Mistral model.""" import gc -import tempfile import unittest import pytest @@ -416,85 +415,6 @@ class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi def test_past_key_values_format(self): pass - @require_flash_attn - @require_torch_gpu - @pytest.mark.flash_attn_test - @slow - def test_flash_attn_2_generate_padding_right(self): - import torch - - for model_class in self.all_generative_model_classes: - config, _ = self.model_tester.prepare_config_and_inputs_for_common() - model = model_class(config) - - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname) - model = 
model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to( - torch_device - ) - - dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device) - dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device) - - model.generate(dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False) - - model = model_class.from_pretrained( - tmpdirname, - torch_dtype=torch.float16, - attn_implementation="flash_attention_2", - low_cpu_mem_usage=True, - ).to(torch_device) - - with self.assertRaises(ValueError): - _ = model.generate( - dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False - ) - - @require_flash_attn - @require_torch_gpu - @pytest.mark.flash_attn_test - @slow - def test_flash_attn_2_generate_use_cache(self): - import torch - - max_new_tokens = 30 - - for model_class in self.all_generative_model_classes: - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - dummy_input = inputs_dict[model_class.main_input_name] - if dummy_input.dtype in [torch.float32, torch.bfloat16]: - dummy_input = dummy_input.to(torch.float16) - - # make sure that all models have enough positions for generation - if hasattr(config, "max_position_embeddings"): - config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1 - - model = model_class(config) - - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname) - - dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input)) - # NOTE: Mistral apparently does not support right padding + use_cache with FA2. - dummy_attention_mask[:, -1] = 1 - - model = model_class.from_pretrained( - tmpdirname, - torch_dtype=torch.float16, - attn_implementation="flash_attention_2", - low_cpu_mem_usage=True, - ).to(torch_device) - - # Just test that a large cache works as expected - _ = model.generate( - dummy_input, - attention_mask=dummy_attention_mask, - max_new_tokens=max_new_tokens, - do_sample=False, - use_cache=True, - ) - @require_flash_attn @require_torch_gpu @pytest.mark.flash_attn_test diff --git a/tests/models/mixtral/test_modeling_mixtral.py b/tests/models/mixtral/test_modeling_mixtral.py index b9b5faed851..931bb1f17be 100644 --- a/tests/models/mixtral/test_modeling_mixtral.py +++ b/tests/models/mixtral/test_modeling_mixtral.py @@ -14,7 +14,6 @@ # limitations under the License. 
"""Testing suite for the PyTorch Mixtral model.""" -import tempfile import unittest import pytest @@ -415,85 +414,6 @@ class MixtralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi def test_past_key_values_format(self): pass - @require_flash_attn - @require_torch_gpu - @pytest.mark.flash_attn_test - @slow - def test_flash_attn_2_generate_padding_right(self): - import torch - - for model_class in self.all_generative_model_classes: - config, _ = self.model_tester.prepare_config_and_inputs_for_common() - model = model_class(config) - - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname) - model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to( - torch_device - ) - - dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device) - dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device) - - model.generate(dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False) - - model = model_class.from_pretrained( - tmpdirname, - torch_dtype=torch.float16, - attn_implementation="flash_attention_2", - low_cpu_mem_usage=True, - ).to(torch_device) - - with self.assertRaises(ValueError): - _ = model.generate( - dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False - ) - - @require_flash_attn - @require_torch_gpu - @pytest.mark.flash_attn_test - @slow - def test_flash_attn_2_generate_use_cache(self): - import torch - - max_new_tokens = 30 - - for model_class in self.all_generative_model_classes: - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - dummy_input = inputs_dict[model_class.main_input_name] - if dummy_input.dtype in [torch.float32, torch.bfloat16]: - dummy_input = dummy_input.to(torch.float16) - - # make sure that all models have enough positions for generation - if hasattr(config, "max_position_embeddings"): - config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1 - - model = model_class(config) - - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname) - - dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input)) - # NOTE: Mixtral apparently does not support right padding + use_cache with FA2. 
- dummy_attention_mask[:, -1] = 1 - - model = model_class.from_pretrained( - tmpdirname, - torch_dtype=torch.float16, - attn_implementation="flash_attention_2", - low_cpu_mem_usage=True, - ).to(torch_device) - - # Just test that a large cache works as expected - _ = model.generate( - dummy_input, - attention_mask=dummy_attention_mask, - max_new_tokens=max_new_tokens, - do_sample=False, - use_cache=True, - ) - @require_flash_attn @require_torch_gpu @pytest.mark.flash_attn_test diff --git a/tests/models/mllama/test_modeling_mllama.py b/tests/models/mllama/test_modeling_mllama.py index 3efa7b778fb..5174247b895 100644 --- a/tests/models/mllama/test_modeling_mllama.py +++ b/tests/models/mllama/test_modeling_mllama.py @@ -126,7 +126,6 @@ class MllamaForCausalLMModelTest(ModelTesterMixin, GenerationTesterMixin, unitte all_generative_model_classes = (MllamaForCausalLM,) if is_torch_available() else () test_pruning = False test_head_masking = False - _torch_compile_test_ckpt = "nltpt/Llama-3.2-11B-Vision" def setUp(self): self.model_tester = MllamaText2TextModelTester(self) diff --git a/tests/models/moshi/test_modeling_moshi.py b/tests/models/moshi/test_modeling_moshi.py index b77a6ff1036..7d4b855c10d 100644 --- a/tests/models/moshi/test_modeling_moshi.py +++ b/tests/models/moshi/test_modeling_moshi.py @@ -560,7 +560,7 @@ class MoshiTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): return config, input_ids, attention_mask, inputs_dict def prepare_config_and_inputs_for_generate(self, batch_size=2): - config, filtered_inputs_dict = super().prepare_config_and_inputs_for_generate() + config, filtered_inputs_dict = super().prepare_config_and_inputs_for_generate(batch_size=batch_size) # Make sure we only return `input_ids`. # Note that audio_codes will still be generated internally, so the ability to test audio codes is still there. 
@@ -591,9 +591,11 @@ class MoshiTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): [expected_shape] * len(iter_hidden_states), ) - def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1): + def _check_outputs(self, output, config, use_cache=False, num_return_sequences=1, num_beams=1): # Overwrite because the generate method actually alway uses `inputs_embeds` so `use_cache` is always `True` - super()._check_outputs(output, input_ids, config, use_cache=True, num_return_sequences=num_return_sequences) + super()._check_outputs( + output, config, use_cache=True, num_return_sequences=num_return_sequences, num_beams=num_beams + ) def _check_hidden_states_for_generate( self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1 @@ -655,59 +657,6 @@ class MoshiTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): msg=f"Parameter {name} of model {model_class} seems not properly initialized", ) - @pytest.mark.generate - @parameterized.expand([(1,), (2,)]) - def test_generate_from_inputs_embeds_decoder_only(self, num_beams): - for model_class in self.all_generative_model_classes: - config, input_ids, _, inputs_dict = self._get_input_ids_and_config() - - model = model_class(config).to(torch_device).eval() - generation_kwargs = { - "return_dict_in_generate": True, - "output_scores": True, - "num_beams": num_beams, - "do_sample": False, - } - - # Traditional way of generating text - outputs_from_ids = model.generate(input_ids, max_new_tokens=5, **generation_kwargs, **inputs_dict) - self.assertEqual(outputs_from_ids.sequences.shape, (input_ids.shape[0], input_ids.shape[1] + 5)) - - # Same thing, but from input embeddings (`input_ids` is passed so the prompt is present in the output) - inputs_embeds = model.get_input_embeddings()(input_ids) - outputs_from_embeds = model.generate( - input_ids, - inputs_embeds=inputs_embeds, - max_new_tokens=5, - **generation_kwargs, - **inputs_dict, - ) - - # But if we pass different inputs_embeds, we should get different outputs (the output text may be the - # same, but the logits will almost surely be different) - random_embeds = torch.rand_like(inputs_embeds) - outputs_from_rand_embeds = model.generate( - input_ids, - inputs_embeds=random_embeds, - max_new_tokens=5, - **generation_kwargs, - **inputs_dict, - ) - for i in range(len(outputs_from_rand_embeds.scores)): - self.assertFalse(torch.allclose(outputs_from_embeds.scores[i], outputs_from_rand_embeds.scores[i])) - - # input_ids is not a required input -- if we don't pass it, the newly generated tokens will be the same - outputs_from_embeds_wo_ids = model.generate( - inputs_embeds=inputs_embeds, - max_new_tokens=5, - **generation_kwargs, - **inputs_dict, - ) - self.assertListEqual( - outputs_from_embeds.sequences[:, inputs_embeds.shape[1] :].tolist(), - outputs_from_embeds_wo_ids.sequences.tolist(), - ) - @unittest.skip(reason="Continuing from past key values is not straightforward as we're dealing with 3 inputs") def test_generate_continue_from_past_key_values(self): pass diff --git a/tests/models/mt5/test_modeling_mt5.py b/tests/models/mt5/test_modeling_mt5.py index 20412da2e1d..1628d3a5893 100644 --- a/tests/models/mt5/test_modeling_mt5.py +++ b/tests/models/mt5/test_modeling_mt5.py @@ -576,9 +576,6 @@ class MT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, # The small MT5 model needs higher percentages for CPU/MP tests model_split_percents = [0.5, 0.8, 0.9] - # used in 
`test_torch_compile` - _torch_compile_test_ckpt = "google/mt5-small" - def setUp(self): self.model_tester = MT5ModelTester(self) self.config_tester = ConfigTester(self, config_class=MT5Config, d_model=37) diff --git a/tests/models/musicgen/test_modeling_musicgen.py b/tests/models/musicgen/test_modeling_musicgen.py index 346ad60debe..963cace28d6 100644 --- a/tests/models/musicgen/test_modeling_musicgen.py +++ b/tests/models/musicgen/test_modeling_musicgen.py @@ -450,144 +450,6 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2) - @require_flash_attn - @require_torch_gpu - @mark.flash_attn_test - @slow - # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_generate_left_padding - def test_flash_attn_2_generate_left_padding(self): - # Ignore copy - for model_class in self.greedy_sample_model_classes: - if not model_class._supports_flash_attn_2: - self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") - - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - model = model_class(config) - - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname) - model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to( - torch_device - ) - - dummy_input = inputs_dict[model.main_input_name] - if dummy_input.dtype in [torch.float32, torch.bfloat16]: - dummy_input = dummy_input.to(torch.float16) - - dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input)) - # make sure we do left padding - dummy_attention_mask[:, :-1] = 0 - dummy_attention_mask[:, -1:] = 1 - - out = model.generate( - dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False - ) - - model = model_class.from_pretrained( - tmpdirname, - torch_dtype=torch.float16, - attn_implementation="flash_attention_2", - low_cpu_mem_usage=True, - ).to(torch_device) - - out_fa = model.generate( - dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False - ) - - self.assertTrue(torch.allclose(out, out_fa)) - - @require_flash_attn - @require_torch_gpu - @mark.flash_attn_test - @slow - # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_generate_padding_right - def test_flash_attn_2_generate_padding_right(self): - # Ignore copy - for model_class in self.greedy_sample_model_classes: - if not model_class._supports_flash_attn_2: - self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") - - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - model = model_class(config) - - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname) - model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to( - torch_device - ) - - dummy_input = inputs_dict[model.main_input_name] - if dummy_input.dtype in [torch.float32, torch.bfloat16]: - dummy_input = dummy_input.to(torch.float16) - - dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input)) - # make sure we do right padding - dummy_attention_mask[:, :-1] = 1 - dummy_attention_mask[:, -1:] = 0 - - out = model.generate( - dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False - ) - - model = model_class.from_pretrained( - tmpdirname, - torch_dtype=torch.float16, - attn_implementation="flash_attention_2", - 
low_cpu_mem_usage=True, - ).to(torch_device) - - out_fa = model.generate( - dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False - ) - - self.assertTrue(torch.allclose(out, out_fa)) - - @require_flash_attn - @require_torch_gpu - @mark.flash_attn_test - @slow - # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_generate_use_cache - def test_flash_attn_2_generate_use_cache(self): - max_new_tokens = 30 - - # Ignore copy - for model_class in self.greedy_sample_model_classes: - if not model_class._supports_flash_attn_2: - self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") - - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - dummy_input = inputs_dict[model_class.main_input_name] - if dummy_input.dtype in [torch.float32, torch.bfloat16]: - dummy_input = dummy_input.to(torch.float16) - - # make sure that all models have enough positions for generation - if hasattr(config, "max_position_embeddings"): - config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1 - - model = model_class(config) - - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname) - - dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input)) - - model = model_class.from_pretrained( - tmpdirname, - torch_dtype=torch.float16, - attn_implementation="flash_attention_2", - low_cpu_mem_usage=True, - ).to(torch_device) - - # Just test that a large cache works as expected - _ = model.generate( - dummy_input, - attention_mask=dummy_attention_mask, - max_new_tokens=max_new_tokens, - do_sample=False, - use_cache=True, - ) - @parameterized.expand([("float16",), ("bfloat16",), ("float32",)]) @require_torch_sdpa @slow @@ -1585,149 +1447,6 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2) - @require_flash_attn - @require_torch_gpu - @mark.flash_attn_test - @slow - # Adapted from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_generate_left_padding - def test_flash_attn_2_generate_left_padding(self): - # Ignore copy - for model_class in self.greedy_sample_model_classes: - if not model_class._supports_flash_attn_2: - self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") - - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - model = model_class(config) - - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname) - model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to( - torch_device - ) - - dummy_input = inputs_dict[model.main_input_name] - if dummy_input.dtype in [torch.float32, torch.bfloat16]: - dummy_input = dummy_input.to(torch.float16) - - dummy_attention_mask = inputs_dict.get("attention_mask") - if dummy_attention_mask is None: - dummy_attention_mask = torch.ones_like(dummy_input) - - # make sure we do left padding - dummy_attention_mask[:, :-1] = 0 - dummy_attention_mask[:, -1:] = 1 - - out = model.generate( - dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False - ) - - model = model_class.from_pretrained( - tmpdirname, - torch_dtype=torch.float16, - attn_implementation={"decoder": "flash_attention_2", "audio_encoder": None, "text_encoder": None}, - low_cpu_mem_usage=True, - ).to(torch_device) - - out_fa = model.generate( - dummy_input, attention_mask=dummy_attention_mask, 
max_new_tokens=8, do_sample=False - ) - - self.assertTrue(torch.allclose(out, out_fa)) - - @require_flash_attn - @require_torch_gpu - @mark.flash_attn_test - @slow - # Adapted from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_generate_padding_right - def test_flash_attn_2_generate_padding_right(self): - # Ignore copy - for model_class in self.greedy_sample_model_classes: - if not model_class._supports_flash_attn_2: - self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") - - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - model = model_class(config) - - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname) - model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to( - torch_device - ) - - dummy_input = inputs_dict[model.main_input_name] - if dummy_input.dtype in [torch.float32, torch.bfloat16]: - dummy_input = dummy_input.to(torch.float16) - - dummy_attention_mask = inputs_dict.get("attention_mask") - if dummy_attention_mask is None: - dummy_attention_mask = torch.ones_like(dummy_input) - # make sure we do right padding - dummy_attention_mask[:, :-1] = 1 - dummy_attention_mask[:, -1:] = 0 - - out = model.generate( - dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False - ) - - model = model_class.from_pretrained( - tmpdirname, - torch_dtype=torch.float16, - attn_implementation={"decoder": "flash_attention_2", "audio_encoder": None, "text_encoder": None}, - low_cpu_mem_usage=True, - ).to(torch_device) - - out_fa = model.generate( - dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False - ) - - self.assertTrue(torch.allclose(out, out_fa)) - - @require_flash_attn - @require_torch_gpu - @mark.flash_attn_test - @slow - # Adapted from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_generate_use_cache - def test_flash_attn_2_generate_use_cache(self): - max_new_tokens = 30 - - # Ignore copy - for model_class in self.greedy_sample_model_classes: - if not model_class._supports_flash_attn_2: - self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") - - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - dummy_input = inputs_dict[model_class.main_input_name] - if dummy_input.dtype in [torch.float32, torch.bfloat16]: - dummy_input = dummy_input.to(torch.float16) - - # make sure that all models have enough positions for generation - if hasattr(config, "max_position_embeddings"): - config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1 - - model = model_class(config) - - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname) - - dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input)) - - model = model_class.from_pretrained( - tmpdirname, - torch_dtype=torch.float16, - attn_implementation={"decoder": "flash_attention_2", "audio_encoder": None, "text_encoder": None}, - low_cpu_mem_usage=True, - ).to(torch_device) - - # Just test that a large cache works as expected - _ = model.generate( - dummy_input, - attention_mask=dummy_attention_mask, - max_new_tokens=max_new_tokens, - do_sample=False, - use_cache=True, - ) - @require_torch_sdpa def test_sdpa_can_dispatch_composite_models(self): if not self.has_attentions: diff --git a/tests/models/musicgen_melody/test_modeling_musicgen_melody.py b/tests/models/musicgen_melody/test_modeling_musicgen_melody.py index 
f3b6be0ac65..957db9f23b0 100644 --- a/tests/models/musicgen_melody/test_modeling_musicgen_melody.py +++ b/tests/models/musicgen_melody/test_modeling_musicgen_melody.py @@ -1437,149 +1437,6 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2) - @require_flash_attn - @require_torch_gpu - @mark.flash_attn_test - @slow - # Adapted from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_generate_left_padding - def test_flash_attn_2_generate_left_padding(self): - # Ignore copy - for model_class in self.greedy_sample_model_classes: - if not model_class._supports_flash_attn_2: - self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") - - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - model = model_class(config) - - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname) - model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to( - torch_device - ) - - dummy_input = inputs_dict[model.main_input_name] - if dummy_input.dtype in [torch.float32, torch.bfloat16]: - dummy_input = dummy_input.to(torch.float16) - - dummy_attention_mask = inputs_dict.get("attention_mask") - if dummy_attention_mask is None: - dummy_attention_mask = torch.ones_like(dummy_input) - - # make sure we do left padding - dummy_attention_mask[:, :-1] = 0 - dummy_attention_mask[:, -1:] = 1 - - out = model.generate( - dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False - ) - - model = model_class.from_pretrained( - tmpdirname, - torch_dtype=torch.float16, - attn_implementation={"decoder": "flash_attention_2", "audio_encoder": None, "text_encoder": None}, - low_cpu_mem_usage=True, - ).to(torch_device) - - out_fa = model.generate( - dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False - ) - - self.assertTrue(torch.allclose(out, out_fa)) - - @require_flash_attn - @require_torch_gpu - @mark.flash_attn_test - @slow - # Adapted from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_generate_padding_right - def test_flash_attn_2_generate_padding_right(self): - # Ignore copy - for model_class in self.greedy_sample_model_classes: - if not model_class._supports_flash_attn_2: - self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") - - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - model = model_class(config) - - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname) - model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to( - torch_device - ) - - dummy_input = inputs_dict[model.main_input_name] - if dummy_input.dtype in [torch.float32, torch.bfloat16]: - dummy_input = dummy_input.to(torch.float16) - - dummy_attention_mask = inputs_dict.get("attention_mask") - if dummy_attention_mask is None: - dummy_attention_mask = torch.ones_like(dummy_input) - # make sure we do right padding - dummy_attention_mask[:, :-1] = 1 - dummy_attention_mask[:, -1:] = 0 - - out = model.generate( - dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False - ) - - model = model_class.from_pretrained( - tmpdirname, - torch_dtype=torch.float16, - attn_implementation={"decoder": "flash_attention_2", "audio_encoder": None, "text_encoder": None}, - low_cpu_mem_usage=True, - ).to(torch_device) - - out_fa = 
model.generate( - dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False - ) - - self.assertTrue(torch.allclose(out, out_fa)) - - @require_flash_attn - @require_torch_gpu - @mark.flash_attn_test - @slow - # Adapted from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_generate_use_cache - def test_flash_attn_2_generate_use_cache(self): - max_new_tokens = 30 - - # Ignore copy - for model_class in self.greedy_sample_model_classes: - if not model_class._supports_flash_attn_2: - self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") - - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - dummy_input = inputs_dict[model_class.main_input_name] - if dummy_input.dtype in [torch.float32, torch.bfloat16]: - dummy_input = dummy_input.to(torch.float16) - - # make sure that all models have enough positions for generation - if hasattr(config, "max_position_embeddings"): - config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1 - - model = model_class(config) - - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname) - - dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input)) - - model = model_class.from_pretrained( - tmpdirname, - torch_dtype=torch.float16, - attn_implementation={"decoder": "flash_attention_2", "audio_encoder": None, "text_encoder": None}, - low_cpu_mem_usage=True, - ).to(torch_device) - - # Just test that a large cache works as expected - _ = model.generate( - dummy_input, - attention_mask=dummy_attention_mask, - max_new_tokens=max_new_tokens, - do_sample=False, - use_cache=True, - ) - @require_torch_sdpa def test_sdpa_can_dispatch_composite_models(self): if not self.has_attentions: diff --git a/tests/models/nemotron/test_modeling_nemotron.py b/tests/models/nemotron/test_modeling_nemotron.py index 13adfe1e579..37a581a3386 100644 --- a/tests/models/nemotron/test_modeling_nemotron.py +++ b/tests/models/nemotron/test_modeling_nemotron.py @@ -92,8 +92,6 @@ class NemotronModelTest(GemmaModelTest): test_pruning = False fx_compatible = False - # used in `test_torch_compile` - _torch_compile_test_ckpt = "nvidia/nemotron-3-8b-base-4k-hf" # used in `test_torch_compile_for_training` _torch_compile_train_cls = NemotronForCausalLM if is_torch_available() else None diff --git a/tests/models/paligemma/test_modeling_paligemma.py b/tests/models/paligemma/test_modeling_paligemma.py index 95ae59dfc08..1d96b9c338f 100644 --- a/tests/models/paligemma/test_modeling_paligemma.py +++ b/tests/models/paligemma/test_modeling_paligemma.py @@ -346,10 +346,6 @@ class PaliGemmaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes def test_generate_from_inputs_embeds_with_static_cache(self): pass - @unittest.skip(reason="TODO (@joao): fix me -- failing to produce similar results") - def test_static_cache_matches_dynamic(self): - pass - @unittest.skip("FlashAttention only support fp16 and bf16 data type") def test_flash_attn_2_fp32_ln(self): pass diff --git a/tests/models/phi/test_modeling_phi.py b/tests/models/phi/test_modeling_phi.py index c17f69a4998..eae6789bef2 100644 --- a/tests/models/phi/test_modeling_phi.py +++ b/tests/models/phi/test_modeling_phi.py @@ -17,15 +17,11 @@ import unittest -import pytest from parameterized import parameterized from transformers import PhiConfig, is_torch_available, set_seed from transformers.testing_utils import ( - require_bitsandbytes, - require_flash_attn, require_torch, - require_torch_gpu, 
slow, torch_device, ) @@ -468,43 +464,6 @@ class PhiModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, torch.testing.assert_close(ntk_sin_long, original_sin_long) self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all()) - @require_flash_attn - @require_torch_gpu - @require_bitsandbytes - @pytest.mark.flash_attn_test - @slow - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_flash_attn_2_generate_padding_right with LlamaForCausalLM->PhiForCausalLM,LlamaTokenizer->AutoTokenizer,meta-llama/Llama-2-7b-hf->microsoft/phi-1 - def test_flash_attn_2_generate_padding_right(self): - """ - Overwritting the common test as the test is flaky on tiny models - """ - model = PhiForCausalLM.from_pretrained( - "microsoft/phi-1", - load_in_4bit=True, - device_map={"": 0}, - ) - - tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1") - - texts = ["hi", "Hello this is a very long sentence"] - - tokenizer.padding_side = "right" - tokenizer.pad_token = tokenizer.eos_token - - inputs = tokenizer(texts, return_tensors="pt", padding=True).to(0) - - output_native = model.generate(**inputs, max_new_tokens=20, do_sample=False) - output_native = tokenizer.batch_decode(output_native) - - model = PhiForCausalLM.from_pretrained( - "microsoft/phi-1", load_in_4bit=True, device_map={"": 0}, attn_implementation="flash_attention_2" - ) - - output_fa_2 = model.generate(**inputs, max_new_tokens=20, do_sample=False) - output_fa_2 = tokenizer.batch_decode(output_fa_2) - - self.assertListEqual(output_native, output_fa_2) - @slow @require_torch diff --git a/tests/models/qwen2/test_modeling_qwen2.py b/tests/models/qwen2/test_modeling_qwen2.py index 4e57f8e0f00..f51dc2e0a5e 100644 --- a/tests/models/qwen2/test_modeling_qwen2.py +++ b/tests/models/qwen2/test_modeling_qwen2.py @@ -15,7 +15,6 @@ """Testing suite for the PyTorch Qwen2 model.""" import gc -import tempfile import unittest import pytest @@ -428,85 +427,6 @@ class Qwen2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi def test_past_key_values_format(self): pass - @require_flash_attn - @require_torch_gpu - @pytest.mark.flash_attn_test - @slow - def test_flash_attn_2_generate_padding_right(self): - import torch - - for model_class in self.all_generative_model_classes: - config, _ = self.model_tester.prepare_config_and_inputs_for_common() - model = model_class(config) - - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname) - model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to( - torch_device - ) - - dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device) - dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device) - - model.generate(dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False) - - model = model_class.from_pretrained( - tmpdirname, - torch_dtype=torch.float16, - attn_implementation="flash_attention_2", - low_cpu_mem_usage=True, - ).to(torch_device) - - with self.assertRaises(ValueError): - _ = model.generate( - dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False - ) - - @require_flash_attn - @require_torch_gpu - @pytest.mark.flash_attn_test - @slow - def test_flash_attn_2_generate_use_cache(self): - import torch - - max_new_tokens = 30 - - for model_class in self.all_generative_model_classes: - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - dummy_input = 
inputs_dict[model_class.main_input_name] - if dummy_input.dtype in [torch.float32, torch.bfloat16]: - dummy_input = dummy_input.to(torch.float16) - - # make sure that all models have enough positions for generation - if hasattr(config, "max_position_embeddings"): - config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1 - - model = model_class(config) - - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname) - - dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input)) - # NOTE: Qwen2 apparently does not support right padding + use_cache with FA2. - dummy_attention_mask[:, -1] = 1 - - model = model_class.from_pretrained( - tmpdirname, - torch_dtype=torch.float16, - attn_implementation="flash_attention_2", - low_cpu_mem_usage=True, - ).to(torch_device) - - # Just test that a large cache works as expected - _ = model.generate( - dummy_input, - attention_mask=dummy_attention_mask, - max_new_tokens=max_new_tokens, - do_sample=False, - use_cache=True, - ) - @require_flash_attn @require_torch_gpu @pytest.mark.flash_attn_test diff --git a/tests/models/qwen2_moe/test_modeling_qwen2_moe.py b/tests/models/qwen2_moe/test_modeling_qwen2_moe.py index c545e882fae..abc7b57919b 100644 --- a/tests/models/qwen2_moe/test_modeling_qwen2_moe.py +++ b/tests/models/qwen2_moe/test_modeling_qwen2_moe.py @@ -15,7 +15,6 @@ """Testing suite for the PyTorch Qwen2MoE model.""" import gc -import tempfile import unittest import pytest @@ -453,85 +452,6 @@ class Qwen2MoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM def test_past_key_values_format(self): pass - @require_flash_attn - @require_torch_gpu - @pytest.mark.flash_attn_test - @slow - def test_flash_attn_2_generate_padding_right(self): - import torch - - for model_class in self.all_generative_model_classes: - config, _ = self.model_tester.prepare_config_and_inputs_for_common() - model = model_class(config) - - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname) - model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to( - torch_device - ) - - dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device) - dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device) - - model.generate(dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False) - - model = model_class.from_pretrained( - tmpdirname, - torch_dtype=torch.float16, - attn_implementation="flash_attention_2", - low_cpu_mem_usage=True, - ).to(torch_device) - - with self.assertRaises(ValueError): - _ = model.generate( - dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False - ) - - @require_flash_attn - @require_torch_gpu - @pytest.mark.flash_attn_test - @slow - def test_flash_attn_2_generate_use_cache(self): - import torch - - max_new_tokens = 30 - - for model_class in self.all_generative_model_classes: - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - dummy_input = inputs_dict[model_class.main_input_name] - if dummy_input.dtype in [torch.float32, torch.bfloat16]: - dummy_input = dummy_input.to(torch.float16) - - # make sure that all models have enough positions for generation - if hasattr(config, "max_position_embeddings"): - config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1 - - model = model_class(config) - - with tempfile.TemporaryDirectory() as tmpdirname: - 
model.save_pretrained(tmpdirname) - - dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input)) - # NOTE: Qwen2Moe apparently does not support right padding + use_cache with FA2. - dummy_attention_mask[:, -1] = 1 - - model = model_class.from_pretrained( - tmpdirname, - torch_dtype=torch.float16, - attn_implementation="flash_attention_2", - low_cpu_mem_usage=True, - ).to(torch_device) - - # Just test that a large cache works as expected - _ = model.generate( - dummy_input, - attention_mask=dummy_attention_mask, - max_new_tokens=max_new_tokens, - do_sample=False, - use_cache=True, - ) - @require_flash_attn @require_torch_gpu @pytest.mark.flash_attn_test diff --git a/tests/models/qwen2_vl/test_modeling_qwen2_vl.py b/tests/models/qwen2_vl/test_modeling_qwen2_vl.py index e1cd715f8f1..a3272853a78 100644 --- a/tests/models/qwen2_vl/test_modeling_qwen2_vl.py +++ b/tests/models/qwen2_vl/test_modeling_qwen2_vl.py @@ -301,10 +301,6 @@ class Qwen2VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas def test_feed_forward_chunking(self): pass - @unittest.skip(reason="Generate needs input ids") - def test_inputs_embeds_matches_input_ids_with_generate(self): - pass - @unittest.skip(reason="CPU offload is not yet supported") def test_cpu_offload(self): pass diff --git a/tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py b/tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py index 542955f9fa4..985115d7707 100644 --- a/tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py +++ b/tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py @@ -420,10 +420,6 @@ class RecurrentGemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT def test_initialization(self): pass - @unittest.skip(reason="RecurrentGemma does not support generating with input embeddings (missing position_ids)") - def test_inputs_embeds_matches_input_ids_with_generate(self): - pass - @require_torch_accelerator @slow diff --git a/tests/models/starcoder2/test_modeling_starcoder2.py b/tests/models/starcoder2/test_modeling_starcoder2.py index 32d28143d72..df743f132c1 100644 --- a/tests/models/starcoder2/test_modeling_starcoder2.py +++ b/tests/models/starcoder2/test_modeling_starcoder2.py @@ -14,7 +14,6 @@ # limitations under the License. 
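For reference, the skip overrides for `test_inputs_embeds_matches_input_ids_with_generate` removed above (Qwen2-VL, RecurrentGemma) refer to the equivalence between generating from `input_ids` and from the matching `inputs_embeds`. A minimal sketch of that check, under the assumption of an illustrative checkpoint whose `prepare_inputs_for_generation` forwards `inputs_embeds` (gpt2 is assumed here, not taken from this diff):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Illustrative checkpoint only; the model must support `inputs_embeds` in `generate`.
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2").eval()

    input_ids = tokenizer("The quick brown fox", return_tensors="pt").input_ids
    inputs_embeds = model.get_input_embeddings()(input_ids)

    # From ids, the output contains prompt + new tokens; from embeds, only the new tokens.
    out_ids = model.generate(input_ids=input_ids, max_new_tokens=2, do_sample=False)
    out_embeds = model.generate(inputs_embeds=inputs_embeds, max_new_tokens=2, do_sample=False)

    # Greedy continuations should almost always match; FP reordering can flip a token.
    print(torch.equal(out_ids[:, -2:], out_embeds))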
"""Testing suite for the PyTorch Starcoder2 model.""" -import tempfile import unittest import pytest @@ -404,85 +403,6 @@ class Starcoder2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste def test_past_key_values_format(self): pass - @require_flash_attn - @require_torch_gpu - @pytest.mark.flash_attn_test - @slow - def test_flash_attn_2_generate_padding_right(self): - import torch - - for model_class in self.all_generative_model_classes: - config, _ = self.model_tester.prepare_config_and_inputs_for_common() - model = model_class(config) - - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname) - model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to( - torch_device - ) - - dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device) - dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device) - - model.generate(dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False) - - model = model_class.from_pretrained( - tmpdirname, - torch_dtype=torch.float16, - attn_implementation="flash_attention_2", - low_cpu_mem_usage=True, - ).to(torch_device) - - with self.assertRaises(ValueError): - _ = model.generate( - dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False - ) - - @require_flash_attn - @require_torch_gpu - @pytest.mark.flash_attn_test - @slow - def test_flash_attn_2_generate_use_cache(self): - import torch - - max_new_tokens = 30 - - for model_class in self.all_generative_model_classes: - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - dummy_input = inputs_dict[model_class.main_input_name] - if dummy_input.dtype in [torch.float32, torch.bfloat16]: - dummy_input = dummy_input.to(torch.float16) - - # make sure that all models have enough positions for generation - if hasattr(config, "max_position_embeddings"): - config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1 - - model = model_class(config) - - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname) - - dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input)) - # NOTE: Starcoder2 apparently does not support right padding + use_cache with FA2. 
- dummy_attention_mask[:, -1] = 1 - - model = model_class.from_pretrained( - tmpdirname, - torch_dtype=torch.float16, - attn_implementation="flash_attention_2", - low_cpu_mem_usage=True, - ).to(torch_device) - - # Just test that a large cache works as expected - _ = model.generate( - dummy_input, - attention_mask=dummy_attention_mask, - max_new_tokens=max_new_tokens, - do_sample=False, - use_cache=True, - ) - @require_flash_attn @require_torch_gpu @pytest.mark.flash_attn_test diff --git a/tests/models/t5/test_modeling_t5.py b/tests/models/t5/test_modeling_t5.py index 68dd5a52b3d..b0341639076 100644 --- a/tests/models/t5/test_modeling_t5.py +++ b/tests/models/t5/test_modeling_t5.py @@ -580,9 +580,6 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, # The small T5 model needs higher percentages for CPU/MP tests model_split_percents = [0.5, 0.8, 0.9] - # used in `test_torch_compile` - _torch_compile_test_ckpt = "google-t5/t5-small" - def setUp(self): self.model_tester = T5ModelTester(self) self.config_tester = ConfigTester(self, config_class=T5Config, d_model=37) diff --git a/tests/models/umt5/test_modeling_umt5.py b/tests/models/umt5/test_modeling_umt5.py index ec4c1d019b6..377668851c5 100644 --- a/tests/models/umt5/test_modeling_umt5.py +++ b/tests/models/umt5/test_modeling_umt5.py @@ -317,9 +317,6 @@ class UMT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin # The small UMT5 model needs higher percentages for CPU/MP tests model_split_percents = [0.5, 0.8, 0.9] - # used in `test_torch_compile` - _torch_compile_test_ckpt = "google/umt5-small" - def setUp(self): self.model_tester = UMT5ModelTester(self) diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index b24c577a16e..12aedaca8cf 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -1574,59 +1574,6 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi ) assert isinstance(pred_ids, expected_output_type) - @require_flash_attn - @require_torch_gpu - @pytest.mark.flash_attn_test - @slow - def test_flash_attn_2_generate_reuse_cache(self): - max_new_tokens = 2 - for model_class in self.all_generative_model_classes: - if not model_class._supports_flash_attn_2: - self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") - - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - dummy_input = inputs_dict[model_class.main_input_name][..., :10] - if dummy_input.dtype in [torch.float32, torch.bfloat16]: - dummy_input = dummy_input.to(torch.float16) - - # make sure that all models have enough positions for generation - if hasattr(config, "max_position_embeddings"): - config.max_position_embeddings = dummy_input.shape[1] * 2 + max_new_tokens * 2 + 1 - - model = model_class(config) - - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname) - - model = model_class.from_pretrained( - tmpdirname, - torch_dtype=torch.float16, - attn_implementation="flash_attention_2", - low_cpu_mem_usage=True, - ).to(torch_device) - - # run generate once to get filled cache - output = model.generate( - dummy_input, - max_new_tokens=max_new_tokens, - do_sample=False, - use_cache=True, - return_dict_in_generate=True, - ) - past_key_values = output.past_key_values - - # Try to continue generation from where we left, given that we have more than 1 new token to process - # e.g. 
this can happen in speculative decoding when feeding candidate tokens back to target model - _ = model.generate( - dummy_input, - decoder_input_ids=output.sequences, - max_new_tokens=max_new_tokens, - do_sample=False, - use_cache=True, - past_key_values=past_key_values, - ) - def test_labels_sequence_max_length_correct(self): config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() @@ -3961,11 +3908,6 @@ class WhisperStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, # generate only works with input ids for whisper pass - @unittest.skip(reason="Generate needs input ids") - def test_inputs_embeds_matches_input_ids_with_generate(self): - # generate only works with input ids for whisper - pass - @unittest.skip(reason="Decoder can't keep attention grads") def test_retain_grad_hidden_states_attentions(self): return @@ -3974,18 +3916,6 @@ class WhisperStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, def test_save_load_fast_init_from_base(self): pass - @unittest.skip( - reason="FA2 testing suite needs to be refactored to be compatible with WhisperDecoder for that test" - ) - def test_flash_attn_2_generate_reuse_cache(self): - pass - - @unittest.skip( - "Duplicated test with WhisperModelTest + the FA2 testing suite needs to be refactored to be compatible with WhisperDecoder for that test" - ) - def test_flash_attn_2_generate_padding_right(self): - pass - @unittest.skip( "Duplicated test with WhisperModelTest + the FA2 testing suite needs to be refactored to be compatible with WhisperDecoder for that test" ) diff --git a/tests/models/zamba/test_modeling_zamba.py b/tests/models/zamba/test_modeling_zamba.py index c0a8020bedd..a6dd516f98a 100644 --- a/tests/models/zamba/test_modeling_zamba.py +++ b/tests/models/zamba/test_modeling_zamba.py @@ -542,93 +542,6 @@ class ZambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi # with attention mask _ = model(dummy_input, attention_mask=dummy_attention_mask) - @require_flash_attn - @require_torch_gpu - @pytest.mark.flash_attn_test - @slow - def test_flash_attn_2_generate_padding_right(self): - r""" - Overriding the test_flash_attn_2_generate_padding_right test as the Zamba model, like Mixtral, doesn't support - right padding + use cache with FA2 - """ - import torch - - for model_class in self.all_generative_model_classes: - config, _ = self.model_tester.prepare_config_and_inputs_for_common() - model = model_class(config) - - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname) - model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to( - torch_device - ) - - dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device) - dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device) - - model.generate(dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False) - - model = model_class.from_pretrained( - tmpdirname, - torch_dtype=torch.float16, - attn_implementation="flash_attention_2", - low_cpu_mem_usage=True, - ).to(torch_device) - - with self.assertRaises(ValueError): - _ = model.generate( - dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False - ) - - @require_flash_attn - @require_torch_gpu - @pytest.mark.flash_attn_test - @slow - def test_flash_attn_2_generate_use_cache(self): - r""" - Overriding the test_flash_attn_2_generate_use_cache test as the Zamba model, like Mixtral, doesn't support - right padding + 
use cache with FA2 - """ - import torch - - max_new_tokens = 30 - - for model_class in self.all_generative_model_classes: - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - dummy_input = inputs_dict[model_class.main_input_name] - if dummy_input.dtype in [torch.float32, torch.bfloat16]: - dummy_input = dummy_input.to(torch.float16) - - # make sure that all models have enough positions for generation - if hasattr(config, "max_position_embeddings"): - config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1 - - model = model_class(config) - - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname) - - dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input)) - # NOTE: Zamba does not support right padding + use_cache with FA2. - dummy_attention_mask[:, -1] = 1 - - model = model_class.from_pretrained( - tmpdirname, - torch_dtype=torch.float16, - attn_implementation="flash_attention_2", - low_cpu_mem_usage=True, - ).to(torch_device) - - # Just test that a large cache works as expected - _ = model.generate( - dummy_input, - attention_mask=dummy_attention_mask, - max_new_tokens=max_new_tokens, - do_sample=False, - use_cache=True, - ) - @require_flash_attn @require_torch_gpu @pytest.mark.flash_attn_test diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index d88b0dc5f02..e2719d8cf1b 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -22,7 +22,6 @@ import os.path import random import re import tempfile -import time import warnings from collections import defaultdict from contextlib import contextmanager @@ -37,10 +36,7 @@ import transformers from transformers import ( AutoModel, AutoModelForCausalLM, - AutoModelForSeq2SeqLM, AutoModelForSequenceClassification, - AutoTokenizer, - GenerationConfig, PretrainedConfig, PreTrainedModel, is_torch_available, @@ -86,7 +82,6 @@ from transformers.testing_utils import ( require_deepspeed, require_flash_attn, require_non_xpu, - require_read_token, require_safetensors, require_torch, require_torch_accelerator, @@ -3000,71 +2995,6 @@ class ModelTesterMixin: )[0] self.assertTrue(torch.allclose(out_embeds, out_ids)) - def test_inputs_embeds_matches_input_ids_with_generate(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - for model_class in self.all_generative_model_classes: - if model_class.__name__ not in [ - *get_values(MODEL_FOR_CAUSAL_LM_MAPPING_NAMES), - *get_values(MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES), - ]: - continue - - model = model_class(config) - model.to(torch_device) - model.eval() - - model_forward_args = inspect.signature(model.forward).parameters - if any(argument not in model_forward_args for argument in ["inputs_embeds", "position_ids"]): - self.skipTest(reason="This model doesn't use `inputs_embeds` or `position_ids`.") - has_inputs_embeds_forwarding = "inputs_embeds" in set( - inspect.signature(model.prepare_inputs_for_generation).parameters.keys() - ) - if not has_inputs_embeds_forwarding: - self.skipTest(reason="This model doesn't support `inputs_embeds` passed to `generate`.") - inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class)) - pad_token_id = config.pad_token_id if config.pad_token_id is not None else 1 - - # VLMs can't generate with embeds and pixels at the same time. 
We expect the user to pass merged - # embeds already - if model_class.__name__ in get_values(MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES): - inputs.pop("pixel_values", None) - inputs.pop("pixel_values_videos", None) - inputs.pop("pixel_values_images", None) - - wte = model.get_input_embeddings() - if not self.is_encoder_decoder: - input_ids = inputs["input_ids"] - # some models infer position ids/attn mask differently when input ids - # by check if pad_token let's make sure no padding is in input ids - not_pad_token_id = pad_token_id + 1 if max(0, pad_token_id - 1) == 0 else pad_token_id - 1 - input_ids[input_ids == pad_token_id] = not_pad_token_id - del inputs["input_ids"] - inputs_embeds = wte(input_ids) - out_ids = model.generate(input_ids=input_ids, **inputs, max_new_tokens=2)[:, -2:] - out_embeds = model.generate(inputs_embeds=inputs_embeds, **inputs, max_new_tokens=2) - else: - encoder_input_ids = inputs["input_ids"] - decoder_input_ids = inputs.get("decoder_input_ids", encoder_input_ids) - encoder_input_ids[encoder_input_ids == pad_token_id] = max(0, pad_token_id + 1) - decoder_input_ids[decoder_input_ids == pad_token_id] = max(0, pad_token_id + 1) - del inputs["input_ids"] - inputs.pop("decoder_input_ids", None) - inputs_embeds = wte(encoder_input_ids) - decoder_inputs_embeds = wte(decoder_input_ids) - out_ids = model.generate( - input_ids=encoder_input_ids, decoder_input_ids=decoder_input_ids, **inputs, max_new_tokens=2 - )[:, -2:] - out_embeds = model.generate( - inputs_embeds=inputs_embeds, - decoder_inputs_embeds=decoder_inputs_embeds, - **inputs, - max_new_tokens=2, - ) - # NOTE: this test changes the order of FP ops, there may be tiny differences in the output - number_of_different_tokens = (out_ids != out_embeds).sum() - max_differences = int(out_ids.shape[0] * out_ids.shape[1] * 0.1) - self.assertTrue(number_of_different_tokens <= max_differences) # accept up to 10% mismatch - @require_non_xpu @require_torch_multi_gpu def test_multi_gpu_data_parallel_forward(self): @@ -3857,102 +3787,6 @@ class ModelTesterMixin: assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2) - @require_flash_attn - @require_torch_gpu - @mark.flash_attn_test - @slow - @is_flaky() - def test_flash_attn_2_generate_left_padding(self): - if not self.has_attentions: - self.skipTest(reason="Model architecture does not support attentions") - - for model_class in self.all_generative_model_classes: - if not model_class._supports_flash_attn_2: - self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") - - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - model = model_class(config) - - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname) - model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to( - torch_device - ) - - dummy_input = inputs_dict[model.main_input_name] - if dummy_input.dtype in [torch.float32, torch.bfloat16]: - dummy_input = dummy_input.to(torch.float16) - - dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input)) - # make sure we do left padding - dummy_attention_mask[:, :-1] = 0 - dummy_attention_mask[:, -1:] = 1 - - out = model.generate( - dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False - ) - - model = model_class.from_pretrained( - tmpdirname, - torch_dtype=torch.float16, - attn_implementation="flash_attention_2", - low_cpu_mem_usage=True, - ).to(torch_device) - - out_fa = model.generate( 
- dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False - ) - - self.assertTrue(torch.allclose(out, out_fa)) - - @require_flash_attn - @require_torch_gpu - @mark.flash_attn_test - @is_flaky() - @slow - def test_flash_attn_2_generate_padding_right(self): - if not self.has_attentions: - self.skipTest(reason="Model architecture does not support attentions") - - for model_class in self.all_generative_model_classes: - if not model_class._supports_flash_attn_2: - self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") - - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - model = model_class(config) - - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname) - model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to( - torch_device - ) - - dummy_input = inputs_dict[model.main_input_name] - if dummy_input.dtype in [torch.float32, torch.bfloat16]: - dummy_input = dummy_input.to(torch.float16) - - dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input)) - # make sure we do right padding - dummy_attention_mask[:, :-1] = 1 - dummy_attention_mask[:, -1:] = 0 - - out = model.generate( - dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False - ) - - model = model_class.from_pretrained( - tmpdirname, - torch_dtype=torch.float16, - attn_implementation="flash_attention_2", - low_cpu_mem_usage=True, - ).to(torch_device) - - out_fa = model.generate( - dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False - ) - - self.assertTrue(torch.allclose(out, out_fa)) - def test_attn_implementation_composite_models(self): """ Tests if composite models can receive a dict object as attn_implementation, where each key should be @@ -4525,65 +4359,6 @@ class ModelTesterMixin: torch.allclose(res_eager[attention_mask == 1], res_sdpa[attention_mask == 1], rtol=1e-4, atol=1e-4) ) - @require_flash_attn - @require_torch_gpu - @mark.flash_attn_test - @slow - def test_flash_attn_2_generate_use_cache(self): - if not self.has_attentions: - self.skipTest(reason="Model architecture does not support attentions") - - max_new_tokens = 30 - - for model_class in self.all_generative_model_classes: - if not model_class._supports_flash_attn_2: - self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") - - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - dummy_input = inputs_dict[model_class.main_input_name] - if dummy_input.dtype in [torch.float32, torch.bfloat16]: - dummy_input = dummy_input.to(torch.float16) - - # make sure that all models have enough positions for generation - if hasattr(config, "max_position_embeddings"): - config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1 - - model = model_class(config) - - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname) - - dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input)) - - model = model_class.from_pretrained( - tmpdirname, - torch_dtype=torch.float16, - attn_implementation="flash_attention_2", - low_cpu_mem_usage=True, - ).to(torch_device) - - # Just test that a large cache works as expected - _ = model.generate( - dummy_input, - attention_mask=dummy_attention_mask, - max_new_tokens=max_new_tokens, - do_sample=False, - use_cache=True, - ) - - # Generate with one batch only to test generation when attention mask 
will be None - # when real inputs are used, because there is no padding. See issue #32237 for more - dummy_input = dummy_input[:1, ...] - dummy_attention_mask = torch.ones_like(dummy_attention_mask[:1, ...]) - _ = model.generate( - dummy_input, - attention_mask=dummy_attention_mask, - max_new_tokens=max_new_tokens, - do_sample=False, - use_cache=True, - ) - @require_flash_attn @require_torch_gpu @mark.flash_attn_test @@ -4640,62 +4415,6 @@ class ModelTesterMixin: if not has_fa2: raise ValueError("The FA2 model should have FA2 layers") - @require_flash_attn - @require_torch_gpu - @mark.flash_attn_test - @slow - def test_flash_attn_2_generate_reuse_cache(self): - if not self.has_attentions: - self.skipTest(reason="Model architecture does not support attentions") - - max_new_tokens = 2 - for model_class in self.all_generative_model_classes: - if not model_class._supports_flash_attn_2: - self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") - - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - dummy_input = inputs_dict[model_class.main_input_name] - if dummy_input.dtype in [torch.float32, torch.bfloat16]: - dummy_input = dummy_input.to(torch.float16) - - # make sure that all models have enough positions for generation - if hasattr(config, "max_position_embeddings"): - config.max_position_embeddings = dummy_input.shape[1] * 2 + max_new_tokens * 2 + 1 - - model = model_class(config) - - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname) - - model = model_class.from_pretrained( - tmpdirname, - torch_dtype=torch.float16, - attn_implementation="flash_attention_2", - low_cpu_mem_usage=True, - ).to(torch_device) - - # run generate once to get filled cache - output = model.generate( - dummy_input, - max_new_tokens=max_new_tokens, - do_sample=False, - use_cache=True, - return_dict_in_generate=True, - ) - past_key_values = output.past_key_values - - # Try to continue generation from where we left, given that we have more than 1 new token to process - # e.g. this can happen in speculative decoding when feeding candidate tokens back to target model - dummy_input_updated = torch.cat([dummy_input, output.sequences], dim=-1) - _ = model.generate( - dummy_input_updated, - max_new_tokens=max_new_tokens, - do_sample=False, - use_cache=True, - past_key_values=past_key_values, - ) - @require_flash_attn @require_torch_gpu @require_bitsandbytes @@ -4999,82 +4718,6 @@ class ModelTesterMixin: normalized_1 = F.softmax(out_shared_prefix_last_tokens) torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4) - def test_static_cache_matches_dynamic(self): - """ - Tests that generating with static cache give almost same results as with dynamic cache. - This test does not compile the model and check only logits similarity for numerical precision - errors. 
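For reference, the dynamic-vs-static cache comparison this removed test performed can be reproduced outside the test suite roughly as follows (the tiny checkpoint is an assumption; substitute any model whose class supports the static cache):

    import torch
    from transformers import AutoModelForCausalLM

    # Assumed tiny checkpoint, used only for illustration.
    model = AutoModelForCausalLM.from_pretrained(
        "hf-internal-testing/tiny-random-LlamaForCausalLM"
    ).eval()
    input_ids = torch.randint(0, model.config.vocab_size, (1, 8))

    dynamic = model.generate(input_ids, do_sample=False, max_new_tokens=10,
                             output_logits=True, return_dict_in_generate=True)
    static = model.generate(input_ids, do_sample=False, max_new_tokens=10,
                            cache_implementation="static",
                            output_logits=True, return_dict_in_generate=True)

    # Loose tolerances: the two cache layouts reorder floating point ops slightly.
    torch.testing.assert_close(dynamic.logits[0], static.logits[0], rtol=1e-3, atol=1e-4)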
- """ - if len(self.all_generative_model_classes) == 0: - self.skipTest( - reason="Model architecture has no generative classes, and thus not necessarily supporting 4D masks" - ) - for model_class in self.all_generative_model_classes: - if not model_class._supports_static_cache: - self.skipTest(f"{model_class.__name__} does not support static cache") - - if not model_class._supports_cache_class: - self.skipTest(f"{model_class.__name__} does not support cache class") - - config, inputs = self.model_tester.prepare_config_and_inputs_for_common() - if getattr(config, "sliding_window", 0) is not None and getattr(config, "sliding_window", 0) > 0: - self.skipTest(f"{model_class.__name__} with sliding window attention is not supported by this test") - - model = model_class(config).to(device=torch_device, dtype=torch.float32) - model.eval() - - dynamic_out = model.generate( - **inputs, do_sample=False, max_new_tokens=10, output_logits=True, return_dict_in_generate=True - ) - static_out = model.generate( - **inputs, - do_sample=False, - max_new_tokens=10, - cache_implementation="static", - output_logits=True, - return_dict_in_generate=True, - ) - self.assertTrue(torch.allclose(dynamic_out.logits[0], static_out.logits[0], rtol=1e-3, atol=1e-4)) - - # For now, Let's focus only on GPU for `torch.compile` - @slow - @require_torch_accelerator - @require_read_token - def test_torch_compile(self): - if version.parse(torch.__version__) < version.parse("2.3"): - self.skipTest(reason="This test requires torch >= 2.3 to run.") - torch.compiler.reset() - if not hasattr(self, "_torch_compile_test_ckpt"): - self.skipTest(f"{self.__class__.__name__} doesn't have the attribute `_torch_compile_test_ckpt`.") - ckpt = self._torch_compile_test_ckpt - revision = "main" if not hasattr(self, "_torch_compile_test_revision") else self._torch_compile_test_revision - - os.environ["TOKENIZERS_PARALLELISM"] = "false" - - batch_size = 1 - n_iter = 3 - - tokenizer = AutoTokenizer.from_pretrained(ckpt) - if self.is_encoder_decoder: - model = AutoModelForSeq2SeqLM.from_pretrained(ckpt, torch_dtype=torch.float16, revision=revision).to( - torch_device - ) - else: - model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16, revision=revision).to( - torch_device - ) - - model.generation_config.max_new_tokens = 4 - - model.generation_config.cache_implementation = "static" - model.forward = torch.compile(model.forward, mode="reduce-overhead") - - input_text = "Why dogs are cute?" - input_ids = tokenizer([input_text] * batch_size, return_tensors="pt").to(torch_device) - - for i in range(n_iter): - _ = model.generate(**input_ids, do_sample=False) - @slow @require_torch_gpu def test_torch_compile_for_training(self): @@ -5118,74 +4761,6 @@ class ModelTesterMixin: for name, param in model._orig_mod.named_parameters(): torch.testing.assert_close(param.grad.detach().cpu(), params[name], rtol=1e-4, atol=1e-4) - @slow - @require_torch_gpu # Testing cuda graphs. - @require_read_token - def test_compile_cuda_graph_time(self): - if version.parse(torch.__version__) < version.parse("2.3"): - self.skipTest(reason="This test requires torch >= 2.3 to run.") - - # TODO felix: All models supporting `StaticCache` or `torch.compile` should be tested. - # At the moment, only llama, gemma and gemma2 are tested here! 
- if not hasattr(self, "_torch_compile_test_ckpt"): - self.skipTest(f"{self.__class__.__name__} doesn't have the attribute `_torch_compile_test_ckpt`.") - ckpt = self._torch_compile_test_ckpt - revision = "main" if not hasattr(self, "_torch_compile_test_revision") else self._torch_compile_test_revision - - os.environ["TOKENIZERS_PARALLELISM"] = "false" - - tokenizer = AutoTokenizer.from_pretrained(ckpt) - if self.is_encoder_decoder: - model = AutoModelForSeq2SeqLM.from_pretrained(ckpt, torch_dtype=torch.float16, revision=revision).to( - torch_device - ) - else: - model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16, revision=revision).to( - torch_device - ) - - cache_implementation = "static" - if model.config.model_type == "gemma2": - cache_implementation = "hybrid" - - new_tokens = 50 - gen_config = GenerationConfig( - max_new_tokens=new_tokens, - min_new_tokens=new_tokens, - use_cache=True, - pad_token_id=tokenizer.pad_token_id, - num_beams=1, - do_sample=False, - eos_token_id=None, # This is required for min_new_tokens to actually have an effect. - ) - model.generation_config.eos_token_id = None # greedy_search falls back on this eos_token_id that we need to set to None as well for min_new_tokens to have an effect. - - model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True) - - inp = tokenizer("Why cats are cute?", return_tensors="pt").to(torch_device) - - # First run: the first run warms up each graph, which does things like CuBlas or Triton benchmarking - start = time.perf_counter() - _ = model.generate(**inp, generation_config=gen_config, cache_implementation=cache_implementation) - end = time.perf_counter() - graph_warmup_time = end - start - - # Second run: CUDA Graph recording, and replays it - start = time.perf_counter() - _ = model.generate(**inp, generation_config=gen_config, cache_implementation=cache_implementation) - end = time.perf_counter() - record_time = end - start - - # Finally: we hit the optimized, CUDA Graph replay path - start = time.perf_counter() - _ = model.generate(**inp, generation_config=gen_config, cache_implementation=cache_implementation) - end = time.perf_counter() - opt_time = end - start - - # For the recording step, we expect only two cuda graphs and this step should be much faster than the first. - self.assertTrue(record_time < 0.15 * graph_warmup_time) - self.assertTrue(opt_time < record_time) - def test_forward_with_num_logits_to_keep(self): for model_class in self.all_generative_model_classes: if "num_logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()):