Generate tests: modality-agnostic input preparation (#33685)

Joao Gante 2024-10-03 14:01:24 +01:00 committed by GitHub
parent f2bf4fcf3d
commit d29738f5b4
34 changed files with 241 additions and 906 deletions

File diff suppressed because it is too large.

View File

@@ -283,28 +283,6 @@ class BigBirdPegasusModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
return False
# overwrite from GenerationTesterMixin to solve problem
# with conflicting random seeds
def _get_input_ids_and_config(self, batch_size=2):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.attention_type = "original_full"
input_ids = inputs_dict.pop(self.input_name)
_ = inputs_dict.pop("attention_mask", None)
_ = inputs_dict.pop("decoder_input_ids", None)
_ = inputs_dict.pop("decoder_attention_mask", None)
attention_mask = torch.ones_like(input_ids, dtype=torch.long)
# cut to half length & take max batch_size 3
sequence_length = input_ids.shape[-1] // 2
input_ids = input_ids[:batch_size, :sequence_length]
attention_mask = attention_mask[:batch_size, :sequence_length]
if config.eos_token_id is not None and config.pad_token_id is None:
# hack to allow generate for models such as GPT2 as is done in `generate()`
config.pad_token_id = config.eos_token_id
return config, input_ids, attention_mask, inputs_dict
def setUp(self):
self.model_tester = BigBirdPegasusModelTester(self)
self.config_tester = ConfigTester(self, config_class=BigBirdPegasusConfig)
@@ -485,6 +463,13 @@ class BigBirdPegasusModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
def test_load_save_without_tied_weights(self):
pass
def test_generate_with_head_masking(self):
# overwritten to temporarily switch the attention type to `original_full`
original_self_attention_type = self.model_tester.attention_type
self.model_tester.attention_type = "original_full"
super().test_generate_with_head_masking()
self.model_tester.attention_type = original_self_attention_type
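The save/override/restore pattern above — stash a tester attribute, flip it, run the inherited test, restore it — recurs throughout this commit (`attention_type` here, `audio_channels` in the MusicGen tests, `seq_length` in Reformer). A minimal sketch of the same idea as a reusable context manager; the helper below is hypothetical and not part of the commit:

    from contextlib import contextmanager

    @contextmanager
    def override_tester_attr(tester, name, value):
        # Temporarily override an attribute on a model tester, restoring it on
        # exit even if the wrapped test fails.
        original = getattr(tester, name)
        setattr(tester, name, value)
        try:
            yield tester
        finally:
            setattr(tester, name, original)

With it, the test body above would collapse to: with override_tester_attr(self.model_tester, "attention_type", "original_full"): super().test_generate_with_head_masking().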
@require_torch
@require_sentencepiece

View File

@@ -116,7 +116,7 @@ class ChameleonModelTester:
input_mask = None
if self.use_input_mask:
input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
sequence_labels = None
token_labels = None
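The one-line change above repeats across a dozen decoder-only testers below (Cohere, Gemma, Granite, Llama, Mistral, Mixtral, Olmo, Persimmon, Phi3, Qwen2, StableLm, Starcoder2, among others): the causal input mask is now derived from `input_ids` rather than from explicit tester dimensions. A standalone sketch of why the two forms are interchangeable, with hypothetical sizes — `torch.ones_like` inherits shape, dtype and device from `input_ids`, so the mask keeps matching the main input without consulting the tester:

    import torch

    batch_size, seq_length = 2, 7
    input_ids = torch.randint(0, 99, (batch_size, seq_length))

    # old form: shape spelled out from tester attributes, float mask moved to the device afterwards
    old_mask = torch.tril(torch.ones(batch_size, seq_length)).to(input_ids.device)

    # new form: shape, dtype (int64) and device all inherited from `input_ids`
    new_mask = torch.tril(torch.ones_like(input_ids))

    assert torch.equal(old_mask.long(), new_mask)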

View File

@@ -95,7 +95,7 @@ class CohereModelTester:
input_mask = None
if self.use_input_mask:
input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
token_type_ids = None
if self.use_token_type_ids:

View File

@@ -123,7 +123,6 @@ class DacModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
test_headmasking = False
test_resize_embeddings = False
pipeline_model_mapping = {"feature-extraction": DacModel} if is_torch_available() else {}
input_name = "input_values"
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
# model does not have attention and does not support returning hidden states

View File

@@ -141,7 +141,6 @@ class EncodecModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
test_headmasking = False
test_resize_embeddings = False
pipeline_model_mapping = {"feature-extraction": EncodecModel} if is_torch_available() else {}
input_name = "input_values"
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
# model does not have attention and does not support returning hidden states

View File

@@ -119,7 +119,7 @@ class GemmaModelTester:
input_mask = None
if self.use_input_mask:
input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
token_type_ids = None
if self.use_token_type_ids:

View File

@@ -106,7 +106,7 @@ class GraniteModelTester:
input_mask = None
if self.use_input_mask:
input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
token_type_ids = None
if self.use_token_type_ids:

View File

@@ -105,7 +105,7 @@ class GraniteMoeModelTester:
input_mask = None
if self.use_input_mask:
input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
token_type_ids = None
if self.use_token_type_ids:

View File

@@ -338,13 +338,11 @@ class LEDModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
self.model_tester.check_global_attention(*config_and_inputs)
def _get_input_ids_and_config(self, batch_size=2):
config, input_ids, attention_mask, inputs_dict = GenerationTesterMixin._get_input_ids_and_config(
self, batch_size=batch_size
)
def prepare_config_and_inputs_for_generate(self, *args, **kwargs):
config, inputs_dict = super().prepare_config_and_inputs_for_generate(*args, **kwargs)
# LED computes attention scores based on mask indices if `is_global`
inputs_dict.pop("global_attention_mask")
return config, input_ids, attention_mask, inputs_dict
return config, inputs_dict
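This LED override shows the new modality-agnostic contract that replaces `_get_input_ids_and_config` throughout the commit: a test receives `(config, inputs_dict)` and forwards the whole dict to `generate()`, instead of unpacking `input_ids` and `attention_mask` positionally. A toy sketch of the consuming side — the helper name and signature below are hypothetical, not the mixin's actual code:

    def run_greedy_generate(model, inputs_dict, max_new_tokens=3):
        # All model inputs -- token ids, audio features, masks -- travel together in
        # one dict, so the test never needs to know which modality the model consumes.
        return model.generate(**inputs_dict, do_sample=False, max_new_tokens=max_new_tokens)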
# LEDForSequenceClassification does not support inputs_embeds
def test_inputs_embeds(self):

View File

@@ -112,7 +112,7 @@ class LlamaModelTester:
input_mask = None
if self.use_input_mask:
input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
token_type_ids = None
if self.use_token_type_ids:

View File

@@ -170,7 +170,6 @@ class MimiModelTest(ModelTesterMixin, unittest.TestCase):
test_headmasking = False
test_resize_embeddings = False
test_torchscript = False
input_name = "input_values"
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
# model does support returning hidden states

View File

@@ -112,7 +112,7 @@ class MistralModelTester:
input_mask = None
if self.use_input_mask:
input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
token_type_ids = None
if self.use_token_type_ids:

View File

@@ -108,7 +108,7 @@ class MixtralModelTester:
input_mask = None
if self.use_input_mask:
input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
token_type_ids = None
if self.use_token_type_ids:

View File

@@ -60,10 +60,6 @@ if is_torch_available():
MusicgenModel,
set_seed,
)
from transformers.generation import (
GenerateDecoderOnlyOutput,
GenerateEncoderDecoderOutput,
)
def _config_zero_init(config):
@@ -124,6 +120,7 @@ class MusicgenDecoderTester:
pad_token_id=99,
bos_token_id=99,
num_codebooks=4,
audio_channels=1,
):
self.parent = parent
self.batch_size = batch_size
@@ -141,6 +138,7 @@ class MusicgenDecoderTester:
self.pad_token_id = pad_token_id
self.bos_token_id = bos_token_id
self.num_codebooks = num_codebooks
self.audio_channels = audio_channels
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size * self.num_codebooks, self.seq_length], self.vocab_size)
@@ -166,6 +164,7 @@ class MusicgenDecoderTester:
bos_token_id=self.bos_token_id,
num_codebooks=self.num_codebooks,
tie_word_embeddings=False,
audio_channels=self.audio_channels,
)
return config
@@ -282,47 +281,15 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
def test_tied_weights_keys(self):
pass
def _get_input_ids_and_config(self, batch_size=2):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
input_ids = inputs_dict["input_ids"]
_ = inputs_dict.pop("attention_mask", None)
inputs_dict = {
k: v[:batch_size, ...]
for k, v in inputs_dict.items()
if "head_mask" not in k and isinstance(v, torch.Tensor)
}
# take max batch_size
sequence_length = input_ids.shape[-1]
input_ids = input_ids[: batch_size * config.num_codebooks, :]
attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long)
return config, input_ids, attention_mask, inputs_dict
def _get_logits_processor_kwargs(self, do_sample=False, config=None):
logits_processor_kwargs = {}
return logits_processor_kwargs
def test_greedy_generate_stereo_outputs(self):
for model_class in self.greedy_sample_model_classes:
config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
config.audio_channels = 2
model = model_class(config).to(torch_device).eval()
output_generate = self._greedy_generate(
model=model,
input_ids=input_ids.to(torch_device),
attention_mask=attention_mask.to(torch_device),
output_scores=True,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
inputs_dict={},
)
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
self.assertNotIn(config.pad_token_id, output_generate)
original_audio_channels = self.model_tester.audio_channels
self.model_tester.audio_channels = 2
super().test_greedy_generate_dict_outputs()
self.model_tester.audio_channels = original_audio_channels
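Flipping `self.model_tester.audio_channels` works because the tester now forwards that attribute into the config it builds (the `audio_channels=self.audio_channels` line added above), so any inherited test that rebuilds its inputs through the tester picks up the stereo setting. A toy model of that data flow, with hypothetical names:

    class ToyTester:
        def __init__(self):
            self.audio_channels = 1  # default: mono

        def get_config(self):
            # the tester snapshots its attributes into a fresh config at call time
            return {"audio_channels": self.audio_channels}

    tester = ToyTester()
    tester.audio_channels = 2                          # temporary override, as in the stereo test
    assert tester.get_config()["audio_channels"] == 2  # config built during the test is stereo
    tester.audio_channels = 1                          # restore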
@require_flash_attn
@require_torch_gpu
@@ -998,6 +965,7 @@ class MusicgenTester:
num_codebooks=4,
num_filters=4,
codebook_size=128,
audio_channels=1,
):
self.parent = parent
self.batch_size = batch_size
@@ -1017,6 +985,7 @@ class MusicgenTester:
self.num_codebooks = num_codebooks
self.num_filters = num_filters
self.codebook_size = codebook_size
self.audio_channels = audio_channels
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
@@ -1052,6 +1021,7 @@ class MusicgenTester:
bos_token_id=self.bos_token_id,
num_codebooks=self.num_codebooks,
tie_word_embeddings=False,
audio_channels=self.audio_channels,
)
config = MusicgenConfig.from_sub_models_config(text_encoder_config, audio_encoder_config, decoder_config)
return config
@@ -1415,170 +1385,10 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
lm_heads = model.get_output_embeddings()
self.assertTrue(lm_heads is None or isinstance(lm_heads[0], torch.nn.Linear))
def _get_input_ids_and_config(self, batch_size=2):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
input_ids = inputs_dict["input_ids"]
# take max batch_size
sequence_length = input_ids.shape[-1]
input_ids = input_ids[:batch_size, :]
attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long)
return config, input_ids, attention_mask
# override since the `input_ids` cannot be used as the `decoder_input_ids` for musicgen (input / outputs are
# different modalities -> different shapes)
def _greedy_generate(
self,
model,
input_ids,
attention_mask,
output_scores=False,
output_attentions=False,
output_hidden_states=False,
return_dict_in_generate=False,
):
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
output_generate = model.generate(
input_ids,
do_sample=False,
num_beams=1,
max_new_tokens=self.max_new_tokens,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
output_scores=output_scores,
return_dict_in_generate=return_dict_in_generate,
remove_invalid_values=True,
**model_kwargs,
)
return output_generate
# override since the `input_ids` cannot be used as the `decoder_input_ids` for musicgen (input / outputs are
# different modalities -> different shapes)
def _sample_generate(
self,
model,
input_ids,
attention_mask,
num_return_sequences,
output_scores=False,
output_attentions=False,
output_hidden_states=False,
return_dict_in_generate=False,
):
torch.manual_seed(0)
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
output_generate = model.generate(
input_ids,
do_sample=True,
num_beams=1,
max_new_tokens=self.max_new_tokens,
num_return_sequences=num_return_sequences,
output_scores=output_scores,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict_in_generate=return_dict_in_generate,
remove_invalid_values=True,
**model_kwargs,
)
return output_generate
def _get_logits_processor_kwargs(self, do_sample=False, config=None):
logits_processor_kwargs = {}
return logits_processor_kwargs
def test_greedy_generate_dict_outputs(self):
for model_class in self.greedy_sample_model_classes:
# disable cache
config, input_ids, attention_mask = self._get_input_ids_and_config()
config.use_cache = False
model = model_class(config).to(torch_device).eval()
output_generate = self._greedy_generate(
model=model,
input_ids=input_ids.to(torch_device),
attention_mask=attention_mask.to(torch_device),
output_scores=True,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
)
self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
self.assertNotIn(config.pad_token_id, output_generate)
def test_greedy_generate_dict_outputs_use_cache(self):
for model_class in self.greedy_sample_model_classes:
# enable cache
config, input_ids, attention_mask = self._get_input_ids_and_config()
config.use_cache = True
config.is_decoder = True
model = model_class(config).to(torch_device).eval()
output_generate = self._greedy_generate(
model=model,
input_ids=input_ids.to(torch_device),
attention_mask=attention_mask.to(torch_device),
output_scores=True,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
)
self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
def test_sample_generate(self):
for model_class in self.greedy_sample_model_classes:
config, input_ids, attention_mask = self._get_input_ids_and_config()
model = model_class(config).to(torch_device).eval()
# check `generate()` and `sample()` are equal
output_generate = self._sample_generate(
model=model,
input_ids=input_ids.to(torch_device),
attention_mask=attention_mask.to(torch_device),
num_return_sequences=1,
)
self.assertIsInstance(output_generate, torch.Tensor)
def test_sample_generate_dict_output(self):
for model_class in self.greedy_sample_model_classes:
# disable cache
config, input_ids, attention_mask = self._get_input_ids_and_config()
config.use_cache = False
model = model_class(config).to(torch_device).eval()
output_generate = self._sample_generate(
model=model,
input_ids=input_ids.to(torch_device),
attention_mask=attention_mask.to(torch_device),
num_return_sequences=3,
output_scores=True,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
)
self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
def test_generate_without_input_ids(self):
config, _, _ = self._get_input_ids_and_config()
# if no bos token id => cannot generate from None
if config.bos_token_id is None:
self.skipTest(reason="bos_token_id is None")
for model_class in self.greedy_sample_model_classes:
model = model_class(config).to(torch_device)
model.eval()
output_ids_generate = model.generate(
do_sample=False, max_new_tokens=self.max_new_tokens, remove_invalid_values=True
)
self.assertIsNotNone(output_ids_generate)
@require_torch_fp16
@require_torch_accelerator # not all operations are supported in fp16 on CPU
def test_generate_fp16(self):
@@ -1595,24 +1405,10 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
)
def test_greedy_generate_stereo_outputs(self):
for model_class in self.greedy_sample_model_classes:
config, input_ids, attention_mask = self._get_input_ids_and_config()
config.audio_channels = 2
model = model_class(config).to(torch_device).eval()
output_generate = self._greedy_generate(
model=model,
input_ids=input_ids.to(torch_device),
attention_mask=attention_mask.to(torch_device),
output_scores=True,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
)
self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
self.assertNotIn(config.pad_token_id, output_generate)
original_audio_channels = self.model_tester.audio_channels
self.model_tester.audio_channels = 2
super().test_greedy_generate_dict_outputs()
self.model_tester.audio_channels = original_audio_channels
@unittest.skip(
reason="MusicgenModel is actually not the base of MusicgenForCausalLM as the latter is a composit model"

View File

@@ -61,9 +61,6 @@ if is_torch_available():
MusicgenMelodyModel,
set_seed,
)
from transformers.generation import (
GenerateDecoderOnlyOutput,
)
if is_torchaudio_available():
from transformers import MusicgenMelodyProcessor
@@ -124,6 +121,7 @@ class MusicgenMelodyDecoderTester:
bos_token_id=99,
num_codebooks=4,
conditional_seq_length=4,
audio_channels=1,
):
self.parent = parent
self.batch_size = batch_size
@@ -143,6 +141,7 @@ class MusicgenMelodyDecoderTester:
self.num_codebooks = num_codebooks
self.conditional_seq_length = conditional_seq_length
self.encoder_seq_length = conditional_seq_length + seq_length
self.audio_channels = audio_channels
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size * self.num_codebooks, self.seq_length], self.vocab_size)
@@ -168,6 +167,7 @@ class MusicgenMelodyDecoderTester:
bos_token_id=self.bos_token_id,
num_codebooks=self.num_codebooks,
tie_word_embeddings=False,
audio_channels=self.audio_channels,
)
return config
@@ -285,46 +285,15 @@ class MusicgenMelodyDecoderTest(ModelTesterMixin, GenerationTesterMixin, unittes
def test_tied_weights_keys(self):
pass
def _get_input_ids_and_config(self, batch_size=2):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
input_ids = inputs_dict["input_ids"]
_ = inputs_dict.pop("attention_mask", None)
inputs_dict = {
k: v[:batch_size, ...]
for k, v in inputs_dict.items()
if "head_mask" not in k and isinstance(v, torch.Tensor)
}
# take max batch_size
sequence_length = input_ids.shape[-1]
input_ids = input_ids[: batch_size * config.num_codebooks, :]
attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long)
return config, input_ids, attention_mask, inputs_dict
def _get_logits_processor_kwargs(self, do_sample=False, config=None):
logits_processor_kwargs = {}
return logits_processor_kwargs
def test_greedy_generate_stereo_outputs(self):
for model_class in self.greedy_sample_model_classes:
config, input_ids, attention_mask, _ = self._get_input_ids_and_config()
config.audio_channels = 2
model = model_class(config).to(torch_device).eval()
output_generate = self._greedy_generate(
model=model,
input_ids=input_ids.to(torch_device),
attention_mask=attention_mask.to(torch_device),
output_scores=True,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
inputs_dict={},
)
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
self.assertNotIn(config.pad_token_id, output_generate)
original_audio_channels = self.model_tester.audio_channels
self.model_tester.audio_channels = 2
super().test_greedy_generate_dict_outputs()
self.model_tester.audio_channels = original_audio_channels
@require_flash_attn
@require_torch_gpu
@@ -996,6 +965,7 @@ class MusicgenMelodyTester:
codebook_size=128,
conditional_seq_length=3,
chroma_length=24,
audio_channels=1,
):
self.parent = parent
self.batch_size = batch_size
@@ -1018,6 +988,7 @@ class MusicgenMelodyTester:
self.conditional_seq_length = conditional_seq_length
self.chroma_length = chroma_length
self.encoder_seq_length = conditional_seq_length + seq_length
self.audio_channels = audio_channels
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.conditional_seq_length], self.vocab_size)
@@ -1053,6 +1024,7 @@ class MusicgenMelodyTester:
bos_token_id=self.bos_token_id,
num_codebooks=self.num_codebooks,
tie_word_embeddings=False,
audio_channels=self.audio_channels,
)
config = MusicgenMelodyConfig.from_sub_models_config(
text_encoder_config, audio_encoder_config, decoder_config, chroma_length=self.chroma_length
@@ -1399,170 +1371,10 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
lm_heads = model.get_output_embeddings()
self.assertTrue(lm_heads is None or isinstance(lm_heads[0], torch.nn.Linear))
def _get_input_ids_and_config(self, batch_size=2):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
input_ids = inputs_dict["input_ids"]
# take max batch_size
sequence_length = input_ids.shape[-1]
input_ids = input_ids[:batch_size, :]
attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long)
return config, input_ids, attention_mask
# override since the `input_ids` cannot be used as the `decoder_input_ids` for musicgen_melody (input / outputs are
# different modalities -> different shapes)
def _greedy_generate(
self,
model,
input_ids,
attention_mask,
output_scores=False,
output_attentions=False,
output_hidden_states=False,
return_dict_in_generate=False,
):
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
output_generate = model.generate(
input_ids,
do_sample=False,
num_beams=1,
max_new_tokens=self.max_new_tokens,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
output_scores=output_scores,
return_dict_in_generate=return_dict_in_generate,
remove_invalid_values=True,
**model_kwargs,
)
return output_generate
# override since the `input_ids` cannot be used as the `decoder_input_ids` for musicgen_melody (input / outputs are
# different modalities -> different shapes)
def _sample_generate(
self,
model,
input_ids,
attention_mask,
num_return_sequences,
output_scores=False,
output_attentions=False,
output_hidden_states=False,
return_dict_in_generate=False,
):
torch.manual_seed(0)
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
output_generate = model.generate(
input_ids,
do_sample=True,
num_beams=1,
max_new_tokens=self.max_new_tokens,
num_return_sequences=num_return_sequences,
output_scores=output_scores,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict_in_generate=return_dict_in_generate,
remove_invalid_values=True,
**model_kwargs,
)
return output_generate
def _get_logits_processor_kwargs(self, do_sample=False, config=None):
logits_processor_kwargs = {}
return logits_processor_kwargs
def test_greedy_generate_dict_outputs(self):
for model_class in self.greedy_sample_model_classes:
# disable cache
config, input_ids, attention_mask = self._get_input_ids_and_config()
config.use_cache = False
model = model_class(config).to(torch_device).eval()
output_generate = self._greedy_generate(
model=model,
input_ids=input_ids.to(torch_device),
attention_mask=attention_mask.to(torch_device),
output_scores=True,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
)
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
self.assertNotIn(config.pad_token_id, output_generate)
def test_greedy_generate_dict_outputs_use_cache(self):
for model_class in self.greedy_sample_model_classes:
# enable cache
config, input_ids, attention_mask = self._get_input_ids_and_config()
config.use_cache = True
config.is_decoder = True
model = model_class(config).to(torch_device).eval()
output_generate = self._greedy_generate(
model=model,
input_ids=input_ids.to(torch_device),
attention_mask=attention_mask.to(torch_device),
output_scores=True,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
)
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
def test_sample_generate(self):
for model_class in self.greedy_sample_model_classes:
config, input_ids, attention_mask = self._get_input_ids_and_config()
model = model_class(config).to(torch_device).eval()
# check `generate()` and `sample()` are equal
output_generate = self._sample_generate(
model=model,
input_ids=input_ids.to(torch_device),
attention_mask=attention_mask.to(torch_device),
num_return_sequences=1,
)
self.assertIsInstance(output_generate, torch.Tensor)
def test_sample_generate_dict_output(self):
for model_class in self.greedy_sample_model_classes:
# disable cache
config, input_ids, attention_mask = self._get_input_ids_and_config()
config.use_cache = False
model = model_class(config).to(torch_device).eval()
output_generate = self._sample_generate(
model=model,
input_ids=input_ids.to(torch_device),
attention_mask=attention_mask.to(torch_device),
num_return_sequences=3,
output_scores=True,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
)
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
def test_generate_without_input_ids(self):
config, _, _ = self._get_input_ids_and_config()
# if no bos token id => cannot generate from None
if config.bos_token_id is None:
self.skipTest(reason="bos_token_id is None")
for model_class in self.greedy_sample_model_classes:
model = model_class(config).to(torch_device)
model.eval()
output_ids_generate = model.generate(
do_sample=False, max_new_tokens=self.max_new_tokens, remove_invalid_values=True
)
self.assertIsNotNone(output_ids_generate)
@require_torch_fp16
@require_torch_accelerator # not all operations are supported in fp16 on CPU
def test_generate_fp16(self):
@@ -1579,24 +1391,10 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
)
def test_greedy_generate_stereo_outputs(self):
for model_class in self.greedy_sample_model_classes:
config, input_ids, attention_mask = self._get_input_ids_and_config()
config.audio_channels = 2
model = model_class(config).to(torch_device).eval()
output_generate = self._greedy_generate(
model=model,
input_ids=input_ids.to(torch_device),
attention_mask=attention_mask.to(torch_device),
output_scores=True,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
)
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
self.assertNotIn(config.pad_token_id, output_generate)
original_audio_channels = self.model_tester.audio_channels
self.model_tester.audio_channels = 2
super().test_greedy_generate_dict_outputs()
self.model_tester.audio_channels = original_audio_channels
@unittest.skip(
reason="MusicgenMelodyModel is actually not the base of MusicgenMelodyForCausalLM as the latter is a composit model"

View File

@@ -101,7 +101,7 @@ class OlmoModelTester:
input_mask = None
if self.use_input_mask:
input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
token_type_ids = None
if self.use_token_type_ids:

View File

@@ -111,7 +111,7 @@ class OlmoeModelTester:
input_mask = None
if self.use_input_mask:
input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
token_type_ids = None
if self.use_token_type_ids:

View File

@@ -110,7 +110,7 @@ class PersimmonModelTester:
input_mask = None
if self.use_input_mask:
input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
token_type_ids = None
if self.use_token_type_ids:

View File

@@ -151,7 +151,7 @@ class Phi3ModelTester:
input_mask = None
if self.use_input_mask:
input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
token_type_ids = None
if self.use_token_type_ids:

View File

@@ -116,7 +116,7 @@ class Qwen2ModelTester:
input_mask = None
if self.use_input_mask:
input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
token_type_ids = None
if self.use_token_type_ids:

View File

@@ -134,7 +134,7 @@ class Qwen2MoeModelTester:
input_mask = None
if self.use_input_mask:
input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
token_type_ids = None
if self.use_token_type_ids:

View File

@@ -103,7 +103,7 @@ class RecurrentGemmaModelTester:
input_mask = None
if self.use_input_mask:
input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
token_type_ids = None
if self.use_token_type_ids:

View File

@@ -684,20 +684,15 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, GenerationTesterMixin, Mod
def test_left_padding_compatibility(self):
pass
def _get_input_ids_and_config(self, batch_size=2):
def prepare_config_and_inputs_for_generate(self, *args, **kwargs):
# override because otherwise we hit max possible seq length for model (4*8=32)
# decreasing the seq_length in tester causes errors for "training_tests", those need exactly max seq length
# NOTE: seq_length has to be multiple of 4, otherwise it fails for other tests
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
input_ids = inputs_dict.pop(self.input_name)
_ = inputs_dict.pop("attention_mask", None)
_ = inputs_dict.pop("decoder_input_ids", None)
_ = inputs_dict.pop("decoder_attention_mask", None)
input_ids = input_ids[:batch_size, :16]
attention_mask = torch.ones_like(input_ids, dtype=torch.long)[:batch_size, :16]
config.eos_token_id = None
config.forced_eos_token_id = None
return config, input_ids, attention_mask, inputs_dict
original_sequence_length = self.model_tester.seq_length
self.model_tester.seq_length = 16
test_inputs = super().prepare_config_and_inputs_for_generate(*args, **kwargs)
self.model_tester.seq_length = original_sequence_length
return test_inputs
@require_torch

View File

@@ -360,8 +360,6 @@ class SeamlessM4TModelWithSpeechInputTest(ModelTesterMixin, unittest.TestCase):
)
all_generative_model_classes = (SeamlessM4TForSpeechToText,) if is_torch_available() else ()
input_name = "input_features"
def setUp(self):
self.model_tester = SeamlessM4TModelTester(self, input_modality="speech")
self.config_tester = ConfigTester(self, config_class=SeamlessM4TConfig)
@@ -379,26 +377,6 @@ class SeamlessM4TModelWithSpeechInputTest(ModelTesterMixin, unittest.TestCase):
model = SeamlessM4TModel.from_pretrained(model_name)
self.assertIsNotNone(model)
def _get_input_ids_and_config(self, batch_size=2):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
input_ids = inputs_dict[self.input_name]
# cut to half length & take max batch_size 3
sequence_length = input_ids.shape[-1] // 2
input_ids = input_ids[:batch_size, :sequence_length]
# generate max 3 tokens
max_length = input_ids.shape[-1] + 3
if config.eos_token_id is not None and config.pad_token_id is None:
# hack to allow generate for models such as GPT2 as is done in `generate()`
if isinstance(config.eos_token_id, int):
config.eos_token_id = [config.eos_token_id]
config.pad_token_id = config.eos_token_id[0]
attention_mask = torch.ones(input_ids.shape[:2], dtype=torch.long)[:batch_size, :sequence_length]
return config, input_ids.float(), attention_mask, max_length
def test_initialization(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

View File

@@ -376,8 +376,6 @@ class SeamlessM4Tv2ModelWithSpeechInputTest(ModelTesterMixin, unittest.TestCase)
)
all_generative_model_classes = (SeamlessM4Tv2ForSpeechToText,) if is_torch_available() else ()
input_name = "input_features"
def setUp(self):
self.model_tester = SeamlessM4Tv2ModelTester(self, input_modality="speech")
self.config_tester = ConfigTester(self, config_class=SeamlessM4Tv2Config)
@@ -395,26 +393,6 @@ class SeamlessM4Tv2ModelWithSpeechInputTest(ModelTesterMixin, unittest.TestCase)
model = SeamlessM4Tv2Model.from_pretrained(model_name)
self.assertIsNotNone(model)
def _get_input_ids_and_config(self, batch_size=2):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
input_ids = inputs_dict[self.input_name]
# cut to half length & take max batch_size 3
sequence_length = input_ids.shape[-1] // 2
input_ids = input_ids[:batch_size, :sequence_length]
# generate max 3 tokens
max_length = input_ids.shape[-1] + 3
if config.eos_token_id is not None and config.pad_token_id is None:
# hack to allow generate for models such as GPT2 as is done in `generate()`
if isinstance(config.eos_token_id, int):
config.eos_token_id = [config.eos_token_id]
config.pad_token_id = config.eos_token_id[0]
attention_mask = torch.ones(input_ids.shape[:2], dtype=torch.long)[:batch_size, :sequence_length]
return config, input_ids.float(), attention_mask, max_length
def test_initialization(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

View File

@@ -282,20 +282,6 @@ class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTest
test_pruning = False
test_missing_keys = False
input_name = "input_features"
def _get_input_ids_and_config(self, batch_size=2):
config, input_ids, attention_mask, inputs_dict = GenerationTesterMixin._get_input_ids_and_config(self)
# `input_ids` is actually `input_features` which is a 3D tensor.
# We must overwrite the mask to make it 2D since the original `_get_input_ids_and_config` creates an
# attention mask of the same shape as `input_ids`.
if len(attention_mask.shape) > 2:
sequence_length = input_ids.shape[1]
attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long, device=attention_mask.device)
return config, input_ids, attention_mask, inputs_dict
def setUp(self):
self.model_tester = Speech2TextModelTester(self)
self.config_tester = ConfigTester(self, config_class=Speech2TextConfig)
@@ -632,46 +618,12 @@ class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTest
def test_generate_without_input_ids(self):
pass
def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
batch_size, seq_length = input_ids.shape[:2]
subsampled_seq_length = self.model_tester.get_subsampled_output_lengths(seq_length)
num_sequences_in_output = batch_size * num_return_sequences
gen_len = (
output.sequences.shape[-1] - 1 if config.is_encoder_decoder else output.sequences.shape[-1] - seq_length
)
# scores
self._check_scores(num_sequences_in_output, output.scores, length=gen_len, config=config)
# Attentions
# encoder
self._check_encoder_attention_for_generate(
output.encoder_attentions, batch_size, config, subsampled_seq_length
)
# decoder
self._check_attentions_for_generate(
num_sequences_in_output,
output.decoder_attentions,
min_length=1,
max_length=output.sequences.shape[-1],
config=config,
use_cache=use_cache,
)
# Hidden States
# encoder
self._check_encoder_hidden_states_for_generate(
output.encoder_hidden_states, batch_size, config, subsampled_seq_length
)
# decoder
self._check_hidden_states_for_generate(
num_sequences_in_output,
output.decoder_hidden_states,
min_length=1,
max_length=output.sequences.shape[-1],
config=config,
use_cache=use_cache,
def _check_outputs(self, output, main_input, config, use_cache=False, num_return_sequences=1):
# In this model, the index of `batch_size` and `sequence_length` in `main_input` is different: they are the
# first two dimensions of the tensor.
main_input = main_input[:, :, 0]
super()._check_outputs(
output, main_input, config, use_cache=use_cache, num_return_sequences=num_return_sequences
)
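A standalone illustration of the slice above, with hypothetical shapes: speech `input_features` are 3-D, so dropping the feature axis leaves the `(batch_size, sequence_length)` view that the base class's shape checks expect:

    import torch

    main_input = torch.randn(2, 30, 80)  # (batch_size, sequence_length, feature_dim)
    main_input = main_input[:, :, 0]     # -> (batch_size, sequence_length), as in the override
    assert main_input.shape == (2, 30)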
def _create_and_check_torchscript(self, config, inputs_dict):

View File

@@ -177,8 +177,6 @@ class SpeechT5ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
test_headmasking = False
test_resize_embeddings = False
input_name = "input_values"
def setUp(self):
self.model_tester = SpeechT5ModelTester(self)
self.config_tester = ConfigTester(self, config_class=SpeechT5Config, hidden_size=37)
@@ -375,8 +373,6 @@ class SpeechT5ForSpeechToTextTest(ModelTesterMixin, unittest.TestCase):
test_pruning = False
test_headmasking = False
input_name = "input_values"
def setUp(self):
self.model_tester = SpeechT5ForSpeechToTextTester(self)
self.config_tester = ConfigTester(self, config_class=SpeechT5Config, hidden_size=37)
@@ -895,8 +891,6 @@ class SpeechT5ForTextToSpeechTest(ModelTesterMixin, unittest.TestCase):
test_pruning = False
test_headmasking = False
input_name = "input_ids"
def setUp(self):
self.model_tester = SpeechT5ForTextToSpeechTester(self)
self.config_tester = ConfigTester(self, config_class=SpeechT5Config, hidden_size=37)
@@ -1441,8 +1435,6 @@ class SpeechT5ForSpeechToSpeechTest(ModelTesterMixin, unittest.TestCase):
test_headmasking = False
test_resize_embeddings = False
input_name = "input_values"
def setUp(self):
self.model_tester = SpeechT5ForSpeechToSpeechTester(self)
self.config_tester = ConfigTester(self, config_class=SpeechT5Config, hidden_size=37)
@@ -1854,8 +1846,6 @@ class SpeechT5HifiGanTest(ModelTesterMixin, unittest.TestCase):
is_encoder_decoder = False
has_attentions = False
input_name = "spectrogram"
def setUp(self):
self.model_tester = SpeechT5HifiGanTester(self)
self.config_tester = ConfigTester(self, config_class=SpeechT5HifiGanConfig)

View File

@@ -113,7 +113,7 @@ class StableLmModelTester:
input_mask = None
if self.use_input_mask:
input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
token_type_ids = None
if self.use_token_type_ids:

View File

@@ -107,7 +107,7 @@ class Starcoder2ModelTester:
input_mask = None
if self.use_input_mask:
input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
token_type_ids = None
if self.use_token_type_ids:

View File

@@ -470,7 +470,7 @@ class TFT5GenerationIntegrationTests(unittest.TestCase):
self.assertListEqual(expected_output_string, output_strings_xla)
@slow
def test_greedy_generate(self):
def test_t5_greedy_generate(self):
model = TFT5ForConditionalGeneration.from_pretrained("google-t5/t5-small")
tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-small")
@@ -520,7 +520,7 @@ class TFT5GenerationIntegrationTests(unittest.TestCase):
self.assertListEqual(expected_output_string_xla, output_strings_xla)
@slow
def test_sample_generate(self):
def test_t5_sample_generate(self):
model = TFT5ForConditionalGeneration.from_pretrained("google-t5/t5-small")
tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-small")

View File

@@ -118,8 +118,6 @@ class UnivNetModelTest(ModelTesterMixin, unittest.TestCase):
is_encoder_decoder = False
has_attentions = False
input_name = "input_features"
def setUp(self):
self.model_tester = UnivNetModelTester(self)
self.config_tester = ConfigTester(

View File

@@ -167,8 +167,6 @@ class VitsModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
test_torchscript = False
has_attentions = False
input_name = "input_ids"
def setUp(self):
self.model_tester = VitsModelTester(self)
self.config_tester = ConfigTester(self, config_class=VitsConfig, hidden_size=37)

View File

@@ -395,8 +395,6 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
# `0.5` is for `test_disk_offload` (which also works for `test_model_parallelism`)
model_split_percents = [0.5, 0.8, 0.9]
input_name = "input_features"
# TODO: Fix the failed tests
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
@@ -868,48 +866,6 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
def test_generate_without_input_ids(self):
pass
def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
batch_size, mel, seq_length = input_ids.shape
subsampled_seq_length = self.model_tester.get_subsampled_output_lengths(seq_length)
num_sequences_in_output = batch_size * num_return_sequences
gen_len = (
output.sequences.shape[-1] - 1 if config.is_encoder_decoder else output.sequences.shape[-1] - seq_length
)
# scores
self._check_scores(num_sequences_in_output, output.scores, length=gen_len, config=config)
# Attentions
# encoder
self._check_encoder_attention_for_generate(
output.encoder_attentions, batch_size, config, subsampled_seq_length
)
# decoder
self._check_attentions_for_generate(
num_sequences_in_output,
output.decoder_attentions,
min_length=1,
max_length=output.sequences.shape[-1],
config=config,
use_cache=use_cache,
)
# Hidden States
# encoder
self._check_encoder_hidden_states_for_generate(
output.encoder_hidden_states, batch_size, config, subsampled_seq_length
)
# decoder
self._check_hidden_states_for_generate(
num_sequences_in_output,
output.decoder_hidden_states,
min_length=1,
max_length=output.sequences.shape[-1],
config=config,
use_cache=use_cache,
)
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@@ -3511,8 +3467,6 @@ class WhisperEncoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.
test_pruning = False
test_missing_keys = False
input_name = "input_features"
def setUp(self):
self.model_tester = WhisperEncoderModelTester(self)
self.config_tester = ConfigTester(self, config_class=WhisperConfig)