Tests: move generate tests to the right mixin and delete redundant tests (#34464)

* tmp commit

* tmp commit

* cull overwrites of deleted tests

* typo

* more specific docstring

* make fixup

* parameterize at the top?

* correction

* more deletions :D

* tmp commit

* for VLMs too

* fix _check_outputs

* test nit

* make fixup

* fix another flaky

* test_generate_from_inputs_embeds -- handle missing attention mask
Joao Gante 2024-10-30 10:59:08 +00:00 committed by GitHub
parent 913330ca9f
commit 8a734ea2c3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
46 changed files with 265 additions and 2348 deletions

View File

@ -378,10 +378,14 @@ class GenerationMixin:
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
# Exception 1: when passing input_embeds, input_ids may be missing entries
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
# Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case
# Exception 3: with synced GPUs cache_position may go out of bounds, but we only want a dummy token in that case.
# (we can't check exception 3 while compiling)
if past_key_values is not None:
model_inputs["past_key_values"] = past_key_values
if inputs_embeds is not None or cache_position[-1] >= input_ids.shape[1]: # Exception 1 or Exception 3
if (
inputs_embeds is not None # Exception 1
or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1]) # Exception 3
):
input_ids = input_ids[:, -cache_position.shape[0] :]
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
input_ids = input_ids[:, cache_position]
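The slicing logic above can be checked in isolation. A minimal sketch with plain tensors (not tied to any model), assuming a 6-token prompt of which the first 4 positions are already cached:

import torch

input_ids = torch.tensor([[11, 12, 13, 14, 15, 16]])
cache_position = torch.tensor([4, 5])  # positions still to be processed

# Default case: keep only the tokens indexed by `cache_position`
sliced = input_ids[:, cache_position]
assert sliced.tolist() == [[15, 16]]

# Exception 1/3 path: `input_ids` may be shorter than the positions imply (inputs_embeds,
# synced GPUs, or compilation), so the code keeps the last `cache_position.shape[0]` tokens instead
sliced_tail = input_ids[:, -cache_position.shape[0] :]
assert sliced_tail.tolist() == [[15, 16]]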
@ -414,7 +418,7 @@ class GenerationMixin:
for model_input_name in ["position_ids", "token_type_ids"]:
model_input = kwargs.get(model_input_name)
if model_input is not None:
if past_key_values:
if past_key_values is not None:
model_input = model_input[:, -input_ids.shape[1] :]
model_input = model_input.clone(memory_format=torch.contiguous_format)
model_inputs[model_input_name] = model_input
@ -568,27 +572,34 @@ class GenerationMixin:
def _prepare_attention_mask_for_generation(
self,
inputs: torch.Tensor,
pad_token_id: Optional[torch.Tensor],
eos_token_id: Optional[torch.Tensor],
inputs_tensor: torch.Tensor,
generation_config: GenerationConfig,
model_kwargs: Dict[str, Any],
) -> torch.LongTensor:
pad_token_id = generation_config._pad_token_tensor
eos_token_id = generation_config._eos_token_tensor
# `input_ids` may be present in the model kwargs, instead of being the main input (e.g. multimodal model)
if "input_ids" in model_kwargs and model_kwargs["input_ids"].shape[1] > 0:
inputs_tensor = model_kwargs["input_ids"]
# No information for attention mask inference -> return default attention mask
default_attention_mask = torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device)
default_attention_mask = torch.ones(inputs_tensor.shape[:2], dtype=torch.long, device=inputs_tensor.device)
if pad_token_id is None:
return default_attention_mask
is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long]
is_input_ids = len(inputs_tensor.shape) == 2 and inputs_tensor.dtype in [torch.int, torch.long]
if not is_input_ids:
return default_attention_mask
is_pad_token_in_inputs = (pad_token_id is not None) and (
isin_mps_friendly(elements=inputs, test_elements=pad_token_id).any()
isin_mps_friendly(elements=inputs_tensor, test_elements=pad_token_id).any()
)
is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or ~(
isin_mps_friendly(elements=eos_token_id, test_elements=pad_token_id).any()
)
can_infer_attention_mask = is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id
attention_mask_from_padding = inputs.ne(pad_token_id).long()
attention_mask_from_padding = inputs_tensor.ne(pad_token_id).long()
attention_mask = (
attention_mask_from_padding * can_infer_attention_mask + default_attention_mask * ~can_infer_attention_mask
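For reference, the padding-based mask inference performed here can be reproduced standalone; a hedged sketch with made-up token ids (`torch.isin` stands in for the MPS-friendly helper used in the real code):

import torch

pad_token_id = torch.tensor(0)
eos_token_id = torch.tensor(2)
input_ids = torch.tensor([[0, 0, 5, 6, 7],   # left-padded row
                          [3, 4, 5, 6, 7]])  # unpadded row

default_attention_mask = torch.ones(input_ids.shape[:2], dtype=torch.long)

# The mask is only inferred when the pad token appears in the input and is not
# also the eos token (otherwise padding and content can't be told apart).
is_pad_in_inputs = torch.isin(input_ids, pad_token_id).any()
pad_is_not_eos = ~torch.isin(eos_token_id, pad_token_id).any()
can_infer = is_pad_in_inputs & pad_is_not_eos

mask_from_padding = input_ids.ne(pad_token_id).long()
attention_mask = mask_from_padding * can_infer + default_attention_mask * ~can_infer
print(attention_mask)  # tensor([[0, 0, 1, 1, 1], [1, 1, 1, 1, 1]])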
@ -2020,7 +2031,7 @@ class GenerationMixin:
if not kwargs_has_attention_mask and requires_attention_mask and accepts_attention_mask:
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
inputs_tensor, generation_config._pad_token_tensor, generation_config._eos_token_tensor
inputs_tensor, generation_config, model_kwargs
)
elif kwargs_has_attention_mask:
# TODO (joao): generalize this check with other types of inputs

View File

@ -911,7 +911,8 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene
if (pixel_values is not None or pixel_values_videos is not None) and inputs_embeds is not None:
raise ValueError(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
"You cannot specify both `pixel_values`/`pixel_values_videos` and `inputs_embeds` at the same time, "
"and must specify either one"
)
legacy_processing = False

View File

@ -424,7 +424,8 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration):
if (pixel_values is not None or pixel_values_videos is not None) and inputs_embeds is not None:
raise ValueError(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
"You cannot specify both `pixel_values`/`pixel_values_videos` and `inputs_embeds` at the same time, "
"and must specify either one"
)
legacy_processing = False

View File

@ -657,7 +657,8 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene
if (pixel_values is not None or pixel_values_videos is not None) and inputs_embeds is not None:
raise ValueError(
"You cannot specify both pixel_values/pixel_values_videos and inputs_embeds at the same time, and must specify either one"
"You cannot specify both `pixel_values`/`pixel_values_videos` and `inputs_embeds` at the same time, "
"and must specify either one"
)
if inputs_embeds is None:

View File

@ -1562,7 +1562,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel, GenerationMixin):
if model_kwargs.get("attention_mask", None) is None and requires_attention_mask:
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
input_ids, generation_config._pad_token_tensor, generation_config._eos_token_tensor
input_ids, generation_config, model_kwargs
)
# 5. Prepare `max_length` depending on other stopping criteria.
@ -2578,7 +2578,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel, GenerationMixin):
if model_kwargs.get("attention_mask", None) is None and requires_attention_mask:
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
inputs_tensor, generation_config._pad_token_tensor, generation_config._eos_token_tensor
inputs_tensor, generation_config, model_kwargs
)
if "encoder_outputs" not in model_kwargs:

View File

@ -1484,7 +1484,7 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel, GenerationMixin):
if model_kwargs.get("attention_mask", None) is None and requires_attention_mask:
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
input_ids, generation_config._pad_token_tensor, generation_config._eos_token_tensor
input_ids, generation_config, model_kwargs
)
# 5. Prepare `max_length` depending on other stopping criteria.
@ -2425,7 +2425,7 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel, GenerationMixin):
if model_kwargs.get("attention_mask", None) is None and requires_attention_mask:
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
inputs_tensor, generation_config._pad_token_tensor, generation_config._eos_token_tensor
inputs_tensor, generation_config, model_kwargs
)
if "encoder_hidden_states" not in model_kwargs:

View File

@ -534,7 +534,8 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi
if (pixel_values_images is not None or pixel_values_videos is not None) and inputs_embeds is not None:
raise ValueError(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
"You cannot specify both `pixel_values_images`/`pixel_values_videos` and `inputs_embeds` at the same "
"time, and must specify either one"
)
legacy_processing = False

View File

@ -29,6 +29,7 @@ from transformers import AutoConfig, is_torch_available, pipeline, set_seed
from transformers.testing_utils import (
is_flaky,
require_accelerate,
require_flash_attn,
require_optimum_quanto,
require_torch,
require_torch_gpu,
@ -136,6 +137,34 @@ class GenerationTesterMixin:
return config, filtered_inputs_dict
def _check_similar_generate_outputs(self, output_1, output_2, atol=1e-5, rtol=1e-5):
"""
Checks whether a pair of generate outputs are similar. Two `generate` call outputs are considered similar in
the following situations:
1. The sequences are the same
2. The sequences are different, but the scores up to (and including) the first mismatch are nearly identical
"""
# scores doesn't include data regarding decoder input tokens
decoder_input_length = output_1.sequences.shape[1] - len(output_1.scores)
output_matches = output_1.sequences == output_2.sequences
has_matching_outputs = output_matches.all()
has_matching_scores = None
if not has_matching_outputs:
for batch_idx in range(output_1.sequences.shape[0]):
batch_matches = output_matches[batch_idx]
if batch_matches.all():
continue
first_mismatch_idx = batch_matches.int().argmin() # gets the index of the first False
first_mismatch_idx -= decoder_input_length
output_1_first_mismatch_scores = output_1.scores[first_mismatch_idx][batch_idx]
output_2_first_mismatch_scores = output_2.scores[first_mismatch_idx][batch_idx]
has_matching_scores = torch.allclose(
output_1_first_mismatch_scores, output_2_first_mismatch_scores, rtol=rtol, atol=atol
)
if not has_matching_scores:
break
self.assertTrue(has_matching_outputs or has_matching_scores)
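To make the comparison rule concrete, here is a standalone sketch of the "scores up to the first mismatch" check described in the docstring, using fabricated sequences and scores (not taken from the PR):

import torch

# Two fake generate outputs: a 2-token prompt followed by 3 generated tokens each.
sequences_1 = torch.tensor([[7, 8, 1, 2, 3]])
sequences_2 = torch.tensor([[7, 8, 1, 2, 4]])  # mismatch at the last token
scores_1 = [torch.tensor([[0.1, 0.9]]), torch.tensor([[0.2, 0.8]]), torch.tensor([[0.5000, 0.5000]])]
scores_2 = [torch.tensor([[0.1, 0.9]]), torch.tensor([[0.2, 0.8]]), torch.tensor([[0.5001, 0.4999]])]

decoder_input_length = sequences_1.shape[1] - len(scores_1)  # = 2 (prompt tokens have no scores)
matches = sequences_1 == sequences_2
first_mismatch = int(matches[0].int().argmin()) - decoder_input_length  # index into `scores`
close_at_mismatch = torch.allclose(scores_1[first_mismatch][0], scores_2[first_mismatch][0], atol=1e-3)
print(close_at_mismatch)  # True -> the two outputs count as "similar"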
def _get_logits_processor_kwargs(self, do_sample=False, config=None):
logits_processor_kwargs = {
"bad_words_ids": [[1, 0]],
@ -426,7 +455,6 @@ class GenerationTesterMixin:
def test_greedy_generate_dict_outputs(self):
for model_class in self.all_generative_model_classes:
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
main_input = inputs_dict[model_class.main_input_name]
model = model_class(config).to(torch_device).eval()
output_generate = self._greedy_generate(
@ -453,13 +481,12 @@ class GenerationTesterMixin:
# Retrocompatibility check
self.assertIsInstance(output_generate, GreedySearchDecoderOnlyOutput)
self._check_outputs(output_generate, main_input, model.config)
self._check_outputs(output_generate, model.config)
@pytest.mark.generate
def test_greedy_generate_dict_outputs_use_cache(self):
for model_class in self.all_generative_model_classes:
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
main_input = inputs_dict[model_class.main_input_name]
if not hasattr(config, "use_cache"):
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
@ -486,7 +513,7 @@ class GenerationTesterMixin:
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
)
self._check_outputs(output_generate, main_input, model.config, use_cache=True)
self._check_outputs(output_generate, model.config, use_cache=True)
@pytest.mark.generate
def test_sample_generate(self):
@ -505,7 +532,6 @@ class GenerationTesterMixin:
def test_sample_generate_dict_output(self):
for model_class in self.all_generative_model_classes:
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
main_input = inputs_dict[model_class.main_input_name]
model = model_class(config).to(torch_device).eval()
output_generate = self._sample_generate(
@ -533,7 +559,7 @@ class GenerationTesterMixin:
# Retrocompatibility check
self.assertIsInstance(output_generate, SampleDecoderOnlyOutput)
self._check_outputs(output_generate, main_input, model.config, num_return_sequences=2)
self._check_outputs(output_generate, model.config, num_return_sequences=2)
@pytest.mark.generate
def test_beam_search_generate(self):
@ -554,7 +580,6 @@ class GenerationTesterMixin:
def test_beam_search_generate_dict_output(self):
for model_class in self.all_generative_model_classes:
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
main_input = inputs_dict[model_class.main_input_name]
model = model_class(config).to(torch_device).eval()
beam_kwargs = self._get_beam_kwargs()
@ -583,14 +608,16 @@ class GenerationTesterMixin:
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
self._check_outputs(
output_generate, main_input, model.config, num_return_sequences=beam_kwargs["num_beams"]
output_generate,
model.config,
num_return_sequences=beam_kwargs["num_return_sequences"],
num_beams=beam_kwargs["num_beams"],
)
@pytest.mark.generate
def test_beam_search_generate_dict_outputs_use_cache(self):
for model_class in self.all_generative_model_classes:
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
main_input = inputs_dict[model_class.main_input_name]
if not hasattr(config, "use_cache"):
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
@ -623,10 +650,10 @@ class GenerationTesterMixin:
self._check_outputs(
output_generate,
main_input,
model.config,
use_cache=True,
num_return_sequences=beam_kwargs["num_beams"],
num_return_sequences=beam_kwargs["num_return_sequences"],
num_beams=beam_kwargs["num_beams"],
)
@require_accelerate
@ -675,7 +702,6 @@ class GenerationTesterMixin:
def test_beam_sample_generate_dict_output(self):
for model_class in self.all_generative_model_classes:
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
main_input = inputs_dict[model_class.main_input_name]
model = model_class(config).to(torch_device).eval()
beam_kwargs = self._get_beam_kwargs()
@ -706,7 +732,10 @@ class GenerationTesterMixin:
self.assertIsInstance(output_generate, BeamSampleDecoderOnlyOutput)
self._check_outputs(
output_generate, main_input, model.config, num_return_sequences=beam_kwargs["num_beams"]
output_generate,
model.config,
num_return_sequences=beam_kwargs["num_return_sequences"],
num_beams=beam_kwargs["num_beams"],
)
@pytest.mark.generate
@ -765,7 +794,6 @@ class GenerationTesterMixin:
def test_group_beam_search_generate_dict_output(self):
for model_class in self.all_generative_model_classes:
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
main_input = inputs_dict[model_class.main_input_name]
model = model_class(config).to(torch_device).eval()
beam_kwargs = self._get_diverse_beam_kwargs()
@ -794,7 +822,10 @@ class GenerationTesterMixin:
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
self._check_outputs(
output_generate, main_input, model.config, num_return_sequences=beam_kwargs["num_beams"]
output_generate,
model.config,
num_return_sequences=beam_kwargs["num_return_sequences"],
num_beams=beam_kwargs["num_beams"],
)
# TODO: @gante check why it is flaky
@ -859,7 +890,6 @@ class GenerationTesterMixin:
def test_constrained_beam_search_generate_dict_output(self):
for model_class in self.all_generative_model_classes:
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
main_input = inputs_dict[model_class.main_input_name]
model = model_class(config).to(torch_device).eval()
@ -899,7 +929,10 @@ class GenerationTesterMixin:
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
self._check_outputs(
output_generate, main_input, model.config, num_return_sequences=beam_kwargs["num_beams"]
output_generate,
model.config,
num_return_sequences=beam_kwargs["num_return_sequences"],
num_beams=beam_kwargs["num_beams"],
)
@pytest.mark.generate
@ -942,7 +975,6 @@ class GenerationTesterMixin:
self.skipTest(reason="Won't fix: old model with different cache format")
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
main_input = inputs_dict[model_class.main_input_name]
# NOTE: contrastive search only works with cache on at the moment.
if not hasattr(config, "use_cache"):
@ -968,7 +1000,7 @@ class GenerationTesterMixin:
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
)
self._check_outputs(output_generate, main_input, model.config, use_cache=True)
self._check_outputs(output_generate, model.config, use_cache=True)
@pytest.mark.generate
def test_contrastive_generate_low_memory(self):
@ -1064,14 +1096,10 @@ class GenerationTesterMixin:
@pytest.mark.generate
@parameterized.expand([("random",), ("same",)])
@is_flaky() # Read NOTE (1) below. If there are API issues, all attempts will fail.
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
# This test ensures that the assisted generation does not introduce output changes over greedy search.
# NOTE (1): The sentence above is true most of the time, there is a tiny difference in the logits due to matmul
# shape differences -- and it may result in a different output. The input shape difference happens in the
# main model, that runs the forward pass with several candidates at once (as opposed to generating one token at
# a time). See https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535 for more info.
# NOTE (2): It breaks the pattern in the tests above, for multiple reasons:
# See https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535 for more info.
# NOTE: It breaks the pattern in the tests above, for multiple reasons:
# - assisted_decoding, contrarily to the other methods, can't be called on its own (e.g. needs to
# prepare the assistant encoder outputs in the main generate body);
# - assisted_decoding does not support `use_cache = False`
@ -1100,7 +1128,6 @@ class GenerationTesterMixin:
# enable cache
config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)
main_input = inputs_dict[model_class.main_input_name]
# NOTE: assisted generation only works with cache on at the moment.
if not hasattr(config, "use_cache"):
@ -1141,12 +1168,10 @@ class GenerationTesterMixin:
output_assisted = model.generate(**generation_kwargs, **inputs_dict)
# The two outputs must match and their shape must be as expected
self.assertListEqual(output_greedy.sequences.tolist(), output_assisted.sequences.tolist())
self._check_similar_generate_outputs(output_greedy, output_assisted)
for output in (output_greedy, output_assisted):
self._check_outputs(output, main_input, model.config, use_cache=True)
self._check_outputs(output, model.config, use_cache=True)
@is_flaky()
@pytest.mark.generate
def test_prompt_lookup_decoding_matches_greedy_search(self):
# This test ensures that the prompt lookup generation does not introduce output changes over greedy search.
@ -1175,7 +1200,6 @@ class GenerationTesterMixin:
# enable cache
config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)
main_input = inputs_dict[model_class.main_input_name]
# NOTE: assisted generation only works with cache on at the moment.
if not hasattr(config, "use_cache"):
@ -1208,10 +1232,9 @@ class GenerationTesterMixin:
output_prompt_lookup = model.generate(**generation_kwargs, **inputs_dict)
# The two outputs must match and their shape must be as expected
self.assertListEqual(output_greedy.sequences.tolist(), output_prompt_lookup.sequences.tolist())
self._check_similar_generate_outputs(output_greedy, output_prompt_lookup)
for output in (output_greedy, output_prompt_lookup):
self._check_outputs(output, main_input, model.config, use_cache=True)
self._check_outputs(output, model.config, use_cache=True)
@pytest.mark.generate
def test_dola_decoding_sample(self):
@ -1231,7 +1254,6 @@ class GenerationTesterMixin:
# enable cache if the model is not openai-gpt, xlnet, cpm, or xlm
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
main_input = inputs_dict[model_class.main_input_name]
# Encoder-decoder models are not supported
if config.is_encoder_decoder:
@ -1259,7 +1281,7 @@ class GenerationTesterMixin:
"dola_layers": "low",
}
output_dola = model.generate(**generation_kwargs, **inputs_dict)
self._check_outputs(output_dola, main_input, model.config, use_cache=getattr(config, "use_cache", False))
self._check_outputs(output_dola, model.config, use_cache=getattr(config, "use_cache", False))
@pytest.mark.generate
def test_assisted_decoding_sample(self):
@ -1289,7 +1311,6 @@ class GenerationTesterMixin:
# enable cache
config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)
main_input = inputs_dict[model_class.main_input_name]
# NOTE: assisted generation only works with cache on at the moment.
if not hasattr(config, "use_cache"):
@ -1321,7 +1342,7 @@ class GenerationTesterMixin:
}
output_assisted = model.generate(**generation_kwargs, **inputs_dict)
self._check_outputs(output_assisted, main_input, config, use_cache=True)
self._check_outputs(output_assisted, config, use_cache=True)
@pytest.mark.generate
def test_prompt_lookup_decoding_stops_at_eos(self):
@ -1547,75 +1568,93 @@ class GenerationTesterMixin:
)
@pytest.mark.generate
@parameterized.expand([(1,), (2,)])
def test_generate_from_inputs_embeds_decoder_only(self, num_beams):
@parameterized.expand([("greedy", 1), ("beam search", 2)])
def test_generate_from_inputs_embeds(self, _, num_beams):
"""Tests that we can generate from `inputs_embeds` instead of `input_ids` in LLMs, VLMs, etc"""
# When supported, tests that the decoder model can generate from `inputs_embeds` instead of `input_ids`
# if fails, you should probably update the `prepare_inputs_for_generation` function
for model_class in self.all_generative_model_classes:
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
# Ignore:
# a) eos (to always output 20 tokens) and pad (so we don't try to infer the attn mask from the input_ids,
# which would cause a mismatch),
config.pad_token_id = config.eos_token_id = -1
# b) embedding scaling, the scaling factor applied after embeding from input_ids (requires knowledge of the
# variable that holds the scaling factor, which is model-dependent)
if hasattr(config, "scale_embedding"):
config.scale_embedding = False
# This test is for decoder-only models (encoder-decoder models have native input embeddings support in the
# decoder)
if config.is_encoder_decoder:
continue
config.is_decoder = True
# Skip models without explicit support
config.is_decoder = True
model = model_class(config).to(torch_device).eval()
if "inputs_embeds" not in inspect.signature(model.prepare_inputs_for_generation).parameters.keys():
continue
# There are a few exception patterns in this test:
# 1 - Some models can't generate without `input_ids`, when `inputs_embeds` are passed
requires_inputs_ids = any(
model_name in model_class.__name__.lower() for model_name in ["idefics", "qwen2vl"]
)
# 2 - Complex `inputs_embeds` computation, i.e. the correct computation of inputs embeds is more complex
# than calling the embedding layer with `input_ids`. Subcases of this exception:
# 2.A - Ignore `scale_embedding`, if the model supports it (it is controlled by a model-dependent flag)
if hasattr(config, "scale_embedding"):
config.scale_embedding = False
# 2.B - Some VLMs assume `inputs_embeds` and `pixel_values` are mutually exclusive AND fall in the
# exception above (complex `inputs_embeds` computation). Popping `pixel_values` allows us to run the
# checks without adding test complexity. Ditto for `pixel_values_videos` and `pixel_values_images`
pixel_values_is_mutually_exclusive = any(
model_name in model_class.__name__.lower()
for model_name in ["llava", "idefics2", "idefics3", "mllama", "paligemma"]
)
if pixel_values_is_mutually_exclusive:
inputs_dict.pop("pixel_values", None)
inputs_dict.pop("pixel_values_videos", None)
inputs_dict.pop("pixel_values_images", None)
# 2.C - No easy fix, let's skip the check that compares the outputs from `input_ids` and `inputs_embeds`
has_complex_embeds_computation = any(
model_name in model_class.__name__.lower() for model_name in ["moshi"]
)
# 3 - `inputs_dict` doesn't contain `attention_mask`. When `attention_mask` is not passed to generate,
# we infer it from `input_ids`. The last test case will fail if there is a pad token in the original input.
missing_attention_mask = "attention_mask" not in inputs_dict
# Traditional way of generating text
input_ids = inputs_dict.pop("input_ids")
generation_kwargs = {
"return_dict_in_generate": True,
"output_scores": True,
"num_beams": num_beams,
"do_sample": False,
"max_new_tokens": 5,
"min_new_tokens": 5, # generate exactly 5 tokens
}
# Traditional way of generating text
outputs_from_ids = model.generate(input_ids, max_new_tokens=5, **generation_kwargs)
outputs_from_ids = model.generate(input_ids, **generation_kwargs, **inputs_dict)
self.assertEqual(outputs_from_ids.sequences.shape, (input_ids.shape[0], input_ids.shape[1] + 5))
# Same thing, but from input embeddings (`input_ids` is passed so the prompt is present in the output)
# Same thing, but from input embeddings (`input_ids` is passed so the prompt is present in the output).
# The output of the two calls should be the same.
inputs_embeds = model.get_input_embeddings()(input_ids)
outputs_from_embeds = model.generate(
input_ids,
inputs_embeds=inputs_embeds,
max_new_tokens=5,
**generation_kwargs,
input_ids, inputs_embeds=inputs_embeds, **generation_kwargs, **inputs_dict
)
self.assertListEqual(outputs_from_ids.sequences.tolist(), outputs_from_embeds.sequences.tolist())
if not has_complex_embeds_computation:
self._check_similar_generate_outputs(outputs_from_ids, outputs_from_embeds)
# But if we pass different inputs_embeds, we should get different outputs (the output text may be the
# If we pass different inputs_embeds, we should get different outputs (the output text may be the
# same, but the logits will almost surely be different)
random_embeds = torch.rand_like(inputs_embeds)
outputs_from_rand_embeds = model.generate(
input_ids,
inputs_embeds=random_embeds,
max_new_tokens=5,
**generation_kwargs,
input_ids, inputs_embeds=random_embeds, **generation_kwargs, **inputs_dict
)
for i in range(len(outputs_from_rand_embeds.scores)):
self.assertFalse(torch.allclose(outputs_from_embeds.scores[i], outputs_from_rand_embeds.scores[i]))
# input_ids is not a required input -- if we don't pass it, the newly generated tokens will be the same
outputs_from_embeds_wo_ids = model.generate(
inputs_embeds=inputs_embeds, max_new_tokens=5, **generation_kwargs
)
self.assertListEqual(
outputs_from_embeds.sequences[:, inputs_embeds.shape[1] :].tolist(),
outputs_from_embeds_wo_ids.sequences.tolist(),
)
# input_ids is not a required input on most models -- if we don't pass it, the newly generated tokens will
# be the same
if not (requires_inputs_ids or missing_attention_mask):
outputs_from_embeds_wo_ids = model.generate(
inputs_embeds=inputs_embeds, **generation_kwargs, **inputs_dict
)
outputs_from_embeds.sequences = outputs_from_embeds.sequences[:, inputs_embeds.shape[1] :]
self._check_similar_generate_outputs(outputs_from_embeds_wo_ids, outputs_from_embeds)
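Outside the test harness, the equivalence being checked can be reproduced with any decoder-only checkpoint whose `prepare_inputs_for_generation` accepts `inputs_embeds`; a hedged sketch (the checkpoint name is purely illustrative):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()

input_ids = tokenizer("The capital of France is", return_tensors="pt").input_ids

# Generate from token ids, then from the corresponding embeddings (passing `input_ids`
# as well so the prompt is present in the returned sequences).
out_from_ids = model.generate(input_ids, max_new_tokens=5, do_sample=False)
inputs_embeds = model.get_input_embeddings()(input_ids)
out_from_embeds = model.generate(input_ids, inputs_embeds=inputs_embeds, max_new_tokens=5, do_sample=False)

# The two continuations should match.
print(tokenizer.batch_decode(out_from_ids[:, input_ids.shape[1] :]))
print(tokenizer.batch_decode(out_from_embeds[:, input_ids.shape[1] :]))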
@pytest.mark.generate
def test_generate_from_inputs_embeds_with_static_cache(self):
@ -1829,10 +1868,8 @@ class GenerationTesterMixin:
@pytest.mark.generate
def test_generate_with_static_cache(self):
"""
Tests if StaticCache works if we set attn_implementation=static when generation.
This doesn't test if generation quality is good, but tests that models with
self._supports_static_cache don't throw an error when generating and return
a StaticCache object at the end.
Tests that generating with a static cache gives almost the same results as with a dynamic cache, and that
the output cache has the expected shapes.
"""
for model_class in self.all_generative_model_classes:
if not model_class._supports_static_cache:
@ -1851,13 +1888,15 @@ class GenerationTesterMixin:
model = model_class(config).to(torch_device).eval()
generation_kwargs = {
"max_length": None,
"max_new_tokens": max_new_tokens,
"cache_implementation": "static",
"return_dict_in_generate": True, # Required to return `past_key_values`
"output_scores": True,
"use_cache": True,
}
static_cache_generation = model.generate(**generation_kwargs, **inputs_dict, cache_implementation="static")
# Check 1: The cache shapes must match the expected shapes
max_cache_len = seq_length + max_new_tokens
config = config.text_config if hasattr(config, "text_config") else config
head_dim = (
@ -1869,12 +1908,14 @@ class GenerationTesterMixin:
else config.num_key_value_heads
)
num_hidden_layers = config.num_hidden_layers
results = model.generate(**generation_kwargs, **inputs_dict)
cache_shape = (batch_size, num_key_value_heads, max_cache_len, head_dim)
self.assertTrue(isinstance(results.past_key_values, StaticCache))
self.assertTrue(len(results.past_key_values.key_cache) == num_hidden_layers)
self.assertTrue(results.past_key_values.key_cache[0].shape == cache_shape)
self.assertTrue(isinstance(static_cache_generation.past_key_values, StaticCache))
self.assertTrue(len(static_cache_generation.past_key_values.key_cache) == num_hidden_layers)
self.assertTrue(static_cache_generation.past_key_values.key_cache[0].shape == cache_shape)
# Check 2: The outputs must be similar to the case with dynamic cache
dynamic_cache_generation = model.generate(**generation_kwargs, **inputs_dict)
self._check_similar_generate_outputs(dynamic_cache_generation, static_cache_generation)
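The behaviour exercised by this test can also be reproduced directly with `generate`; a minimal sketch assuming a checkpoint whose architecture supports the static cache (the small Llama-style model below is used purely as an example):

from transformers import AutoModelForCausalLM, AutoTokenizer

name = "HuggingFaceTB/SmolLM-135M"  # illustrative; any `_supports_static_cache` model works
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name).eval()

inputs = tokenizer("Static caches preallocate their tensors", return_tensors="pt")
kwargs = {"max_new_tokens": 8, "do_sample": False, "return_dict_in_generate": True, "output_scores": True}

dynamic = model.generate(**inputs, **kwargs)                                 # default dynamic cache
static = model.generate(**inputs, **kwargs, cache_implementation="static")  # StaticCache

# Greedy decoding with either cache should produce the same tokens.
print(dynamic.sequences.tolist() == static.sequences.tolist())
print(type(static.past_key_values).__name__)  # expected: StaticCache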
@require_optimum_quanto
@pytest.mark.generate
@ -1908,25 +1949,32 @@ class GenerationTesterMixin:
with self.assertRaises(ValueError):
model.generate(**generation_kwargs, **inputs_dict)
@parameterized.expand(
[
("forward_only", False), # TODO (@joao): a few models failing. After fixed, this should not be "@slow"
("end_to_end", True), # TODO (@joao): end-to-end compilation is broken with torch 2.5+, explore and fix
]
)
@pytest.mark.generate
@require_torch_gpu
@slow
@is_flaky() # compilation may result in equivalent (!= same) FP ops, causing the argmax in `generate` to be flaky
def test_generate_compile_fullgraph(self):
def test_generate_compile(self, _, end_to_end):
"""
Tests that `.generate` is compatible with torch.compile without graph breaks, keeping the same results.
Tests that `.generate` is compatible with torch.compile without graph breaks, keeping the same results. Tests
both end-to-end compilation and forward-pass-only compilation.
Runs two sequential generations to ensure the cache doesn't get stuck after the first compiled run! ⚠️
"""
for model_class in self.all_generative_model_classes:
if not model_class._supports_static_cache:
self.skipTest("This model doesn't support static cache")
# TODO (joao) -- fix and enable me :)
if any(model_name in model_class.__name__.lower() for model_name in ["whisper"]):
if end_to_end and any(model_name in model_class.__name__.lower() for model_name in ["whisper"]):
self.skipTest("whisper model end-to-end generate compile not yet supported")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
# TODO (joao) -- fix and enable me :)
if config.is_encoder_decoder:
if end_to_end and config.is_encoder_decoder:
self.skipTest("Encoder-decoder model end-to-end generate compile not yet supported")
model = model_class(config).to(torch_device)
@ -1941,27 +1989,33 @@ class GenerationTesterMixin:
generation_kwargs = {
"do_sample": False,
"max_new_tokens": 10,
"return_dict_in_generate": True,
"output_scores": True,
}
# end-to-end works best with dynamic cache, forward compilation works best with static cache
if not end_to_end:
generation_kwargs["cache_implementation"] = "static"
max_cache_len = input_ids.shape[1] + generation_kwargs["max_new_tokens"]
config = config.get_text_config()
past_key_values = StaticCache(
config, batch_size=half_batch_size, max_cache_len=max_cache_len, device=torch_device
)
# get eager + dynamic cache results for future comparison
dynamic_outputs = []
for model_inputs in input_ids_sets:
# eager dynamic cache
output_dynamic = model.generate(model_inputs, **generation_kwargs)
dynamic_outputs.append(model.generate(model_inputs, **generation_kwargs))
# end-to-end compiled dynamic cache
torch.compiler.reset()
compiled_generate = torch.compile(model.generate, fullgraph=True, mode="reduce-overhead")
generation_config = copy.deepcopy(model.generation_config)
generation_config.update(**generation_kwargs)
output_compiled = compiled_generate(
model_inputs, generation_config=generation_config, past_key_values=past_key_values
)
self.assertListEqual(output_dynamic.tolist(), output_compiled.tolist())
# get compiled results
generation_config = copy.deepcopy(model.generation_config)
generation_config.update(**generation_kwargs)
torch.compiler.reset()
if end_to_end:
model.generate = torch.compile(model.generate, fullgraph=True, mode="reduce-overhead")
else:
model.forward = torch.compile(model.forward, fullgraph=True, mode="reduce-overhead")
compiled_outputs = []
for model_inputs in input_ids_sets:
compiled_outputs.append(model.generate(model_inputs, generation_config=generation_config))
for dynamic_result, compiled_result in zip(dynamic_outputs, compiled_outputs):
self._check_similar_generate_outputs(dynamic_result, compiled_result)
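The forward-only compilation path added here follows the same pattern outside the test suite: compile `model.forward` and let `generate` drive it with a static cache so shapes stay fixed across decoding steps. A rough standalone sketch (checkpoint name is illustrative, and compilation behaviour depends on the installed torch version):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

name = "HuggingFaceTB/SmolLM-135M"  # illustrative static-cache-capable checkpoint
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name).eval()

inputs = tokenizer("Compilation should not change the output", return_tensors="pt")

# Eager baseline with a dynamic cache.
eager_out = model.generate(**inputs, max_new_tokens=10, do_sample=False)

# Forward-only compilation: the generation loop itself stays in Python.
model.forward = torch.compile(model.forward, fullgraph=True)
compiled_out = model.generate(**inputs, max_new_tokens=10, do_sample=False, cache_implementation="static")

print(eager_out.tolist() == compiled_out.tolist())  # usually True for greedy decoding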
@pytest.mark.generate
def test_generate_methods_with_num_logits_to_keep(self):
@ -1989,7 +2043,6 @@ class GenerationTesterMixin:
self.assertEqual(with_all_logits.tolist(), without_all_logits.tolist())
@pytest.mark.generate
@is_flaky() # assisted generation tests are flaky (minor fp ops differences)
def test_assisted_decoding_with_num_logits_to_keep(self):
for model_class in self.all_generative_model_classes:
if "num_logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()):
@ -1998,6 +2051,9 @@ class GenerationTesterMixin:
self.skipTest(reason="Stateful models don't support assisted generation")
config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)
# NOTE: assisted generation only works with cache on at the moment.
if not hasattr(config, "use_cache"):
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
config.use_cache = True
config.is_decoder = True
@ -2010,14 +2066,16 @@ class GenerationTesterMixin:
"max_new_tokens": 10,
"do_sample": False,
"assistant_model": assistant_model,
"return_dict_in_generate": True,
"output_scores": True,
}
assistant_model.generation_config.assistant_confidence_threshold = None
# Setting num_logits_to_keep at 0 keeps all logits (old behavior)
with_all_logits = model.generate(**generation_kwargs, **inputs_dict, num_logits_to_keep=0)
# By default, num_logits_to_keep is automatically set to 1 if not provided (new behavior)
without_all_logits = model.generate(**inputs_dict, **generation_kwargs)
self.assertEqual(with_all_logits.tolist(), without_all_logits.tolist())
self._check_similar_generate_outputs(with_all_logits, without_all_logits)
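The flag exercised above controls how many positions of logits the forward pass materializes; during decoding only the last position is needed. A hedged illustration at the `forward` level, assuming a model whose `forward` accepted `num_logits_to_keep` at the time (Llama-family models did; the checkpoint name is illustrative):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

name = "HuggingFaceTB/SmolLM-135M"  # illustrative Llama-style checkpoint
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name).eval()

input_ids = tokenizer("num_logits_to_keep trims the LM head output", return_tensors="pt").input_ids

with torch.no_grad():
    all_logits = model(input_ids, num_logits_to_keep=0).logits   # logits for every position
    last_logits = model(input_ids, num_logits_to_keep=1).logits  # only the final position

print(all_logits.shape, last_logits.shape)  # (1, seq_len, vocab) vs (1, 1, vocab)
torch.testing.assert_close(all_logits[:, -1:], last_logits)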
@pytest.mark.generate
def test_inherits_generation_mixin(self):
@ -2028,14 +2086,21 @@ class GenerationTesterMixin:
for model_class in self.all_generative_model_classes:
self.assertTrue("GenerationMixin" in str(model_class.__bases__))
@require_torch_sdpa
@slow
def test_eager_matches_sdpa_generate(self):
def _test_attention_implementation(self, attn_implementation):
"""
Compares the output of generate with the eager attention implementation against other implementations.
NOTE: despite the test logic being the same, different implementations actually need different decorators, hence
this separate function.
"""
max_new_tokens = 30
support_flag = {
"sdpa": "_supports_sdpa",
"flash_attention_2": "_supports_flash_attn_2",
}
for model_class in self.all_generative_model_classes:
if not model_class._supports_sdpa:
self.skipTest(f"{model_class.__name__} does not support SDPA")
if not getattr(model_class, support_flag[attn_implementation]):
self.skipTest(f"{model_class.__name__} does not support `attn_implementation={attn_implementation}`")
config, original_inputs_dict = self.prepare_config_and_inputs_for_generate()
inputs_dict = {}
@ -2062,17 +2127,9 @@ class GenerationTesterMixin:
"do_sample": False,
"return_dict_in_generate": True,
"output_scores": True,
"use_cache": True,
}
model_sdpa = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
).to(torch_device)
res_sdpa = model_sdpa.generate(**inputs_dict, **generate_kwargs)
del model_sdpa
gc.collect()
model_eager = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
@ -2083,42 +2140,46 @@ class GenerationTesterMixin:
del model_eager
gc.collect()
# Eager and SDPA are very similar, but not exactly the same. Because we are using random models, this
# test would be flaky if we only checked the sequences. Two situations in which this test passes:
# 1. The sequences are the same
# 2. The sequences are different, but the scores up until the first mismatch are nearly identical
output_matches = res_eager.sequences == res_sdpa.sequences
has_matching_outputs = output_matches.all()
has_matching_scores = None
if not has_matching_outputs:
input_length = main_input.shape[1]
for batch_idx in range(res_eager.sequences.shape[0]):
batch_matches = output_matches[batch_idx]
if batch_matches.all():
continue
first_mismatch_idx = batch_matches.int().argmin() # gets the index of the first False
first_mismatch_idx -= input_length # scores doesn't include data regarding input tokens
sdpa_first_mismatch_scores = res_sdpa.scores[first_mismatch_idx][batch_idx]
eager_first_mismatch_scores = res_eager.scores[first_mismatch_idx][batch_idx]
has_matching_scores = torch.allclose(
sdpa_first_mismatch_scores, eager_first_mismatch_scores, rtol=1e-3, atol=1e-3
)
if not has_matching_scores:
break
model_attn = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
attn_implementation=attn_implementation,
).to(torch_device)
res_attn = model_attn.generate(**inputs_dict, **generate_kwargs)
del model_attn
gc.collect()
self.assertTrue(has_matching_outputs or has_matching_scores)
self._check_similar_generate_outputs(res_eager, res_attn, atol=1e-3, rtol=1e-3)
def _check_outputs(self, output, main_input, config, use_cache=False, num_return_sequences=1):
# we can be sure what is batch size from main input but seq length depends on model type and whether input is text/audio/image
# so we infer actual text seq length from model_tester, same was as it is done in `test_modeling_common.py` tests`
batch_size = main_input.shape[0]
@pytest.mark.generate
@require_torch_sdpa
@slow
def test_eager_matches_sdpa_generate(self):
"""Tests that generate has equivalent outputs with SDPA and eager attention implementations."""
self._test_attention_implementation("sdpa")
@pytest.mark.flash_attn_test
@require_flash_attn
@require_torch_gpu
@slow
def test_eager_matches_fa2_generate(self):
"""Tests that generate has equivalent outputs with FA2 and eager attention implementations."""
# TODO (@joao @raushan) -- this test is failing the output checks on most models, investigate. After fixing,
# check whether we still need the overwrites
self._test_attention_implementation("flash_attention_2")
def _check_outputs(self, output, config, use_cache=False, num_return_sequences=1, num_beams=1):
input_batch_size = int(output.sequences.shape[0] / num_return_sequences)
internal_batch_size = (
input_batch_size * num_beams if num_beams > 1 else input_batch_size * num_return_sequences
)
seq_length = getattr(self.model_tester, "seq_length", None)
seq_length = getattr(self.model_tester, "encoder_seq_length", seq_length)
seq_length = getattr(self.model_tester, "text_seq_length", seq_length)
config = config.text_config if hasattr(config, "text_config") else config
num_sequences_in_output = batch_size * num_return_sequences
gen_len = (
output.sequences.shape[-1] - 1 if config.is_encoder_decoder else output.sequences.shape[-1] - seq_length
@ -2129,19 +2190,21 @@ class GenerationTesterMixin:
seq_length = self.model_tester.get_subsampled_output_lengths(seq_length)
# scores
self._check_scores(num_sequences_in_output, output.scores, length=gen_len, config=config)
self._check_scores(internal_batch_size, output.scores, length=gen_len, config=config)
# unprocessed logits
self._check_logits(num_sequences_in_output, output.logits, config=config)
self._check_logits(internal_batch_size, output.logits, config=config)
# Attentions
if self.has_attentions:
if config.is_encoder_decoder:
# encoder
self._check_encoder_attention_for_generate(output.encoder_attentions, batch_size, config, seq_length)
self._check_encoder_attention_for_generate(
output.encoder_attentions, input_batch_size, config, seq_length
)
# decoder
self._check_attentions_for_generate(
num_sequences_in_output,
internal_batch_size,
output.decoder_attentions,
min_length=1,
max_length=output.sequences.shape[-1],
@ -2153,7 +2216,7 @@ class GenerationTesterMixin:
attentions = output.attentions if not use_cache else output.attentions[1:]
min_length = seq_length if not use_cache else seq_length + 1
self._check_attentions_for_generate(
num_sequences_in_output,
internal_batch_size,
attentions=attentions,
min_length=min_length,
max_length=output.sequences.shape[-1],
@ -2165,12 +2228,12 @@ class GenerationTesterMixin:
if config.is_encoder_decoder:
# encoder
self._check_encoder_hidden_states_for_generate(
output.encoder_hidden_states, batch_size, config, seq_length
output.encoder_hidden_states, input_batch_size, config, seq_length
)
# decoder
self._check_hidden_states_for_generate(
num_sequences_in_output,
internal_batch_size,
output.decoder_hidden_states,
min_length=1,
max_length=output.sequences.shape[-1],
@ -2182,7 +2245,7 @@ class GenerationTesterMixin:
hidden_states = output.hidden_states if not use_cache else output.hidden_states[1:]
min_length = seq_length if not use_cache else seq_length + 1
self._check_hidden_states_for_generate(
num_sequences_in_output,
internal_batch_size,
hidden_states,
min_length=min_length,
max_length=output.sequences.shape[-1],
@ -2213,7 +2276,7 @@ class GenerationTesterMixin:
past_key_values = output.past_key_values
past_sequence_length = output.sequences.shape[-1] - 1
self._check_past_key_values_for_generate(
num_sequences_in_output,
internal_batch_size,
past_key_values,
seq_length=past_sequence_length,
config=config,

View File

@ -1532,8 +1532,3 @@ class BartStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, un
@unittest.skip
def test_save_load_fast_init_from_base(self):
pass
@unittest.skip(reason="Generate needs input ids")
def test_inputs_embeds_matches_input_ids_with_generate(self):
# generate only works with input ids for bartforcausalLM
pass

View File

@ -511,11 +511,6 @@ class BertModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
self.model_tester.create_and_check_model_as_decoder(*config_and_inputs)
@unittest.skip(reason="Generate needs input ids")
def test_inputs_embeds_matches_input_ids_with_generate(self):
# generate only works with input ids for bertforcausalLM
pass
def test_model_as_decoder_with_default_input_mask(self):
# This regression test was failing with PyTorch < 1.3
(

View File

@ -16,17 +16,14 @@
import unittest
import pytest
import requests
from parameterized import parameterized
from transformers import ChameleonConfig, is_torch_available, is_vision_available, set_seed
from transformers.testing_utils import (
require_bitsandbytes,
require_flash_attn,
require_read_token,
require_torch,
require_torch_gpu,
slow,
torch_device,
)
@ -329,43 +326,6 @@ class ChameleonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
# The output should be different for long inputs
self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))
@require_flash_attn
@require_read_token
@require_torch_gpu
@require_bitsandbytes
@pytest.mark.flash_attn_test
@slow
def test_flash_attn_2_generate_padding_right(self):
"""
Overwritting the common test as the test is flaky on tiny models
"""
model = ChameleonForConditionalGeneration.from_pretrained(
"facebook/chameleon-7b",
load_in_4bit=True,
device_map={"": 0},
)
processor = ChameleonProcessor.from_pretrained("facebook/chameleon-7b")
texts = ["hi", "Hello this is a very long sentence"]
processor.tokenizer.padding_side = "right"
inputs = processor(text=texts, return_tensors="pt", padding=True).to(0)
output_native = model.generate(**inputs, max_new_tokens=20, do_sample=False)
output_native = processor.tokenizer.batch_decode(output_native)
model = ChameleonForConditionalGeneration.from_pretrained(
"facebook/chameleon-7b",
load_in_4bit=True,
attn_implementation="flash_attention_2",
)
output_fa_2 = model.generate(**inputs, max_new_tokens=20, do_sample=False)
output_fa_2 = processor.tokenizer.batch_decode(output_fa_2)
self.assertListEqual(output_native, output_fa_2)
@unittest.skip("Chameleon forces some token ids to be -inf!")
def test_batching_equivalence(self):
pass

View File

@ -319,9 +319,6 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
# This is because we are hitting edge cases with the causal_mask buffer
model_split_percents = [0.5, 0.6]
# used in `test_torch_compile`
_torch_compile_test_ckpt = "google/gemma-2b"
# used in `test_torch_compile_for_training`
_torch_compile_train_cls = GemmaForCausalLM if is_torch_available() else None
@ -419,51 +416,6 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
def test_past_key_values_format(self):
pass
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@slow
def test_flash_attn_2_generate_use_cache(self):
import torch
max_new_tokens = 30
for model_class in self.all_generative_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
dummy_input = inputs_dict[model_class.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
dummy_input = dummy_input.to(torch.float16)
# make sure that all models have enough positions for generation
if hasattr(config, "max_position_embeddings"):
config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
# NOTE: Gemma apparently does not support right padding + use_cache with FA2.
dummy_attention_mask[:, -1] = 1
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
).to(torch_device)
# Just test that a large cache works as expected
_ = model.generate(
dummy_input,
attention_mask=dummy_attention_mask,
max_new_tokens=max_new_tokens,
do_sample=False,
use_cache=True,
)
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test

View File

@ -78,7 +78,6 @@ class Gemma2ModelTest(GemmaModelTest, unittest.TestCase):
test_pruning = False
_is_stateful = True
model_split_percents = [0.5, 0.6]
_torch_compile_test_ckpt = "google/gemma-2-9b"
def setUp(self):
self.model_tester = Gemma2ModelTester(self)

View File

@ -28,7 +28,6 @@ from transformers.testing_utils import (
require_flash_attn,
require_torch,
require_torch_accelerator,
require_torch_gpu,
require_torch_sdpa,
slow,
torch_device,
@ -306,10 +305,6 @@ class GlmModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
test_headmasking = False
test_pruning = False
# used in `test_torch_compile`
_torch_compile_test_ckpt = "THUDM/glm-4-9b"
_torch_compile_test_revision = "refs/pr/15"
def setUp(self):
self.model_tester = GlmModelTester(self)
self.config_tester = ConfigTester(self, config_class=GlmConfig, hidden_size=37)
@ -426,41 +421,6 @@ class GlmModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-3)
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@slow
def test_flash_attn_2_generate_padding_right(self):
"""Overwrite the common test as the test is flaky on tiny models."""
model = GlmForCausalLM.from_pretrained(
"THUDM/glm-4-9b",
device_map={"": 0},
torch_dtype=torch.bfloat16,
revision="refs/pr/15",
)
tokenizer = AutoTokenizer.from_pretrained("THUDM/glm-4-9b", revision="refs/pr/15")
tokenizer.padding_side = "right"
texts = ["hi", "Hello this is a very long sentence"]
inputs = tokenizer(texts, return_tensors="pt", padding=True).to(0)
output_native = model.generate(**inputs, max_new_tokens=15, do_sample=False)
output_native = tokenizer.batch_decode(output_native)
model = GlmForCausalLM.from_pretrained(
"THUDM/glm-4-9b",
device_map={"": 0},
attn_implementation="flash_attention_2",
torch_dtype=torch.bfloat16,
revision="refs/pr/15",
)
output_fa_2 = model.generate(**inputs, max_new_tokens=15, do_sample=False)
output_fa_2 = tokenizer.batch_decode(output_fa_2)
self.assertListEqual(output_native, output_fa_2)
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
@require_torch_sdpa
@slow

View File

@ -17,14 +17,9 @@
import datetime
import unittest
import pytest
from transformers import BitsAndBytesConfig, GPTJConfig, is_torch_available
from transformers import GPTJConfig, is_torch_available
from transformers.testing_utils import (
require_bitsandbytes,
require_flash_attn,
require_torch,
require_torch_gpu,
slow,
tooslow,
torch_device,
@ -505,44 +500,6 @@ class GPTJModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
model = GPTJModel.from_pretrained(model_name, revision="float16", torch_dtype=torch.float16)
self.assertIsNotNone(model)
@require_flash_attn
@require_torch_gpu
@require_bitsandbytes
@pytest.mark.flash_attn_test
@slow
def test_flash_attn_2_generate_padding_right(self):
"""
Overwritting the common test as the test is flaky on tiny models
"""
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6b")
texts = ["hi", "Hello this is a very long sentence"]
expected_outputs = [
"hi<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>Q: I have a question about the new version of the game. I have a question about the",
"Hello this is a very long sentence.\n\nA:\n\nI think the best way to understand this is to think of it",
]
tokenizer.padding_side = "right"
tokenizer.pad_token = tokenizer.eos_token
inputs = tokenizer(texts, return_tensors="pt", padding=True).to(0)
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
model = GPTJForCausalLM.from_pretrained(
"EleutherAI/gpt-j-6b",
device_map={"": 0},
attn_implementation="flash_attention_2",
revision="float16",
torch_dtype=torch.float16,
quantization_config=quantization_config,
)
output_fa_2 = model.generate(**inputs, max_new_tokens=20, do_sample=False)
output_fa_2 = tokenizer.batch_decode(output_fa_2)
self.assertListEqual(expected_outputs, output_fa_2)
@require_torch
class GPTJModelLanguageGenerationTest(unittest.TestCase):

View File

@ -17,12 +17,10 @@
import tempfile
import unittest
import pytest
from parameterized import parameterized
from transformers import AutoTokenizer, GraniteConfig, is_torch_available, set_seed
from transformers import GraniteConfig, is_torch_available, set_seed
from transformers.testing_utils import (
require_bitsandbytes,
require_flash_attn,
require_read_token,
require_torch,
@ -303,9 +301,6 @@ class GraniteModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
# This is because we are hitting edge cases with the causal_mask buffer
model_split_percents = [0.5, 0.7, 0.8]
# used in `test_torch_compile`
_torch_compile_test_ckpt = "ibm/PowerLM-3b"
def setUp(self):
self.model_tester = GraniteModelTester(self)
self.config_tester = ConfigTester(self, config_class=GraniteConfig, hidden_size=37)
@ -423,46 +418,6 @@ class GraniteModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
with self.assertRaises(AssertionError):
torch.testing.assert_close(yarn_sin_long, original_sin_long)
@require_flash_attn
@require_torch_gpu
@require_bitsandbytes
@pytest.mark.flash_attn_test
@require_read_token
@slow
def test_flash_attn_2_generate_padding_right(self):
"""
Overwritting the common test as the test is flaky on tiny models
"""
model = GraniteForCausalLM.from_pretrained(
"ibm/PowerLM-3b",
load_in_4bit=True,
device_map={"": 0},
)
tokenizer = AutoTokenizer.from_pretrained("ibm/PowerLM-3b")
texts = ["hi", "Hello this is a very long sentence"]
tokenizer.padding_side = "right"
tokenizer.pad_token = tokenizer.eos_token
inputs = tokenizer(texts, return_tensors="pt", padding=True).to(0)
output_native = model.generate(**inputs, max_new_tokens=20, do_sample=False)
output_native = tokenizer.batch_decode(output_native)
model = GraniteForCausalLM.from_pretrained(
"ibm/PowerLM-3b",
load_in_4bit=True,
device_map={"": 0},
attn_implementation="flash_attention_2",
)
output_fa_2 = model.generate(**inputs, max_new_tokens=20, do_sample=False)
output_fa_2 = tokenizer.batch_decode(output_fa_2)
self.assertListEqual(output_native, output_fa_2)
@require_flash_attn
@require_torch_gpu
@slow

View File

@ -17,12 +17,10 @@
import tempfile
import unittest
import pytest
from parameterized import parameterized
from transformers import AutoTokenizer, GraniteMoeConfig, is_torch_available, set_seed
from transformers.testing_utils import (
require_bitsandbytes,
require_flash_attn,
require_read_token,
require_torch,
@ -302,9 +300,6 @@ class GraniteMoeModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test
# This is because we are hitting edge cases with the causal_mask buffer
model_split_percents = [0.5, 0.7, 0.8]
# used in `test_torch_compile`
_torch_compile_test_ckpt = "ibm/PowerMoE-3b"
def setUp(self):
self.model_tester = GraniteMoeModelTester(self)
self.config_tester = ConfigTester(self, config_class=GraniteMoeConfig, hidden_size=37)
@ -422,46 +417,6 @@ class GraniteMoeModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test
with self.assertRaises(AssertionError):
torch.testing.assert_close(yarn_sin_long, original_sin_long)
@require_flash_attn
@require_torch_gpu
@require_bitsandbytes
@pytest.mark.flash_attn_test
@require_read_token
@slow
def test_flash_attn_2_generate_padding_right(self):
"""
Overwritting the common test as the test is flaky on tiny models
"""
model = GraniteMoeForCausalLM.from_pretrained(
"ibm-granite/granitemoe-3b",
load_in_4bit=True,
device_map={"": 0},
)
tokenizer = AutoTokenizer.from_pretrained("ibm-granite/granitemoe-3b")
texts = ["hi", "Hello this is a very long sentence"]
tokenizer.padding_side = "right"
tokenizer.pad_token = tokenizer.eos_token
inputs = tokenizer(texts, return_tensors="pt", padding=True).to(0)
output_native = model.generate(**inputs, max_new_tokens=20, do_sample=False)
output_native = tokenizer.batch_decode(output_native)
model = GraniteMoeForCausalLM.from_pretrained(
"ibm-granite/granitemoe-3b",
load_in_4bit=True,
device_map={"": 0},
attn_implementation="flash_attention_2",
)
output_fa_2 = model.generate(**inputs, max_new_tokens=20, do_sample=False)
output_fa_2 = tokenizer.batch_decode(output_fa_2)
self.assertListEqual(output_native, output_fa_2)
@require_flash_attn
@require_torch_gpu
@slow

View File

@ -770,13 +770,6 @@ class IdeficsForVisionText2TextTest(IdeficsModelTest, GenerationTesterMixin, uni
def test_custom_4d_attention_mask(self):
pass
@unittest.skip(
reason="IDEFICS has specific requirements for working with inputs embeds like passing also the ids and pixels"
)
@parameterized.expand([(1,), (2,)])
def test_generate_from_inputs_embeds_decoder_only(self, num_beams):
pass
@unittest.skip(reason="IDEFICS cannot compile due to dynamic control flow when checking inputs")
def test_generate_compile_fullgraph(self):
pass

View File

@ -20,7 +20,6 @@ import tempfile
import unittest
from io import BytesIO
import pytest
import requests
from transformers import (
@ -420,50 +419,6 @@ class Idefics2ForConditionalGenerationModelTest(GenerationTesterMixin, ModelTest
def test_flash_attn_2_fp32_ln(self):
pass
@pytest.mark.generate
def test_generate_from_inputs_embeds_decoder_only(self):
# overwrite because IDEFICS needs ids and embeds at the input to be not None
for model_class in self.all_generative_model_classes:
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
# Ignore:
# a) eos (to always output 20 tokens) and pad (so we don't try to infer the attn mask from the input_ids,
# which would cause a mismatch),
config.pad_token_id = config.eos_token_id = -1
config.is_decoder = True
model = model_class(config).to(torch_device).eval()
input_ids = inputs_dict.pop("input_ids")
# Traditional way of generating text
outputs_from_ids = model.generate(
input_ids, max_new_tokens=5, return_dict_in_generate=True, output_scores=True
)
self.assertEqual(outputs_from_ids.sequences.shape, (input_ids.shape[0], input_ids.shape[1] + 5))
# Same thing, but from input embeddings (`input_ids` is passed so the prompt is present in the output)
inputs_embeds = model.get_input_embeddings()(input_ids)
outputs_from_embeds = model.generate(
input_ids,
inputs_embeds=inputs_embeds,
max_new_tokens=5,
return_dict_in_generate=True,
output_scores=True,
)
self.assertListEqual(outputs_from_ids.sequences.tolist(), outputs_from_embeds.sequences.tolist())
# But if we pass different inputs_embeds, we should get different outputs (the output text may be the
# same, but the logits will almost surely be different)
random_embeds = torch.rand_like(inputs_embeds)
outputs_from_rand_embeds = model.generate(
input_ids,
inputs_embeds=random_embeds,
max_new_tokens=5,
return_dict_in_generate=True,
output_scores=True,
)
for i in range(len(outputs_from_rand_embeds.scores)):
self.assertFalse(torch.allclose(outputs_from_embeds.scores[i], outputs_from_rand_embeds.scores[i]))
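For reference, the override deleted above is now covered by the common generation mixin; a minimal, hedged sketch of the same `inputs_embeds` pattern outside the test suite could look like the following (the tiny checkpoint name is an assumption, not taken from this diff).

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "hf-internal-testing/tiny-random-LlamaForCausalLM"  # assumed tiny checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).eval()

input_ids = tokenizer("hello there", return_tensors="pt").input_ids
inputs_embeds = model.get_input_embeddings()(input_ids)

with torch.no_grad():
    # Passing both: the prompt ids are kept at the start of the returned sequences.
    with_ids = model.generate(input_ids, inputs_embeds=inputs_embeds, max_new_tokens=5)
    # Passing only the embeddings: only the newly generated tokens are returned.
    embeds_only = model.generate(inputs_embeds=inputs_embeds, max_new_tokens=5)

print(with_ids.shape, embeds_only.shape)  # (1, prompt_len + 5) vs (1, 5), barring an early EOS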
# We need to override this test to prepare inputs such that the image token is the last token
def test_resize_tokens_embeddings(self):
(original_config, inputs_dict) = self.model_tester.prepare_config_and_inputs_for_common()

View File

@ -19,7 +19,6 @@ import gc
import unittest
from io import BytesIO
import pytest
import requests
from transformers import (
@ -180,10 +179,6 @@ class Idefics3ModelTest(ModelTesterMixin, unittest.TestCase):
def test_inputs_embeds_matches_input_ids(self):
pass
@unittest.skip(reason="Model does not support padding right")
def test_flash_attn_2_generate_padding_right(self):
pass
@unittest.skip(reason="Model does not support padding right")
def test_flash_attn_2_inference_padding_right(self):
pass
@ -337,10 +332,6 @@ class Idefics3ForConditionalGenerationModelTest(GenerationTesterMixin, ModelTest
def test_inputs_embeds():
pass
@unittest.skip(reason="Model does not support padding right")
def test_flash_attn_2_generate_padding_right(self):
pass
@unittest.skip(reason="Model does not support padding right")
def test_flash_attn_2_inference_padding_right(self):
pass
@ -367,50 +358,6 @@ class Idefics3ForConditionalGenerationModelTest(GenerationTesterMixin, ModelTest
def test_flash_attn_2_fp32_ln(self):
pass
@pytest.mark.generate
def test_generate_from_inputs_embeds_decoder_only(self):
# overwrite because IDEFICS needs both the input ids and the input embeds to be not None
for model_class in self.all_generative_model_classes:
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
# Ignore eos (to always generate the requested number of new tokens) and pad (so we don't try to infer the
# attention mask from the input_ids, which would cause a mismatch)
config.pad_token_id = config.eos_token_id = -1
config.is_decoder = True
model = model_class(config).to(torch_device).eval()
input_ids = inputs_dict.pop("input_ids")
# Traditional way of generating text
outputs_from_ids = model.generate(
input_ids, max_new_tokens=5, return_dict_in_generate=True, output_scores=True
)
self.assertEqual(outputs_from_ids.sequences.shape, (input_ids.shape[0], input_ids.shape[1] + 5))
# Same thing, but from input embeddings (`input_ids` is passed so the prompt is present in the output)
inputs_embeds = model.get_input_embeddings()(input_ids)
outputs_from_embeds = model.generate(
input_ids,
inputs_embeds=inputs_embeds,
max_new_tokens=5,
return_dict_in_generate=True,
output_scores=True,
)
self.assertListEqual(outputs_from_ids.sequences.tolist(), outputs_from_embeds.sequences.tolist())
# But if we pass different inputs_embeds, we should get different outputs (the output text may be the
# same, but the logits will almost surely be different)
random_embeds = torch.rand_like(inputs_embeds)
outputs_from_rand_embeds = model.generate(
input_ids,
inputs_embeds=random_embeds,
max_new_tokens=5,
return_dict_in_generate=True,
output_scores=True,
)
for i in range(len(outputs_from_rand_embeds.scores)):
self.assertFalse(torch.allclose(outputs_from_embeds.scores[i], outputs_from_rand_embeds.scores[i]))
# We need to override this test to prepare inputs such that the image token is the last token
def test_resize_tokens_embeddings(self):
(original_config, inputs_dict) = self.model_tester.prepare_config_and_inputs_for_common()
@ -526,31 +473,6 @@ class Idefics3ForConditionalGenerationModelTest(GenerationTesterMixin, ModelTest
# Check that the model can still do a forward pass successfully (every parameter should be resized)
model(**self._prepare_for_class(inputs_dict, model_class))
def test_inputs_embeds_matches_input_ids_with_generate(self):
# overwrite because IDEFICS needs both the input ids and the input embeds to be not None
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
pad_token_id = config.pad_token_id if config.pad_token_id is not None else 1
wte = model.get_input_embeddings()
input_ids = inputs["input_ids"]
# some models infer position ids / attention mask differently when the input ids contain padding
# (detected via the pad_token), so make sure no padding is present in the input ids
not_pad_token_id = pad_token_id + 1 if max(0, pad_token_id - 1) == 0 else pad_token_id - 1
input_ids[input_ids == pad_token_id] = not_pad_token_id
del inputs["input_ids"]
inputs_embeds = wte(input_ids)
out_ids = model.generate(input_ids=input_ids, **inputs, max_new_tokens=2)
out_embeds = model.generate(input_ids=input_ids, inputs_embeds=inputs_embeds, **inputs, max_new_tokens=2)
self.assertTrue(torch.allclose(out_embeds, out_ids))
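A related, hedged sketch of the point made in the comments above: when generating from `inputs_embeds` there are no input ids to infer the attention mask from, so for padded batches the mask should be passed explicitly (the checkpoint name below is an assumption).

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "hf-internal-testing/tiny-random-LlamaForCausalLM"  # assumed tiny checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_name).eval()

batch = tokenizer(["hi", "a much longer prompt"], return_tensors="pt", padding=True)
inputs_embeds = model.get_input_embeddings()(batch.input_ids)

with torch.no_grad():
    _ = model.generate(
        inputs_embeds=inputs_embeds,
        attention_mask=batch.attention_mask,  # cannot be inferred from the embeddings
        max_new_tokens=4,
    )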
@require_torch
class Idefics3ForConditionalGenerationIntegrationTest(unittest.TestCase):

View File

@ -539,93 +539,6 @@ class JambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
# with attention mask
_ = model(dummy_input, attention_mask=dummy_attention_mask)
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@slow
def test_flash_attn_2_generate_padding_right(self):
r"""
Overriding the test_flash_attn_2_generate_padding_right test as the Jamba model, like Mixtral, doesn't support
right padding + use cache with FA2
"""
import torch
for model_class in self.all_generative_model_classes:
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
torch_device
)
dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device)
model.generate(dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False)
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
).to(torch_device)
with self.assertRaises(ValueError):
_ = model.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
)
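For the right-padding limitation described in the docstring above, the supported setup is left padding. A hedged sketch, assuming an FA2-capable checkpoint and a CUDA GPU with flash-attn installed (the model name is an assumption).

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "mistralai/Mistral-7B-v0.1"  # assumption: any FA2-capable causal LM
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.padding_side = "left"  # right padding + use_cache raises a ValueError with FA2
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.float16, attn_implementation="flash_attention_2"
).to("cuda")

inputs = tokenizer(["hi", "a much longer prompt"], return_tensors="pt", padding=True).to("cuda")
_ = model.generate(**inputs, max_new_tokens=8, do_sample=False, use_cache=True)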
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@slow
def test_flash_attn_2_generate_use_cache(self):
r"""
Overriding the test_flash_attn_2_generate_use_cache test as the Jamba model, like Mixtral, doesn't support
right padding + use cache with FA2
"""
import torch
max_new_tokens = 30
for model_class in self.all_generative_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
dummy_input = inputs_dict[model_class.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
dummy_input = dummy_input.to(torch.float16)
# make sure that all models have enough positions for generation
if hasattr(config, "max_position_embeddings"):
config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
# NOTE: Jamba does not support right padding + use_cache with FA2.
dummy_attention_mask[:, -1] = 1
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
).to(torch_device)
# Just test that a large cache works as expected
_ = model.generate(
dummy_input,
attention_mask=dummy_attention_mask,
max_new_tokens=max_new_tokens,
do_sample=False,
use_cache=True,
)
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test

View File

@ -15,7 +15,6 @@
"""Testing suite for the PyTorch JetMoe model."""
import gc
import tempfile
import unittest
import pytest
@ -377,85 +376,6 @@ class JetMoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
def test_past_key_values_format(self):
pass
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@slow
def test_flash_attn_2_generate_padding_right(self):
import torch
for model_class in self.all_generative_model_classes:
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
torch_device
)
dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device)
model.generate(dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False)
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
).to(torch_device)
with self.assertRaises(ValueError):
_ = model.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
)
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@slow
def test_flash_attn_2_generate_use_cache(self):
import torch
max_new_tokens = 30
for model_class in self.all_generative_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
dummy_input = inputs_dict[model_class.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
dummy_input = dummy_input.to(torch.float16)
# make sure that all models have enough positions for generation
if hasattr(config, "max_position_embeddings"):
config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
# NOTE: JetMoe apparently does not support right padding + use_cache with FA2.
dummy_attention_mask[:, -1] = 1
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
).to(torch_device)
# Just test that a large cache works as expected
_ = model.generate(
dummy_input,
attention_mask=dummy_attention_mask,
max_new_tokens=max_new_tokens,
do_sample=False,
use_cache=True,
)
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test

View File

@ -438,12 +438,6 @@ class Kosmos2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
# self.assertTrue(model.transformer.wte.weight.shape, model.lm_head.weight.shape)
# self.assertTrue(check_same_values(model.transformer.wte, model.lm_head))
@unittest.skip(
"KOSMOS-2 doesn't support inputs embeds. The test isn't skipped by checking ipnut args because KOSMOS-2 has `generate()` overwritten"
)
def test_inputs_embeds_matches_input_ids_with_generate(self):
pass
@slow
def test_model_from_pretrained(self):
model_name = "microsoft/kosmos-2-patch14-224"

View File

@ -26,7 +26,6 @@ from transformers import AutoTokenizer, LlamaConfig, StaticCache, is_torch_avail
from transformers.generation.configuration_utils import GenerationConfig
from transformers.testing_utils import (
backend_empty_cache,
require_bitsandbytes,
require_flash_attn,
require_read_token,
require_torch,
@ -316,9 +315,6 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
# This is because we are hitting edge cases with the causal_mask buffer
model_split_percents = [0.5, 0.7, 0.8]
# used in `test_torch_compile`
_torch_compile_test_ckpt = "meta-llama/Llama-2-7b-hf"
# used in `test_torch_compile_for_training`
_torch_compile_train_cls = LlamaForCausalLM if is_torch_available() else None
@ -585,43 +581,6 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
with self.assertRaises(KeyError):
config = _reinitialize_config(base_config, {"rope_scaling": {"rope_type": "linear"}}) # missing "factor"
@require_flash_attn
@require_torch_gpu
@require_bitsandbytes
@pytest.mark.flash_attn_test
@require_read_token
@slow
def test_flash_attn_2_generate_padding_right(self):
"""
Overwriting the common test as the test is flaky on tiny models
"""
model = LlamaForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-hf",
load_in_4bit=True,
device_map={"": 0},
)
tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
texts = ["hi", "Hello this is a very long sentence"]
tokenizer.padding_side = "right"
tokenizer.pad_token = tokenizer.eos_token
inputs = tokenizer(texts, return_tensors="pt", padding=True).to(0)
output_native = model.generate(**inputs, max_new_tokens=20, do_sample=False)
output_native = tokenizer.batch_decode(output_native)
model = LlamaForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-hf", load_in_4bit=True, device_map={"": 0}, attn_implementation="flash_attention_2"
)
output_fa_2 = model.generate(**inputs, max_new_tokens=20, do_sample=False)
output_fa_2 = tokenizer.batch_decode(output_fa_2)
self.assertListEqual(output_native, output_fa_2)
@require_flash_attn
@require_torch_gpu
@slow

View File

@ -204,8 +204,8 @@ class Mamba2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
pass
@unittest.skip(reason="To fix, Mamba 2 cache slicing test case is an edge case")
@parameterized.expand([(1,), (2,)])
def test_generate_from_inputs_embeds_decoder_only(self, num_beams):
@parameterized.expand([("greedy", 1), ("beam search", 2)])
def test_generate_from_inputs_embeds(self, _, num_beams):
pass
@unittest.skip(reason="To fix, Mamba 2 cache slicing test case is an edge case")
@ -276,12 +276,6 @@ class Mamba2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
@unittest.skip(
reason="Mamba2 does not support generating with input embeddings (custom cache_position computation)"
)
def test_inputs_embeds_matches_input_ids_with_generate(self):
pass
@require_torch
@slow

View File

@ -21,7 +21,6 @@ import unittest
import numpy as np
from datasets import Audio, load_dataset
from packaging import version
from parameterized import parameterized
from pytest import mark
@ -745,22 +744,6 @@ class MimiModelTest(ModelTesterMixin, unittest.TestCase):
def test_sdpa_can_compile_dynamic(self):
pass
# For now, let's focus only on GPU for `torch.compile`
@slow
@require_torch_gpu
def test_torch_compile(self):
if version.parse(torch.__version__) < version.parse("2.3"):
self.skipTest(reason="This test requires torch >= 2.3 to run.")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
n_iter = 3
for model_class in self.all_model_classes:
model = model_class(config).to(torch_device)
model.forward = torch.compile(model.forward)
for i in range(n_iter):
_ = model(inputs_dict["input_values"].to(torch_device))
@is_flaky()
def test_batching_equivalence(self):
super().test_batching_equivalence()

View File

@ -15,7 +15,6 @@
"""Testing suite for the PyTorch Mistral model."""
import gc
import tempfile
import unittest
import pytest
@ -416,85 +415,6 @@ class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
def test_past_key_values_format(self):
pass
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@slow
def test_flash_attn_2_generate_padding_right(self):
import torch
for model_class in self.all_generative_model_classes:
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
torch_device
)
dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device)
model.generate(dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False)
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
).to(torch_device)
with self.assertRaises(ValueError):
_ = model.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
)
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@slow
def test_flash_attn_2_generate_use_cache(self):
import torch
max_new_tokens = 30
for model_class in self.all_generative_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
dummy_input = inputs_dict[model_class.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
dummy_input = dummy_input.to(torch.float16)
# make sure that all models have enough positions for generation
if hasattr(config, "max_position_embeddings"):
config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
# NOTE: Mistral apparently does not support right padding + use_cache with FA2.
dummy_attention_mask[:, -1] = 1
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
).to(torch_device)
# Just test that a large cache works as expected
_ = model.generate(
dummy_input,
attention_mask=dummy_attention_mask,
max_new_tokens=max_new_tokens,
do_sample=False,
use_cache=True,
)
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test

View File

@ -14,7 +14,6 @@
# limitations under the License.
"""Testing suite for the PyTorch Mixtral model."""
import tempfile
import unittest
import pytest
@ -415,85 +414,6 @@ class MixtralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
def test_past_key_values_format(self):
pass
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@slow
def test_flash_attn_2_generate_padding_right(self):
import torch
for model_class in self.all_generative_model_classes:
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
torch_device
)
dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device)
model.generate(dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False)
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
).to(torch_device)
with self.assertRaises(ValueError):
_ = model.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
)
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@slow
def test_flash_attn_2_generate_use_cache(self):
import torch
max_new_tokens = 30
for model_class in self.all_generative_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
dummy_input = inputs_dict[model_class.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
dummy_input = dummy_input.to(torch.float16)
# make sure that all models have enough positions for generation
if hasattr(config, "max_position_embeddings"):
config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
# NOTE: Mixtral apparently does not support right padding + use_cache with FA2.
dummy_attention_mask[:, -1] = 1
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
).to(torch_device)
# Just test that a large cache works as expected
_ = model.generate(
dummy_input,
attention_mask=dummy_attention_mask,
max_new_tokens=max_new_tokens,
do_sample=False,
use_cache=True,
)
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test

View File

@ -126,7 +126,6 @@ class MllamaForCausalLMModelTest(ModelTesterMixin, GenerationTesterMixin, unitte
all_generative_model_classes = (MllamaForCausalLM,) if is_torch_available() else ()
test_pruning = False
test_head_masking = False
_torch_compile_test_ckpt = "nltpt/Llama-3.2-11B-Vision"
def setUp(self):
self.model_tester = MllamaText2TextModelTester(self)

View File

@ -560,7 +560,7 @@ class MoshiTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
return config, input_ids, attention_mask, inputs_dict
def prepare_config_and_inputs_for_generate(self, batch_size=2):
config, filtered_inputs_dict = super().prepare_config_and_inputs_for_generate()
config, filtered_inputs_dict = super().prepare_config_and_inputs_for_generate(batch_size=batch_size)
# Make sure we only return `input_ids`.
# Note that audio_codes will still be generated internally, so the ability to test audio codes is still there.
@ -591,9 +591,11 @@ class MoshiTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
[expected_shape] * len(iter_hidden_states),
)
def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
def _check_outputs(self, output, config, use_cache=False, num_return_sequences=1, num_beams=1):
# Overwrite because the generate method actually always uses `inputs_embeds`, so `use_cache` is always `True`
super()._check_outputs(output, input_ids, config, use_cache=True, num_return_sequences=num_return_sequences)
super()._check_outputs(
output, config, use_cache=True, num_return_sequences=num_return_sequences, num_beams=num_beams
)
def _check_hidden_states_for_generate(
self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1
@ -655,59 +657,6 @@ class MoshiTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
@pytest.mark.generate
@parameterized.expand([(1,), (2,)])
def test_generate_from_inputs_embeds_decoder_only(self, num_beams):
for model_class in self.all_generative_model_classes:
config, input_ids, _, inputs_dict = self._get_input_ids_and_config()
model = model_class(config).to(torch_device).eval()
generation_kwargs = {
"return_dict_in_generate": True,
"output_scores": True,
"num_beams": num_beams,
"do_sample": False,
}
# Traditional way of generating text
outputs_from_ids = model.generate(input_ids, max_new_tokens=5, **generation_kwargs, **inputs_dict)
self.assertEqual(outputs_from_ids.sequences.shape, (input_ids.shape[0], input_ids.shape[1] + 5))
# Same thing, but from input embeddings (`input_ids` is passed so the prompt is present in the output)
inputs_embeds = model.get_input_embeddings()(input_ids)
outputs_from_embeds = model.generate(
input_ids,
inputs_embeds=inputs_embeds,
max_new_tokens=5,
**generation_kwargs,
**inputs_dict,
)
# But if we pass different inputs_embeds, we should get different outputs (the output text may be the
# same, but the logits will almost surely be different)
random_embeds = torch.rand_like(inputs_embeds)
outputs_from_rand_embeds = model.generate(
input_ids,
inputs_embeds=random_embeds,
max_new_tokens=5,
**generation_kwargs,
**inputs_dict,
)
for i in range(len(outputs_from_rand_embeds.scores)):
self.assertFalse(torch.allclose(outputs_from_embeds.scores[i], outputs_from_rand_embeds.scores[i]))
# input_ids is not a required input -- if we don't pass it, the newly generated tokens will be the same
outputs_from_embeds_wo_ids = model.generate(
inputs_embeds=inputs_embeds,
max_new_tokens=5,
**generation_kwargs,
**inputs_dict,
)
self.assertListEqual(
outputs_from_embeds.sequences[:, inputs_embeds.shape[1] :].tolist(),
outputs_from_embeds_wo_ids.sequences.tolist(),
)
@unittest.skip(reason="Continuing from past key values is not straightforward as we're dealing with 3 inputs")
def test_generate_continue_from_past_key_values(self):
pass

View File

@ -576,9 +576,6 @@ class MT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
# The small MT5 model needs higher percentages for CPU/MP tests
model_split_percents = [0.5, 0.8, 0.9]
# used in `test_torch_compile`
_torch_compile_test_ckpt = "google/mt5-small"
def setUp(self):
self.model_tester = MT5ModelTester(self)
self.config_tester = ConfigTester(self, config_class=MT5Config, d_model=37)

View File

@ -450,144 +450,6 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2)
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
# Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_generate_left_padding
def test_flash_attn_2_generate_left_padding(self):
# Ignore copy
for model_class in self.greedy_sample_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
torch_device
)
dummy_input = inputs_dict[model.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
dummy_input = dummy_input.to(torch.float16)
dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
# make sure we do left padding
dummy_attention_mask[:, :-1] = 0
dummy_attention_mask[:, -1:] = 1
out = model.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
)
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
).to(torch_device)
out_fa = model.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
)
self.assertTrue(torch.allclose(out, out_fa))
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
# Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_generate_padding_right
def test_flash_attn_2_generate_padding_right(self):
# Ignore copy
for model_class in self.greedy_sample_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
torch_device
)
dummy_input = inputs_dict[model.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
dummy_input = dummy_input.to(torch.float16)
dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
# make sure we do right padding
dummy_attention_mask[:, :-1] = 1
dummy_attention_mask[:, -1:] = 0
out = model.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
)
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
).to(torch_device)
out_fa = model.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
)
self.assertTrue(torch.allclose(out, out_fa))
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
# Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_generate_use_cache
def test_flash_attn_2_generate_use_cache(self):
max_new_tokens = 30
# Ignore copy
for model_class in self.greedy_sample_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
dummy_input = inputs_dict[model_class.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
dummy_input = dummy_input.to(torch.float16)
# make sure that all models have enough positions for generation
if hasattr(config, "max_position_embeddings"):
config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
).to(torch_device)
# Just test that a large cache works as expected
_ = model.generate(
dummy_input,
attention_mask=dummy_attention_mask,
max_new_tokens=max_new_tokens,
do_sample=False,
use_cache=True,
)
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
@require_torch_sdpa
@slow
@ -1585,149 +1447,6 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2)
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
# Adapted from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_generate_left_padding
def test_flash_attn_2_generate_left_padding(self):
# Ignore copy
for model_class in self.greedy_sample_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
torch_device
)
dummy_input = inputs_dict[model.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
dummy_input = dummy_input.to(torch.float16)
dummy_attention_mask = inputs_dict.get("attention_mask")
if dummy_attention_mask is None:
dummy_attention_mask = torch.ones_like(dummy_input)
# make sure we do left padding
dummy_attention_mask[:, :-1] = 0
dummy_attention_mask[:, -1:] = 1
out = model.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
)
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
attn_implementation={"decoder": "flash_attention_2", "audio_encoder": None, "text_encoder": None},
low_cpu_mem_usage=True,
).to(torch_device)
out_fa = model.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
)
self.assertTrue(torch.allclose(out, out_fa))
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
# Adapted from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_generate_padding_right
def test_flash_attn_2_generate_padding_right(self):
# Ignore copy
for model_class in self.greedy_sample_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
torch_device
)
dummy_input = inputs_dict[model.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
dummy_input = dummy_input.to(torch.float16)
dummy_attention_mask = inputs_dict.get("attention_mask")
if dummy_attention_mask is None:
dummy_attention_mask = torch.ones_like(dummy_input)
# make sure we do right padding
dummy_attention_mask[:, :-1] = 1
dummy_attention_mask[:, -1:] = 0
out = model.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
)
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
attn_implementation={"decoder": "flash_attention_2", "audio_encoder": None, "text_encoder": None},
low_cpu_mem_usage=True,
).to(torch_device)
out_fa = model.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
)
self.assertTrue(torch.allclose(out, out_fa))
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
# Adapted from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_generate_use_cache
def test_flash_attn_2_generate_use_cache(self):
max_new_tokens = 30
# Ignore copy
for model_class in self.greedy_sample_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
dummy_input = inputs_dict[model_class.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
dummy_input = dummy_input.to(torch.float16)
# make sure that all models have enough positions for generation
if hasattr(config, "max_position_embeddings"):
config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
attn_implementation={"decoder": "flash_attention_2", "audio_encoder": None, "text_encoder": None},
low_cpu_mem_usage=True,
).to(torch_device)
# Just test that a large cache works as expected
_ = model.generate(
dummy_input,
attention_mask=dummy_attention_mask,
max_new_tokens=max_new_tokens,
do_sample=False,
use_cache=True,
)
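The attn_implementation mapping used above is specific to composite models. A hedged sketch of loading MusicGen with FA2 on the decoder only, assuming a GPU with flash-attn and the facebook/musicgen-small checkpoint.

import torch
from transformers import MusicgenForConditionalGeneration

# Per-submodule attention implementations: FA2 for the decoder, defaults elsewhere.
model = MusicgenForConditionalGeneration.from_pretrained(
    "facebook/musicgen-small",  # assumed checkpoint
    torch_dtype=torch.float16,
    attn_implementation={"decoder": "flash_attention_2", "audio_encoder": None, "text_encoder": None},
).to("cuda")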
@require_torch_sdpa
def test_sdpa_can_dispatch_composite_models(self):
if not self.has_attentions:

View File

@ -1437,149 +1437,6 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2)
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
# Adapted from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_generate_left_padding
def test_flash_attn_2_generate_left_padding(self):
# Ignore copy
for model_class in self.greedy_sample_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
torch_device
)
dummy_input = inputs_dict[model.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
dummy_input = dummy_input.to(torch.float16)
dummy_attention_mask = inputs_dict.get("attention_mask")
if dummy_attention_mask is None:
dummy_attention_mask = torch.ones_like(dummy_input)
# make sure we do left padding
dummy_attention_mask[:, :-1] = 0
dummy_attention_mask[:, -1:] = 1
out = model.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
)
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
attn_implementation={"decoder": "flash_attention_2", "audio_encoder": None, "text_encoder": None},
low_cpu_mem_usage=True,
).to(torch_device)
out_fa = model.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
)
self.assertTrue(torch.allclose(out, out_fa))
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
# Adapted from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_generate_padding_right
def test_flash_attn_2_generate_padding_right(self):
# Ignore copy
for model_class in self.greedy_sample_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
torch_device
)
dummy_input = inputs_dict[model.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
dummy_input = dummy_input.to(torch.float16)
dummy_attention_mask = inputs_dict.get("attention_mask")
if dummy_attention_mask is None:
dummy_attention_mask = torch.ones_like(dummy_input)
# make sure we do right padding
dummy_attention_mask[:, :-1] = 1
dummy_attention_mask[:, -1:] = 0
out = model.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
)
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
attn_implementation={"decoder": "flash_attention_2", "audio_encoder": None, "text_encoder": None},
low_cpu_mem_usage=True,
).to(torch_device)
out_fa = model.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
)
self.assertTrue(torch.allclose(out, out_fa))
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
# Adapted from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_generate_use_cache
def test_flash_attn_2_generate_use_cache(self):
max_new_tokens = 30
# Ignore copy
for model_class in self.greedy_sample_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
dummy_input = inputs_dict[model_class.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
dummy_input = dummy_input.to(torch.float16)
# make sure that all models have enough positions for generation
if hasattr(config, "max_position_embeddings"):
config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
attn_implementation={"decoder": "flash_attention_2", "audio_encoder": None, "text_encoder": None},
low_cpu_mem_usage=True,
).to(torch_device)
# Just test that a large cache works as expected
_ = model.generate(
dummy_input,
attention_mask=dummy_attention_mask,
max_new_tokens=max_new_tokens,
do_sample=False,
use_cache=True,
)
@require_torch_sdpa
def test_sdpa_can_dispatch_composite_models(self):
if not self.has_attentions:

View File

@ -92,8 +92,6 @@ class NemotronModelTest(GemmaModelTest):
test_pruning = False
fx_compatible = False
# used in `test_torch_compile`
_torch_compile_test_ckpt = "nvidia/nemotron-3-8b-base-4k-hf"
# used in `test_torch_compile_for_training`
_torch_compile_train_cls = NemotronForCausalLM if is_torch_available() else None

View File

@ -346,10 +346,6 @@ class PaliGemmaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes
def test_generate_from_inputs_embeds_with_static_cache(self):
pass
@unittest.skip(reason="TODO (@joao): fix me -- failing to produce similar results")
def test_static_cache_matches_dynamic(self):
pass
@unittest.skip("FlashAttention only support fp16 and bf16 data type")
def test_flash_attn_2_fp32_ln(self):
pass

View File

@ -17,15 +17,11 @@
import unittest
import pytest
from parameterized import parameterized
from transformers import PhiConfig, is_torch_available, set_seed
from transformers.testing_utils import (
require_bitsandbytes,
require_flash_attn,
require_torch,
require_torch_gpu,
slow,
torch_device,
)
@ -468,43 +464,6 @@ class PhiModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
torch.testing.assert_close(ntk_sin_long, original_sin_long)
self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all())
@require_flash_attn
@require_torch_gpu
@require_bitsandbytes
@pytest.mark.flash_attn_test
@slow
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_flash_attn_2_generate_padding_right with LlamaForCausalLM->PhiForCausalLM,LlamaTokenizer->AutoTokenizer,meta-llama/Llama-2-7b-hf->microsoft/phi-1
def test_flash_attn_2_generate_padding_right(self):
"""
Overwriting the common test as the test is flaky on tiny models
"""
model = PhiForCausalLM.from_pretrained(
"microsoft/phi-1",
load_in_4bit=True,
device_map={"": 0},
)
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1")
texts = ["hi", "Hello this is a very long sentence"]
tokenizer.padding_side = "right"
tokenizer.pad_token = tokenizer.eos_token
inputs = tokenizer(texts, return_tensors="pt", padding=True).to(0)
output_native = model.generate(**inputs, max_new_tokens=20, do_sample=False)
output_native = tokenizer.batch_decode(output_native)
model = PhiForCausalLM.from_pretrained(
"microsoft/phi-1", load_in_4bit=True, device_map={"": 0}, attn_implementation="flash_attention_2"
)
output_fa_2 = model.generate(**inputs, max_new_tokens=20, do_sample=False)
output_fa_2 = tokenizer.batch_decode(output_fa_2)
self.assertListEqual(output_native, output_fa_2)
@slow
@require_torch

View File

@ -15,7 +15,6 @@
"""Testing suite for the PyTorch Qwen2 model."""
import gc
import tempfile
import unittest
import pytest
@ -428,85 +427,6 @@ class Qwen2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
def test_past_key_values_format(self):
pass
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@slow
def test_flash_attn_2_generate_padding_right(self):
import torch
for model_class in self.all_generative_model_classes:
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
torch_device
)
dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device)
model.generate(dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False)
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
).to(torch_device)
with self.assertRaises(ValueError):
_ = model.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
)
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@slow
def test_flash_attn_2_generate_use_cache(self):
import torch
max_new_tokens = 30
for model_class in self.all_generative_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
dummy_input = inputs_dict[model_class.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
dummy_input = dummy_input.to(torch.float16)
# make sure that all models have enough positions for generation
if hasattr(config, "max_position_embeddings"):
config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
# NOTE: Qwen2 apparently does not support right padding + use_cache with FA2.
dummy_attention_mask[:, -1] = 1
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
).to(torch_device)
# Just test that a large cache works as expected
_ = model.generate(
dummy_input,
attention_mask=dummy_attention_mask,
max_new_tokens=max_new_tokens,
do_sample=False,
use_cache=True,
)
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test

View File

@ -15,7 +15,6 @@
"""Testing suite for the PyTorch Qwen2MoE model."""
import gc
import tempfile
import unittest
import pytest
@ -453,85 +452,6 @@ class Qwen2MoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
def test_past_key_values_format(self):
pass
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@slow
def test_flash_attn_2_generate_padding_right(self):
import torch
for model_class in self.all_generative_model_classes:
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
torch_device
)
dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device)
model.generate(dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False)
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
).to(torch_device)
with self.assertRaises(ValueError):
_ = model.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
)
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@slow
def test_flash_attn_2_generate_use_cache(self):
import torch
max_new_tokens = 30
for model_class in self.all_generative_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
dummy_input = inputs_dict[model_class.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
dummy_input = dummy_input.to(torch.float16)
# make sure that all models have enough positions for generation
if hasattr(config, "max_position_embeddings"):
config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
# NOTE: Qwen2Moe apparently does not support right padding + use_cache with FA2.
dummy_attention_mask[:, -1] = 1
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
).to(torch_device)
# Just test that a large cache works as expected
_ = model.generate(
dummy_input,
attention_mask=dummy_attention_mask,
max_new_tokens=max_new_tokens,
do_sample=False,
use_cache=True,
)
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test

View File

@ -301,10 +301,6 @@ class Qwen2VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
def test_feed_forward_chunking(self):
pass
@unittest.skip(reason="Generate needs input ids")
def test_inputs_embeds_matches_input_ids_with_generate(self):
pass
@unittest.skip(reason="CPU offload is not yet supported")
def test_cpu_offload(self):
pass

View File

@ -420,10 +420,6 @@ class RecurrentGemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
def test_initialization(self):
pass
@unittest.skip(reason="RecurrentGemma does not support generating with input embeddings (missing position_ids)")
def test_inputs_embeds_matches_input_ids_with_generate(self):
pass
@require_torch_accelerator
@slow

View File

@ -14,7 +14,6 @@
# limitations under the License.
"""Testing suite for the PyTorch Starcoder2 model."""
import tempfile
import unittest
import pytest
@ -404,85 +403,6 @@ class Starcoder2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
def test_past_key_values_format(self):
pass
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@slow
def test_flash_attn_2_generate_padding_right(self):
import torch
for model_class in self.all_generative_model_classes:
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
torch_device
)
dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device)
model.generate(dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False)
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
).to(torch_device)
with self.assertRaises(ValueError):
_ = model.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
)
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@slow
def test_flash_attn_2_generate_use_cache(self):
import torch
max_new_tokens = 30
for model_class in self.all_generative_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
dummy_input = inputs_dict[model_class.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
dummy_input = dummy_input.to(torch.float16)
# make sure that all models have enough positions for generation
if hasattr(config, "max_position_embeddings"):
config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
# NOTE: Starcoder2 apparently does not support right padding + use_cache with FA2.
dummy_attention_mask[:, -1] = 1
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
).to(torch_device)
# Just test that a large cache works as expected
_ = model.generate(
dummy_input,
attention_mask=dummy_attention_mask,
max_new_tokens=max_new_tokens,
do_sample=False,
use_cache=True,
)
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test

View File

@ -580,9 +580,6 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
# The small T5 model needs higher percentages for CPU/MP tests
model_split_percents = [0.5, 0.8, 0.9]
# used in `test_torch_compile`
_torch_compile_test_ckpt = "google-t5/t5-small"
def setUp(self):
self.model_tester = T5ModelTester(self)
self.config_tester = ConfigTester(self, config_class=T5Config, d_model=37)

View File

@ -317,9 +317,6 @@ class UMT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
# The small UMT5 model needs higher percentages for CPU/MP tests
model_split_percents = [0.5, 0.8, 0.9]
# used in `test_torch_compile`
_torch_compile_test_ckpt = "google/umt5-small"
def setUp(self):
self.model_tester = UMT5ModelTester(self)

View File

@ -1574,59 +1574,6 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
)
assert isinstance(pred_ids, expected_output_type)
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@slow
def test_flash_attn_2_generate_reuse_cache(self):
max_new_tokens = 2
for model_class in self.all_generative_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
dummy_input = inputs_dict[model_class.main_input_name][..., :10]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
dummy_input = dummy_input.to(torch.float16)
# make sure that all models have enough positions for generation
if hasattr(config, "max_position_embeddings"):
config.max_position_embeddings = dummy_input.shape[1] * 2 + max_new_tokens * 2 + 1
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
).to(torch_device)
# run generate once to get filled cache
output = model.generate(
dummy_input,
max_new_tokens=max_new_tokens,
do_sample=False,
use_cache=True,
return_dict_in_generate=True,
)
past_key_values = output.past_key_values
# Try to continue generation from where we left, given that we have more than 1 new token to process
# e.g. this can happen in speculative decoding when feeding candidate tokens back to target model
_ = model.generate(
dummy_input,
decoder_input_ids=output.sequences,
max_new_tokens=max_new_tokens,
do_sample=False,
use_cache=True,
past_key_values=past_key_values,
)
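A hedged sketch of the cache-reuse pattern exercised above, written for a generic decoder-only model rather than Whisper (the tiny checkpoint name is an assumption).

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "hf-internal-testing/tiny-random-gpt2"  # assumed tiny checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).eval()

inputs = tokenizer("continue from the cache", return_tensors="pt")
with torch.no_grad():
    first = model.generate(
        **inputs, max_new_tokens=2, do_sample=False, use_cache=True, return_dict_in_generate=True
    )
    # Feed the full sequence so far plus the filled cache back in to keep generating.
    _ = model.generate(
        first.sequences,
        past_key_values=first.past_key_values,
        max_new_tokens=2,
        do_sample=False,
        use_cache=True,
    )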
def test_labels_sequence_max_length_correct(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
@ -3961,11 +3908,6 @@ class WhisperStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin,
# generate only works with input ids for whisper
pass
@unittest.skip(reason="Generate needs input ids")
def test_inputs_embeds_matches_input_ids_with_generate(self):
# generate only works with input ids for whisper
pass
@unittest.skip(reason="Decoder can't keep attention grads")
def test_retain_grad_hidden_states_attentions(self):
return
@ -3974,18 +3916,6 @@ class WhisperStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin,
def test_save_load_fast_init_from_base(self):
pass
@unittest.skip(
reason="FA2 testing suite needs to be refactored to be compatible with WhisperDecoder for that test"
)
def test_flash_attn_2_generate_reuse_cache(self):
pass
@unittest.skip(
"Duplicated test with WhisperModelTest + the FA2 testing suite needs to be refactored to be compatible with WhisperDecoder for that test"
)
def test_flash_attn_2_generate_padding_right(self):
pass
@unittest.skip(
"Duplicated test with WhisperModelTest + the FA2 testing suite needs to be refactored to be compatible with WhisperDecoder for that test"
)

View File

@ -542,93 +542,6 @@ class ZambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
# with attention mask
_ = model(dummy_input, attention_mask=dummy_attention_mask)
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@slow
def test_flash_attn_2_generate_padding_right(self):
r"""
Overriding the test_flash_attn_2_generate_padding_right test as the Zamba model, like Mixtral, doesn't support
right padding + use cache with FA2
"""
import torch
for model_class in self.all_generative_model_classes:
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
torch_device
)
dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device)
model.generate(dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False)
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
).to(torch_device)
with self.assertRaises(ValueError):
_ = model.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
)
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@slow
def test_flash_attn_2_generate_use_cache(self):
r"""
Overriding the test_flash_attn_2_generate_use_cache test as the Zamba model, like Mixtral, doesn't support
right padding + use cache with FA2
"""
import torch
max_new_tokens = 30
for model_class in self.all_generative_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
dummy_input = inputs_dict[model_class.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
dummy_input = dummy_input.to(torch.float16)
# make sure that all models have enough positions for generation
if hasattr(config, "max_position_embeddings"):
config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
# NOTE: Zamba does not support right padding + use_cache with FA2.
dummy_attention_mask[:, -1] = 1
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
).to(torch_device)
# Just test that a large cache works as expected
_ = model.generate(
dummy_input,
attention_mask=dummy_attention_mask,
max_new_tokens=max_new_tokens,
do_sample=False,
use_cache=True,
)
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test

@ -22,7 +22,6 @@ import os.path
import random
import re
import tempfile
import time
import warnings
from collections import defaultdict
from contextlib import contextmanager
@ -37,10 +36,7 @@ import transformers
from transformers import (
AutoModel,
AutoModelForCausalLM,
AutoModelForSeq2SeqLM,
AutoModelForSequenceClassification,
AutoTokenizer,
GenerationConfig,
PretrainedConfig,
PreTrainedModel,
is_torch_available,
@ -86,7 +82,6 @@ from transformers.testing_utils import (
require_deepspeed,
require_flash_attn,
require_non_xpu,
require_read_token,
require_safetensors,
require_torch,
require_torch_accelerator,
@ -3000,71 +2995,6 @@ class ModelTesterMixin:
)[0]
self.assertTrue(torch.allclose(out_embeds, out_ids))
def test_inputs_embeds_matches_input_ids_with_generate(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_generative_model_classes:
if model_class.__name__ not in [
*get_values(MODEL_FOR_CAUSAL_LM_MAPPING_NAMES),
*get_values(MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES),
]:
continue
model = model_class(config)
model.to(torch_device)
model.eval()
model_forward_args = inspect.signature(model.forward).parameters
if any(argument not in model_forward_args for argument in ["inputs_embeds", "position_ids"]):
self.skipTest(reason="This model doesn't use `inputs_embeds` or `position_ids`.")
has_inputs_embeds_forwarding = "inputs_embeds" in set(
inspect.signature(model.prepare_inputs_for_generation).parameters.keys()
)
if not has_inputs_embeds_forwarding:
self.skipTest(reason="This model doesn't support `inputs_embeds` passed to `generate`.")
inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
pad_token_id = config.pad_token_id if config.pad_token_id is not None else 1
# VLMs can't generate with embeds and pixels at the same time. We expect the user to pass merged
# embeds already
if model_class.__name__ in get_values(MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES):
inputs.pop("pixel_values", None)
inputs.pop("pixel_values_videos", None)
inputs.pop("pixel_values_images", None)
wte = model.get_input_embeddings()
if not self.is_encoder_decoder:
input_ids = inputs["input_ids"]
# some models infer position ids / attention mask differently when the input ids
# contain the pad token, so make sure no padding token is left in the input ids
not_pad_token_id = pad_token_id + 1 if max(0, pad_token_id - 1) == 0 else pad_token_id - 1
input_ids[input_ids == pad_token_id] = not_pad_token_id
del inputs["input_ids"]
inputs_embeds = wte(input_ids)
out_ids = model.generate(input_ids=input_ids, **inputs, max_new_tokens=2)[:, -2:]
out_embeds = model.generate(inputs_embeds=inputs_embeds, **inputs, max_new_tokens=2)
else:
encoder_input_ids = inputs["input_ids"]
decoder_input_ids = inputs.get("decoder_input_ids", encoder_input_ids)
encoder_input_ids[encoder_input_ids == pad_token_id] = max(0, pad_token_id + 1)
decoder_input_ids[decoder_input_ids == pad_token_id] = max(0, pad_token_id + 1)
del inputs["input_ids"]
inputs.pop("decoder_input_ids", None)
inputs_embeds = wte(encoder_input_ids)
decoder_inputs_embeds = wte(decoder_input_ids)
out_ids = model.generate(
input_ids=encoder_input_ids, decoder_input_ids=decoder_input_ids, **inputs, max_new_tokens=2
)[:, -2:]
out_embeds = model.generate(
inputs_embeds=inputs_embeds,
decoder_inputs_embeds=decoder_inputs_embeds,
**inputs,
max_new_tokens=2,
)
# NOTE: this test changes the order of FP ops, there may be tiny differences in the output
number_of_different_tokens = (out_ids != out_embeds).sum()
max_differences = int(out_ids.shape[0] * out_ids.shape[1] * 0.1)
self.assertTrue(number_of_different_tokens <= max_differences) # accept up to 10% mismatch
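The deleted test compares `generate` driven by `input_ids` against `generate` driven by the matching `inputs_embeds`. The same check can be reproduced with a short sketch, assuming a decoder-only checkpoint whose model class forwards `inputs_embeds` through `generate` (`gpt2` is used here only as an example); tiny mismatches are expected, as the test's own 10% tolerance acknowledges:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()

input_ids = tokenizer("The quick brown fox", return_tensors="pt").input_ids
inputs_embeds = model.get_input_embeddings()(input_ids)  # same embedding lookup the model does internally

with torch.no_grad():
    out_ids = model.generate(input_ids=input_ids, max_new_tokens=2, do_sample=False)[:, -2:]
    out_embeds = model.generate(inputs_embeds=inputs_embeds, max_new_tokens=2, do_sample=False)

# When only embeddings are passed, generate returns just the newly generated tokens,
# so the id-driven run is sliced to its last two tokens before comparing.
print(out_ids, out_embeds)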
@require_non_xpu
@require_torch_multi_gpu
def test_multi_gpu_data_parallel_forward(self):
@ -3857,102 +3787,6 @@ class ModelTesterMixin:
assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2)
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
@is_flaky()
def test_flash_attn_2_generate_left_padding(self):
if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions")
for model_class in self.all_generative_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
torch_device
)
dummy_input = inputs_dict[model.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
dummy_input = dummy_input.to(torch.float16)
dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
# make sure we do left padding
dummy_attention_mask[:, :-1] = 0
dummy_attention_mask[:, -1:] = 1
out = model.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
)
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
).to(torch_device)
out_fa = model.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
)
self.assertTrue(torch.allclose(out, out_fa))
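What this deleted test checks is that greedy generation does not change when the attention backend is swapped. The same comparison can be made by loading one checkpoint twice with different `attn_implementation` values; the sketch below uses eager vs. SDPA instead of FA2 so it runs without flash-attn, and `gpt2` is only an example of a checkpoint that supports both backends:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

ckpt = "gpt2"  # example checkpoint; any model supporting both backends works
tokenizer = AutoTokenizer.from_pretrained(ckpt)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"  # left padding, as in the deleted test

inputs = tokenizer(["Hello world", "Hi"], return_tensors="pt", padding=True)

model_eager = AutoModelForCausalLM.from_pretrained(ckpt, attn_implementation="eager").eval()
model_sdpa = AutoModelForCausalLM.from_pretrained(ckpt, attn_implementation="sdpa").eval()

out_eager = model_eager.generate(**inputs, max_new_tokens=1, do_sample=False)
out_sdpa = model_sdpa.generate(**inputs, max_new_tokens=1, do_sample=False)
# Greedy-decoded token ids are normally identical across backends
print(torch.equal(out_eager, out_sdpa))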
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@is_flaky()
@slow
def test_flash_attn_2_generate_padding_right(self):
if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions")
for model_class in self.all_generative_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
torch_device
)
dummy_input = inputs_dict[model.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
dummy_input = dummy_input.to(torch.float16)
dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
# make sure we do right padding
dummy_attention_mask[:, :-1] = 1
dummy_attention_mask[:, -1:] = 0
out = model.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
)
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
).to(torch_device)
out_fa = model.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
)
self.assertTrue(torch.allclose(out, out_fa))
def test_attn_implementation_composite_models(self):
"""
Tests if composite models can receive a dict object as attn_implementation, where each key should be
@ -4525,65 +4359,6 @@ class ModelTesterMixin:
torch.allclose(res_eager[attention_mask == 1], res_sdpa[attention_mask == 1], rtol=1e-4, atol=1e-4)
)
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
def test_flash_attn_2_generate_use_cache(self):
if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions")
max_new_tokens = 30
for model_class in self.all_generative_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
dummy_input = inputs_dict[model_class.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
dummy_input = dummy_input.to(torch.float16)
# make sure that all models have enough positions for generation
if hasattr(config, "max_position_embeddings"):
config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
).to(torch_device)
# Just test that a large cache works as expected
_ = model.generate(
dummy_input,
attention_mask=dummy_attention_mask,
max_new_tokens=max_new_tokens,
do_sample=False,
use_cache=True,
)
# Generate with one batch only to test generation when attention mask will be None
# when real inputs are used, because there is no padding. See issue #32237 for more
dummy_input = dummy_input[:1, ...]
dummy_attention_mask = torch.ones_like(dummy_attention_mask[:1, ...])
_ = model.generate(
dummy_input,
attention_mask=dummy_attention_mask,
max_new_tokens=max_new_tokens,
do_sample=False,
use_cache=True,
)
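The second call in the deleted test covers the case where no attention mask is supplied at all, which is what happens with real, unpadded inputs (see the issue #32237 reference in the comment): generate then infers a default all-ones mask. In user code this is simply a plain call on a single prompt; `gpt2` below is only a stand-in checkpoint:

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Single unpadded prompt: no attention_mask is passed, so generate falls back to an all-ones mask
inputs = tokenizer("Flash attention makes long prompts", return_tensors="pt")
_ = model.generate(inputs.input_ids, max_new_tokens=5, do_sample=False, use_cache=True)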
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@ -4640,62 +4415,6 @@ class ModelTesterMixin:
if not has_fa2:
raise ValueError("The FA2 model should have FA2 layers")
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
def test_flash_attn_2_generate_reuse_cache(self):
if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions")
max_new_tokens = 2
for model_class in self.all_generative_model_classes:
if not model_class._supports_flash_attn_2:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
dummy_input = inputs_dict[model_class.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
dummy_input = dummy_input.to(torch.float16)
# make sure that all models have enough positions for generation
if hasattr(config, "max_position_embeddings"):
config.max_position_embeddings = dummy_input.shape[1] * 2 + max_new_tokens * 2 + 1
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
).to(torch_device)
# run generate once to get filled cache
output = model.generate(
dummy_input,
max_new_tokens=max_new_tokens,
do_sample=False,
use_cache=True,
return_dict_in_generate=True,
)
past_key_values = output.past_key_values
# Try to continue generation from where we left, given that we have more than 1 new token to process
# e.g. this can happen in speculative decoding when feeding candidate tokens back to target model
dummy_input_updated = torch.cat([dummy_input, output.sequences], dim=-1)
_ = model.generate(
dummy_input_updated,
max_new_tokens=max_new_tokens,
do_sample=False,
use_cache=True,
past_key_values=past_key_values,
)
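The pattern removed here, generating once with `use_cache=True` and `return_dict_in_generate=True` and then handing the returned `past_key_values` back to a second `generate` call whose input already contains several unprocessed tokens, is the call shape used when feeding speculative-decoding candidates back to a target model. A minimal sketch with a plain causal LM, assuming a recent transformers version where `generate` returns a reusable cache and accepts it back as a keyword argument (`gpt2` is only an example):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()

inputs = tokenizer("Speculative decoding reuses", return_tensors="pt")

# First call: generate a couple of tokens and keep the cache
first = model.generate(
    **inputs,
    max_new_tokens=2,
    do_sample=False,
    use_cache=True,
    return_dict_in_generate=True,
)

# Pretend a draft model proposed candidate tokens; append them to the sequence so far
candidates = tokenizer(" fast", return_tensors="pt").input_ids
full_input = torch.cat([first.sequences, candidates], dim=-1)

# Second call: several tokens are not yet covered by past_key_values, so generate first
# runs them through the model using the cache, then keeps generating from there
_ = model.generate(
    full_input,
    past_key_values=first.past_key_values,
    max_new_tokens=2,
    do_sample=False,
    use_cache=True,
)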
@require_flash_attn
@require_torch_gpu
@require_bitsandbytes
@ -4999,82 +4718,6 @@ class ModelTesterMixin:
normalized_1 = F.softmax(out_shared_prefix_last_tokens)
torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4)
def test_static_cache_matches_dynamic(self):
"""
Tests that generating with a static cache gives almost the same results as with a dynamic cache.
This test does not compile the model and only checks logits similarity, allowing for numerical
precision errors.
"""
if len(self.all_generative_model_classes) == 0:
self.skipTest(
reason="Model architecture has no generative classes, and thus not necessarily supporting 4D masks"
)
for model_class in self.all_generative_model_classes:
if not model_class._supports_static_cache:
self.skipTest(f"{model_class.__name__} does not support static cache")
if not model_class._supports_cache_class:
self.skipTest(f"{model_class.__name__} does not support cache class")
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
if getattr(config, "sliding_window", 0) is not None and getattr(config, "sliding_window", 0) > 0:
self.skipTest(f"{model_class.__name__} with sliding window attention is not supported by this test")
model = model_class(config).to(device=torch_device, dtype=torch.float32)
model.eval()
dynamic_out = model.generate(
**inputs, do_sample=False, max_new_tokens=10, output_logits=True, return_dict_in_generate=True
)
static_out = model.generate(
**inputs,
do_sample=False,
max_new_tokens=10,
cache_implementation="static",
output_logits=True,
return_dict_in_generate=True,
)
self.assertTrue(torch.allclose(dynamic_out.logits[0], static_out.logits[0], rtol=1e-3, atol=1e-4))
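The same dynamic-vs-static comparison can be made directly through `generate` by requesting the per-step logits. A sketch under the assumption that the checkpoint's model class supports the static cache (the Llama checkpoint name is only a placeholder; gated or differently named models work the same way):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

ckpt = "meta-llama/Llama-3.2-1B"  # placeholder: any model whose class supports the static cache
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float32).eval()

inputs = tokenizer("Static caches preallocate", return_tensors="pt")

dynamic = model.generate(
    **inputs, do_sample=False, max_new_tokens=10, output_logits=True, return_dict_in_generate=True
)
static = model.generate(
    **inputs,
    do_sample=False,
    max_new_tokens=10,
    cache_implementation="static",
    output_logits=True,
    return_dict_in_generate=True,
)

# Logits of the first generated token should match up to numerical noise
print(torch.allclose(dynamic.logits[0], static.logits[0], rtol=1e-3, atol=1e-4))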
# For now, let's focus only on GPU for `torch.compile`
@slow
@require_torch_accelerator
@require_read_token
def test_torch_compile(self):
if version.parse(torch.__version__) < version.parse("2.3"):
self.skipTest(reason="This test requires torch >= 2.3 to run.")
torch.compiler.reset()
if not hasattr(self, "_torch_compile_test_ckpt"):
self.skipTest(f"{self.__class__.__name__} doesn't have the attribute `_torch_compile_test_ckpt`.")
ckpt = self._torch_compile_test_ckpt
revision = "main" if not hasattr(self, "_torch_compile_test_revision") else self._torch_compile_test_revision
os.environ["TOKENIZERS_PARALLELISM"] = "false"
batch_size = 1
n_iter = 3
tokenizer = AutoTokenizer.from_pretrained(ckpt)
if self.is_encoder_decoder:
model = AutoModelForSeq2SeqLM.from_pretrained(ckpt, torch_dtype=torch.float16, revision=revision).to(
torch_device
)
else:
model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16, revision=revision).to(
torch_device
)
model.generation_config.max_new_tokens = 4
model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode="reduce-overhead")
input_text = "Why dogs are cute?"
input_ids = tokenizer([input_text] * batch_size, return_tensors="pt").to(torch_device)
for i in range(n_iter):
_ = model.generate(**input_ids, do_sample=False)
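Stripped of the test harness, the combination exercised above is a static cache (so tensor shapes stay fixed) plus `torch.compile` on the forward pass. A minimal sketch, assuming torch >= 2.3, a CUDA device and a compile-friendly checkpoint (the name is again a placeholder):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

ckpt = "meta-llama/Llama-3.2-1B"  # placeholder: any model supporting the static cache
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16).to("cuda")

model.generation_config.max_new_tokens = 4
model.generation_config.cache_implementation = "static"  # fixed shapes, needed for compilation
model.forward = torch.compile(model.forward, mode="reduce-overhead")

inputs = tokenizer("Why dogs are cute?", return_tensors="pt").to("cuda")
for _ in range(3):  # first iterations pay the compilation cost, later ones reuse the compiled graph
    _ = model.generate(**inputs, do_sample=False)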
@slow
@require_torch_gpu
def test_torch_compile_for_training(self):
@ -5118,74 +4761,6 @@ class ModelTesterMixin:
for name, param in model._orig_mod.named_parameters():
torch.testing.assert_close(param.grad.detach().cpu(), params[name], rtol=1e-4, atol=1e-4)
@slow
@require_torch_gpu # Testing cuda graphs.
@require_read_token
def test_compile_cuda_graph_time(self):
if version.parse(torch.__version__) < version.parse("2.3"):
self.skipTest(reason="This test requires torch >= 2.3 to run.")
# TODO felix: All models supporting `StaticCache` or `torch.compile` should be tested.
# At the moment, only llama, gemma and gemma2 are tested here!
if not hasattr(self, "_torch_compile_test_ckpt"):
self.skipTest(f"{self.__class__.__name__} doesn't have the attribute `_torch_compile_test_ckpt`.")
ckpt = self._torch_compile_test_ckpt
revision = "main" if not hasattr(self, "_torch_compile_test_revision") else self._torch_compile_test_revision
os.environ["TOKENIZERS_PARALLELISM"] = "false"
tokenizer = AutoTokenizer.from_pretrained(ckpt)
if self.is_encoder_decoder:
model = AutoModelForSeq2SeqLM.from_pretrained(ckpt, torch_dtype=torch.float16, revision=revision).to(
torch_device
)
else:
model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16, revision=revision).to(
torch_device
)
cache_implementation = "static"
if model.config.model_type == "gemma2":
cache_implementation = "hybrid"
new_tokens = 50
gen_config = GenerationConfig(
max_new_tokens=new_tokens,
min_new_tokens=new_tokens,
use_cache=True,
pad_token_id=tokenizer.pad_token_id,
num_beams=1,
do_sample=False,
eos_token_id=None, # This is required for min_new_tokens to actually have an effect.
)
model.generation_config.eos_token_id = None # greedy_search falls back on this eos_token_id that we need to set to None as well for min_new_tokens to have an effect.
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
inp = tokenizer("Why cats are cute?", return_tensors="pt").to(torch_device)
# First run: the first run warms up each graph, which does things like CuBlas or Triton benchmarking
start = time.perf_counter()
_ = model.generate(**inp, generation_config=gen_config, cache_implementation=cache_implementation)
end = time.perf_counter()
graph_warmup_time = end - start
# Second run: CUDA Graph recording, and replays it
start = time.perf_counter()
_ = model.generate(**inp, generation_config=gen_config, cache_implementation=cache_implementation)
end = time.perf_counter()
record_time = end - start
# Finally: we hit the optimized, CUDA Graph replay path
start = time.perf_counter()
_ = model.generate(**inp, generation_config=gen_config, cache_implementation=cache_implementation)
end = time.perf_counter()
opt_time = end - start
# For the recording step, we expect only two cuda graphs and this step should be much faster than the first.
self.assertTrue(record_time < 0.15 * graph_warmup_time)
self.assertTrue(opt_time < record_time)
def test_forward_with_num_logits_to_keep(self):
for model_class in self.all_generative_model_classes:
if "num_logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()):