Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-04 05:10:06 +06:00)
Tests: move generate tests to the right mixin and delete redundant tests (#34464)

* tmp commit
* tmp commit
* cull overwrites of deleted tests
* typo
* more specific docstring
* make fixup
* parameterize at the top?
* correction
* more deletions :D
* tmp commit
* for VLMs too
* fix _check_outputs
* test nit
* make fixup
* fix another flaky
* test_generate_from_inputs_embeds -- handle missing attention mask
This commit is contained in:
parent 913330ca9f
commit 8a734ea2c3
@@ -378,10 +378,14 @@ class GenerationMixin:
         # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
         # Exception 1: when passing input_embeds, input_ids may be missing entries
         # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
-        # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case
+        # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
+        # (we can't check exception 3 while compiling)
         if past_key_values is not None:
             model_inputs["past_key_values"] = past_key_values
-            if inputs_embeds is not None or cache_position[-1] >= input_ids.shape[1]:  # Exception 1 or Exception 3
+            if (
+                inputs_embeds is not None  # Exception 1
+                or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1])  # Exception 3
+            ):
                 input_ids = input_ids[:, -cache_position.shape[0] :]
             elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
                 input_ids = input_ids[:, cache_position]
@@ -414,7 +418,7 @@ class GenerationMixin:
         for model_input_name in ["position_ids", "token_type_ids"]:
             model_input = kwargs.get(model_input_name)
             if model_input is not None:
-                if past_key_values:
+                if past_key_values is not None:
                     model_input = model_input[:, -input_ids.shape[1] :]
                     model_input = model_input.clone(memory_format=torch.contiguous_format)
                 model_inputs[model_input_name] = model_input
@@ -568,27 +572,34 @@ class GenerationMixin:

     def _prepare_attention_mask_for_generation(
         self,
-        inputs: torch.Tensor,
-        pad_token_id: Optional[torch.Tensor],
-        eos_token_id: Optional[torch.Tensor],
+        inputs_tensor: torch.Tensor,
+        generation_config: GenerationConfig,
+        model_kwargs: Dict[str, Any],
     ) -> torch.LongTensor:
+        pad_token_id = generation_config._pad_token_tensor
+        eos_token_id = generation_config._eos_token_tensor
+
+        # `input_ids` may be present in the model kwargs, instead of being the main input (e.g. multimodal model)
+        if "input_ids" in model_kwargs and model_kwargs["input_ids"].shape[1] > 0:
+            inputs_tensor = model_kwargs["input_ids"]
+
         # No information for attention mask inference -> return default attention mask
-        default_attention_mask = torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device)
+        default_attention_mask = torch.ones(inputs_tensor.shape[:2], dtype=torch.long, device=inputs_tensor.device)
         if pad_token_id is None:
             return default_attention_mask

-        is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long]
+        is_input_ids = len(inputs_tensor.shape) == 2 and inputs_tensor.dtype in [torch.int, torch.long]
         if not is_input_ids:
             return default_attention_mask

         is_pad_token_in_inputs = (pad_token_id is not None) and (
-            isin_mps_friendly(elements=inputs, test_elements=pad_token_id).any()
+            isin_mps_friendly(elements=inputs_tensor, test_elements=pad_token_id).any()
         )
         is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or ~(
             isin_mps_friendly(elements=eos_token_id, test_elements=pad_token_id).any()
         )
         can_infer_attention_mask = is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id
-        attention_mask_from_padding = inputs.ne(pad_token_id).long()
+        attention_mask_from_padding = inputs_tensor.ne(pad_token_id).long()

         attention_mask = (
             attention_mask_from_padding * can_infer_attention_mask + default_attention_mask * ~can_infer_attention_mask
@@ -2020,7 +2031,7 @@ class GenerationMixin:

         if not kwargs_has_attention_mask and requires_attention_mask and accepts_attention_mask:
             model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
-                inputs_tensor, generation_config._pad_token_tensor, generation_config._eos_token_tensor
+                inputs_tensor, generation_config, model_kwargs
             )
         elif kwargs_has_attention_mask:
             # TODO (joao): generalize this check with other types of inputs
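For reference, the mask-inference rule that `_prepare_attention_mask_for_generation` applies reduces to the `inputs_tensor.ne(pad_token_id)` line in the hunk above. A minimal, self-contained illustration with toy tensors (the pad token id of 0 is an assumption of this sketch, not something fixed by the library):

    import torch

    input_ids = torch.tensor([[5, 17, 42, 0, 0]])  # 0 stands in for the pad token id
    pad_token_id = torch.tensor(0)
    attention_mask = input_ids.ne(pad_token_id).long()
    print(attention_mask)  # tensor([[1, 1, 1, 0, 0]])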
@@ -911,7 +911,8 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, GenerationMixin):

         if (pixel_values is not None or pixel_values_videos is not None) and inputs_embeds is not None:
             raise ValueError(
-                "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
+                "You cannot specify both `pixel_values`/`pixel_values_videos` and `inputs_embeds` at the same time, "
+                "and must specify either one"
             )

         legacy_processing = False
@@ -424,7 +424,8 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration):

         if (pixel_values is not None or pixel_values_videos is not None) and inputs_embeds is not None:
             raise ValueError(
-                "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
+                "You cannot specify both `pixel_values`/`pixel_values_videos` and `inputs_embeds` at the same time, "
+                "and must specify either one"
             )

         legacy_processing = False
@@ -657,7 +657,8 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, GenerationMixin):

         if (pixel_values is not None or pixel_values_videos is not None) and inputs_embeds is not None:
             raise ValueError(
-                "You cannot specify both pixel_values/pixel_values_videos and inputs_embeds at the same time, and must specify either one"
+                "You cannot specify both `pixel_values`/`pixel_values_videos` and `inputs_embeds` at the same time, "
+                "and must specify either one"
             )

         if inputs_embeds is None:
@@ -1562,7 +1562,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel, GenerationMixin):

         if model_kwargs.get("attention_mask", None) is None and requires_attention_mask:
             model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
-                input_ids, generation_config._pad_token_tensor, generation_config._eos_token_tensor
+                input_ids, generation_config, model_kwargs
             )

         # 5. Prepare `max_length` depending on other stopping criteria.
@@ -2578,7 +2578,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel, GenerationMixin):

         if model_kwargs.get("attention_mask", None) is None and requires_attention_mask:
             model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
-                inputs_tensor, generation_config._pad_token_tensor, generation_config._eos_token_tensor
+                inputs_tensor, generation_config, model_kwargs
             )

         if "encoder_outputs" not in model_kwargs:
@@ -1484,7 +1484,7 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel, GenerationMixin):

         if model_kwargs.get("attention_mask", None) is None and requires_attention_mask:
             model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
-                input_ids, generation_config._pad_token_tensor, generation_config._eos_token_tensor
+                input_ids, generation_config, model_kwargs
             )

         # 5. Prepare `max_length` depending on other stopping criteria.
@@ -2425,7 +2425,7 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel, GenerationMixin):

         if model_kwargs.get("attention_mask", None) is None and requires_attention_mask:
             model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
-                inputs_tensor, generation_config._pad_token_tensor, generation_config._eos_token_tensor
+                inputs_tensor, generation_config, model_kwargs
             )

         if "encoder_hidden_states" not in model_kwargs:
@@ -534,7 +534,8 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMixin):

         if (pixel_values_images is not None or pixel_values_videos is not None) and inputs_embeds is not None:
             raise ValueError(
-                "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
+                "You cannot specify both `pixel_values_images`/`pixel_values_videos` and `inputs_embeds` at the same "
+                "time, and must specify either one"
             )

         legacy_processing = False
@@ -29,6 +29,7 @@ from transformers import AutoConfig, is_torch_available, pipeline, set_seed
 from transformers.testing_utils import (
     is_flaky,
     require_accelerate,
+    require_flash_attn,
     require_optimum_quanto,
     require_torch,
     require_torch_gpu,
@@ -136,6 +137,34 @@ class GenerationTesterMixin:

         return config, filtered_inputs_dict

+    def _check_similar_generate_outputs(self, output_1, output_2, atol=1e-5, rtol=1e-5):
+        """
+        Checks whether a pair of generate outputs are similar. Two `generate` call outputs are considered similar in
+        the following siturations:
+        1. The sequences are the same
+        2. The sequences are different, but the scores up to (and including) the first mismatch are nearly identical
+        """
+        # scores doesn't include data regarding decoder input tokens
+        decoder_input_length = output_1.sequences.shape[1] - len(output_1.scores)
+        output_matches = output_1.sequences == output_2.sequences
+        has_matching_outputs = output_matches.all()
+        has_matching_scores = None
+        if not has_matching_outputs:
+            for batch_idx in range(output_1.sequences.shape[0]):
+                batch_matches = output_matches[batch_idx]
+                if batch_matches.all():
+                    continue
+                first_mismatch_idx = batch_matches.int().argmin()  # gets the index of the first False
+                first_mismatch_idx -= decoder_input_length
+                output_1_first_mismatch_scores = output_1.scores[first_mismatch_idx][batch_idx]
+                output_2_first_mismatch_scores = output_2.scores[first_mismatch_idx][batch_idx]
+                has_matching_scores = torch.allclose(
+                    output_1_first_mismatch_scores, output_2_first_mismatch_scores, rtol=atol, atol=rtol
+                )
+                if not has_matching_scores:
+                    break
+        self.assertTrue(has_matching_outputs or has_matching_scores)
+
     def _get_logits_processor_kwargs(self, do_sample=False, config=None):
         logits_processor_kwargs = {
             "bad_words_ids": [[1, 0]],
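A minimal, self-contained sketch of the comparison rule encoded by `_check_similar_generate_outputs` above: two outputs count as similar when their sequences match, or when the scores at the first mismatching position are numerically close. The objects below are stand-ins built with `SimpleNamespace`, not real `generate` outputs:

    import torch
    from types import SimpleNamespace

    scores_a = (torch.tensor([[0.1, 0.9]]), torch.tensor([[0.6, 0.4]]))
    scores_b = (torch.tensor([[0.1, 0.9]]), torch.tensor([[0.6001, 0.3999]]))
    out_a = SimpleNamespace(sequences=torch.tensor([[7, 1, 0]]), scores=scores_a)
    out_b = SimpleNamespace(sequences=torch.tensor([[7, 1, 1]]), scores=scores_b)

    # sequences have 3 positions but only 2 scored steps, so 1 position belongs to the prompt
    decoder_input_length = out_a.sequences.shape[1] - len(out_a.scores)
    matches = out_a.sequences[0] == out_b.sequences[0]
    first_mismatch = int(matches.int().argmin()) - decoder_input_length  # index into `scores`
    print(torch.allclose(out_a.scores[first_mismatch][0], out_b.scores[first_mismatch][0], atol=1e-3))  # True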
@@ -426,7 +455,6 @@ class GenerationTesterMixin:
     def test_greedy_generate_dict_outputs(self):
         for model_class in self.all_generative_model_classes:
             config, inputs_dict = self.prepare_config_and_inputs_for_generate()
-            main_input = inputs_dict[model_class.main_input_name]

             model = model_class(config).to(torch_device).eval()
             output_generate = self._greedy_generate(
@@ -453,13 +481,12 @@ class GenerationTesterMixin:
                 # Retrocompatibility check
                 self.assertIsInstance(output_generate, GreedySearchDecoderOnlyOutput)

-            self._check_outputs(output_generate, main_input, model.config)
+            self._check_outputs(output_generate, model.config)

     @pytest.mark.generate
     def test_greedy_generate_dict_outputs_use_cache(self):
         for model_class in self.all_generative_model_classes:
             config, inputs_dict = self.prepare_config_and_inputs_for_generate()
-            main_input = inputs_dict[model_class.main_input_name]

             if not hasattr(config, "use_cache"):
                 self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
@@ -486,7 +513,7 @@ class GenerationTesterMixin:
                 output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
             )

-            self._check_outputs(output_generate, main_input, model.config, use_cache=True)
+            self._check_outputs(output_generate, model.config, use_cache=True)

     @pytest.mark.generate
     def test_sample_generate(self):
@@ -505,7 +532,6 @@ class GenerationTesterMixin:
     def test_sample_generate_dict_output(self):
         for model_class in self.all_generative_model_classes:
             config, inputs_dict = self.prepare_config_and_inputs_for_generate()
-            main_input = inputs_dict[model_class.main_input_name]

             model = model_class(config).to(torch_device).eval()
             output_generate = self._sample_generate(
@@ -533,7 +559,7 @@ class GenerationTesterMixin:
                 # Retrocompatibility check
                 self.assertIsInstance(output_generate, SampleDecoderOnlyOutput)

-            self._check_outputs(output_generate, main_input, model.config, num_return_sequences=2)
+            self._check_outputs(output_generate, model.config, num_return_sequences=2)

     @pytest.mark.generate
     def test_beam_search_generate(self):
@@ -554,7 +580,6 @@ class GenerationTesterMixin:
     def test_beam_search_generate_dict_output(self):
         for model_class in self.all_generative_model_classes:
             config, inputs_dict = self.prepare_config_and_inputs_for_generate()
-            main_input = inputs_dict[model_class.main_input_name]

             model = model_class(config).to(torch_device).eval()
             beam_kwargs = self._get_beam_kwargs()
@@ -583,14 +608,16 @@ class GenerationTesterMixin:
                 self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)

             self._check_outputs(
-                output_generate, main_input, model.config, num_return_sequences=beam_kwargs["num_beams"]
+                output_generate,
+                model.config,
+                num_return_sequences=beam_kwargs["num_return_sequences"],
+                num_beams=beam_kwargs["num_beams"],
             )

     @pytest.mark.generate
     def test_beam_search_generate_dict_outputs_use_cache(self):
         for model_class in self.all_generative_model_classes:
             config, inputs_dict = self.prepare_config_and_inputs_for_generate()
-            main_input = inputs_dict[model_class.main_input_name]

             if not hasattr(config, "use_cache"):
                 self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
@@ -623,10 +650,10 @@ class GenerationTesterMixin:

             self._check_outputs(
                 output_generate,
-                main_input,
                 model.config,
                 use_cache=True,
-                num_return_sequences=beam_kwargs["num_beams"],
+                num_return_sequences=beam_kwargs["num_return_sequences"],
+                num_beams=beam_kwargs["num_beams"],
             )

     @require_accelerate
@@ -675,7 +702,6 @@ class GenerationTesterMixin:
     def test_beam_sample_generate_dict_output(self):
         for model_class in self.all_generative_model_classes:
             config, inputs_dict = self.prepare_config_and_inputs_for_generate()
-            main_input = inputs_dict[model_class.main_input_name]

             model = model_class(config).to(torch_device).eval()
             beam_kwargs = self._get_beam_kwargs()
@@ -706,7 +732,10 @@ class GenerationTesterMixin:
                 self.assertIsInstance(output_generate, BeamSampleDecoderOnlyOutput)

             self._check_outputs(
-                output_generate, main_input, model.config, num_return_sequences=beam_kwargs["num_beams"]
+                output_generate,
+                model.config,
+                num_return_sequences=beam_kwargs["num_return_sequences"],
+                num_beams=beam_kwargs["num_beams"],
             )

     @pytest.mark.generate
@@ -765,7 +794,6 @@ class GenerationTesterMixin:
     def test_group_beam_search_generate_dict_output(self):
         for model_class in self.all_generative_model_classes:
             config, inputs_dict = self.prepare_config_and_inputs_for_generate()
-            main_input = inputs_dict[model_class.main_input_name]

             model = model_class(config).to(torch_device).eval()
             beam_kwargs = self._get_diverse_beam_kwargs()
@@ -794,7 +822,10 @@ class GenerationTesterMixin:
                 self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)

             self._check_outputs(
-                output_generate, main_input, model.config, num_return_sequences=beam_kwargs["num_beams"]
+                output_generate,
+                model.config,
+                num_return_sequences=beam_kwargs["num_return_sequences"],
+                num_beams=beam_kwargs["num_beams"],
             )

     # TODO: @gante check why it is flaky
@@ -859,7 +890,6 @@ class GenerationTesterMixin:
     def test_constrained_beam_search_generate_dict_output(self):
         for model_class in self.all_generative_model_classes:
             config, inputs_dict = self.prepare_config_and_inputs_for_generate()
-            main_input = inputs_dict[model_class.main_input_name]

             model = model_class(config).to(torch_device).eval()

@@ -899,7 +929,10 @@ class GenerationTesterMixin:
                 self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)

             self._check_outputs(
-                output_generate, main_input, model.config, num_return_sequences=beam_kwargs["num_beams"]
+                output_generate,
+                model.config,
+                num_return_sequences=beam_kwargs["num_return_sequences"],
+                num_beams=beam_kwargs["num_beams"],
             )

     @pytest.mark.generate
@@ -942,7 +975,6 @@ class GenerationTesterMixin:
                 self.skipTest(reason="Won't fix: old model with different cache format")

             config, inputs_dict = self.prepare_config_and_inputs_for_generate()
-            main_input = inputs_dict[model_class.main_input_name]

             # NOTE: contrastive search only works with cache on at the moment.
             if not hasattr(config, "use_cache"):
@@ -968,7 +1000,7 @@ class GenerationTesterMixin:
                 output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
             )

-            self._check_outputs(output_generate, main_input, model.config, use_cache=True)
+            self._check_outputs(output_generate, model.config, use_cache=True)

     @pytest.mark.generate
     def test_contrastive_generate_low_memory(self):
@@ -1064,14 +1096,10 @@ class GenerationTesterMixin:

     @pytest.mark.generate
     @parameterized.expand([("random",), ("same",)])
-    @is_flaky()  # Read NOTE (1) below. If there are API issues, all attempts will fail.
     def test_assisted_decoding_matches_greedy_search(self, assistant_type):
         # This test ensures that the assisted generation does not introduce output changes over greedy search.
-        # NOTE (1): The sentence above is true most of the time, there is a tiny difference in the logits due to matmul
-        # shape differences -- and it may result in a different output. The input shape difference happens in the
-        # main model, that runs the forward pass with several candidates at once (as opposed to generating one token at
-        # a time). See https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535 for more info.
-        # NOTE (2): It breaks the pattern in the tests above, for multiple reasons:
+        # See https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535 for more info.
+        # NOTE: It breaks the pattern in the tests above, for multiple reasons:
         #  - assisted_decoding, contrarily to the other methods, can't be called on its own (e.g. needs to
         #    prepare the assistant encoder outputs in the main generate body);
         #  - assisted_decoding does not support `use_cache = False`
@@ -1100,7 +1128,6 @@ class GenerationTesterMixin:

             # enable cache
             config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)
-            main_input = inputs_dict[model_class.main_input_name]

             # NOTE: assisted generation only works with cache on at the moment.
             if not hasattr(config, "use_cache"):
@@ -1141,12 +1168,10 @@ class GenerationTesterMixin:
             output_assisted = model.generate(**generation_kwargs, **inputs_dict)

             # The two outputs must match and their shape must be as expected
-            self.assertListEqual(output_greedy.sequences.tolist(), output_assisted.sequences.tolist())
+            self._check_similar_generate_outputs(output_greedy, output_assisted)
             for output in (output_greedy, output_assisted):
-                self._check_outputs(output, main_input, model.config, use_cache=True)
+                self._check_outputs(output, model.config, use_cache=True)

-    @is_flaky()
     @pytest.mark.generate
     def test_prompt_lookup_decoding_matches_greedy_search(self):
         # This test ensures that the prompt lookup generation does not introduce output changes over greedy search.
@@ -1175,7 +1200,6 @@ class GenerationTesterMixin:

             # enable cache
             config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)
-            main_input = inputs_dict[model_class.main_input_name]

             # NOTE: assisted generation only works with cache on at the moment.
             if not hasattr(config, "use_cache"):
@@ -1208,10 +1232,9 @@ class GenerationTesterMixin:
             output_prompt_lookup = model.generate(**generation_kwargs, **inputs_dict)

             # The two outputs must match and their shape must be as expected
-            self.assertListEqual(output_greedy.sequences.tolist(), output_prompt_lookup.sequences.tolist())
+            self._check_similar_generate_outputs(output_greedy, output_prompt_lookup)
             for output in (output_greedy, output_prompt_lookup):
-                self._check_outputs(output, main_input, model.config, use_cache=True)
+                self._check_outputs(output, model.config, use_cache=True)

     @pytest.mark.generate
     def test_dola_decoding_sample(self):
@@ -1231,7 +1254,6 @@ class GenerationTesterMixin:

             # enable cache if the model is not openai-gpt, xlnet, cpm, or xlm
             config, inputs_dict = self.prepare_config_and_inputs_for_generate()
-            main_input = inputs_dict[model_class.main_input_name]

             # Encoder-decoder models are not supported
             if config.is_encoder_decoder:
@@ -1259,7 +1281,7 @@ class GenerationTesterMixin:
                 "dola_layers": "low",
             }
             output_dola = model.generate(**generation_kwargs, **inputs_dict)
-            self._check_outputs(output_dola, main_input, model.config, use_cache=getattr(config, "use_cache", False))
+            self._check_outputs(output_dola, model.config, use_cache=getattr(config, "use_cache", False))

     @pytest.mark.generate
     def test_assisted_decoding_sample(self):
@@ -1289,7 +1311,6 @@ class GenerationTesterMixin:

             # enable cache
             config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)
-            main_input = inputs_dict[model_class.main_input_name]

             # NOTE: assisted generation only works with cache on at the moment.
             if not hasattr(config, "use_cache"):
@@ -1321,7 +1342,7 @@ class GenerationTesterMixin:
             }
             output_assisted = model.generate(**generation_kwargs, **inputs_dict)

-            self._check_outputs(output_assisted, main_input, config, use_cache=True)
+            self._check_outputs(output_assisted, config, use_cache=True)

     @pytest.mark.generate
     def test_prompt_lookup_decoding_stops_at_eos(self):
@@ -1547,75 +1568,93 @@ class GenerationTesterMixin:
             )

     @pytest.mark.generate
-    @parameterized.expand([(1,), (2,)])
-    def test_generate_from_inputs_embeds_decoder_only(self, num_beams):
+    @parameterized.expand([("greedy", 1), ("beam search", 2)])
+    def test_generate_from_inputs_embeds(self, _, num_beams):
+        """Tests that we can generate from `inputs_embeds` instead of `input_ids` in LLMs, VLMs, etc"""
         # When supported, tests that the decoder model can generate from `inputs_embeds` instead of `input_ids`
         # if fails, you should probably update the `prepare_inputs_for_generation` function
         for model_class in self.all_generative_model_classes:
             config, inputs_dict = self.prepare_config_and_inputs_for_generate()

-            # Ignore:
-            # a) eos (to always output 20 tokens) and pad (so we don't try to infer the attn mask from the input_ids,
-            #    which would cause a mismatch),
-            config.pad_token_id = config.eos_token_id = -1
-            # b) embedding scaling, the scaling factor applied after embeding from input_ids (requires knowledge of the
-            #    variable that holds the scaling factor, which is model-dependent)
-            if hasattr(config, "scale_embedding"):
-                config.scale_embedding = False
-
             # This test is for decoder-only models (encoder-decoder models have native input embeddings support in the
             # decoder)
             if config.is_encoder_decoder:
                 continue
+            config.is_decoder = True

             # Skip models without explicit support
-            config.is_decoder = True
             model = model_class(config).to(torch_device).eval()
             if "inputs_embeds" not in inspect.signature(model.prepare_inputs_for_generation).parameters.keys():
                 continue

+            # There are a few exception patterns in this test:
+            # 1 - Some models can't generate without `input_ids`, when `inputs_embeds` are passed
+            requires_inputs_ids = any(
+                model_name in model_class.__name__.lower() for model_name in ["idefics", "qwen2vl"]
+            )
+            # 2 - Complex `inputs_embeds` computation, i.e. the correct computation of inputs embeds is more complex
+            # than calling the embedding layer with `input_ids`. Subcases of this exception:
+            # 2.A - Ignore `scale_embedding`, if the model supports it (it is controlled by a model-dependent flag)
+            if hasattr(config, "scale_embedding"):
+                config.scale_embedding = False
+            # 2.B - Some VLMs assume `inputs_embeds` and `pixel_values` are mutually exclusive AND fall in the
+            # exception above (complex `inputs_embeds` computation). Popping `pixel_values` allow us to run the
+            # checks without adding test complexity. Ditto for `pixel_values_videos` and `pixel_values_images`
+            pixel_values_is_mutually_exclusive = any(
+                model_name in model_class.__name__.lower()
+                for model_name in ["llava", "idefics2", "idefics3", "mllama", "paligemma"]
+            )
+            if pixel_values_is_mutually_exclusive:
+                inputs_dict.pop("pixel_values", None)
+                inputs_dict.pop("pixel_values_videos", None)
+                inputs_dict.pop("pixel_values_images", None)
+            # 2.C - No easy fix, let's skip the check that compares the outputs from `input_ids` and `inputs_embeds`
+            has_complex_embeds_computation = any(
+                model_name in model_class.__name__.lower() for model_name in ["moshi"]
+            )
+            # 3 - `inputs_dict` doesn't contain `attention_mask`. When `attention_mask` is not passed to generate,
+            # we infer it from `input_ids`. The last test case will fail if there is a pad token in the original input.
+            missing_attention_mask = "attention_mask" not in inputs_dict

+            # Traditional way of generating text
             input_ids = inputs_dict.pop("input_ids")
             generation_kwargs = {
                 "return_dict_in_generate": True,
                 "output_scores": True,
                 "num_beams": num_beams,
                 "do_sample": False,
+                "max_new_tokens": 5,
+                "min_new_tokens": 5,  # generate exactly 5 tokens
             }
+            outputs_from_ids = model.generate(input_ids, **generation_kwargs, **inputs_dict)
-            # Traditional way of generating text
-            outputs_from_ids = model.generate(input_ids, max_new_tokens=5, **generation_kwargs)
             self.assertEqual(outputs_from_ids.sequences.shape, (input_ids.shape[0], input_ids.shape[1] + 5))

-            # Same thing, but from input embeddings (`input_ids` is passed so the prompt is present in the output)
+            # Same thing, but from input embeddings (`input_ids` is passed so the prompt is present in the output).
+            # The output of the two calls should be the same.
             inputs_embeds = model.get_input_embeddings()(input_ids)
             outputs_from_embeds = model.generate(
-                input_ids,
-                inputs_embeds=inputs_embeds,
-                max_new_tokens=5,
-                **generation_kwargs,
+                input_ids, inputs_embeds=inputs_embeds, **generation_kwargs, **inputs_dict
             )
-            self.assertListEqual(outputs_from_ids.sequences.tolist(), outputs_from_embeds.sequences.tolist())
+            if not has_complex_embeds_computation:
+                self._check_similar_generate_outputs(outputs_from_ids, outputs_from_embeds)

-            # But if we pass different inputs_embeds, we should get different outputs (the output text may be the
+            # If we pass different inputs_embeds, we should get different outputs (the output text may be the
             # same, but the logits will almost surely be different)
             random_embeds = torch.rand_like(inputs_embeds)
             outputs_from_rand_embeds = model.generate(
-                input_ids,
-                inputs_embeds=random_embeds,
-                max_new_tokens=5,
-                **generation_kwargs,
+                input_ids, inputs_embeds=random_embeds, **generation_kwargs, **inputs_dict
             )
             for i in range(len(outputs_from_rand_embeds.scores)):
                 self.assertFalse(torch.allclose(outputs_from_embeds.scores[i], outputs_from_rand_embeds.scores[i]))

-            # input_ids is not a required input -- if we don't pass it, the newly generated tokens will be the same
-            outputs_from_embeds_wo_ids = model.generate(
-                inputs_embeds=inputs_embeds, max_new_tokens=5, **generation_kwargs
-            )
-            self.assertListEqual(
-                outputs_from_embeds.sequences[:, inputs_embeds.shape[1] :].tolist(),
-                outputs_from_embeds_wo_ids.sequences.tolist(),
-            )
+            # input_ids is not a required input on most models -- if we don't pass it, the newly generated tokens will
+            # be the same
+            if not (requires_inputs_ids or missing_attention_mask):
+                outputs_from_embeds_wo_ids = model.generate(
+                    inputs_embeds=inputs_embeds, **generation_kwargs, **inputs_dict
+                )
+                outputs_from_embeds.sequences = outputs_from_embeds.sequences[:, inputs_embeds.shape[1] :]
+                self._check_similar_generate_outputs(outputs_from_embeds_wo_ids, outputs_from_embeds)

     @pytest.mark.generate
     def test_generate_from_inputs_embeds_with_static_cache(self):
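A hedged usage sketch of what the rewritten test exercises, outside the test harness. It assumes the public `gpt2` checkpoint and a transformers version that accepts `inputs_embeds` in `generate`; when only `inputs_embeds` is passed, the returned sequences contain just the newly generated tokens:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    inputs = tokenizer("The quick brown fox", return_tensors="pt")
    inputs_embeds = model.get_input_embeddings()(inputs.input_ids)

    generated = model.generate(
        inputs_embeds=inputs_embeds,
        attention_mask=inputs.attention_mask,
        max_new_tokens=5,
        do_sample=False,
    )
    print(tokenizer.decode(generated[0]))  # continuation only, no prompt echo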
@@ -1829,10 +1868,8 @@ class GenerationTesterMixin:
     @pytest.mark.generate
     def test_generate_with_static_cache(self):
         """
-        Tests if StaticCache works if we set attn_implementation=static when generation.
-        This doesn't test if generation quality is good, but tests that models with
-        self._supports_static_cache don't throw an error when generating and return
-        a StaticCache object at the end.
+        Tests that generating with static cache give almost same results as with dynamic cache, and the output cache
+        has the expected shapes
         """
         for model_class in self.all_generative_model_classes:
             if not model_class._supports_static_cache:
@@ -1851,13 +1888,15 @@ class GenerationTesterMixin:

             model = model_class(config).to(torch_device).eval()
             generation_kwargs = {
-                "max_length": None,
                 "max_new_tokens": max_new_tokens,
-                "cache_implementation": "static",
                 "return_dict_in_generate": True,  # Required to return `past_key_values`
+                "output_scores": True,
                 "use_cache": True,
             }

+            static_cache_generation = model.generate(**generation_kwargs, **inputs_dict, cache_implementation="static")
+
+            # Check 1: The cache shapes must match the expected shapes
             max_cache_len = seq_length + max_new_tokens
             config = config.text_config if hasattr(config, "text_config") else config
             head_dim = (
@@ -1869,12 +1908,14 @@ class GenerationTesterMixin:
                 else config.num_key_value_heads
             )
             num_hidden_layers = config.num_hidden_layers
-            results = model.generate(**generation_kwargs, **inputs_dict)

             cache_shape = (batch_size, num_key_value_heads, max_cache_len, head_dim)
-            self.assertTrue(isinstance(results.past_key_values, StaticCache))
-            self.assertTrue(len(results.past_key_values.key_cache) == num_hidden_layers)
-            self.assertTrue(results.past_key_values.key_cache[0].shape == cache_shape)
+            self.assertTrue(isinstance(static_cache_generation.past_key_values, StaticCache))
+            self.assertTrue(len(static_cache_generation.past_key_values.key_cache) == num_hidden_layers)
+            self.assertTrue(static_cache_generation.past_key_values.key_cache[0].shape == cache_shape)
+
+            # Check 2: The outputs must be similar to the case with dynamic cache
+            dynamic_cache_generation = model.generate(**generation_kwargs, **inputs_dict)
+            self._check_similar_generate_outputs(dynamic_cache_generation, static_cache_generation)

     @require_optimum_quanto
     @pytest.mark.generate
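For context, a hedged sketch of the two checks the rewritten static-cache test performs, as a standalone script. The checkpoint name is an assumption; any model whose architecture supports static caches (a Llama-family model, for instance) should behave the same way:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, StaticCache

    name = "HuggingFaceTB/SmolLM-135M"  # assumed static-cache-capable checkpoint
    tokenizer = AutoTokenizer.from_pretrained(name)
    model = AutoModelForCausalLM.from_pretrained(name)
    inputs = tokenizer("Static caches preallocate", return_tensors="pt")

    common = dict(max_new_tokens=8, do_sample=False, return_dict_in_generate=True, output_scores=True)
    static_out = model.generate(**inputs, cache_implementation="static", **common)
    dynamic_out = model.generate(**inputs, **common)

    print(isinstance(static_out.past_key_values, StaticCache))       # check 1: cache type
    print(torch.equal(static_out.sequences, dynamic_out.sequences))  # check 2: expected True for greedy decoding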
@@ -1908,25 +1949,32 @@ class GenerationTesterMixin:
             with self.assertRaises(ValueError):
                 model.generate(**generation_kwargs, **inputs_dict)

+    @parameterized.expand(
+        [
+            ("forward_only", False),  # TODO (@joao): a few models failing. After fixed, this should not be "@slow"
+            ("end_to_end", True),  # TODO (@joao): end-to-end compilation is broken with torch 2.5+, explore and fix
+        ]
+    )
     @pytest.mark.generate
     @require_torch_gpu
     @slow
-    @is_flaky()  # compilation may result in equivalent (!= same) FP ops, causing the argmax in `generate` to be flaky
-    def test_generate_compile_fullgraph(self):
+    def test_generate_compile(self, _, end_to_end):
         """
-        Tests that `.generate` is compatible with torch.compile without graph breaks, keeping the same results.
+        Tests that `.generate` is compatible with torch.compile without graph breaks, keeping the same results. Tests
+        end-to-end compilation and forward pass compilation only.
         ⚠️ Runs two sequential generations to ensure the cache doesn't get stuck after the first compiled run! ⚠️
         """
         for model_class in self.all_generative_model_classes:
             if not model_class._supports_static_cache:
                 self.skipTest("This model doesn't support static cache")

             # TODO (joao) -- fix and enable me :)
-            if any(model_name in model_class.__name__.lower() for model_name in ["whisper"]):
+            if end_to_end and any(model_name in model_class.__name__.lower() for model_name in ["whisper"]):
                 self.skipTest("whisper model end-to-end generate compile not yet supported")

             config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
             # TODO (joao) -- fix and enable me :)
-            if config.is_encoder_decoder:
+            if end_to_end and config.is_encoder_decoder:
                 self.skipTest("Encoder-decoder model end-to-end generate compile not yet supported")

             model = model_class(config).to(torch_device)
@@ -1941,27 +1989,33 @@ class GenerationTesterMixin:
             generation_kwargs = {
                 "do_sample": False,
                 "max_new_tokens": 10,
+                "return_dict_in_generate": True,
+                "output_scores": True,
             }
+            # end-to-end works best with dynamic cache, forward compilation works best with static cache
+            if not end_to_end:
+                generation_kwargs["cache_implementation"] = "static"

-            max_cache_len = input_ids.shape[1] + generation_kwargs["max_new_tokens"]
-            config = config.get_text_config()
-            past_key_values = StaticCache(
-                config, batch_size=half_batch_size, max_cache_len=max_cache_len, device=torch_device
-            )
+            # get eager + dynamic cache results for future comparison
+            dynamic_outputs = []

             for model_inputs in input_ids_sets:
-                # eager dynamic cache
-                output_dynamic = model.generate(model_inputs, **generation_kwargs)
+                dynamic_outputs.append(model.generate(model_inputs, **generation_kwargs))

-                # end-to-end compiled dynamic cache
-                torch.compiler.reset()
-                compiled_generate = torch.compile(model.generate, fullgraph=True, mode="reduce-overhead")
-                generation_config = copy.deepcopy(model.generation_config)
-                generation_config.update(**generation_kwargs)
-                output_compiled = compiled_generate(
-                    model_inputs, generation_config=generation_config, past_key_values=past_key_values
-                )
-                self.assertListEqual(output_dynamic.tolist(), output_compiled.tolist())
+            # get compiled results
+            generation_config = copy.deepcopy(model.generation_config)
+            generation_config.update(**generation_kwargs)
+            torch.compiler.reset()
+            if end_to_end:
+                model.generate = torch.compile(model.generate, fullgraph=True, mode="reduce-overhead")
+            else:
+                model.forward = torch.compile(model.forward, fullgraph=True, mode="reduce-overhead")
+
+            compiled_outputs = []
+            for model_inputs in input_ids_sets:
+                compiled_outputs.append(model.generate(model_inputs, generation_config=generation_config))
+
+            for dynamic_result, compiled_result in zip(dynamic_outputs, compiled_outputs):
+                self._check_similar_generate_outputs(dynamic_result, compiled_result)

     @pytest.mark.generate
     def test_generate_methods_with_num_logits_to_keep(self):
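Similarly, a hedged sketch of the two compilation modes the parameterized test distinguishes (forward-only vs. end-to-end). It assumes a CUDA device and the same hypothetical checkpoint as above; the "reduce-overhead" mode relies on CUDA graphs:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    name = "HuggingFaceTB/SmolLM-135M"  # assumed static-cache-capable checkpoint
    model = AutoModelForCausalLM.from_pretrained(name).to("cuda")
    tokenizer = AutoTokenizer.from_pretrained(name)
    inputs = tokenizer("Compiled generation should match eager generation", return_tensors="pt").to("cuda")

    # forward-only compilation, paired with a static cache
    model.forward = torch.compile(model.forward, fullgraph=True, mode="reduce-overhead")
    out = model.generate(**inputs, max_new_tokens=10, do_sample=False, cache_implementation="static")

    # end-to-end compilation would instead wrap `generate` itself:
    # model.generate = torch.compile(model.generate, fullgraph=True, mode="reduce-overhead")
    print(tokenizer.decode(out[0]))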
@@ -1989,7 +2043,6 @@ class GenerationTesterMixin:
             self.assertEqual(with_all_logits.tolist(), without_all_logits.tolist())

     @pytest.mark.generate
-    @is_flaky()  # assisted generation tests are flaky (minor fp ops differences)
     def test_assisted_decoding_with_num_logits_to_keep(self):
         for model_class in self.all_generative_model_classes:
             if "num_logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()):
@@ -1998,6 +2051,9 @@ class GenerationTesterMixin:
                 self.skipTest(reason="Stateful models don't support assisted generation")

             config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)
+            # NOTE: assisted generation only works with cache on at the moment.
+            if not hasattr(config, "use_cache"):
+                self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
             config.use_cache = True
             config.is_decoder = True

@@ -2010,14 +2066,16 @@ class GenerationTesterMixin:
                 "max_new_tokens": 10,
                 "do_sample": False,
                 "assistant_model": assistant_model,
+                "return_dict_in_generate": True,
+                "output_scores": True,
             }

-            assistant_model.generation_config.assistant_confidence_threshold = None
             # Setting num_logits_to_keep at 0 keeps all logits (old behavior)
             with_all_logits = model.generate(**generation_kwargs, **inputs_dict, num_logits_to_keep=0)
             # By default, num_logits_to_keep is automatically set to 1 if not provided (new behavior)
             without_all_logits = model.generate(**inputs_dict, **generation_kwargs)
-            self.assertEqual(with_all_logits.tolist(), without_all_logits.tolist())
+            self._check_similar_generate_outputs(with_all_logits, without_all_logits)

     @pytest.mark.generate
     def test_inherits_generation_mixin(self):
@@ -2028,14 +2086,21 @@ class GenerationTesterMixin:
         for model_class in self.all_generative_model_classes:
             self.assertTrue("GenerationMixin" in str(model_class.__bases__))
 
-    @require_torch_sdpa
-    @slow
-    def test_eager_matches_sdpa_generate(self):
+    def _test_attention_implementation(self, attn_implementation):
+        """
+        Compares the output of generate with the eager attention implementation against other implementations.
+        NOTE: despite the test logic being the same, different implementations actually need different decorators, hence
+        this separate function.
+        """
         max_new_tokens = 30
+        support_flag = {
+            "sdpa": "_supports_sdpa",
+            "flash_attention_2": "_supports_flash_attn_2",
+        }
 
         for model_class in self.all_generative_model_classes:
-            if not model_class._supports_sdpa:
-                self.skipTest(f"{model_class.__name__} does not support SDPA")
+            if not getattr(model_class, support_flag[attn_implementation]):
+                self.skipTest(f"{model_class.__name__} does not support `attn_implementation={attn_implementation}`")
 
             config, original_inputs_dict = self.prepare_config_and_inputs_for_generate()
             inputs_dict = {}
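The `support_flag` mapping above turns the requested attention backend name into the class-level capability flag each model class advertises, so the guard is just a `getattr` on the class, no instantiation needed. A small illustration of the lookup (the variable names mirror the diff, nothing else is assumed):

    flag_name = {"sdpa": "_supports_sdpa", "flash_attention_2": "_supports_flash_attn_2"}[attn_implementation]
    if not getattr(model_class, flag_name):
        ...  # skip this model class for the requested backend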
@@ -2062,17 +2127,9 @@ class GenerationTesterMixin:
                 "do_sample": False,
                 "return_dict_in_generate": True,
                 "output_scores": True,
+                "use_cache": True,
             }
 
-            model_sdpa = model_class.from_pretrained(
-                tmpdirname,
-                torch_dtype=torch.float16,
-                low_cpu_mem_usage=True,
-            ).to(torch_device)
-            res_sdpa = model_sdpa.generate(**inputs_dict, **generate_kwargs)
-            del model_sdpa
-            gc.collect()
-
             model_eager = model_class.from_pretrained(
                 tmpdirname,
                 torch_dtype=torch.float16,
@@ -2083,42 +2140,46 @@ class GenerationTesterMixin:
             del model_eager
             gc.collect()
 
-            # Eager and SDPA are very similar, but not exactly the same. Because we are using random models, this
-            # test would be flaky if we only checked the sequences. Two situations in which this test passes:
-            # 1. The sequences are the same
-            # 2. The sequences are different, but the scores up until the first mismatch are nearly identical
-            output_matches = res_eager.sequences == res_sdpa.sequences
-            has_matching_outputs = output_matches.all()
-            has_matching_scores = None
-            if not has_matching_outputs:
-                input_length = main_input.shape[1]
-                for batch_idx in range(res_eager.sequences.shape[0]):
-                    batch_matches = output_matches[batch_idx]
-                    if batch_matches.all():
-                        continue
-                    first_mismatch_idx = batch_matches.int().argmin()  # gets the index of the first False
-                    first_mismatch_idx -= input_length  # scores doesn't include data regarding input tokens
-                    sdpa_first_mismatch_scores = res_sdpa.scores[first_mismatch_idx][batch_idx]
-                    eager_first_mismatch_scores = res_eager.scores[first_mismatch_idx][batch_idx]
-                    has_matching_scores = torch.allclose(
-                        sdpa_first_mismatch_scores, eager_first_mismatch_scores, rtol=1e-3, atol=1e-3
-                    )
-                    if not has_matching_scores:
-                        break
-
-            self.assertTrue(has_matching_outputs or has_matching_scores)
+            model_attn = model_class.from_pretrained(
+                tmpdirname,
+                torch_dtype=torch.float16,
+                low_cpu_mem_usage=True,
+                attn_implementation=attn_implementation,
+            ).to(torch_device)
+            res_attn = model_attn.generate(**inputs_dict, **generate_kwargs)
+            del model_attn
+            gc.collect()
+
+            self._check_similar_generate_outputs(res_eager, res_attn, atol=1e-3, rtol=1e-3)
 
-    def _check_outputs(self, output, main_input, config, use_cache=False, num_return_sequences=1):
-        # we can be sure what is batch size from main input but seq length depends on model type and whether input is text/audio/image
-        # so we infer actual text seq length from model_tester, same was as it is done in `test_modeling_common.py` tests`
-        batch_size = main_input.shape[0]
+    @pytest.mark.generate
+    @require_torch_sdpa
+    @slow
+    def test_eager_matches_sdpa_generate(self):
+        """Tests that generate has equivalent outputs with SDPA and eager attention implementations."""
+        self._test_attention_implementation("sdpa")
+
+    @pytest.mark.flash_attn_test
+    @require_flash_attn
+    @require_torch_gpu
+    @slow
+    def test_eager_matches_fa2_generate(self):
+        """Tests that generate has equivalent outputs with FA2 and eager attention implementations."""
+        # TODO (@joao @raushan) -- this test is failing the output checks on most models, investigate. After fixing,
+        # check whether we still need the overwrites
+        self._test_attention_implementation("flash_attention_2")
 
+    def _check_outputs(self, output, config, use_cache=False, num_return_sequences=1, num_beams=1):
+        input_batch_size = int(output.sequences.shape[0] / num_return_sequences)
+        internal_batch_size = (
+            input_batch_size * num_beams if num_beams > 1 else input_batch_size * num_return_sequences
+        )
 
         seq_length = getattr(self.model_tester, "seq_length", None)
         seq_length = getattr(self.model_tester, "encoder_seq_length", seq_length)
         seq_length = getattr(self.model_tester, "text_seq_length", seq_length)
 
         config = config.text_config if hasattr(config, "text_config") else config
-        num_sequences_in_output = batch_size * num_return_sequences
 
         gen_len = (
             output.sequences.shape[-1] - 1 if config.is_encoder_decoder else output.sequences.shape[-1] - seq_length
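The new `internal_batch_size` separates what the caller asked for from what the model keeps alive internally: scores, logits, attentions and cache entries are produced per live beam, while the returned sequences are grouped by `num_return_sequences`. A worked example of the formula above:

    input_batch_size = 2
    # sampling:    num_beams=1, num_return_sequences=3 -> internal_batch_size = 2 * 3 = 6
    # beam search: num_beams=4, num_return_sequences=2 -> internal_batch_size = 2 * 4 = 8
    internal_batch_size = input_batch_size * num_beams if num_beams > 1 else input_batch_size * num_return_sequences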
@@ -2129,19 +2190,21 @@ class GenerationTesterMixin:
             seq_length = self.model_tester.get_subsampled_output_lengths(seq_length)
 
         # scores
-        self._check_scores(num_sequences_in_output, output.scores, length=gen_len, config=config)
+        self._check_scores(internal_batch_size, output.scores, length=gen_len, config=config)
 
         # unprocessed logits
-        self._check_logits(num_sequences_in_output, output.logits, config=config)
+        self._check_logits(internal_batch_size, output.logits, config=config)
 
         # Attentions
         if self.has_attentions:
             if config.is_encoder_decoder:
                 # encoder
-                self._check_encoder_attention_for_generate(output.encoder_attentions, batch_size, config, seq_length)
+                self._check_encoder_attention_for_generate(
+                    output.encoder_attentions, input_batch_size, config, seq_length
+                )
                 # decoder
                 self._check_attentions_for_generate(
-                    num_sequences_in_output,
+                    internal_batch_size,
                     output.decoder_attentions,
                     min_length=1,
                     max_length=output.sequences.shape[-1],
@@ -2153,7 +2216,7 @@ class GenerationTesterMixin:
                 attentions = output.attentions if not use_cache else output.attentions[1:]
                 min_length = seq_length if not use_cache else seq_length + 1
                 self._check_attentions_for_generate(
-                    num_sequences_in_output,
+                    internal_batch_size,
                     attentions=attentions,
                     min_length=min_length,
                     max_length=output.sequences.shape[-1],
@@ -2165,12 +2228,12 @@ class GenerationTesterMixin:
         if config.is_encoder_decoder:
             # encoder
             self._check_encoder_hidden_states_for_generate(
-                output.encoder_hidden_states, batch_size, config, seq_length
+                output.encoder_hidden_states, input_batch_size, config, seq_length
             )
 
             # decoder
             self._check_hidden_states_for_generate(
-                num_sequences_in_output,
+                internal_batch_size,
                 output.decoder_hidden_states,
                 min_length=1,
                 max_length=output.sequences.shape[-1],
@@ -2182,7 +2245,7 @@ class GenerationTesterMixin:
             hidden_states = output.hidden_states if not use_cache else output.hidden_states[1:]
             min_length = seq_length if not use_cache else seq_length + 1
             self._check_hidden_states_for_generate(
-                num_sequences_in_output,
+                internal_batch_size,
                 hidden_states,
                 min_length=min_length,
                 max_length=output.sequences.shape[-1],
@@ -2213,7 +2276,7 @@ class GenerationTesterMixin:
             past_key_values = output.past_key_values
             past_sequence_length = output.sequences.shape[-1] - 1
             self._check_past_key_values_for_generate(
-                num_sequences_in_output,
+                internal_batch_size,
                 past_key_values,
                 seq_length=past_sequence_length,
                 config=config,
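The same `internal_batch_size` is what the cache check expects in its leading dimension, since beam search keeps one cache row per live beam. For a typical decoder-only model with a legacy tuple cache the per-layer tensors look roughly like the following (layout varies by model and cache class, so this is only indicative):

    key, value = past_key_values[0]  # first layer
    assert key.shape[0] == internal_batch_size  # one cache row per live beam / returned sequence
    # full expected shape: (internal_batch_size, num_key_value_heads, past_sequence_length, head_dim)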
@@ -1532,8 +1532,3 @@ class BartStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, un
     @unittest.skip
     def test_save_load_fast_init_from_base(self):
         pass
-
-    @unittest.skip(reason="Generate needs input ids")
-    def test_inputs_embeds_matches_input_ids_with_generate(self):
-        # generate only works with input ids for bartforcausalLM
-        pass
@@ -511,11 +511,6 @@ class BertModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
         config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
         self.model_tester.create_and_check_model_as_decoder(*config_and_inputs)
 
-    @unittest.skip(reason="Generate needs input ids")
-    def test_inputs_embeds_matches_input_ids_with_generate(self):
-        # generate only works with input ids for bertforcausalLM
-        pass
-
     def test_model_as_decoder_with_default_input_mask(self):
         # This regression test was failing with PyTorch < 1.3
         (
@@ -16,17 +16,14 @@
 
 import unittest
 
-import pytest
 import requests
 from parameterized import parameterized
 
 from transformers import ChameleonConfig, is_torch_available, is_vision_available, set_seed
 from transformers.testing_utils import (
     require_bitsandbytes,
-    require_flash_attn,
     require_read_token,
     require_torch,
-    require_torch_gpu,
     slow,
     torch_device,
 )
@@ -329,43 +326,6 @@ class ChameleonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
         # The output should be different for long inputs
         self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))
 
-    @require_flash_attn
-    @require_read_token
-    @require_torch_gpu
-    @require_bitsandbytes
-    @pytest.mark.flash_attn_test
-    @slow
-    def test_flash_attn_2_generate_padding_right(self):
-        """
-        Overwritting the common test as the test is flaky on tiny models
-        """
-        model = ChameleonForConditionalGeneration.from_pretrained(
-            "facebook/chameleon-7b",
-            load_in_4bit=True,
-            device_map={"": 0},
-        )
-
-        processor = ChameleonProcessor.from_pretrained("facebook/chameleon-7b")
-        texts = ["hi", "Hello this is a very long sentence"]
-
-        processor.tokenizer.padding_side = "right"
-
-        inputs = processor(text=texts, return_tensors="pt", padding=True).to(0)
-
-        output_native = model.generate(**inputs, max_new_tokens=20, do_sample=False)
-        output_native = processor.tokenizer.batch_decode(output_native)
-
-        model = ChameleonForConditionalGeneration.from_pretrained(
-            "facebook/chameleon-7b",
-            load_in_4bit=True,
-            attn_implementation="flash_attention_2",
-        )
-
-        output_fa_2 = model.generate(**inputs, max_new_tokens=20, do_sample=False)
-        output_fa_2 = processor.tokenizer.batch_decode(output_fa_2)
-
-        self.assertListEqual(output_native, output_fa_2)
-
     @unittest.skip("Chameleon forces some token ids to be -inf!")
     def test_batching_equivalence(self):
         pass
@@ -319,9 +319,6 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
     # This is because we are hitting edge cases with the causal_mask buffer
     model_split_percents = [0.5, 0.6]
 
-    # used in `test_torch_compile`
-    _torch_compile_test_ckpt = "google/gemma-2b"
-
     # used in `test_torch_compile_for_training`
     _torch_compile_train_cls = GemmaForCausalLM if is_torch_available() else None
 
@@ -419,51 +416,6 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
     def test_past_key_values_format(self):
         pass
 
-    @require_flash_attn
-    @require_torch_gpu
-    @pytest.mark.flash_attn_test
-    @slow
-    def test_flash_attn_2_generate_use_cache(self):
-        import torch
-
-        max_new_tokens = 30
-
-        for model_class in self.all_generative_model_classes:
-            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
-            dummy_input = inputs_dict[model_class.main_input_name]
-            if dummy_input.dtype in [torch.float32, torch.bfloat16]:
-                dummy_input = dummy_input.to(torch.float16)
-
-            # make sure that all models have enough positions for generation
-            if hasattr(config, "max_position_embeddings"):
-                config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
-
-            model = model_class(config)
-
-            with tempfile.TemporaryDirectory() as tmpdirname:
-                model.save_pretrained(tmpdirname)
-
-                dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
-                # NOTE: Gemma apparently does not support right padding + use_cache with FA2.
-                dummy_attention_mask[:, -1] = 1
-
-                model = model_class.from_pretrained(
-                    tmpdirname,
-                    torch_dtype=torch.float16,
-                    attn_implementation="flash_attention_2",
-                    low_cpu_mem_usage=True,
-                ).to(torch_device)
-
-                # Just test that a large cache works as expected
-                _ = model.generate(
-                    dummy_input,
-                    attention_mask=dummy_attention_mask,
-                    max_new_tokens=max_new_tokens,
-                    do_sample=False,
-                    use_cache=True,
-                )
-
     @require_flash_attn
     @require_torch_gpu
     @pytest.mark.flash_attn_test
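The deleted override forced the last attention-mask column to 1 because FlashAttention-2 generation with cache does not accept right-padded batches; that constraint now lives in the shared tests rather than per-model copies. When preparing inputs for such models directly, the usual workaround is left padding, shown here only as a reminder of the standard tokenizer API:

    tokenizer.padding_side = "left"
    tokenizer.pad_token = tokenizer.pad_token or tokenizer.eos_token
    batch = tokenizer(["hi", "Hello this is a very long sentence"], padding=True, return_tensors="pt")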
@@ -78,7 +78,6 @@ class Gemma2ModelTest(GemmaModelTest, unittest.TestCase):
     test_pruning = False
     _is_stateful = True
     model_split_percents = [0.5, 0.6]
-    _torch_compile_test_ckpt = "google/gemma-2-9b"
 
     def setUp(self):
         self.model_tester = Gemma2ModelTester(self)
@@ -28,7 +28,6 @@ from transformers.testing_utils import (
     require_flash_attn,
     require_torch,
     require_torch_accelerator,
-    require_torch_gpu,
     require_torch_sdpa,
     slow,
     torch_device,
@@ -306,10 +305,6 @@ class GlmModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
     test_headmasking = False
     test_pruning = False
 
-    # used in `test_torch_compile`
-    _torch_compile_test_ckpt = "THUDM/glm-4-9b"
-    _torch_compile_test_revision = "refs/pr/15"
-
     def setUp(self):
         self.model_tester = GlmModelTester(self)
         self.config_tester = ConfigTester(self, config_class=GlmConfig, hidden_size=37)
@@ -426,41 +421,6 @@ class GlmModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
 
         torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-3)
 
-    @require_flash_attn
-    @require_torch_gpu
-    @pytest.mark.flash_attn_test
-    @slow
-    def test_flash_attn_2_generate_padding_right(self):
-        """Overwrite the common test as the test is flaky on tiny models."""
-        model = GlmForCausalLM.from_pretrained(
-            "THUDM/glm-4-9b",
-            device_map={"": 0},
-            torch_dtype=torch.bfloat16,
-            revision="refs/pr/15",
-        )
-
-        tokenizer = AutoTokenizer.from_pretrained("THUDM/glm-4-9b", revision="refs/pr/15")
-        tokenizer.padding_side = "right"
-
-        texts = ["hi", "Hello this is a very long sentence"]
-        inputs = tokenizer(texts, return_tensors="pt", padding=True).to(0)
-
-        output_native = model.generate(**inputs, max_new_tokens=15, do_sample=False)
-        output_native = tokenizer.batch_decode(output_native)
-
-        model = GlmForCausalLM.from_pretrained(
-            "THUDM/glm-4-9b",
-            device_map={"": 0},
-            attn_implementation="flash_attention_2",
-            torch_dtype=torch.bfloat16,
-            revision="refs/pr/15",
-        )
-
-        output_fa_2 = model.generate(**inputs, max_new_tokens=15, do_sample=False)
-        output_fa_2 = tokenizer.batch_decode(output_fa_2)
-
-        self.assertListEqual(output_native, output_fa_2)
-
     @parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
     @require_torch_sdpa
     @slow
@@ -17,14 +17,9 @@
 import datetime
 import unittest
 
-import pytest
-
-from transformers import BitsAndBytesConfig, GPTJConfig, is_torch_available
+from transformers import GPTJConfig, is_torch_available
 from transformers.testing_utils import (
-    require_bitsandbytes,
-    require_flash_attn,
     require_torch,
-    require_torch_gpu,
     slow,
     tooslow,
     torch_device,
@@ -505,44 +500,6 @@ class GPTJModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
         model = GPTJModel.from_pretrained(model_name, revision="float16", torch_dtype=torch.float16)
         self.assertIsNotNone(model)
 
-    @require_flash_attn
-    @require_torch_gpu
-    @require_bitsandbytes
-    @pytest.mark.flash_attn_test
-    @slow
-    def test_flash_attn_2_generate_padding_right(self):
-        """
-        Overwritting the common test as the test is flaky on tiny models
-        """
-        tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6b")
-
-        texts = ["hi", "Hello this is a very long sentence"]
-        expected_outputs = [
-            "hi<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>Q: I have a question about the new version of the game. I have a question about the",
-            "Hello this is a very long sentence.\n\nA:\n\nI think the best way to understand this is to think of it",
-        ]
-
-        tokenizer.padding_side = "right"
-        tokenizer.pad_token = tokenizer.eos_token
-
-        inputs = tokenizer(texts, return_tensors="pt", padding=True).to(0)
-
-        quantization_config = BitsAndBytesConfig(load_in_4bit=True)
-
-        model = GPTJForCausalLM.from_pretrained(
-            "EleutherAI/gpt-j-6b",
-            device_map={"": 0},
-            attn_implementation="flash_attention_2",
-            revision="float16",
-            torch_dtype=torch.float16,
-            quantization_config=quantization_config,
-        )
-
-        output_fa_2 = model.generate(**inputs, max_new_tokens=20, do_sample=False)
-        output_fa_2 = tokenizer.batch_decode(output_fa_2)
-
-        self.assertListEqual(expected_outputs, output_fa_2)
-
 
 @require_torch
 class GPTJModelLanguageGenerationTest(unittest.TestCase):
@@ -17,12 +17,10 @@
 import tempfile
 import unittest
 
-import pytest
 from parameterized import parameterized
 
-from transformers import AutoTokenizer, GraniteConfig, is_torch_available, set_seed
+from transformers import GraniteConfig, is_torch_available, set_seed
 from transformers.testing_utils import (
-    require_bitsandbytes,
     require_flash_attn,
     require_read_token,
     require_torch,
@@ -303,9 +301,6 @@ class GraniteModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
     # This is because we are hitting edge cases with the causal_mask buffer
     model_split_percents = [0.5, 0.7, 0.8]
 
-    # used in `test_torch_compile`
-    _torch_compile_test_ckpt = "ibm/PowerLM-3b"
-
     def setUp(self):
         self.model_tester = GraniteModelTester(self)
         self.config_tester = ConfigTester(self, config_class=GraniteConfig, hidden_size=37)
@@ -423,46 +418,6 @@ class GraniteModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
         with self.assertRaises(AssertionError):
             torch.testing.assert_close(yarn_sin_long, original_sin_long)
 
-    @require_flash_attn
-    @require_torch_gpu
-    @require_bitsandbytes
-    @pytest.mark.flash_attn_test
-    @require_read_token
-    @slow
-    def test_flash_attn_2_generate_padding_right(self):
-        """
-        Overwritting the common test as the test is flaky on tiny models
-        """
-        model = GraniteForCausalLM.from_pretrained(
-            "ibm/PowerLM-3b",
-            load_in_4bit=True,
-            device_map={"": 0},
-        )
-
-        tokenizer = AutoTokenizer.from_pretrained("ibm/PowerLM-3b")
-
-        texts = ["hi", "Hello this is a very long sentence"]
-
-        tokenizer.padding_side = "right"
-        tokenizer.pad_token = tokenizer.eos_token
-
-        inputs = tokenizer(texts, return_tensors="pt", padding=True).to(0)
-
-        output_native = model.generate(**inputs, max_new_tokens=20, do_sample=False)
-        output_native = tokenizer.batch_decode(output_native)
-
-        model = GraniteForCausalLM.from_pretrained(
-            "ibm/PowerLM-3b",
-            load_in_4bit=True,
-            device_map={"": 0},
-            attn_implementation="flash_attention_2",
-        )
-
-        output_fa_2 = model.generate(**inputs, max_new_tokens=20, do_sample=False)
-        output_fa_2 = tokenizer.batch_decode(output_fa_2)
-
-        self.assertListEqual(output_native, output_fa_2)
-
     @require_flash_attn
     @require_torch_gpu
     @slow
@@ -17,12 +17,10 @@
 import tempfile
 import unittest
 
-import pytest
 from parameterized import parameterized
 
 from transformers import AutoTokenizer, GraniteMoeConfig, is_torch_available, set_seed
 from transformers.testing_utils import (
-    require_bitsandbytes,
     require_flash_attn,
     require_read_token,
     require_torch,
@@ -302,9 +300,6 @@ class GraniteMoeModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test
     # This is because we are hitting edge cases with the causal_mask buffer
     model_split_percents = [0.5, 0.7, 0.8]
 
-    # used in `test_torch_compile`
-    _torch_compile_test_ckpt = "ibm/PowerMoE-3b"
-
     def setUp(self):
         self.model_tester = GraniteMoeModelTester(self)
         self.config_tester = ConfigTester(self, config_class=GraniteMoeConfig, hidden_size=37)
@@ -422,46 +417,6 @@ class GraniteMoeModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test
         with self.assertRaises(AssertionError):
             torch.testing.assert_close(yarn_sin_long, original_sin_long)
 
-    @require_flash_attn
-    @require_torch_gpu
-    @require_bitsandbytes
-    @pytest.mark.flash_attn_test
-    @require_read_token
-    @slow
-    def test_flash_attn_2_generate_padding_right(self):
-        """
-        Overwritting the common test as the test is flaky on tiny models
-        """
-        model = GraniteMoeForCausalLM.from_pretrained(
-            "ibm-granite/granitemoe-3b",
-            load_in_4bit=True,
-            device_map={"": 0},
-        )
-
-        tokenizer = AutoTokenizer.from_pretrained("ibm-granite/granitemoe-3b")
-
-        texts = ["hi", "Hello this is a very long sentence"]
-
-        tokenizer.padding_side = "right"
-        tokenizer.pad_token = tokenizer.eos_token
-
-        inputs = tokenizer(texts, return_tensors="pt", padding=True).to(0)
-
-        output_native = model.generate(**inputs, max_new_tokens=20, do_sample=False)
-        output_native = tokenizer.batch_decode(output_native)
-
-        model = GraniteMoeForCausalLM.from_pretrained(
-            "ibm-granite/granitemoe-3b",
-            load_in_4bit=True,
-            device_map={"": 0},
-            attn_implementation="flash_attention_2",
-        )
-
-        output_fa_2 = model.generate(**inputs, max_new_tokens=20, do_sample=False)
-        output_fa_2 = tokenizer.batch_decode(output_fa_2)
-
-        self.assertListEqual(output_native, output_fa_2)
-
     @require_flash_attn
     @require_torch_gpu
     @slow
@@ -770,13 +770,6 @@ class IdeficsForVisionText2TextTest(IdeficsModelTest, GenerationTesterMixin, uni
     def test_custom_4d_attention_mask(self):
         pass
 
-    @unittest.skip(
-        reason="IDEFICS has specific requirements for working with inputs embeds like passing also the ids and pixels"
-    )
-    @parameterized.expand([(1,), (2,)])
-    def test_generate_from_inputs_embeds_decoder_only(self, num_beams):
-        pass
-
     @unittest.skip(reason="IDEFICS cannot compile due to dynamic control flow when checking inputs")
     def test_generate_compile_fullgraph(self):
         pass
@@ -20,7 +20,6 @@ import tempfile
 import unittest
 from io import BytesIO
 
-import pytest
 import requests
 
 from transformers import (
@@ -420,50 +419,6 @@ class Idefics2ForConditionalGenerationModelTest(GenerationTesterMixin, ModelTest
     def test_flash_attn_2_fp32_ln(self):
         pass
 
-    @pytest.mark.generate
-    def test_generate_from_inputs_embeds_decoder_only(self):
-        # overwrite because IDEFICS needs ids and embeds at the input to be not None
-        for model_class in self.all_generative_model_classes:
-            config, inputs_dict = self.prepare_config_and_inputs_for_generate()
-
-            # Ignore:
-            # a) eos (to always output 20 tokens) and pad (so we don't try to infer the attn mask from the input_ids,
-            # which would cause a mismatch),
-            config.pad_token_id = config.eos_token_id = -1
-            config.is_decoder = True
-            model = model_class(config).to(torch_device).eval()
-            input_ids = inputs_dict.pop("input_ids")
-
-            # Traditional way of generating text
-            outputs_from_ids = model.generate(
-                input_ids, max_new_tokens=5, return_dict_in_generate=True, output_scores=True
-            )
-            self.assertEqual(outputs_from_ids.sequences.shape, (input_ids.shape[0], input_ids.shape[1] + 5))
-
-            # Same thing, but from input embeddings (`input_ids` is passed so the prompt is present in the output)
-            inputs_embeds = model.get_input_embeddings()(input_ids)
-            outputs_from_embeds = model.generate(
-                input_ids,
-                inputs_embeds=inputs_embeds,
-                max_new_tokens=5,
-                return_dict_in_generate=True,
-                output_scores=True,
-            )
-            self.assertListEqual(outputs_from_ids.sequences.tolist(), outputs_from_embeds.sequences.tolist())
-
-            # But if we pass different inputs_embeds, we should get different outputs (the output text may be the
-            # same, but the logits will almost surely be different)
-            random_embeds = torch.rand_like(inputs_embeds)
-            outputs_from_rand_embeds = model.generate(
-                input_ids,
-                inputs_embeds=random_embeds,
-                max_new_tokens=5,
-                return_dict_in_generate=True,
-                output_scores=True,
-            )
-            for i in range(len(outputs_from_rand_embeds.scores)):
-                self.assertFalse(torch.allclose(outputs_from_embeds.scores[i], outputs_from_rand_embeds.scores[i]))
-
     # We need to override as we need to prepare such that the image token is the last token
     def test_resize_tokens_embeddings(self):
         (original_config, inputs_dict) = self.model_tester.prepare_config_and_inputs_for_common()
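The deleted override exercised the same contract the shared test now covers: when both `input_ids` and `inputs_embeds` are supplied, generation is driven by the embeddings while the ids provide the prompt that is echoed in the returned sequence. The core call pattern, reduced to its essentials and taken from the deleted code above (illustrative only, not tied to a specific checkpoint):

    inputs_embeds = model.get_input_embeddings()(input_ids)
    out = model.generate(input_ids, inputs_embeds=inputs_embeds, max_new_tokens=5)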
@@ -19,7 +19,6 @@ import gc
 import unittest
 from io import BytesIO
 
-import pytest
 import requests
 
 from transformers import (
@@ -180,10 +179,6 @@ class Idefics3ModelTest(ModelTesterMixin, unittest.TestCase):
     def test_inputs_embeds_matches_input_ids(self):
         pass
 
-    @unittest.skip(reason="Model does not support padding right")
-    def test_flash_attn_2_generate_padding_right(self):
-        pass
-
     @unittest.skip(reason="Model does not support padding right")
     def test_flash_attn_2_inference_padding_right(self):
         pass
@@ -337,10 +332,6 @@ class Idefics3ForConditionalGenerationModelTest(GenerationTesterMixin, ModelTest
     def test_inputs_embeds():
         pass
 
-    @unittest.skip(reason="Model does not support padding right")
-    def test_flash_attn_2_generate_padding_right(self):
-        pass
-
     @unittest.skip(reason="Model does not support padding right")
     def test_flash_attn_2_inference_padding_right(self):
         pass
@@ -367,50 +358,6 @@ class Idefics3ForConditionalGenerationModelTest(GenerationTesterMixin, ModelTest
     def test_flash_attn_2_fp32_ln(self):
         pass
 
-    @pytest.mark.generate
-    def test_generate_from_inputs_embeds_decoder_only(self):
-        # overwrite because IDEFICS needs ids and embeds at the input to be not None
-        for model_class in self.all_generative_model_classes:
-            config, inputs_dict = self.prepare_config_and_inputs_for_generate()
-
-            # Ignore:
-            # a) eos (to always output 20 tokens) and pad (so we don't try to infer the attn mask from the input_ids,
-            # which would cause a mismatch),
-            config.pad_token_id = config.eos_token_id = -1
-            config.is_decoder = True
-            model = model_class(config).to(torch_device).eval()
-            input_ids = inputs_dict.pop("input_ids")
-
-            # Traditional way of generating text
-            outputs_from_ids = model.generate(
-                input_ids, max_new_tokens=5, return_dict_in_generate=True, output_scores=True
-            )
-            self.assertEqual(outputs_from_ids.sequences.shape, (input_ids.shape[0], input_ids.shape[1] + 5))
-
-            # Same thing, but from input embeddings (`input_ids` is passed so the prompt is present in the output)
-            inputs_embeds = model.get_input_embeddings()(input_ids)
-            outputs_from_embeds = model.generate(
-                input_ids,
-                inputs_embeds=inputs_embeds,
-                max_new_tokens=5,
-                return_dict_in_generate=True,
-                output_scores=True,
-            )
-            self.assertListEqual(outputs_from_ids.sequences.tolist(), outputs_from_embeds.sequences.tolist())
-
-            # But if we pass different inputs_embeds, we should get different outputs (the output text may be the
-            # same, but the logits will almost surely be different)
-            random_embeds = torch.rand_like(inputs_embeds)
-            outputs_from_rand_embeds = model.generate(
-                input_ids,
-                inputs_embeds=random_embeds,
-                max_new_tokens=5,
-                return_dict_in_generate=True,
-                output_scores=True,
-            )
-            for i in range(len(outputs_from_rand_embeds.scores)):
-                self.assertFalse(torch.allclose(outputs_from_embeds.scores[i], outputs_from_rand_embeds.scores[i]))
-
     # We need to override as we need to prepare such that the image token is the last token
     def test_resize_tokens_embeddings(self):
         (original_config, inputs_dict) = self.model_tester.prepare_config_and_inputs_for_common()
@@ -526,31 +473,6 @@ class Idefics3ForConditionalGenerationModelTest(GenerationTesterMixin, ModelTest
         # Check that the model can still do a forward pass successfully (every parameter should be resized)
         model(**self._prepare_for_class(inputs_dict, model_class))
 
-    def test_inputs_embeds_matches_input_ids_with_generate(self):
-        # overwrite because IDEFICS needs ids and embeds at the input to be not None
-        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        for model_class in self.all_model_classes:
-            model = model_class(config)
-            model.to(torch_device)
-            model.eval()
-
-            inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
-            pad_token_id = config.pad_token_id if config.pad_token_id is not None else 1
-
-            wte = model.get_input_embeddings()
-
-            input_ids = inputs["input_ids"]
-            # some models infer position ids/attn mask differently when input ids
-            # by check if pad_token let's make sure no padding is in input ids
-            not_pad_token_id = pad_token_id + 1 if max(0, pad_token_id - 1) == 0 else pad_token_id - 1
-            input_ids[input_ids == pad_token_id] = not_pad_token_id
-            del inputs["input_ids"]
-            inputs_embeds = wte(input_ids)
-            out_ids = model.generate(input_ids=input_ids, **inputs, max_new_tokens=2)
-            out_embeds = model.generate(input_ids=input_ids, inputs_embeds=inputs_embeds, **inputs, max_new_tokens=2)
-
-            self.assertTrue(torch.allclose(out_embeds, out_ids))
-
 
 @require_torch
 class Idefics3ForConditionalGenerationIntegrationTest(unittest.TestCase):
@@ -539,93 +539,6 @@ class JambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
             # with attention mask
             _ = model(dummy_input, attention_mask=dummy_attention_mask)
 
-    @require_flash_attn
-    @require_torch_gpu
-    @pytest.mark.flash_attn_test
-    @slow
-    def test_flash_attn_2_generate_padding_right(self):
-        r"""
-        Overriding the test_flash_attn_2_generate_padding_right test as the Jamba model, like Mixtral, doesn't support
-        right padding + use cache with FA2
-        """
-        import torch
-
-        for model_class in self.all_generative_model_classes:
-            config, _ = self.model_tester.prepare_config_and_inputs_for_common()
-            model = model_class(config)
-
-            with tempfile.TemporaryDirectory() as tmpdirname:
-                model.save_pretrained(tmpdirname)
-                model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
-                    torch_device
-                )
-
-                dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
-                dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device)
-
-                model.generate(dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False)
-
-                model = model_class.from_pretrained(
-                    tmpdirname,
-                    torch_dtype=torch.float16,
-                    attn_implementation="flash_attention_2",
-                    low_cpu_mem_usage=True,
-                ).to(torch_device)
-
-                with self.assertRaises(ValueError):
-                    _ = model.generate(
-                        dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
-                    )
-
-    @require_flash_attn
-    @require_torch_gpu
-    @pytest.mark.flash_attn_test
-    @slow
-    def test_flash_attn_2_generate_use_cache(self):
-        r"""
-        Overriding the test_flash_attn_2_generate_use_cache test as the Jamba model, like Mixtral, doesn't support
-        right padding + use cache with FA2
-        """
-        import torch
-
-        max_new_tokens = 30
-
-        for model_class in self.all_generative_model_classes:
-            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
-            dummy_input = inputs_dict[model_class.main_input_name]
-            if dummy_input.dtype in [torch.float32, torch.bfloat16]:
-                dummy_input = dummy_input.to(torch.float16)
-
-            # make sure that all models have enough positions for generation
-            if hasattr(config, "max_position_embeddings"):
-                config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
-
-            model = model_class(config)
-
-            with tempfile.TemporaryDirectory() as tmpdirname:
-                model.save_pretrained(tmpdirname)
-
-                dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
-                # NOTE: Jamba does not support right padding + use_cache with FA2.
-                dummy_attention_mask[:, -1] = 1
-
-                model = model_class.from_pretrained(
-                    tmpdirname,
-                    torch_dtype=torch.float16,
-                    attn_implementation="flash_attention_2",
-                    low_cpu_mem_usage=True,
-                ).to(torch_device)
-
-                # Just test that a large cache works as expected
-                _ = model.generate(
-                    dummy_input,
-                    attention_mask=dummy_attention_mask,
-                    max_new_tokens=max_new_tokens,
-                    do_sample=False,
-                    use_cache=True,
-                )
-
     @require_flash_attn
     @require_torch_gpu
     @pytest.mark.flash_attn_test
@@ -15,7 +15,6 @@
 """Testing suite for the PyTorch JetMoe model."""
 
 import gc
-import tempfile
 import unittest
 
 import pytest
@@ -377,85 +376,6 @@ class JetMoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
     def test_past_key_values_format(self):
         pass
 
-    @require_flash_attn
-    @require_torch_gpu
-    @pytest.mark.flash_attn_test
-    @slow
-    def test_flash_attn_2_generate_padding_right(self):
-        import torch
-
-        for model_class in self.all_generative_model_classes:
-            config, _ = self.model_tester.prepare_config_and_inputs_for_common()
-            model = model_class(config)
-
-            with tempfile.TemporaryDirectory() as tmpdirname:
-                model.save_pretrained(tmpdirname)
-                model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
-                    torch_device
-                )
-
-                dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
-                dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device)
-
-                model.generate(dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False)
-
-                model = model_class.from_pretrained(
-                    tmpdirname,
-                    torch_dtype=torch.float16,
-                    attn_implementation="flash_attention_2",
-                    low_cpu_mem_usage=True,
-                ).to(torch_device)
-
-                with self.assertRaises(ValueError):
-                    _ = model.generate(
-                        dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
-                    )
-
-    @require_flash_attn
-    @require_torch_gpu
-    @pytest.mark.flash_attn_test
-    @slow
-    def test_flash_attn_2_generate_use_cache(self):
-        import torch
-
-        max_new_tokens = 30
-
-        for model_class in self.all_generative_model_classes:
-            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
-            dummy_input = inputs_dict[model_class.main_input_name]
-            if dummy_input.dtype in [torch.float32, torch.bfloat16]:
-                dummy_input = dummy_input.to(torch.float16)
-
-            # make sure that all models have enough positions for generation
-            if hasattr(config, "max_position_embeddings"):
-                config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
-
-            model = model_class(config)
-
-            with tempfile.TemporaryDirectory() as tmpdirname:
-                model.save_pretrained(tmpdirname)
-
-                dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
-                # NOTE: JetMoe apparently does not support right padding + use_cache with FA2.
-                dummy_attention_mask[:, -1] = 1
-
-                model = model_class.from_pretrained(
-                    tmpdirname,
-                    torch_dtype=torch.float16,
-                    attn_implementation="flash_attention_2",
-                    low_cpu_mem_usage=True,
-                ).to(torch_device)
-
-                # Just test that a large cache works as expected
-                _ = model.generate(
-                    dummy_input,
-                    attention_mask=dummy_attention_mask,
-                    max_new_tokens=max_new_tokens,
-                    do_sample=False,
-                    use_cache=True,
-                )
-
     @require_flash_attn
     @require_torch_gpu
     @pytest.mark.flash_attn_test
@@ -438,12 +438,6 @@ class Kosmos2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
         # self.assertTrue(model.transformer.wte.weight.shape, model.lm_head.weight.shape)
         # self.assertTrue(check_same_values(model.transformer.wte, model.lm_head))
 
-    @unittest.skip(
-        "KOSMOS-2 doesn't support inputs embeds. The test isn't skipped by checking ipnut args because KOSMOS-2 has `generate()` overwritten"
-    )
-    def test_inputs_embeds_matches_input_ids_with_generate(self):
-        pass
-
     @slow
     def test_model_from_pretrained(self):
         model_name = "microsoft/kosmos-2-patch14-224"
@@ -26,7 +26,6 @@ from transformers import AutoTokenizer, LlamaConfig, StaticCache, is_torch_avail
 from transformers.generation.configuration_utils import GenerationConfig
 from transformers.testing_utils import (
     backend_empty_cache,
-    require_bitsandbytes,
     require_flash_attn,
     require_read_token,
     require_torch,
@@ -316,9 +315,6 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
     # This is because we are hitting edge cases with the causal_mask buffer
     model_split_percents = [0.5, 0.7, 0.8]
 
-    # used in `test_torch_compile`
-    _torch_compile_test_ckpt = "meta-llama/Llama-2-7b-hf"
-
     # used in `test_torch_compile_for_training`
     _torch_compile_train_cls = LlamaForCausalLM if is_torch_available() else None
 
@@ -585,43 +581,6 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
         with self.assertRaises(KeyError):
             config = _reinitialize_config(base_config, {"rope_scaling": {"rope_type": "linear"}}) # missing "factor"
 
-    @require_flash_attn
-    @require_torch_gpu
-    @require_bitsandbytes
-    @pytest.mark.flash_attn_test
-    @require_read_token
-    @slow
-    def test_flash_attn_2_generate_padding_right(self):
-        """
-        Overwritting the common test as the test is flaky on tiny models
-        """
-        model = LlamaForCausalLM.from_pretrained(
-            "meta-llama/Llama-2-7b-hf",
-            load_in_4bit=True,
-            device_map={"": 0},
-        )
-
-        tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
-
-        texts = ["hi", "Hello this is a very long sentence"]
-
-        tokenizer.padding_side = "right"
-        tokenizer.pad_token = tokenizer.eos_token
-
-        inputs = tokenizer(texts, return_tensors="pt", padding=True).to(0)
-
-        output_native = model.generate(**inputs, max_new_tokens=20, do_sample=False)
-        output_native = tokenizer.batch_decode(output_native)
-
-        model = LlamaForCausalLM.from_pretrained(
-            "meta-llama/Llama-2-7b-hf", load_in_4bit=True, device_map={"": 0}, attn_implementation="flash_attention_2"
-        )
-
-        output_fa_2 = model.generate(**inputs, max_new_tokens=20, do_sample=False)
-        output_fa_2 = tokenizer.batch_decode(output_fa_2)
-
-        self.assertListEqual(output_native, output_fa_2)
-
     @require_flash_attn
     @require_torch_gpu
     @slow
@@ -204,8 +204,8 @@ class Mamba2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
         pass
 
     @unittest.skip(reason="To fix, Mamba 2 cache slicing test case is an edge case")
-    @parameterized.expand([(1,), (2,)])
-    def test_generate_from_inputs_embeds_decoder_only(self, num_beams):
+    @parameterized.expand([("greedy", 1), ("beam search", 2)])
+    def test_generate_from_inputs_embeds(self, _, num_beams):
         pass
 
     @unittest.skip(reason="To fix, Mamba 2 cache slicing test case is an edge case")
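`parameterized.expand` uses the first element of each tuple when it names the generated tests, so the new cases read as greedy vs. beam search in the test report while the second element still feeds `num_beams`; the extra leading parameter is swallowed by the `_` argument. Roughly:

    @parameterized.expand([("greedy", 1), ("beam search", 2)])
    def test_generate_from_inputs_embeds(self, _, num_beams):
        ...  # collected as ..._0_greedy and ..._1_beam_search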
@@ -276,12 +276,6 @@ class Mamba2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
             dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
             check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
 
-    @unittest.skip(
-        reason="Mamba2 does not support generating with input embeddings (custom cache_position computation)"
-    )
-    def test_inputs_embeds_matches_input_ids_with_generate(self):
-        pass
-
 
 @require_torch
 @slow
@@ -21,7 +21,6 @@ import unittest
 
 import numpy as np
 from datasets import Audio, load_dataset
-from packaging import version
 from parameterized import parameterized
 from pytest import mark
 
@@ -745,22 +744,6 @@ class MimiModelTest(ModelTesterMixin, unittest.TestCase):
     def test_sdpa_can_compile_dynamic(self):
         pass
 
-    # For now, Let's focus only on GPU for `torch.compile`
-    @slow
-    @require_torch_gpu
-    def test_torch_compile(self):
-        if version.parse(torch.__version__) < version.parse("2.3"):
-            self.skipTest(reason="This test requires torch >= 2.3 to run.")
-
-        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
-        n_iter = 3
-        for model_class in self.all_model_classes:
-            model = model_class(config).to(torch_device)
-            model.forward = torch.compile(model.forward)
-            for i in range(n_iter):
-                _ = model(inputs_dict["input_values"].to(torch_device))
-
     @is_flaky()
     def test_batching_equivalence(self):
         super().test_batching_equivalence()
@@ -15,7 +15,6 @@
 """Testing suite for the PyTorch Mistral model."""
 
 import gc
-import tempfile
 import unittest
 
 import pytest
@ -416,85 +415,6 @@ class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
|||||||
def test_past_key_values_format(self):
|
def test_past_key_values_format(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@require_flash_attn
|
|
||||||
@require_torch_gpu
|
|
||||||
@pytest.mark.flash_attn_test
|
|
||||||
@slow
|
|
||||||
def test_flash_attn_2_generate_padding_right(self):
|
|
||||||
import torch
|
|
||||||
|
|
||||||
for model_class in self.all_generative_model_classes:
|
|
||||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
|
||||||
model = model_class(config)
|
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
|
||||||
model.save_pretrained(tmpdirname)
|
|
||||||
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
|
|
||||||
torch_device
|
|
||||||
)
|
|
||||||
|
|
||||||
dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
|
|
||||||
dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device)
|
|
||||||
|
|
||||||
model.generate(dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False)
|
|
||||||
|
|
||||||
model = model_class.from_pretrained(
|
|
||||||
tmpdirname,
|
|
||||||
torch_dtype=torch.float16,
|
|
||||||
attn_implementation="flash_attention_2",
|
|
||||||
low_cpu_mem_usage=True,
|
|
||||||
).to(torch_device)
|
|
||||||
|
|
||||||
with self.assertRaises(ValueError):
|
|
||||||
_ = model.generate(
|
|
||||||
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
|
|
||||||
)
|
|
||||||
|
|
||||||
@require_flash_attn
|
|
||||||
@require_torch_gpu
|
|
||||||
@pytest.mark.flash_attn_test
|
|
||||||
@slow
|
|
||||||
def test_flash_attn_2_generate_use_cache(self):
|
|
||||||
import torch
|
|
||||||
|
|
||||||
max_new_tokens = 30
|
|
||||||
|
|
||||||
for model_class in self.all_generative_model_classes:
|
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
|
||||||
|
|
||||||
dummy_input = inputs_dict[model_class.main_input_name]
|
|
||||||
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
|
|
||||||
dummy_input = dummy_input.to(torch.float16)
|
|
||||||
|
|
||||||
# make sure that all models have enough positions for generation
|
|
||||||
if hasattr(config, "max_position_embeddings"):
|
|
||||||
config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
|
|
||||||
|
|
||||||
model = model_class(config)
|
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
|
||||||
model.save_pretrained(tmpdirname)
|
|
||||||
|
|
||||||
dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
|
|
||||||
# NOTE: Mistral apparently does not support right padding + use_cache with FA2.
|
|
||||||
dummy_attention_mask[:, -1] = 1
|
|
||||||
|
|
||||||
model = model_class.from_pretrained(
|
|
||||||
tmpdirname,
|
|
||||||
torch_dtype=torch.float16,
|
|
||||||
attn_implementation="flash_attention_2",
|
|
||||||
low_cpu_mem_usage=True,
|
|
||||||
).to(torch_device)
|
|
||||||
|
|
||||||
# Just test that a large cache works as expected
|
|
||||||
_ = model.generate(
|
|
||||||
dummy_input,
|
|
||||||
attention_mask=dummy_attention_mask,
|
|
||||||
max_new_tokens=max_new_tokens,
|
|
||||||
do_sample=False,
|
|
||||||
use_cache=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
@require_flash_attn
|
@require_flash_attn
|
||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
@pytest.mark.flash_attn_test
|
@pytest.mark.flash_attn_test
|
||||||
@@ -14,7 +14,6 @@
 # limitations under the License.
 """Testing suite for the PyTorch Mixtral model."""
 
-import tempfile
 import unittest
 
 import pytest
@@ -415,85 +414,6 @@ class MixtralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
     def test_past_key_values_format(self):
         pass
 
-    @require_flash_attn
-    @require_torch_gpu
-    @pytest.mark.flash_attn_test
-    @slow
-    def test_flash_attn_2_generate_padding_right(self):
-        import torch
-
-        for model_class in self.all_generative_model_classes:
-            config, _ = self.model_tester.prepare_config_and_inputs_for_common()
-            model = model_class(config)
-
-            with tempfile.TemporaryDirectory() as tmpdirname:
-                model.save_pretrained(tmpdirname)
-                model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
-                    torch_device
-                )
-
-                dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
-                dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device)
-
-                model.generate(dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False)
-
-                model = model_class.from_pretrained(
-                    tmpdirname,
-                    torch_dtype=torch.float16,
-                    attn_implementation="flash_attention_2",
-                    low_cpu_mem_usage=True,
-                ).to(torch_device)
-
-                with self.assertRaises(ValueError):
-                    _ = model.generate(
-                        dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
-                    )
-
-    @require_flash_attn
-    @require_torch_gpu
-    @pytest.mark.flash_attn_test
-    @slow
-    def test_flash_attn_2_generate_use_cache(self):
-        import torch
-
-        max_new_tokens = 30
-
-        for model_class in self.all_generative_model_classes:
-            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
-            dummy_input = inputs_dict[model_class.main_input_name]
-            if dummy_input.dtype in [torch.float32, torch.bfloat16]:
-                dummy_input = dummy_input.to(torch.float16)
-
-            # make sure that all models have enough positions for generation
-            if hasattr(config, "max_position_embeddings"):
-                config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
-
-            model = model_class(config)
-
-            with tempfile.TemporaryDirectory() as tmpdirname:
-                model.save_pretrained(tmpdirname)
-
-                dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
-                # NOTE: Mixtral apparently does not support right padding + use_cache with FA2.
-                dummy_attention_mask[:, -1] = 1
-
-                model = model_class.from_pretrained(
-                    tmpdirname,
-                    torch_dtype=torch.float16,
-                    attn_implementation="flash_attention_2",
-                    low_cpu_mem_usage=True,
-                ).to(torch_device)
-
-                # Just test that a large cache works as expected
-                _ = model.generate(
-                    dummy_input,
-                    attention_mask=dummy_attention_mask,
-                    max_new_tokens=max_new_tokens,
-                    do_sample=False,
-                    use_cache=True,
-                )
-
     @require_flash_attn
     @require_torch_gpu
     @pytest.mark.flash_attn_test
@@ -126,7 +126,6 @@ class MllamaForCausalLMModelTest(ModelTesterMixin, GenerationTesterMixin, unitte
     all_generative_model_classes = (MllamaForCausalLM,) if is_torch_available() else ()
     test_pruning = False
     test_head_masking = False
-    _torch_compile_test_ckpt = "nltpt/Llama-3.2-11B-Vision"
 
     def setUp(self):
         self.model_tester = MllamaText2TextModelTester(self)
@@ -560,7 +560,7 @@ class MoshiTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
         return config, input_ids, attention_mask, inputs_dict
 
     def prepare_config_and_inputs_for_generate(self, batch_size=2):
-        config, filtered_inputs_dict = super().prepare_config_and_inputs_for_generate()
+        config, filtered_inputs_dict = super().prepare_config_and_inputs_for_generate(batch_size=batch_size)
 
         # Make sure we only return `input_ids`.
         # Note that audio_codes will still be generated internally, so the ability to test audio codes is still there.
@@ -591,9 +591,11 @@ class MoshiTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
                 [expected_shape] * len(iter_hidden_states),
             )
 
-    def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
+    def _check_outputs(self, output, config, use_cache=False, num_return_sequences=1, num_beams=1):
         # Overwrite because the generate method actually alway uses `inputs_embeds` so `use_cache` is always `True`
-        super()._check_outputs(output, input_ids, config, use_cache=True, num_return_sequences=num_return_sequences)
+        super()._check_outputs(
+            output, config, use_cache=True, num_return_sequences=num_return_sequences, num_beams=num_beams
+        )
 
     def _check_hidden_states_for_generate(
         self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1
@@ -655,59 +657,6 @@ class MoshiTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
                         msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                     )
 
-    @pytest.mark.generate
-    @parameterized.expand([(1,), (2,)])
-    def test_generate_from_inputs_embeds_decoder_only(self, num_beams):
-        for model_class in self.all_generative_model_classes:
-            config, input_ids, _, inputs_dict = self._get_input_ids_and_config()
-
-            model = model_class(config).to(torch_device).eval()
-            generation_kwargs = {
-                "return_dict_in_generate": True,
-                "output_scores": True,
-                "num_beams": num_beams,
-                "do_sample": False,
-            }
-
-            # Traditional way of generating text
-            outputs_from_ids = model.generate(input_ids, max_new_tokens=5, **generation_kwargs, **inputs_dict)
-            self.assertEqual(outputs_from_ids.sequences.shape, (input_ids.shape[0], input_ids.shape[1] + 5))
-
-            # Same thing, but from input embeddings (`input_ids` is passed so the prompt is present in the output)
-            inputs_embeds = model.get_input_embeddings()(input_ids)
-            outputs_from_embeds = model.generate(
-                input_ids,
-                inputs_embeds=inputs_embeds,
-                max_new_tokens=5,
-                **generation_kwargs,
-                **inputs_dict,
-            )
-
-            # But if we pass different inputs_embeds, we should get different outputs (the output text may be the
-            # same, but the logits will almost surely be different)
-            random_embeds = torch.rand_like(inputs_embeds)
-            outputs_from_rand_embeds = model.generate(
-                input_ids,
-                inputs_embeds=random_embeds,
-                max_new_tokens=5,
-                **generation_kwargs,
-                **inputs_dict,
-            )
-            for i in range(len(outputs_from_rand_embeds.scores)):
-                self.assertFalse(torch.allclose(outputs_from_embeds.scores[i], outputs_from_rand_embeds.scores[i]))
-
-            # input_ids is not a required input -- if we don't pass it, the newly generated tokens will be the same
-            outputs_from_embeds_wo_ids = model.generate(
-                inputs_embeds=inputs_embeds,
-                max_new_tokens=5,
-                **generation_kwargs,
-                **inputs_dict,
-            )
-            self.assertListEqual(
-                outputs_from_embeds.sequences[:, inputs_embeds.shape[1] :].tolist(),
-                outputs_from_embeds_wo_ids.sequences.tolist(),
-            )
-
     @unittest.skip(reason="Continuing from past key values is not straightforward as we're dealing with 3 inputs")
     def test_generate_continue_from_past_key_values(self):
         pass
@@ -576,9 +576,6 @@ class MT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
     # The small MT5 model needs higher percentages for CPU/MP tests
     model_split_percents = [0.5, 0.8, 0.9]
-
-    # used in `test_torch_compile`
-    _torch_compile_test_ckpt = "google/mt5-small"
 
     def setUp(self):
         self.model_tester = MT5ModelTester(self)
         self.config_tester = ConfigTester(self, config_class=MT5Config, d_model=37)
@@ -450,144 +450,6 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
 
             assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2)
 
-    @require_flash_attn
-    @require_torch_gpu
-    @mark.flash_attn_test
-    @slow
-    # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_generate_left_padding
-    def test_flash_attn_2_generate_left_padding(self):
-        # Ignore copy
-        for model_class in self.greedy_sample_model_classes:
-            if not model_class._supports_flash_attn_2:
-                self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
-
-            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-            model = model_class(config)
-
-            with tempfile.TemporaryDirectory() as tmpdirname:
-                model.save_pretrained(tmpdirname)
-                model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
-                    torch_device
-                )
-
-                dummy_input = inputs_dict[model.main_input_name]
-                if dummy_input.dtype in [torch.float32, torch.bfloat16]:
-                    dummy_input = dummy_input.to(torch.float16)
-
-                dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
-                # make sure we do left padding
-                dummy_attention_mask[:, :-1] = 0
-                dummy_attention_mask[:, -1:] = 1
-
-                out = model.generate(
-                    dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
-                )
-
-                model = model_class.from_pretrained(
-                    tmpdirname,
-                    torch_dtype=torch.float16,
-                    attn_implementation="flash_attention_2",
-                    low_cpu_mem_usage=True,
-                ).to(torch_device)
-
-                out_fa = model.generate(
-                    dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
-                )
-
-                self.assertTrue(torch.allclose(out, out_fa))
-
-    @require_flash_attn
-    @require_torch_gpu
-    @mark.flash_attn_test
-    @slow
-    # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_generate_padding_right
-    def test_flash_attn_2_generate_padding_right(self):
-        # Ignore copy
-        for model_class in self.greedy_sample_model_classes:
-            if not model_class._supports_flash_attn_2:
-                self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
-
-            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-            model = model_class(config)
-
-            with tempfile.TemporaryDirectory() as tmpdirname:
-                model.save_pretrained(tmpdirname)
-                model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
-                    torch_device
-                )
-
-                dummy_input = inputs_dict[model.main_input_name]
-                if dummy_input.dtype in [torch.float32, torch.bfloat16]:
-                    dummy_input = dummy_input.to(torch.float16)
-
-                dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
-                # make sure we do right padding
-                dummy_attention_mask[:, :-1] = 1
-                dummy_attention_mask[:, -1:] = 0
-
-                out = model.generate(
-                    dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
-                )
-
-                model = model_class.from_pretrained(
-                    tmpdirname,
-                    torch_dtype=torch.float16,
-                    attn_implementation="flash_attention_2",
-                    low_cpu_mem_usage=True,
-                ).to(torch_device)
-
-                out_fa = model.generate(
-                    dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
-                )
-
-                self.assertTrue(torch.allclose(out, out_fa))
-
-    @require_flash_attn
-    @require_torch_gpu
-    @mark.flash_attn_test
-    @slow
-    # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_generate_use_cache
-    def test_flash_attn_2_generate_use_cache(self):
-        max_new_tokens = 30
-
-        # Ignore copy
-        for model_class in self.greedy_sample_model_classes:
-            if not model_class._supports_flash_attn_2:
-                self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
-
-            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
-            dummy_input = inputs_dict[model_class.main_input_name]
-            if dummy_input.dtype in [torch.float32, torch.bfloat16]:
-                dummy_input = dummy_input.to(torch.float16)
-
-            # make sure that all models have enough positions for generation
-            if hasattr(config, "max_position_embeddings"):
-                config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
-
-            model = model_class(config)
-
-            with tempfile.TemporaryDirectory() as tmpdirname:
-                model.save_pretrained(tmpdirname)
-
-                dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
-
-                model = model_class.from_pretrained(
-                    tmpdirname,
-                    torch_dtype=torch.float16,
-                    attn_implementation="flash_attention_2",
-                    low_cpu_mem_usage=True,
-                ).to(torch_device)
-
-                # Just test that a large cache works as expected
-                _ = model.generate(
-                    dummy_input,
-                    attention_mask=dummy_attention_mask,
-                    max_new_tokens=max_new_tokens,
-                    do_sample=False,
-                    use_cache=True,
-                )
-
     @parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
     @require_torch_sdpa
     @slow
@@ -1585,149 +1447,6 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
 
             assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2)
 
-    @require_flash_attn
-    @require_torch_gpu
-    @mark.flash_attn_test
-    @slow
-    # Adapted from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_generate_left_padding
-    def test_flash_attn_2_generate_left_padding(self):
-        # Ignore copy
-        for model_class in self.greedy_sample_model_classes:
-            if not model_class._supports_flash_attn_2:
-                self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
-
-            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-            model = model_class(config)
-
-            with tempfile.TemporaryDirectory() as tmpdirname:
-                model.save_pretrained(tmpdirname)
-                model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
-                    torch_device
-                )
-
-                dummy_input = inputs_dict[model.main_input_name]
-                if dummy_input.dtype in [torch.float32, torch.bfloat16]:
-                    dummy_input = dummy_input.to(torch.float16)
-
-                dummy_attention_mask = inputs_dict.get("attention_mask")
-                if dummy_attention_mask is None:
-                    dummy_attention_mask = torch.ones_like(dummy_input)
-
-                # make sure we do left padding
-                dummy_attention_mask[:, :-1] = 0
-                dummy_attention_mask[:, -1:] = 1
-
-                out = model.generate(
-                    dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
-                )
-
-                model = model_class.from_pretrained(
-                    tmpdirname,
-                    torch_dtype=torch.float16,
-                    attn_implementation={"decoder": "flash_attention_2", "audio_encoder": None, "text_encoder": None},
-                    low_cpu_mem_usage=True,
-                ).to(torch_device)
-
-                out_fa = model.generate(
-                    dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
-                )
-
-                self.assertTrue(torch.allclose(out, out_fa))
-
-    @require_flash_attn
-    @require_torch_gpu
-    @mark.flash_attn_test
-    @slow
-    # Adapted from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_generate_padding_right
-    def test_flash_attn_2_generate_padding_right(self):
-        # Ignore copy
-        for model_class in self.greedy_sample_model_classes:
-            if not model_class._supports_flash_attn_2:
-                self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
-
-            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-            model = model_class(config)
-
-            with tempfile.TemporaryDirectory() as tmpdirname:
-                model.save_pretrained(tmpdirname)
-                model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
-                    torch_device
-                )
-
-                dummy_input = inputs_dict[model.main_input_name]
-                if dummy_input.dtype in [torch.float32, torch.bfloat16]:
-                    dummy_input = dummy_input.to(torch.float16)
-
-                dummy_attention_mask = inputs_dict.get("attention_mask")
-                if dummy_attention_mask is None:
-                    dummy_attention_mask = torch.ones_like(dummy_input)
-                # make sure we do right padding
-                dummy_attention_mask[:, :-1] = 1
-                dummy_attention_mask[:, -1:] = 0
-
-                out = model.generate(
-                    dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
-                )
-
-                model = model_class.from_pretrained(
-                    tmpdirname,
-                    torch_dtype=torch.float16,
-                    attn_implementation={"decoder": "flash_attention_2", "audio_encoder": None, "text_encoder": None},
-                    low_cpu_mem_usage=True,
-                ).to(torch_device)
-
-                out_fa = model.generate(
-                    dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
-                )
-
-                self.assertTrue(torch.allclose(out, out_fa))
-
-    @require_flash_attn
-    @require_torch_gpu
-    @mark.flash_attn_test
-    @slow
-    # Adapted from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_generate_use_cache
-    def test_flash_attn_2_generate_use_cache(self):
-        max_new_tokens = 30
-
-        # Ignore copy
-        for model_class in self.greedy_sample_model_classes:
-            if not model_class._supports_flash_attn_2:
-                self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
-
-            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
-            dummy_input = inputs_dict[model_class.main_input_name]
-            if dummy_input.dtype in [torch.float32, torch.bfloat16]:
-                dummy_input = dummy_input.to(torch.float16)
-
-            # make sure that all models have enough positions for generation
-            if hasattr(config, "max_position_embeddings"):
-                config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
-
-            model = model_class(config)
-
-            with tempfile.TemporaryDirectory() as tmpdirname:
-                model.save_pretrained(tmpdirname)
-
-                dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
-
-                model = model_class.from_pretrained(
-                    tmpdirname,
-                    torch_dtype=torch.float16,
-                    attn_implementation={"decoder": "flash_attention_2", "audio_encoder": None, "text_encoder": None},
-                    low_cpu_mem_usage=True,
-                ).to(torch_device)
-
-                # Just test that a large cache works as expected
-                _ = model.generate(
-                    dummy_input,
-                    attention_mask=dummy_attention_mask,
-                    max_new_tokens=max_new_tokens,
-                    do_sample=False,
-                    use_cache=True,
-                )
-
     @require_torch_sdpa
     def test_sdpa_can_dispatch_composite_models(self):
         if not self.has_attentions:
@@ -1437,149 +1437,6 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
 
             assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2)
 
-    @require_flash_attn
-    @require_torch_gpu
-    @mark.flash_attn_test
-    @slow
-    # Adapted from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_generate_left_padding
-    def test_flash_attn_2_generate_left_padding(self):
-        # Ignore copy
-        for model_class in self.greedy_sample_model_classes:
-            if not model_class._supports_flash_attn_2:
-                self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
-
-            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-            model = model_class(config)
-
-            with tempfile.TemporaryDirectory() as tmpdirname:
-                model.save_pretrained(tmpdirname)
-                model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
-                    torch_device
-                )
-
-                dummy_input = inputs_dict[model.main_input_name]
-                if dummy_input.dtype in [torch.float32, torch.bfloat16]:
-                    dummy_input = dummy_input.to(torch.float16)
-
-                dummy_attention_mask = inputs_dict.get("attention_mask")
-                if dummy_attention_mask is None:
-                    dummy_attention_mask = torch.ones_like(dummy_input)
-
-                # make sure we do left padding
-                dummy_attention_mask[:, :-1] = 0
-                dummy_attention_mask[:, -1:] = 1
-
-                out = model.generate(
-                    dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
-                )
-
-                model = model_class.from_pretrained(
-                    tmpdirname,
-                    torch_dtype=torch.float16,
-                    attn_implementation={"decoder": "flash_attention_2", "audio_encoder": None, "text_encoder": None},
-                    low_cpu_mem_usage=True,
-                ).to(torch_device)
-
-                out_fa = model.generate(
-                    dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
-                )
-
-                self.assertTrue(torch.allclose(out, out_fa))
-
-    @require_flash_attn
-    @require_torch_gpu
-    @mark.flash_attn_test
-    @slow
-    # Adapted from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_generate_padding_right
-    def test_flash_attn_2_generate_padding_right(self):
-        # Ignore copy
-        for model_class in self.greedy_sample_model_classes:
-            if not model_class._supports_flash_attn_2:
-                self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
-
-            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-            model = model_class(config)
-
-            with tempfile.TemporaryDirectory() as tmpdirname:
-                model.save_pretrained(tmpdirname)
-                model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
-                    torch_device
-                )
-
-                dummy_input = inputs_dict[model.main_input_name]
-                if dummy_input.dtype in [torch.float32, torch.bfloat16]:
-                    dummy_input = dummy_input.to(torch.float16)
-
-                dummy_attention_mask = inputs_dict.get("attention_mask")
-                if dummy_attention_mask is None:
-                    dummy_attention_mask = torch.ones_like(dummy_input)
-                # make sure we do right padding
-                dummy_attention_mask[:, :-1] = 1
-                dummy_attention_mask[:, -1:] = 0
-
-                out = model.generate(
-                    dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
-                )
-
-                model = model_class.from_pretrained(
-                    tmpdirname,
-                    torch_dtype=torch.float16,
-                    attn_implementation={"decoder": "flash_attention_2", "audio_encoder": None, "text_encoder": None},
-                    low_cpu_mem_usage=True,
-                ).to(torch_device)
-
-                out_fa = model.generate(
-                    dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
-                )
-
-                self.assertTrue(torch.allclose(out, out_fa))
-
-    @require_flash_attn
-    @require_torch_gpu
-    @mark.flash_attn_test
-    @slow
-    # Adapted from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_generate_use_cache
-    def test_flash_attn_2_generate_use_cache(self):
-        max_new_tokens = 30
-
-        # Ignore copy
-        for model_class in self.greedy_sample_model_classes:
-            if not model_class._supports_flash_attn_2:
-                self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
-
-            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
-            dummy_input = inputs_dict[model_class.main_input_name]
-            if dummy_input.dtype in [torch.float32, torch.bfloat16]:
-                dummy_input = dummy_input.to(torch.float16)
-
-            # make sure that all models have enough positions for generation
-            if hasattr(config, "max_position_embeddings"):
-                config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
-
-            model = model_class(config)
-
-            with tempfile.TemporaryDirectory() as tmpdirname:
-                model.save_pretrained(tmpdirname)
-
-                dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
-
-                model = model_class.from_pretrained(
-                    tmpdirname,
-                    torch_dtype=torch.float16,
-                    attn_implementation={"decoder": "flash_attention_2", "audio_encoder": None, "text_encoder": None},
-                    low_cpu_mem_usage=True,
-                ).to(torch_device)
-
-                # Just test that a large cache works as expected
-                _ = model.generate(
-                    dummy_input,
-                    attention_mask=dummy_attention_mask,
-                    max_new_tokens=max_new_tokens,
-                    do_sample=False,
-                    use_cache=True,
-                )
-
     @require_torch_sdpa
     def test_sdpa_can_dispatch_composite_models(self):
         if not self.has_attentions:
@@ -92,8 +92,6 @@ class NemotronModelTest(GemmaModelTest):
     test_pruning = False
     fx_compatible = False
 
-    # used in `test_torch_compile`
-    _torch_compile_test_ckpt = "nvidia/nemotron-3-8b-base-4k-hf"
     # used in `test_torch_compile_for_training`
     _torch_compile_train_cls = NemotronForCausalLM if is_torch_available() else None
 
@@ -346,10 +346,6 @@ class PaliGemmaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes
     def test_generate_from_inputs_embeds_with_static_cache(self):
         pass
 
-    @unittest.skip(reason="TODO (@joao): fix me -- failing to produce similar results")
-    def test_static_cache_matches_dynamic(self):
-        pass
-
     @unittest.skip("FlashAttention only support fp16 and bf16 data type")
     def test_flash_attn_2_fp32_ln(self):
         pass
@@ -17,15 +17,11 @@
 
 import unittest
 
-import pytest
 from parameterized import parameterized
 
 from transformers import PhiConfig, is_torch_available, set_seed
 from transformers.testing_utils import (
-    require_bitsandbytes,
-    require_flash_attn,
     require_torch,
-    require_torch_gpu,
     slow,
     torch_device,
 )
@@ -468,43 +464,6 @@ class PhiModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
         torch.testing.assert_close(ntk_sin_long, original_sin_long)
         self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all())
 
-    @require_flash_attn
-    @require_torch_gpu
-    @require_bitsandbytes
-    @pytest.mark.flash_attn_test
-    @slow
-    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_flash_attn_2_generate_padding_right with LlamaForCausalLM->PhiForCausalLM,LlamaTokenizer->AutoTokenizer,meta-llama/Llama-2-7b-hf->microsoft/phi-1
-    def test_flash_attn_2_generate_padding_right(self):
-        """
-        Overwritting the common test as the test is flaky on tiny models
-        """
-        model = PhiForCausalLM.from_pretrained(
-            "microsoft/phi-1",
-            load_in_4bit=True,
-            device_map={"": 0},
-        )
-
-        tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1")
-
-        texts = ["hi", "Hello this is a very long sentence"]
-
-        tokenizer.padding_side = "right"
-        tokenizer.pad_token = tokenizer.eos_token
-
-        inputs = tokenizer(texts, return_tensors="pt", padding=True).to(0)
-
-        output_native = model.generate(**inputs, max_new_tokens=20, do_sample=False)
-        output_native = tokenizer.batch_decode(output_native)
-
-        model = PhiForCausalLM.from_pretrained(
-            "microsoft/phi-1", load_in_4bit=True, device_map={"": 0}, attn_implementation="flash_attention_2"
-        )
-
-        output_fa_2 = model.generate(**inputs, max_new_tokens=20, do_sample=False)
-        output_fa_2 = tokenizer.batch_decode(output_fa_2)
-
-        self.assertListEqual(output_native, output_fa_2)
-
 
 @slow
 @require_torch
@@ -15,7 +15,6 @@
 """Testing suite for the PyTorch Qwen2 model."""
 
 import gc
-import tempfile
 import unittest
 
 import pytest
@@ -428,85 +427,6 @@ class Qwen2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
     def test_past_key_values_format(self):
         pass
 
-    @require_flash_attn
-    @require_torch_gpu
-    @pytest.mark.flash_attn_test
-    @slow
-    def test_flash_attn_2_generate_padding_right(self):
-        import torch
-
-        for model_class in self.all_generative_model_classes:
-            config, _ = self.model_tester.prepare_config_and_inputs_for_common()
-            model = model_class(config)
-
-            with tempfile.TemporaryDirectory() as tmpdirname:
-                model.save_pretrained(tmpdirname)
-                model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
-                    torch_device
-                )
-
-                dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
-                dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device)
-
-                model.generate(dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False)
-
-                model = model_class.from_pretrained(
-                    tmpdirname,
-                    torch_dtype=torch.float16,
-                    attn_implementation="flash_attention_2",
-                    low_cpu_mem_usage=True,
-                ).to(torch_device)
-
-                with self.assertRaises(ValueError):
-                    _ = model.generate(
-                        dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
-                    )
-
-    @require_flash_attn
-    @require_torch_gpu
-    @pytest.mark.flash_attn_test
-    @slow
-    def test_flash_attn_2_generate_use_cache(self):
-        import torch
-
-        max_new_tokens = 30
-
-        for model_class in self.all_generative_model_classes:
-            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
-            dummy_input = inputs_dict[model_class.main_input_name]
-            if dummy_input.dtype in [torch.float32, torch.bfloat16]:
-                dummy_input = dummy_input.to(torch.float16)
-
-            # make sure that all models have enough positions for generation
-            if hasattr(config, "max_position_embeddings"):
-                config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
-
-            model = model_class(config)
-
-            with tempfile.TemporaryDirectory() as tmpdirname:
-                model.save_pretrained(tmpdirname)
-
-                dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
-                # NOTE: Qwen2 apparently does not support right padding + use_cache with FA2.
-                dummy_attention_mask[:, -1] = 1
-
-                model = model_class.from_pretrained(
-                    tmpdirname,
-                    torch_dtype=torch.float16,
-                    attn_implementation="flash_attention_2",
-                    low_cpu_mem_usage=True,
-                ).to(torch_device)
-
-                # Just test that a large cache works as expected
-                _ = model.generate(
-                    dummy_input,
-                    attention_mask=dummy_attention_mask,
-                    max_new_tokens=max_new_tokens,
-                    do_sample=False,
-                    use_cache=True,
-                )
-
     @require_flash_attn
     @require_torch_gpu
     @pytest.mark.flash_attn_test
@@ -15,7 +15,6 @@
 """Testing suite for the PyTorch Qwen2MoE model."""
 
 import gc
-import tempfile
 import unittest
 
 import pytest
@@ -453,85 +452,6 @@ class Qwen2MoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
     def test_past_key_values_format(self):
         pass
 
-    @require_flash_attn
-    @require_torch_gpu
-    @pytest.mark.flash_attn_test
-    @slow
-    def test_flash_attn_2_generate_padding_right(self):
-        import torch
-
-        for model_class in self.all_generative_model_classes:
-            config, _ = self.model_tester.prepare_config_and_inputs_for_common()
-            model = model_class(config)
-
-            with tempfile.TemporaryDirectory() as tmpdirname:
-                model.save_pretrained(tmpdirname)
-                model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
-                    torch_device
-                )
-
-                dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
-                dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device)
-
-                model.generate(dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False)
-
-                model = model_class.from_pretrained(
-                    tmpdirname,
-                    torch_dtype=torch.float16,
-                    attn_implementation="flash_attention_2",
-                    low_cpu_mem_usage=True,
-                ).to(torch_device)
-
-                with self.assertRaises(ValueError):
-                    _ = model.generate(
-                        dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
-                    )
-
-    @require_flash_attn
-    @require_torch_gpu
-    @pytest.mark.flash_attn_test
-    @slow
-    def test_flash_attn_2_generate_use_cache(self):
-        import torch
-
-        max_new_tokens = 30
-
-        for model_class in self.all_generative_model_classes:
-            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
-            dummy_input = inputs_dict[model_class.main_input_name]
-            if dummy_input.dtype in [torch.float32, torch.bfloat16]:
-                dummy_input = dummy_input.to(torch.float16)
-
-            # make sure that all models have enough positions for generation
-            if hasattr(config, "max_position_embeddings"):
-                config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
-
-            model = model_class(config)
-
-            with tempfile.TemporaryDirectory() as tmpdirname:
-                model.save_pretrained(tmpdirname)
-
-                dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
-                # NOTE: Qwen2Moe apparently does not support right padding + use_cache with FA2.
-                dummy_attention_mask[:, -1] = 1
-
-                model = model_class.from_pretrained(
-                    tmpdirname,
-                    torch_dtype=torch.float16,
-                    attn_implementation="flash_attention_2",
-                    low_cpu_mem_usage=True,
-                ).to(torch_device)
-
-                # Just test that a large cache works as expected
-                _ = model.generate(
-                    dummy_input,
-                    attention_mask=dummy_attention_mask,
-                    max_new_tokens=max_new_tokens,
-                    do_sample=False,
-                    use_cache=True,
-                )
-
     @require_flash_attn
     @require_torch_gpu
     @pytest.mark.flash_attn_test
@@ -301,10 +301,6 @@ class Qwen2VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
     def test_feed_forward_chunking(self):
         pass
 
-    @unittest.skip(reason="Generate needs input ids")
-    def test_inputs_embeds_matches_input_ids_with_generate(self):
-        pass
-
     @unittest.skip(reason="CPU offload is not yet supported")
     def test_cpu_offload(self):
         pass
@@ -420,10 +420,6 @@ class RecurrentGemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
     def test_initialization(self):
         pass
 
-    @unittest.skip(reason="RecurrentGemma does not support generating with input embeddings (missing position_ids)")
-    def test_inputs_embeds_matches_input_ids_with_generate(self):
-        pass
-
 
 @require_torch_accelerator
 @slow
@@ -14,7 +14,6 @@
 # limitations under the License.
 """Testing suite for the PyTorch Starcoder2 model."""
 
-import tempfile
 import unittest
 
 import pytest
@@ -404,85 +403,6 @@ class Starcoder2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
     def test_past_key_values_format(self):
         pass
 
-    @require_flash_attn
-    @require_torch_gpu
-    @pytest.mark.flash_attn_test
-    @slow
-    def test_flash_attn_2_generate_padding_right(self):
-        import torch
-
-        for model_class in self.all_generative_model_classes:
-            config, _ = self.model_tester.prepare_config_and_inputs_for_common()
-            model = model_class(config)
-
-            with tempfile.TemporaryDirectory() as tmpdirname:
-                model.save_pretrained(tmpdirname)
-                model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
-                    torch_device
-                )
-
-                dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
-                dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device)
-
-                model.generate(dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False)
-
-                model = model_class.from_pretrained(
-                    tmpdirname,
-                    torch_dtype=torch.float16,
-                    attn_implementation="flash_attention_2",
-                    low_cpu_mem_usage=True,
-                ).to(torch_device)
-
-                with self.assertRaises(ValueError):
-                    _ = model.generate(
-                        dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
-                    )
-
-    @require_flash_attn
-    @require_torch_gpu
-    @pytest.mark.flash_attn_test
-    @slow
-    def test_flash_attn_2_generate_use_cache(self):
-        import torch
-
-        max_new_tokens = 30
-
-        for model_class in self.all_generative_model_classes:
-            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
-            dummy_input = inputs_dict[model_class.main_input_name]
-            if dummy_input.dtype in [torch.float32, torch.bfloat16]:
-                dummy_input = dummy_input.to(torch.float16)
-
-            # make sure that all models have enough positions for generation
-            if hasattr(config, "max_position_embeddings"):
-                config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
-
-            model = model_class(config)
-
-            with tempfile.TemporaryDirectory() as tmpdirname:
-                model.save_pretrained(tmpdirname)
-
-                dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
-                # NOTE: Starcoder2 apparently does not support right padding + use_cache with FA2.
-                dummy_attention_mask[:, -1] = 1
-
-                model = model_class.from_pretrained(
-                    tmpdirname,
-                    torch_dtype=torch.float16,
-                    attn_implementation="flash_attention_2",
-                    low_cpu_mem_usage=True,
-                ).to(torch_device)
-
-                # Just test that a large cache works as expected
-                _ = model.generate(
-                    dummy_input,
-                    attention_mask=dummy_attention_mask,
-                    max_new_tokens=max_new_tokens,
-                    do_sample=False,
-                    use_cache=True,
-                )
-
     @require_flash_attn
     @require_torch_gpu
     @pytest.mark.flash_attn_test
@@ -580,9 +580,6 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
     # The small T5 model needs higher percentages for CPU/MP tests
     model_split_percents = [0.5, 0.8, 0.9]
-
-    # used in `test_torch_compile`
-    _torch_compile_test_ckpt = "google-t5/t5-small"
 
     def setUp(self):
         self.model_tester = T5ModelTester(self)
         self.config_tester = ConfigTester(self, config_class=T5Config, d_model=37)
@@ -317,9 +317,6 @@ class UMT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
     # The small UMT5 model needs higher percentages for CPU/MP tests
     model_split_percents = [0.5, 0.8, 0.9]
-
-    # used in `test_torch_compile`
-    _torch_compile_test_ckpt = "google/umt5-small"
 
     def setUp(self):
         self.model_tester = UMT5ModelTester(self)
 
@@ -1574,59 +1574,6 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
         )
         assert isinstance(pred_ids, expected_output_type)
 
-    @require_flash_attn
-    @require_torch_gpu
-    @pytest.mark.flash_attn_test
-    @slow
-    def test_flash_attn_2_generate_reuse_cache(self):
-        max_new_tokens = 2
-        for model_class in self.all_generative_model_classes:
-            if not model_class._supports_flash_attn_2:
-                self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
-
-            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
-            dummy_input = inputs_dict[model_class.main_input_name][..., :10]
-            if dummy_input.dtype in [torch.float32, torch.bfloat16]:
-                dummy_input = dummy_input.to(torch.float16)
-
-            # make sure that all models have enough positions for generation
-            if hasattr(config, "max_position_embeddings"):
-                config.max_position_embeddings = dummy_input.shape[1] * 2 + max_new_tokens * 2 + 1
-
-            model = model_class(config)
-
-            with tempfile.TemporaryDirectory() as tmpdirname:
-                model.save_pretrained(tmpdirname)
-
-                model = model_class.from_pretrained(
-                    tmpdirname,
-                    torch_dtype=torch.float16,
-                    attn_implementation="flash_attention_2",
-                    low_cpu_mem_usage=True,
-                ).to(torch_device)
-
-                # run generate once to get filled cache
-                output = model.generate(
-                    dummy_input,
-                    max_new_tokens=max_new_tokens,
-                    do_sample=False,
-                    use_cache=True,
-                    return_dict_in_generate=True,
-                )
-                past_key_values = output.past_key_values
-
-                # Try to continue generation from where we left, given that we have more than 1 new token to process
-                # e.g. this can happen in speculative decoding when feeding candidate tokens back to target model
-                _ = model.generate(
-                    dummy_input,
-                    decoder_input_ids=output.sequences,
-                    max_new_tokens=max_new_tokens,
-                    do_sample=False,
-                    use_cache=True,
-                    past_key_values=past_key_values,
-                )
-
     def test_labels_sequence_max_length_correct(self):
         config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
 
@@ -3961,11 +3908,6 @@ class WhisperStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin,
         # generate only works with input ids for whisper
         pass
 
-    @unittest.skip(reason="Generate needs input ids")
-    def test_inputs_embeds_matches_input_ids_with_generate(self):
-        # generate only works with input ids for whisper
-        pass
-
     @unittest.skip(reason="Decoder can't keep attention grads")
     def test_retain_grad_hidden_states_attentions(self):
         return
@@ -3974,18 +3916,6 @@ class WhisperStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin,
     def test_save_load_fast_init_from_base(self):
         pass
 
-    @unittest.skip(
-        reason="FA2 testing suite needs to be refactored to be compatible with WhisperDecoder for that test"
-    )
-    def test_flash_attn_2_generate_reuse_cache(self):
-        pass
-
-    @unittest.skip(
-        "Duplicated test with WhisperModelTest + the FA2 testing suite needs to be refactored to be compatible with WhisperDecoder for that test"
-    )
-    def test_flash_attn_2_generate_padding_right(self):
-        pass
-
     @unittest.skip(
         "Duplicated test with WhisperModelTest + the FA2 testing suite needs to be refactored to be compatible with WhisperDecoder for that test"
     )
|
@@ -542,93 +542,6 @@ class ZambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
             # with attention mask
             _ = model(dummy_input, attention_mask=dummy_attention_mask)

-    @require_flash_attn
-    @require_torch_gpu
-    @pytest.mark.flash_attn_test
-    @slow
-    def test_flash_attn_2_generate_padding_right(self):
-        r"""
-        Overriding the test_flash_attn_2_generate_padding_right test as the Zamba model, like Mixtral, doesn't support
-        right padding + use cache with FA2
-        """
-        import torch
-
-        for model_class in self.all_generative_model_classes:
-            config, _ = self.model_tester.prepare_config_and_inputs_for_common()
-            model = model_class(config)
-
-            with tempfile.TemporaryDirectory() as tmpdirname:
-                model.save_pretrained(tmpdirname)
-                model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
-                    torch_device
-                )
-
-                dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
-                dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device)
-
-                model.generate(dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False)
-
-                model = model_class.from_pretrained(
-                    tmpdirname,
-                    torch_dtype=torch.float16,
-                    attn_implementation="flash_attention_2",
-                    low_cpu_mem_usage=True,
-                ).to(torch_device)
-
-                with self.assertRaises(ValueError):
-                    _ = model.generate(
-                        dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
-                    )
-
-    @require_flash_attn
-    @require_torch_gpu
-    @pytest.mark.flash_attn_test
-    @slow
-    def test_flash_attn_2_generate_use_cache(self):
-        r"""
-        Overriding the test_flash_attn_2_generate_use_cache test as the Zamba model, like Mixtral, doesn't support
-        right padding + use cache with FA2
-        """
-        import torch
-
-        max_new_tokens = 30
-
-        for model_class in self.all_generative_model_classes:
-            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
-            dummy_input = inputs_dict[model_class.main_input_name]
-            if dummy_input.dtype in [torch.float32, torch.bfloat16]:
-                dummy_input = dummy_input.to(torch.float16)
-
-            # make sure that all models have enough positions for generation
-            if hasattr(config, "max_position_embeddings"):
-                config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
-
-            model = model_class(config)
-
-            with tempfile.TemporaryDirectory() as tmpdirname:
-                model.save_pretrained(tmpdirname)
-
-                dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
-                # NOTE: Zamba does not support right padding + use_cache with FA2.
-                dummy_attention_mask[:, -1] = 1
-
-                model = model_class.from_pretrained(
-                    tmpdirname,
-                    torch_dtype=torch.float16,
-                    attn_implementation="flash_attention_2",
-                    low_cpu_mem_usage=True,
-                ).to(torch_device)
-
-                # Just test that a large cache works as expected
-                _ = model.generate(
-                    dummy_input,
-                    attention_mask=dummy_attention_mask,
-                    max_new_tokens=max_new_tokens,
-                    do_sample=False,
-                    use_cache=True,
-                )
-
     @require_flash_attn
     @require_torch_gpu
     @pytest.mark.flash_attn_test
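Note on the deleted Zamba overrides above: both removed tests encode the same constraint, namely that Flash Attention 2 combined with a KV cache expects left-padded batches (right padding raises a ValueError). A minimal sketch of the calling pattern this implies for users; the checkpoint name and prompts are illustrative placeholders, not part of the diff:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "Zyphra/Zamba-7B-v1"  # illustrative; any FA2-capable decoder-only model behaves the same
tokenizer = AutoTokenizer.from_pretrained(checkpoint, padding_side="left")  # left padding for generation
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    checkpoint, torch_dtype=torch.float16, attn_implementation="flash_attention_2"
).to("cuda")

batch = tokenizer(["Hi", "A much longer prompt"], return_tensors="pt", padding=True).to("cuda")
# With left padding the pad tokens sit before the prompt, so the cached positions stay contiguous.
out = model.generate(**batch, max_new_tokens=20, do_sample=False)
print(tokenizer.batch_decode(out, skip_special_tokens=True))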
@@ -22,7 +22,6 @@ import os.path
 import random
 import re
 import tempfile
-import time
 import warnings
 from collections import defaultdict
 from contextlib import contextmanager
@@ -37,10 +36,7 @@ import transformers
 from transformers import (
     AutoModel,
     AutoModelForCausalLM,
-    AutoModelForSeq2SeqLM,
     AutoModelForSequenceClassification,
-    AutoTokenizer,
-    GenerationConfig,
     PretrainedConfig,
     PreTrainedModel,
     is_torch_available,
@@ -86,7 +82,6 @@ from transformers.testing_utils import (
     require_deepspeed,
     require_flash_attn,
     require_non_xpu,
-    require_read_token,
     require_safetensors,
     require_torch,
     require_torch_accelerator,
@@ -3000,71 +2995,6 @@ class ModelTesterMixin:
                 )[0]
             self.assertTrue(torch.allclose(out_embeds, out_ids))

-    def test_inputs_embeds_matches_input_ids_with_generate(self):
-        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        for model_class in self.all_generative_model_classes:
-            if model_class.__name__ not in [
-                *get_values(MODEL_FOR_CAUSAL_LM_MAPPING_NAMES),
-                *get_values(MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES),
-            ]:
-                continue
-
-            model = model_class(config)
-            model.to(torch_device)
-            model.eval()
-
-            model_forward_args = inspect.signature(model.forward).parameters
-            if any(argument not in model_forward_args for argument in ["inputs_embeds", "position_ids"]):
-                self.skipTest(reason="This model doesn't use `inputs_embeds` or `position_ids`.")
-            has_inputs_embeds_forwarding = "inputs_embeds" in set(
-                inspect.signature(model.prepare_inputs_for_generation).parameters.keys()
-            )
-            if not has_inputs_embeds_forwarding:
-                self.skipTest(reason="This model doesn't support `inputs_embeds` passed to `generate`.")
-            inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
-            pad_token_id = config.pad_token_id if config.pad_token_id is not None else 1
-
-            # VLMs can't generate with embeds and pixels at the same time. We expect the user to pass merged
-            # embeds already
-            if model_class.__name__ in get_values(MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES):
-                inputs.pop("pixel_values", None)
-                inputs.pop("pixel_values_videos", None)
-                inputs.pop("pixel_values_images", None)
-
-            wte = model.get_input_embeddings()
-            if not self.is_encoder_decoder:
-                input_ids = inputs["input_ids"]
-                # some models infer position ids/attn mask differently when input ids
-                # by check if pad_token let's make sure no padding is in input ids
-                not_pad_token_id = pad_token_id + 1 if max(0, pad_token_id - 1) == 0 else pad_token_id - 1
-                input_ids[input_ids == pad_token_id] = not_pad_token_id
-                del inputs["input_ids"]
-                inputs_embeds = wte(input_ids)
-                out_ids = model.generate(input_ids=input_ids, **inputs, max_new_tokens=2)[:, -2:]
-                out_embeds = model.generate(inputs_embeds=inputs_embeds, **inputs, max_new_tokens=2)
-            else:
-                encoder_input_ids = inputs["input_ids"]
-                decoder_input_ids = inputs.get("decoder_input_ids", encoder_input_ids)
-                encoder_input_ids[encoder_input_ids == pad_token_id] = max(0, pad_token_id + 1)
-                decoder_input_ids[decoder_input_ids == pad_token_id] = max(0, pad_token_id + 1)
-                del inputs["input_ids"]
-                inputs.pop("decoder_input_ids", None)
-                inputs_embeds = wte(encoder_input_ids)
-                decoder_inputs_embeds = wte(decoder_input_ids)
-                out_ids = model.generate(
-                    input_ids=encoder_input_ids, decoder_input_ids=decoder_input_ids, **inputs, max_new_tokens=2
-                )[:, -2:]
-                out_embeds = model.generate(
-                    inputs_embeds=inputs_embeds,
-                    decoder_inputs_embeds=decoder_inputs_embeds,
-                    **inputs,
-                    max_new_tokens=2,
-                )
-            # NOTE: this test changes the order of FP ops, there may be tiny differences in the output
-            number_of_different_tokens = (out_ids != out_embeds).sum()
-            max_differences = int(out_ids.shape[0] * out_ids.shape[1] * 0.1)
-            self.assertTrue(number_of_different_tokens <= max_differences)  # accept up to 10% mismatch
-
     @require_non_xpu
     @require_torch_multi_gpu
     def test_multi_gpu_data_parallel_forward(self):
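The mixin test removed above (now covered by GenerationTesterMixin) asserts that generating from `inputs_embeds` matches generating from `input_ids`. A condensed sketch of that equivalence outside the test harness; the checkpoint is a placeholder and must support `inputs_embeds` forwarding in `prepare_inputs_for_generation`:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "HuggingFaceTB/SmolLM2-135M"  # placeholder decoder-only model
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint).eval()

input_ids = tokenizer("The quick brown fox", return_tensors="pt").input_ids
inputs_embeds = model.get_input_embeddings()(input_ids)

with torch.no_grad():
    out_ids = model.generate(input_ids=input_ids, max_new_tokens=5, do_sample=False)
    # When only `inputs_embeds` is passed, the returned tensor contains just the new tokens.
    out_embeds = model.generate(inputs_embeds=inputs_embeds, max_new_tokens=5, do_sample=False)

# Up to floating-point reordering, the continuations should match.
print(out_ids[:, input_ids.shape[1]:].tolist())
print(out_embeds.tolist())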
@@ -3857,102 +3787,6 @@ class ModelTesterMixin:

                 assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2)

-    @require_flash_attn
-    @require_torch_gpu
-    @mark.flash_attn_test
-    @slow
-    @is_flaky()
-    def test_flash_attn_2_generate_left_padding(self):
-        if not self.has_attentions:
-            self.skipTest(reason="Model architecture does not support attentions")
-
-        for model_class in self.all_generative_model_classes:
-            if not model_class._supports_flash_attn_2:
-                self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
-
-            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-            model = model_class(config)
-
-            with tempfile.TemporaryDirectory() as tmpdirname:
-                model.save_pretrained(tmpdirname)
-                model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
-                    torch_device
-                )
-
-                dummy_input = inputs_dict[model.main_input_name]
-                if dummy_input.dtype in [torch.float32, torch.bfloat16]:
-                    dummy_input = dummy_input.to(torch.float16)
-
-                dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
-                # make sure we do left padding
-                dummy_attention_mask[:, :-1] = 0
-                dummy_attention_mask[:, -1:] = 1
-
-                out = model.generate(
-                    dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
-                )
-
-                model = model_class.from_pretrained(
-                    tmpdirname,
-                    torch_dtype=torch.float16,
-                    attn_implementation="flash_attention_2",
-                    low_cpu_mem_usage=True,
-                ).to(torch_device)
-
-                out_fa = model.generate(
-                    dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
-                )
-
-                self.assertTrue(torch.allclose(out, out_fa))
-
-    @require_flash_attn
-    @require_torch_gpu
-    @mark.flash_attn_test
-    @is_flaky()
-    @slow
-    def test_flash_attn_2_generate_padding_right(self):
-        if not self.has_attentions:
-            self.skipTest(reason="Model architecture does not support attentions")
-
-        for model_class in self.all_generative_model_classes:
-            if not model_class._supports_flash_attn_2:
-                self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
-
-            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-            model = model_class(config)
-
-            with tempfile.TemporaryDirectory() as tmpdirname:
-                model.save_pretrained(tmpdirname)
-                model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
-                    torch_device
-                )
-
-                dummy_input = inputs_dict[model.main_input_name]
-                if dummy_input.dtype in [torch.float32, torch.bfloat16]:
-                    dummy_input = dummy_input.to(torch.float16)
-
-                dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
-                # make sure we do right padding
-                dummy_attention_mask[:, :-1] = 1
-                dummy_attention_mask[:, -1:] = 0
-
-                out = model.generate(
-                    dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
-                )
-
-                model = model_class.from_pretrained(
-                    tmpdirname,
-                    torch_dtype=torch.float16,
-                    attn_implementation="flash_attention_2",
-                    low_cpu_mem_usage=True,
-                ).to(torch_device)
-
-                out_fa = model.generate(
-                    dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
-                )
-
-                self.assertTrue(torch.allclose(out, out_fa))
-
     def test_attn_implementation_composite_models(self):
         """
         Tests if composite models can receive a dict object as attn_implementation, where each key should be
@@ -4525,65 +4359,6 @@ class ModelTesterMixin:
                 torch.allclose(res_eager[attention_mask == 1], res_sdpa[attention_mask == 1], rtol=1e-4, atol=1e-4)
             )

-    @require_flash_attn
-    @require_torch_gpu
-    @mark.flash_attn_test
-    @slow
-    def test_flash_attn_2_generate_use_cache(self):
-        if not self.has_attentions:
-            self.skipTest(reason="Model architecture does not support attentions")
-
-        max_new_tokens = 30
-
-        for model_class in self.all_generative_model_classes:
-            if not model_class._supports_flash_attn_2:
-                self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
-
-            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
-            dummy_input = inputs_dict[model_class.main_input_name]
-            if dummy_input.dtype in [torch.float32, torch.bfloat16]:
-                dummy_input = dummy_input.to(torch.float16)
-
-            # make sure that all models have enough positions for generation
-            if hasattr(config, "max_position_embeddings"):
-                config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
-
-            model = model_class(config)
-
-            with tempfile.TemporaryDirectory() as tmpdirname:
-                model.save_pretrained(tmpdirname)
-
-                dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
-
-                model = model_class.from_pretrained(
-                    tmpdirname,
-                    torch_dtype=torch.float16,
-                    attn_implementation="flash_attention_2",
-                    low_cpu_mem_usage=True,
-                ).to(torch_device)
-
-                # Just test that a large cache works as expected
-                _ = model.generate(
-                    dummy_input,
-                    attention_mask=dummy_attention_mask,
-                    max_new_tokens=max_new_tokens,
-                    do_sample=False,
-                    use_cache=True,
-                )
-
-                # Generate with one batch only to test generation when attention mask will be None
-                # when real inputs are used, because there is no padding. See issue #32237 for more
-                dummy_input = dummy_input[:1, ...]
-                dummy_attention_mask = torch.ones_like(dummy_attention_mask[:1, ...])
-                _ = model.generate(
-                    dummy_input,
-                    attention_mask=dummy_attention_mask,
-                    max_new_tokens=max_new_tokens,
-                    do_sample=False,
-                    use_cache=True,
-                )
-
     @require_flash_attn
     @require_torch_gpu
     @mark.flash_attn_test
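The removed FA2 use_cache test above only smoke-tests that a long cached generation runs. The part worth keeping in mind as a user is the loading pattern it relies on: FA2 has to be requested explicitly and needs a half-precision dtype. A hedged sketch with a placeholder checkpoint and a runtime availability check:

import torch
from transformers import AutoModelForCausalLM
from transformers.utils import is_flash_attn_2_available

checkpoint = "meta-llama/Llama-3.1-8B"  # placeholder; substitute any FA2-capable model
attn_impl = "flash_attention_2" if is_flash_attn_2_available() else "sdpa"

model = AutoModelForCausalLM.from_pretrained(
    checkpoint,
    torch_dtype=torch.float16,      # FA2 kernels require fp16/bf16
    attn_implementation=attn_impl,  # falls back to SDPA when flash-attn is not installed
    low_cpu_mem_usage=True,
)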
@@ -4640,62 +4415,6 @@ class ModelTesterMixin:
             if not has_fa2:
                 raise ValueError("The FA2 model should have FA2 layers")

-    @require_flash_attn
-    @require_torch_gpu
-    @mark.flash_attn_test
-    @slow
-    def test_flash_attn_2_generate_reuse_cache(self):
-        if not self.has_attentions:
-            self.skipTest(reason="Model architecture does not support attentions")
-
-        max_new_tokens = 2
-        for model_class in self.all_generative_model_classes:
-            if not model_class._supports_flash_attn_2:
-                self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
-
-            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
-            dummy_input = inputs_dict[model_class.main_input_name]
-            if dummy_input.dtype in [torch.float32, torch.bfloat16]:
-                dummy_input = dummy_input.to(torch.float16)
-
-            # make sure that all models have enough positions for generation
-            if hasattr(config, "max_position_embeddings"):
-                config.max_position_embeddings = dummy_input.shape[1] * 2 + max_new_tokens * 2 + 1
-
-            model = model_class(config)
-
-            with tempfile.TemporaryDirectory() as tmpdirname:
-                model.save_pretrained(tmpdirname)
-
-                model = model_class.from_pretrained(
-                    tmpdirname,
-                    torch_dtype=torch.float16,
-                    attn_implementation="flash_attention_2",
-                    low_cpu_mem_usage=True,
-                ).to(torch_device)
-
-                # run generate once to get filled cache
-                output = model.generate(
-                    dummy_input,
-                    max_new_tokens=max_new_tokens,
-                    do_sample=False,
-                    use_cache=True,
-                    return_dict_in_generate=True,
-                )
-                past_key_values = output.past_key_values
-
-                # Try to continue generation from where we left, given that we have more than 1 new token to process
-                # e.g. this can happen in speculative decoding when feeding candidate tokens back to target model
-                dummy_input_updated = torch.cat([dummy_input, output.sequences], dim=-1)
-                _ = model.generate(
-                    dummy_input_updated,
-                    max_new_tokens=max_new_tokens,
-                    do_sample=False,
-                    use_cache=True,
-                    past_key_values=past_key_values,
-                )
-
     @require_flash_attn
     @require_torch_gpu
     @require_bitsandbytes
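The reuse-cache test deleted above exercises a pattern that stays supported after the move: call `generate` once with `return_dict_in_generate=True`, keep the returned `past_key_values`, and pass them back together with the sequence generated so far, so only the unprocessed suffix is fed through the model (the same mechanism speculative decoding uses for candidate tokens). A hedged sketch with a placeholder checkpoint:

from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "HuggingFaceTB/SmolLM2-135M"  # placeholder decoder-only model
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint).eval()

inputs = tokenizer("The capital of France is", return_tensors="pt")

# First call fills the KV cache and returns it alongside the generated sequences.
first = model.generate(
    **inputs, max_new_tokens=2, do_sample=False, use_cache=True, return_dict_in_generate=True
)

# Second call continues from the cached prefix: pass the full sequence so far plus the cache,
# and only the tokens that are not yet in the cache get processed.
second = model.generate(
    first.sequences,
    past_key_values=first.past_key_values,
    max_new_tokens=2,
    do_sample=False,
    use_cache=True,
)
print(tokenizer.decode(second[0], skip_special_tokens=True))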
@@ -4999,82 +4718,6 @@ class ModelTesterMixin:
             normalized_1 = F.softmax(out_shared_prefix_last_tokens)
             torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4)

-    def test_static_cache_matches_dynamic(self):
-        """
-        Tests that generating with static cache give almost same results as with dynamic cache.
-        This test does not compile the model and check only logits similarity for numerical precision
-        errors.
-        """
-        if len(self.all_generative_model_classes) == 0:
-            self.skipTest(
-                reason="Model architecture has no generative classes, and thus not necessarily supporting 4D masks"
-            )
-        for model_class in self.all_generative_model_classes:
-            if not model_class._supports_static_cache:
-                self.skipTest(f"{model_class.__name__} does not support static cache")
-
-            if not model_class._supports_cache_class:
-                self.skipTest(f"{model_class.__name__} does not support cache class")
-
-            config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
-            if getattr(config, "sliding_window", 0) is not None and getattr(config, "sliding_window", 0) > 0:
-                self.skipTest(f"{model_class.__name__} with sliding window attention is not supported by this test")
-
-            model = model_class(config).to(device=torch_device, dtype=torch.float32)
-            model.eval()
-
-            dynamic_out = model.generate(
-                **inputs, do_sample=False, max_new_tokens=10, output_logits=True, return_dict_in_generate=True
-            )
-            static_out = model.generate(
-                **inputs,
-                do_sample=False,
-                max_new_tokens=10,
-                cache_implementation="static",
-                output_logits=True,
-                return_dict_in_generate=True,
-            )
-            self.assertTrue(torch.allclose(dynamic_out.logits[0], static_out.logits[0], rtol=1e-3, atol=1e-4))
-
-    # For now, Let's focus only on GPU for `torch.compile`
-    @slow
-    @require_torch_accelerator
-    @require_read_token
-    def test_torch_compile(self):
-        if version.parse(torch.__version__) < version.parse("2.3"):
-            self.skipTest(reason="This test requires torch >= 2.3 to run.")
-        torch.compiler.reset()
-        if not hasattr(self, "_torch_compile_test_ckpt"):
-            self.skipTest(f"{self.__class__.__name__} doesn't have the attribute `_torch_compile_test_ckpt`.")
-        ckpt = self._torch_compile_test_ckpt
-        revision = "main" if not hasattr(self, "_torch_compile_test_revision") else self._torch_compile_test_revision
-
-        os.environ["TOKENIZERS_PARALLELISM"] = "false"
-
-        batch_size = 1
-        n_iter = 3
-
-        tokenizer = AutoTokenizer.from_pretrained(ckpt)
-        if self.is_encoder_decoder:
-            model = AutoModelForSeq2SeqLM.from_pretrained(ckpt, torch_dtype=torch.float16, revision=revision).to(
-                torch_device
-            )
-        else:
-            model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16, revision=revision).to(
-                torch_device
-            )
-
-        model.generation_config.max_new_tokens = 4
-
-        model.generation_config.cache_implementation = "static"
-        model.forward = torch.compile(model.forward, mode="reduce-overhead")
-
-        input_text = "Why dogs are cute?"
-        input_ids = tokenizer([input_text] * batch_size, return_tensors="pt").to(torch_device)
-
-        for i in range(n_iter):
-            _ = model.generate(**input_ids, do_sample=False)
-
     @slow
     @require_torch_gpu
     def test_torch_compile_for_training(self):
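For reference, the compile path the two deleted tests above cover boils down to: pick a static (fixed-shape) cache so the decoding step can be captured by `torch.compile` without shape-driven recompilation, compile `model.forward`, and let the first couple of `generate` calls pay the warm-up cost. A condensed, hedged sketch (placeholder checkpoint; assumes a model that supports the static cache and torch >= 2.3):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "meta-llama/Llama-3.1-8B"  # placeholder; must support the static cache
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.float16).to("cuda")

# Fixed-shape cache + compiled forward: later calls replay the recorded CUDA graphs.
model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

inputs = tokenizer("Why dogs are cute?", return_tensors="pt").to("cuda")
for _ in range(3):  # first iterations warm up and record; the last one hits the fast replay path
    _ = model.generate(**inputs, max_new_tokens=4, do_sample=False)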
@@ -5118,74 +4761,6 @@ class ModelTesterMixin:
         for name, param in model._orig_mod.named_parameters():
             torch.testing.assert_close(param.grad.detach().cpu(), params[name], rtol=1e-4, atol=1e-4)

-    @slow
-    @require_torch_gpu  # Testing cuda graphs.
-    @require_read_token
-    def test_compile_cuda_graph_time(self):
-        if version.parse(torch.__version__) < version.parse("2.3"):
-            self.skipTest(reason="This test requires torch >= 2.3 to run.")
-
-        # TODO felix: All models supporting `StaticCache` or `torch.compile` should be tested.
-        # At the moment, only llama, gemma and gemma2 are tested here!
-        if not hasattr(self, "_torch_compile_test_ckpt"):
-            self.skipTest(f"{self.__class__.__name__} doesn't have the attribute `_torch_compile_test_ckpt`.")
-        ckpt = self._torch_compile_test_ckpt
-        revision = "main" if not hasattr(self, "_torch_compile_test_revision") else self._torch_compile_test_revision
-
-        os.environ["TOKENIZERS_PARALLELISM"] = "false"
-
-        tokenizer = AutoTokenizer.from_pretrained(ckpt)
-        if self.is_encoder_decoder:
-            model = AutoModelForSeq2SeqLM.from_pretrained(ckpt, torch_dtype=torch.float16, revision=revision).to(
-                torch_device
-            )
-        else:
-            model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16, revision=revision).to(
-                torch_device
-            )
-
-        cache_implementation = "static"
-        if model.config.model_type == "gemma2":
-            cache_implementation = "hybrid"
-
-        new_tokens = 50
-        gen_config = GenerationConfig(
-            max_new_tokens=new_tokens,
-            min_new_tokens=new_tokens,
-            use_cache=True,
-            pad_token_id=tokenizer.pad_token_id,
-            num_beams=1,
-            do_sample=False,
-            eos_token_id=None,  # This is required for min_new_tokens to actually have an effect.
-        )
-        model.generation_config.eos_token_id = None  # greedy_search falls back on this eos_token_id that we need to set to None as well for min_new_tokens to have an effect.
-
-        model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
-
-        inp = tokenizer("Why cats are cute?", return_tensors="pt").to(torch_device)
-
-        # First run: the first run warms up each graph, which does things like CuBlas or Triton benchmarking
-        start = time.perf_counter()
-        _ = model.generate(**inp, generation_config=gen_config, cache_implementation=cache_implementation)
-        end = time.perf_counter()
-        graph_warmup_time = end - start
-
-        # Second run: CUDA Graph recording, and replays it
-        start = time.perf_counter()
-        _ = model.generate(**inp, generation_config=gen_config, cache_implementation=cache_implementation)
-        end = time.perf_counter()
-        record_time = end - start
-
-        # Finally: we hit the optimized, CUDA Graph replay path
-        start = time.perf_counter()
-        _ = model.generate(**inp, generation_config=gen_config, cache_implementation=cache_implementation)
-        end = time.perf_counter()
-        opt_time = end - start
-
-        # For the recording step, we expect only two cuda graphs and this step should be much faster than the first.
-        self.assertTrue(record_time < 0.15 * graph_warmup_time)
-        self.assertTrue(opt_time < record_time)
-
     def test_forward_with_num_logits_to_keep(self):
         for model_class in self.all_generative_model_classes:
             if "num_logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()):
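The surviving test_forward_with_num_logits_to_keep context at the end refers to a forward argument that lets causal LMs materialize logits only for the last few positions, which is all `generate` needs at decode time. A hedged sketch of the effect, assuming a model whose `forward` accepts `num_logits_to_keep` (e.g. the Llama family at the time of this commit):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "HuggingFaceTB/SmolLM2-135M"  # placeholder; `forward` must accept `num_logits_to_keep`
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint).eval()

inputs = tokenizer("Hello world", return_tensors="pt")
with torch.no_grad():
    full = model(**inputs)                        # logits for every position: (1, seq_len, vocab)
    last = model(**inputs, num_logits_to_keep=1)  # logits for the final position only: (1, 1, vocab)

torch.testing.assert_close(full.logits[:, -1:, :], last.logits)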