mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 21:00:08 +06:00
Generate tests: modality-agnostic input preparation (#33685)
This commit is contained in:
parent
f2bf4fcf3d
commit
d29738f5b4
File diff suppressed because it is too large
Load Diff
@ -283,28 +283,6 @@ class BigBirdPegasusModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
|
||||
|
||||
return False
|
||||
|
||||
# overwrite from GenerationTesterMixin to solve problem
|
||||
# with conflicting random seeds
|
||||
def _get_input_ids_and_config(self, batch_size=2):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.attention_type = "original_full"
|
||||
|
||||
input_ids = inputs_dict.pop(self.input_name)
|
||||
_ = inputs_dict.pop("attention_mask", None)
|
||||
_ = inputs_dict.pop("decoder_input_ids", None)
|
||||
_ = inputs_dict.pop("decoder_attention_mask", None)
|
||||
attention_mask = torch.ones_like(input_ids, dtype=torch.long)
|
||||
|
||||
# cut to half length & take max batch_size 3
|
||||
sequence_length = input_ids.shape[-1] // 2
|
||||
input_ids = input_ids[:batch_size, :sequence_length]
|
||||
attention_mask = attention_mask[:batch_size, :sequence_length]
|
||||
|
||||
if config.eos_token_id is not None and config.pad_token_id is None:
|
||||
# hack to allow generate for models such as GPT2 as is done in `generate()`
|
||||
config.pad_token_id = config.eos_token_id
|
||||
return config, input_ids, attention_mask, inputs_dict
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = BigBirdPegasusModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=BigBirdPegasusConfig)
|
||||
@ -485,6 +463,13 @@ class BigBirdPegasusModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
|
||||
def test_load_save_without_tied_weights(self):
|
||||
pass
|
||||
|
||||
def test_generate_with_head_masking(self):
|
||||
# overwritten to temporarily switch the attention type to `original_full`
|
||||
original_self_attention_type = self.model_tester.attention_type
|
||||
self.model_tester.attention_type = "original_full"
|
||||
super().test_generate_with_head_masking()
|
||||
self.model_tester.attention_type = original_self_attention_type
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_sentencepiece
|
||||
|
@ -116,7 +116,7 @@ class ChameleonModelTester:
|
||||
|
||||
input_mask = None
|
||||
if self.use_input_mask:
|
||||
input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
|
||||
input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
|
||||
|
||||
sequence_labels = None
|
||||
token_labels = None
|
||||
|
@ -95,7 +95,7 @@ class CohereModelTester:
|
||||
|
||||
input_mask = None
|
||||
if self.use_input_mask:
|
||||
input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
|
||||
input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
|
||||
|
||||
token_type_ids = None
|
||||
if self.use_token_type_ids:
|
||||
|
@ -123,7 +123,6 @@ class DacModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
test_headmasking = False
|
||||
test_resize_embeddings = False
|
||||
pipeline_model_mapping = {"feature-extraction": DacModel} if is_torch_available() else {}
|
||||
input_name = "input_values"
|
||||
|
||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||
# model does not have attention and does not support returning hidden states
|
||||
|
@ -141,7 +141,6 @@ class EncodecModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
|
||||
test_headmasking = False
|
||||
test_resize_embeddings = False
|
||||
pipeline_model_mapping = {"feature-extraction": EncodecModel} if is_torch_available() else {}
|
||||
input_name = "input_values"
|
||||
|
||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||
# model does not have attention and does not support returning hidden states
|
||||
|
@ -119,7 +119,7 @@ class GemmaModelTester:
|
||||
|
||||
input_mask = None
|
||||
if self.use_input_mask:
|
||||
input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
|
||||
input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
|
||||
|
||||
token_type_ids = None
|
||||
if self.use_token_type_ids:
|
||||
|
@ -106,7 +106,7 @@ class GraniteModelTester:
|
||||
|
||||
input_mask = None
|
||||
if self.use_input_mask:
|
||||
input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
|
||||
input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
|
||||
|
||||
token_type_ids = None
|
||||
if self.use_token_type_ids:
|
||||
|
@ -105,7 +105,7 @@ class GraniteMoeModelTester:
|
||||
|
||||
input_mask = None
|
||||
if self.use_input_mask:
|
||||
input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
|
||||
input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
|
||||
|
||||
token_type_ids = None
|
||||
if self.use_token_type_ids:
|
||||
|
@ -338,13 +338,11 @@ class LEDModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
self.model_tester.check_global_attention(*config_and_inputs)
|
||||
|
||||
def _get_input_ids_and_config(self, batch_size=2):
|
||||
config, input_ids, attention_mask, inputs_dict = GenerationTesterMixin._get_input_ids_and_config(
|
||||
self, batch_size=batch_size
|
||||
)
|
||||
def prepare_config_and_inputs_for_generate(self, *args, **kwargs):
|
||||
config, inputs_dict = super().prepare_config_and_inputs_for_generate(*args, **kwargs)
|
||||
# LED computes attention scores based on mask indices if `is_global`
|
||||
inputs_dict.pop("global_attention_mask")
|
||||
return config, input_ids, attention_mask, inputs_dict
|
||||
return config, inputs_dict
|
||||
|
||||
# LEDForSequenceClassification does not support inputs_embeds
|
||||
def test_inputs_embeds(self):
|
||||
|
@ -112,7 +112,7 @@ class LlamaModelTester:
|
||||
|
||||
input_mask = None
|
||||
if self.use_input_mask:
|
||||
input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
|
||||
input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
|
||||
|
||||
token_type_ids = None
|
||||
if self.use_token_type_ids:
|
||||
|
@ -170,7 +170,6 @@ class MimiModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
test_headmasking = False
|
||||
test_resize_embeddings = False
|
||||
test_torchscript = False
|
||||
input_name = "input_values"
|
||||
|
||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||
# model does support returning hidden states
|
||||
|
@ -112,7 +112,7 @@ class MistralModelTester:
|
||||
|
||||
input_mask = None
|
||||
if self.use_input_mask:
|
||||
input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
|
||||
input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
|
||||
|
||||
token_type_ids = None
|
||||
if self.use_token_type_ids:
|
||||
|
@ -108,7 +108,7 @@ class MixtralModelTester:
|
||||
|
||||
input_mask = None
|
||||
if self.use_input_mask:
|
||||
input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
|
||||
input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
|
||||
|
||||
token_type_ids = None
|
||||
if self.use_token_type_ids:
|
||||
|
@ -60,10 +60,6 @@ if is_torch_available():
|
||||
MusicgenModel,
|
||||
set_seed,
|
||||
)
|
||||
from transformers.generation import (
|
||||
GenerateDecoderOnlyOutput,
|
||||
GenerateEncoderDecoderOutput,
|
||||
)
|
||||
|
||||
|
||||
def _config_zero_init(config):
|
||||
@ -124,6 +120,7 @@ class MusicgenDecoderTester:
|
||||
pad_token_id=99,
|
||||
bos_token_id=99,
|
||||
num_codebooks=4,
|
||||
audio_channels=1,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
@ -141,6 +138,7 @@ class MusicgenDecoderTester:
|
||||
self.pad_token_id = pad_token_id
|
||||
self.bos_token_id = bos_token_id
|
||||
self.num_codebooks = num_codebooks
|
||||
self.audio_channels = audio_channels
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_ids = ids_tensor([self.batch_size * self.num_codebooks, self.seq_length], self.vocab_size)
|
||||
@ -166,6 +164,7 @@ class MusicgenDecoderTester:
|
||||
bos_token_id=self.bos_token_id,
|
||||
num_codebooks=self.num_codebooks,
|
||||
tie_word_embeddings=False,
|
||||
audio_channels=self.audio_channels,
|
||||
)
|
||||
return config
|
||||
|
||||
@ -282,47 +281,15 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
|
||||
def test_tied_weights_keys(self):
|
||||
pass
|
||||
|
||||
def _get_input_ids_and_config(self, batch_size=2):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
input_ids = inputs_dict["input_ids"]
|
||||
|
||||
_ = inputs_dict.pop("attention_mask", None)
|
||||
inputs_dict = {
|
||||
k: v[:batch_size, ...]
|
||||
for k, v in inputs_dict.items()
|
||||
if "head_mask" not in k and isinstance(v, torch.Tensor)
|
||||
}
|
||||
|
||||
# take max batch_size
|
||||
sequence_length = input_ids.shape[-1]
|
||||
input_ids = input_ids[: batch_size * config.num_codebooks, :]
|
||||
|
||||
attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long)
|
||||
return config, input_ids, attention_mask, inputs_dict
|
||||
|
||||
def _get_logits_processor_kwargs(self, do_sample=False, config=None):
|
||||
logits_processor_kwargs = {}
|
||||
return logits_processor_kwargs
|
||||
|
||||
def test_greedy_generate_stereo_outputs(self):
|
||||
for model_class in self.greedy_sample_model_classes:
|
||||
config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
|
||||
config.audio_channels = 2
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
output_generate = self._greedy_generate(
|
||||
model=model,
|
||||
input_ids=input_ids.to(torch_device),
|
||||
attention_mask=attention_mask.to(torch_device),
|
||||
output_scores=True,
|
||||
output_hidden_states=True,
|
||||
output_attentions=True,
|
||||
return_dict_in_generate=True,
|
||||
inputs_dict={},
|
||||
)
|
||||
|
||||
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
|
||||
|
||||
self.assertNotIn(config.pad_token_id, output_generate)
|
||||
original_audio_channels = self.model_tester.audio_channels
|
||||
self.model_tester.audio_channels = 2
|
||||
super().test_greedy_generate_dict_outputs()
|
||||
self.model_tester.audio_channels = original_audio_channels
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@ -998,6 +965,7 @@ class MusicgenTester:
|
||||
num_codebooks=4,
|
||||
num_filters=4,
|
||||
codebook_size=128,
|
||||
audio_channels=1,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
@ -1017,6 +985,7 @@ class MusicgenTester:
|
||||
self.num_codebooks = num_codebooks
|
||||
self.num_filters = num_filters
|
||||
self.codebook_size = codebook_size
|
||||
self.audio_channels = audio_channels
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
@ -1052,6 +1021,7 @@ class MusicgenTester:
|
||||
bos_token_id=self.bos_token_id,
|
||||
num_codebooks=self.num_codebooks,
|
||||
tie_word_embeddings=False,
|
||||
audio_channels=self.audio_channels,
|
||||
)
|
||||
config = MusicgenConfig.from_sub_models_config(text_encoder_config, audio_encoder_config, decoder_config)
|
||||
return config
|
||||
@ -1415,170 +1385,10 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
lm_heads = model.get_output_embeddings()
|
||||
self.assertTrue(lm_heads is None or isinstance(lm_heads[0], torch.nn.Linear))
|
||||
|
||||
def _get_input_ids_and_config(self, batch_size=2):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
input_ids = inputs_dict["input_ids"]
|
||||
|
||||
# take max batch_size
|
||||
sequence_length = input_ids.shape[-1]
|
||||
input_ids = input_ids[:batch_size, :]
|
||||
attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long)
|
||||
|
||||
return config, input_ids, attention_mask
|
||||
|
||||
# override since the `input_ids` cannot be used as the `decoder_input_ids` for musicgen (input / outputs are
|
||||
# different modalities -> different shapes)
|
||||
def _greedy_generate(
|
||||
self,
|
||||
model,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
output_scores=False,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
return_dict_in_generate=False,
|
||||
):
|
||||
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
||||
output_generate = model.generate(
|
||||
input_ids,
|
||||
do_sample=False,
|
||||
num_beams=1,
|
||||
max_new_tokens=self.max_new_tokens,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
output_scores=output_scores,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
remove_invalid_values=True,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
return output_generate
|
||||
|
||||
# override since the `input_ids` cannot be used as the `decoder_input_ids` for musicgen (input / outputs are
|
||||
# different modalities -> different shapes)
|
||||
def _sample_generate(
|
||||
self,
|
||||
model,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
num_return_sequences,
|
||||
output_scores=False,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
return_dict_in_generate=False,
|
||||
):
|
||||
torch.manual_seed(0)
|
||||
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
||||
output_generate = model.generate(
|
||||
input_ids,
|
||||
do_sample=True,
|
||||
num_beams=1,
|
||||
max_new_tokens=self.max_new_tokens,
|
||||
num_return_sequences=num_return_sequences,
|
||||
output_scores=output_scores,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
remove_invalid_values=True,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
return output_generate
|
||||
|
||||
def _get_logits_processor_kwargs(self, do_sample=False, config=None):
|
||||
logits_processor_kwargs = {}
|
||||
return logits_processor_kwargs
|
||||
|
||||
def test_greedy_generate_dict_outputs(self):
|
||||
for model_class in self.greedy_sample_model_classes:
|
||||
# disable cache
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
config.use_cache = False
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
output_generate = self._greedy_generate(
|
||||
model=model,
|
||||
input_ids=input_ids.to(torch_device),
|
||||
attention_mask=attention_mask.to(torch_device),
|
||||
output_scores=True,
|
||||
output_hidden_states=True,
|
||||
output_attentions=True,
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
|
||||
self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
|
||||
|
||||
self.assertNotIn(config.pad_token_id, output_generate)
|
||||
|
||||
def test_greedy_generate_dict_outputs_use_cache(self):
|
||||
for model_class in self.greedy_sample_model_classes:
|
||||
# enable cache
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
|
||||
config.use_cache = True
|
||||
config.is_decoder = True
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
output_generate = self._greedy_generate(
|
||||
model=model,
|
||||
input_ids=input_ids.to(torch_device),
|
||||
attention_mask=attention_mask.to(torch_device),
|
||||
output_scores=True,
|
||||
output_hidden_states=True,
|
||||
output_attentions=True,
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
|
||||
self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
|
||||
|
||||
def test_sample_generate(self):
|
||||
for model_class in self.greedy_sample_model_classes:
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
|
||||
# check `generate()` and `sample()` are equal
|
||||
output_generate = self._sample_generate(
|
||||
model=model,
|
||||
input_ids=input_ids.to(torch_device),
|
||||
attention_mask=attention_mask.to(torch_device),
|
||||
num_return_sequences=1,
|
||||
)
|
||||
self.assertIsInstance(output_generate, torch.Tensor)
|
||||
|
||||
def test_sample_generate_dict_output(self):
|
||||
for model_class in self.greedy_sample_model_classes:
|
||||
# disable cache
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
config.use_cache = False
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
|
||||
output_generate = self._sample_generate(
|
||||
model=model,
|
||||
input_ids=input_ids.to(torch_device),
|
||||
attention_mask=attention_mask.to(torch_device),
|
||||
num_return_sequences=3,
|
||||
output_scores=True,
|
||||
output_hidden_states=True,
|
||||
output_attentions=True,
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
|
||||
self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
|
||||
|
||||
def test_generate_without_input_ids(self):
|
||||
config, _, _ = self._get_input_ids_and_config()
|
||||
|
||||
# if no bos token id => cannot generate from None
|
||||
if config.bos_token_id is None:
|
||||
self.skipTest(reason="bos_token_id is None")
|
||||
|
||||
for model_class in self.greedy_sample_model_classes:
|
||||
model = model_class(config).to(torch_device)
|
||||
model.eval()
|
||||
|
||||
output_ids_generate = model.generate(
|
||||
do_sample=False, max_new_tokens=self.max_new_tokens, remove_invalid_values=True
|
||||
)
|
||||
self.assertIsNotNone(output_ids_generate)
|
||||
|
||||
@require_torch_fp16
|
||||
@require_torch_accelerator # not all operations are supported in fp16 on CPU
|
||||
def test_generate_fp16(self):
|
||||
@ -1595,24 +1405,10 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
)
|
||||
|
||||
def test_greedy_generate_stereo_outputs(self):
|
||||
for model_class in self.greedy_sample_model_classes:
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
config.audio_channels = 2
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
output_generate = self._greedy_generate(
|
||||
model=model,
|
||||
input_ids=input_ids.to(torch_device),
|
||||
attention_mask=attention_mask.to(torch_device),
|
||||
output_scores=True,
|
||||
output_hidden_states=True,
|
||||
output_attentions=True,
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
|
||||
self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
|
||||
|
||||
self.assertNotIn(config.pad_token_id, output_generate)
|
||||
original_audio_channels = self.model_tester.audio_channels
|
||||
self.model_tester.audio_channels = 2
|
||||
super().test_greedy_generate_dict_outputs()
|
||||
self.model_tester.audio_channels = original_audio_channels
|
||||
|
||||
@unittest.skip(
|
||||
reason="MusicgenModel is actually not the base of MusicgenForCausalLM as the latter is a composit model"
|
||||
|
@ -61,9 +61,6 @@ if is_torch_available():
|
||||
MusicgenMelodyModel,
|
||||
set_seed,
|
||||
)
|
||||
from transformers.generation import (
|
||||
GenerateDecoderOnlyOutput,
|
||||
)
|
||||
|
||||
if is_torchaudio_available():
|
||||
from transformers import MusicgenMelodyProcessor
|
||||
@ -124,6 +121,7 @@ class MusicgenMelodyDecoderTester:
|
||||
bos_token_id=99,
|
||||
num_codebooks=4,
|
||||
conditional_seq_length=4,
|
||||
audio_channels=1,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
@ -143,6 +141,7 @@ class MusicgenMelodyDecoderTester:
|
||||
self.num_codebooks = num_codebooks
|
||||
self.conditional_seq_length = conditional_seq_length
|
||||
self.encoder_seq_length = conditional_seq_length + seq_length
|
||||
self.audio_channels = audio_channels
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_ids = ids_tensor([self.batch_size * self.num_codebooks, self.seq_length], self.vocab_size)
|
||||
@ -168,6 +167,7 @@ class MusicgenMelodyDecoderTester:
|
||||
bos_token_id=self.bos_token_id,
|
||||
num_codebooks=self.num_codebooks,
|
||||
tie_word_embeddings=False,
|
||||
audio_channels=self.audio_channels,
|
||||
)
|
||||
return config
|
||||
|
||||
@ -285,46 +285,15 @@ class MusicgenMelodyDecoderTest(ModelTesterMixin, GenerationTesterMixin, unittes
|
||||
def test_tied_weights_keys(self):
|
||||
pass
|
||||
|
||||
def _get_input_ids_and_config(self, batch_size=2):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
input_ids = inputs_dict["input_ids"]
|
||||
|
||||
_ = inputs_dict.pop("attention_mask", None)
|
||||
inputs_dict = {
|
||||
k: v[:batch_size, ...]
|
||||
for k, v in inputs_dict.items()
|
||||
if "head_mask" not in k and isinstance(v, torch.Tensor)
|
||||
}
|
||||
|
||||
# take max batch_size
|
||||
sequence_length = input_ids.shape[-1]
|
||||
input_ids = input_ids[: batch_size * config.num_codebooks, :]
|
||||
|
||||
attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long)
|
||||
return config, input_ids, attention_mask, inputs_dict
|
||||
|
||||
def _get_logits_processor_kwargs(self, do_sample=False, config=None):
|
||||
logits_processor_kwargs = {}
|
||||
return logits_processor_kwargs
|
||||
|
||||
def test_greedy_generate_stereo_outputs(self):
|
||||
for model_class in self.greedy_sample_model_classes:
|
||||
config, input_ids, attention_mask, _ = self._get_input_ids_and_config()
|
||||
config.audio_channels = 2
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
output_generate = self._greedy_generate(
|
||||
model=model,
|
||||
input_ids=input_ids.to(torch_device),
|
||||
attention_mask=attention_mask.to(torch_device),
|
||||
output_scores=True,
|
||||
output_hidden_states=True,
|
||||
output_attentions=True,
|
||||
return_dict_in_generate=True,
|
||||
inputs_dict={},
|
||||
)
|
||||
|
||||
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
|
||||
self.assertNotIn(config.pad_token_id, output_generate)
|
||||
original_audio_channels = self.model_tester.audio_channels
|
||||
self.model_tester.audio_channels = 2
|
||||
super().test_greedy_generate_dict_outputs()
|
||||
self.model_tester.audio_channels = original_audio_channels
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@ -996,6 +965,7 @@ class MusicgenMelodyTester:
|
||||
codebook_size=128,
|
||||
conditional_seq_length=3,
|
||||
chroma_length=24,
|
||||
audio_channels=1,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
@ -1018,6 +988,7 @@ class MusicgenMelodyTester:
|
||||
self.conditional_seq_length = conditional_seq_length
|
||||
self.chroma_length = chroma_length
|
||||
self.encoder_seq_length = conditional_seq_length + seq_length
|
||||
self.audio_channels = audio_channels
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_ids = ids_tensor([self.batch_size, self.conditional_seq_length], self.vocab_size)
|
||||
@ -1053,6 +1024,7 @@ class MusicgenMelodyTester:
|
||||
bos_token_id=self.bos_token_id,
|
||||
num_codebooks=self.num_codebooks,
|
||||
tie_word_embeddings=False,
|
||||
audio_channels=self.audio_channels,
|
||||
)
|
||||
config = MusicgenMelodyConfig.from_sub_models_config(
|
||||
text_encoder_config, audio_encoder_config, decoder_config, chroma_length=self.chroma_length
|
||||
@ -1399,170 +1371,10 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
||||
lm_heads = model.get_output_embeddings()
|
||||
self.assertTrue(lm_heads is None or isinstance(lm_heads[0], torch.nn.Linear))
|
||||
|
||||
def _get_input_ids_and_config(self, batch_size=2):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
input_ids = inputs_dict["input_ids"]
|
||||
|
||||
# take max batch_size
|
||||
sequence_length = input_ids.shape[-1]
|
||||
input_ids = input_ids[:batch_size, :]
|
||||
attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long)
|
||||
|
||||
return config, input_ids, attention_mask
|
||||
|
||||
# override since the `input_ids` cannot be used as the `decoder_input_ids` for musicgen_melody (input / outputs are
|
||||
# different modalities -> different shapes)
|
||||
def _greedy_generate(
|
||||
self,
|
||||
model,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
output_scores=False,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
return_dict_in_generate=False,
|
||||
):
|
||||
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
||||
output_generate = model.generate(
|
||||
input_ids,
|
||||
do_sample=False,
|
||||
num_beams=1,
|
||||
max_new_tokens=self.max_new_tokens,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
output_scores=output_scores,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
remove_invalid_values=True,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
return output_generate
|
||||
|
||||
# override since the `input_ids` cannot be used as the `decoder_input_ids` for musicgen_melody (input / outputs are
|
||||
# different modalities -> different shapes)
|
||||
def _sample_generate(
|
||||
self,
|
||||
model,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
num_return_sequences,
|
||||
output_scores=False,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
return_dict_in_generate=False,
|
||||
):
|
||||
torch.manual_seed(0)
|
||||
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
||||
output_generate = model.generate(
|
||||
input_ids,
|
||||
do_sample=True,
|
||||
num_beams=1,
|
||||
max_new_tokens=self.max_new_tokens,
|
||||
num_return_sequences=num_return_sequences,
|
||||
output_scores=output_scores,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
remove_invalid_values=True,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
return output_generate
|
||||
|
||||
def _get_logits_processor_kwargs(self, do_sample=False, config=None):
|
||||
logits_processor_kwargs = {}
|
||||
return logits_processor_kwargs
|
||||
|
||||
def test_greedy_generate_dict_outputs(self):
|
||||
for model_class in self.greedy_sample_model_classes:
|
||||
# disable cache
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
config.use_cache = False
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
output_generate = self._greedy_generate(
|
||||
model=model,
|
||||
input_ids=input_ids.to(torch_device),
|
||||
attention_mask=attention_mask.to(torch_device),
|
||||
output_scores=True,
|
||||
output_hidden_states=True,
|
||||
output_attentions=True,
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
|
||||
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
|
||||
|
||||
self.assertNotIn(config.pad_token_id, output_generate)
|
||||
|
||||
def test_greedy_generate_dict_outputs_use_cache(self):
|
||||
for model_class in self.greedy_sample_model_classes:
|
||||
# enable cache
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
|
||||
config.use_cache = True
|
||||
config.is_decoder = True
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
output_generate = self._greedy_generate(
|
||||
model=model,
|
||||
input_ids=input_ids.to(torch_device),
|
||||
attention_mask=attention_mask.to(torch_device),
|
||||
output_scores=True,
|
||||
output_hidden_states=True,
|
||||
output_attentions=True,
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
|
||||
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
|
||||
|
||||
def test_sample_generate(self):
|
||||
for model_class in self.greedy_sample_model_classes:
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
|
||||
# check `generate()` and `sample()` are equal
|
||||
output_generate = self._sample_generate(
|
||||
model=model,
|
||||
input_ids=input_ids.to(torch_device),
|
||||
attention_mask=attention_mask.to(torch_device),
|
||||
num_return_sequences=1,
|
||||
)
|
||||
self.assertIsInstance(output_generate, torch.Tensor)
|
||||
|
||||
def test_sample_generate_dict_output(self):
|
||||
for model_class in self.greedy_sample_model_classes:
|
||||
# disable cache
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
config.use_cache = False
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
|
||||
output_generate = self._sample_generate(
|
||||
model=model,
|
||||
input_ids=input_ids.to(torch_device),
|
||||
attention_mask=attention_mask.to(torch_device),
|
||||
num_return_sequences=3,
|
||||
output_scores=True,
|
||||
output_hidden_states=True,
|
||||
output_attentions=True,
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
|
||||
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
|
||||
|
||||
def test_generate_without_input_ids(self):
|
||||
config, _, _ = self._get_input_ids_and_config()
|
||||
|
||||
# if no bos token id => cannot generate from None
|
||||
if config.bos_token_id is None:
|
||||
self.skipTest(reason="bos_token_id is None")
|
||||
|
||||
for model_class in self.greedy_sample_model_classes:
|
||||
model = model_class(config).to(torch_device)
|
||||
model.eval()
|
||||
|
||||
output_ids_generate = model.generate(
|
||||
do_sample=False, max_new_tokens=self.max_new_tokens, remove_invalid_values=True
|
||||
)
|
||||
self.assertIsNotNone(output_ids_generate)
|
||||
|
||||
@require_torch_fp16
|
||||
@require_torch_accelerator # not all operations are supported in fp16 on CPU
|
||||
def test_generate_fp16(self):
|
||||
@ -1579,24 +1391,10 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
||||
)
|
||||
|
||||
def test_greedy_generate_stereo_outputs(self):
|
||||
for model_class in self.greedy_sample_model_classes:
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
config.audio_channels = 2
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
output_generate = self._greedy_generate(
|
||||
model=model,
|
||||
input_ids=input_ids.to(torch_device),
|
||||
attention_mask=attention_mask.to(torch_device),
|
||||
output_scores=True,
|
||||
output_hidden_states=True,
|
||||
output_attentions=True,
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
|
||||
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
|
||||
|
||||
self.assertNotIn(config.pad_token_id, output_generate)
|
||||
original_audio_channels = self.model_tester.audio_channels
|
||||
self.model_tester.audio_channels = 2
|
||||
super().test_greedy_generate_dict_outputs()
|
||||
self.model_tester.audio_channels = original_audio_channels
|
||||
|
||||
@unittest.skip(
|
||||
reason="MusicgenMelodyModel is actually not the base of MusicgenMelodyForCausalLM as the latter is a composit model"
|
||||
|
@ -101,7 +101,7 @@ class OlmoModelTester:
|
||||
|
||||
input_mask = None
|
||||
if self.use_input_mask:
|
||||
input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
|
||||
input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
|
||||
|
||||
token_type_ids = None
|
||||
if self.use_token_type_ids:
|
||||
|
@ -111,7 +111,7 @@ class OlmoeModelTester:
|
||||
|
||||
input_mask = None
|
||||
if self.use_input_mask:
|
||||
input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
|
||||
input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
|
||||
|
||||
token_type_ids = None
|
||||
if self.use_token_type_ids:
|
||||
|
@ -110,7 +110,7 @@ class PersimmonModelTester:
|
||||
|
||||
input_mask = None
|
||||
if self.use_input_mask:
|
||||
input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
|
||||
input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
|
||||
|
||||
token_type_ids = None
|
||||
if self.use_token_type_ids:
|
||||
|
@ -151,7 +151,7 @@ class Phi3ModelTester:
|
||||
|
||||
input_mask = None
|
||||
if self.use_input_mask:
|
||||
input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
|
||||
input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
|
||||
|
||||
token_type_ids = None
|
||||
if self.use_token_type_ids:
|
||||
|
@ -116,7 +116,7 @@ class Qwen2ModelTester:
|
||||
|
||||
input_mask = None
|
||||
if self.use_input_mask:
|
||||
input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
|
||||
input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
|
||||
|
||||
token_type_ids = None
|
||||
if self.use_token_type_ids:
|
||||
|
@ -134,7 +134,7 @@ class Qwen2MoeModelTester:
|
||||
|
||||
input_mask = None
|
||||
if self.use_input_mask:
|
||||
input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
|
||||
input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
|
||||
|
||||
token_type_ids = None
|
||||
if self.use_token_type_ids:
|
||||
|
@ -103,7 +103,7 @@ class RecurrentGemmaModelTester:
|
||||
|
||||
input_mask = None
|
||||
if self.use_input_mask:
|
||||
input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
|
||||
input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
|
||||
|
||||
token_type_ids = None
|
||||
if self.use_token_type_ids:
|
||||
|
@ -684,20 +684,15 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, GenerationTesterMixin, Mod
|
||||
def test_left_padding_compatibility(self):
|
||||
pass
|
||||
|
||||
def _get_input_ids_and_config(self, batch_size=2):
|
||||
def prepare_config_and_inputs_for_generate(self, *args, **kwargs):
|
||||
# override because overwise we hit max possible seq length for model (4*8=32)
|
||||
# decreasing the seq_length in tester causes errors for "training_tests", those need exactly max seq length
|
||||
# NOTE: seq_length has to be multiple of 4, otherwise it fails for other tests
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
input_ids = inputs_dict.pop(self.input_name)
|
||||
_ = inputs_dict.pop("attention_mask", None)
|
||||
_ = inputs_dict.pop("decoder_input_ids", None)
|
||||
_ = inputs_dict.pop("decoder_attention_mask", None)
|
||||
input_ids = input_ids[:batch_size, :16]
|
||||
attention_mask = torch.ones_like(input_ids, dtype=torch.long)[:batch_size, :16]
|
||||
config.eos_token_id = None
|
||||
config.forced_eos_token_id = None
|
||||
return config, input_ids, attention_mask, inputs_dict
|
||||
original_sequence_length = self.model_tester.seq_length
|
||||
self.model_tester.seq_length = 16
|
||||
test_inputs = super().prepare_config_and_inputs_for_generate(*args, **kwargs)
|
||||
self.model_tester.seq_length = original_sequence_length
|
||||
return test_inputs
|
||||
|
||||
|
||||
@require_torch
|
||||
|
@ -360,8 +360,6 @@ class SeamlessM4TModelWithSpeechInputTest(ModelTesterMixin, unittest.TestCase):
|
||||
)
|
||||
all_generative_model_classes = (SeamlessM4TForSpeechToText,) if is_torch_available() else ()
|
||||
|
||||
input_name = "input_features"
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = SeamlessM4TModelTester(self, input_modality="speech")
|
||||
self.config_tester = ConfigTester(self, config_class=SeamlessM4TConfig)
|
||||
@ -379,26 +377,6 @@ class SeamlessM4TModelWithSpeechInputTest(ModelTesterMixin, unittest.TestCase):
|
||||
model = SeamlessM4TModel.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
def _get_input_ids_and_config(self, batch_size=2):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
input_ids = inputs_dict[self.input_name]
|
||||
|
||||
# cut to half length & take max batch_size 3
|
||||
sequence_length = input_ids.shape[-1] // 2
|
||||
input_ids = input_ids[:batch_size, :sequence_length]
|
||||
|
||||
# generate max 3 tokens
|
||||
max_length = input_ids.shape[-1] + 3
|
||||
if config.eos_token_id is not None and config.pad_token_id is None:
|
||||
# hack to allow generate for models such as GPT2 as is done in `generate()`
|
||||
if isinstance(config.eos_token_id, int):
|
||||
config.eos_token_id = [config.eos_token_id]
|
||||
config.pad_token_id = config.eos_token_id[0]
|
||||
|
||||
attention_mask = torch.ones(input_ids.shape[:2], dtype=torch.long)[:batch_size, :sequence_length]
|
||||
|
||||
return config, input_ids.float(), attention_mask, max_length
|
||||
|
||||
def test_initialization(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
|
@ -376,8 +376,6 @@ class SeamlessM4Tv2ModelWithSpeechInputTest(ModelTesterMixin, unittest.TestCase)
|
||||
)
|
||||
all_generative_model_classes = (SeamlessM4Tv2ForSpeechToText,) if is_torch_available() else ()
|
||||
|
||||
input_name = "input_features"
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = SeamlessM4Tv2ModelTester(self, input_modality="speech")
|
||||
self.config_tester = ConfigTester(self, config_class=SeamlessM4Tv2Config)
|
||||
@ -395,26 +393,6 @@ class SeamlessM4Tv2ModelWithSpeechInputTest(ModelTesterMixin, unittest.TestCase)
|
||||
model = SeamlessM4Tv2Model.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
def _get_input_ids_and_config(self, batch_size=2):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
input_ids = inputs_dict[self.input_name]
|
||||
|
||||
# cut to half length & take max batch_size 3
|
||||
sequence_length = input_ids.shape[-1] // 2
|
||||
input_ids = input_ids[:batch_size, :sequence_length]
|
||||
|
||||
# generate max 3 tokens
|
||||
max_length = input_ids.shape[-1] + 3
|
||||
if config.eos_token_id is not None and config.pad_token_id is None:
|
||||
# hack to allow generate for models such as GPT2 as is done in `generate()`
|
||||
if isinstance(config.eos_token_id, int):
|
||||
config.eos_token_id = [config.eos_token_id]
|
||||
config.pad_token_id = config.eos_token_id[0]
|
||||
|
||||
attention_mask = torch.ones(input_ids.shape[:2], dtype=torch.long)[:batch_size, :sequence_length]
|
||||
|
||||
return config, input_ids.float(), attention_mask, max_length
|
||||
|
||||
def test_initialization(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
|
@ -282,20 +282,6 @@ class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTest
|
||||
test_pruning = False
|
||||
test_missing_keys = False
|
||||
|
||||
input_name = "input_features"
|
||||
|
||||
def _get_input_ids_and_config(self, batch_size=2):
|
||||
config, input_ids, attention_mask, inputs_dict = GenerationTesterMixin._get_input_ids_and_config(self)
|
||||
|
||||
# `input_ids` is actually `input_features` which is a 3D tensor.
|
||||
# We must overwrite the mask to make it 2D since the original `_get_input_ids_and_config` creates an
|
||||
# attention mask of the same shape as `input_ids`.
|
||||
if len(attention_mask.shape) > 2:
|
||||
sequence_length = input_ids.shape[1]
|
||||
attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long, device=attention_mask.device)
|
||||
|
||||
return config, input_ids, attention_mask, inputs_dict
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = Speech2TextModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=Speech2TextConfig)
|
||||
@ -632,46 +618,12 @@ class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTest
|
||||
def test_generate_without_input_ids(self):
|
||||
pass
|
||||
|
||||
def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
|
||||
batch_size, seq_length = input_ids.shape[:2]
|
||||
subsampled_seq_length = self.model_tester.get_subsampled_output_lengths(seq_length)
|
||||
num_sequences_in_output = batch_size * num_return_sequences
|
||||
gen_len = (
|
||||
output.sequences.shape[-1] - 1 if config.is_encoder_decoder else output.sequences.shape[-1] - seq_length
|
||||
)
|
||||
|
||||
# scores
|
||||
self._check_scores(num_sequences_in_output, output.scores, length=gen_len, config=config)
|
||||
|
||||
# Attentions
|
||||
# encoder
|
||||
self._check_encoder_attention_for_generate(
|
||||
output.encoder_attentions, batch_size, config, subsampled_seq_length
|
||||
)
|
||||
# decoder
|
||||
self._check_attentions_for_generate(
|
||||
num_sequences_in_output,
|
||||
output.decoder_attentions,
|
||||
min_length=1,
|
||||
max_length=output.sequences.shape[-1],
|
||||
config=config,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
|
||||
# Hidden States
|
||||
# encoder
|
||||
self._check_encoder_hidden_states_for_generate(
|
||||
output.encoder_hidden_states, batch_size, config, subsampled_seq_length
|
||||
)
|
||||
|
||||
# decoder
|
||||
self._check_hidden_states_for_generate(
|
||||
num_sequences_in_output,
|
||||
output.decoder_hidden_states,
|
||||
min_length=1,
|
||||
max_length=output.sequences.shape[-1],
|
||||
config=config,
|
||||
use_cache=use_cache,
|
||||
def _check_outputs(self, output, main_input, config, use_cache=False, num_return_sequences=1):
|
||||
# In this model, the index of `batch_size` and `sequence_length`` in `main_input` is different: they are the
|
||||
# first two dimensions of the tensor.
|
||||
main_input = main_input[:, :, 0]
|
||||
super()._check_outputs(
|
||||
output, main_input, config, use_cache=use_cache, num_return_sequences=num_return_sequences
|
||||
)
|
||||
|
||||
def _create_and_check_torchscript(self, config, inputs_dict):
|
||||
|
@ -177,8 +177,6 @@ class SpeechT5ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
|
||||
test_headmasking = False
|
||||
test_resize_embeddings = False
|
||||
|
||||
input_name = "input_values"
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = SpeechT5ModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=SpeechT5Config, hidden_size=37)
|
||||
@ -375,8 +373,6 @@ class SpeechT5ForSpeechToTextTest(ModelTesterMixin, unittest.TestCase):
|
||||
test_pruning = False
|
||||
test_headmasking = False
|
||||
|
||||
input_name = "input_values"
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = SpeechT5ForSpeechToTextTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=SpeechT5Config, hidden_size=37)
|
||||
@ -895,8 +891,6 @@ class SpeechT5ForTextToSpeechTest(ModelTesterMixin, unittest.TestCase):
|
||||
test_pruning = False
|
||||
test_headmasking = False
|
||||
|
||||
input_name = "input_ids"
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = SpeechT5ForTextToSpeechTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=SpeechT5Config, hidden_size=37)
|
||||
@ -1441,8 +1435,6 @@ class SpeechT5ForSpeechToSpeechTest(ModelTesterMixin, unittest.TestCase):
|
||||
test_headmasking = False
|
||||
test_resize_embeddings = False
|
||||
|
||||
input_name = "input_values"
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = SpeechT5ForSpeechToSpeechTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=SpeechT5Config, hidden_size=37)
|
||||
@ -1854,8 +1846,6 @@ class SpeechT5HifiGanTest(ModelTesterMixin, unittest.TestCase):
|
||||
is_encoder_decoder = False
|
||||
has_attentions = False
|
||||
|
||||
input_name = "spectrogram"
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = SpeechT5HifiGanTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=SpeechT5HifiGanConfig)
|
||||
|
@ -113,7 +113,7 @@ class StableLmModelTester:
|
||||
|
||||
input_mask = None
|
||||
if self.use_input_mask:
|
||||
input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
|
||||
input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
|
||||
|
||||
token_type_ids = None
|
||||
if self.use_token_type_ids:
|
||||
|
@ -107,7 +107,7 @@ class Starcoder2ModelTester:
|
||||
|
||||
input_mask = None
|
||||
if self.use_input_mask:
|
||||
input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
|
||||
input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
|
||||
|
||||
token_type_ids = None
|
||||
if self.use_token_type_ids:
|
||||
|
@ -470,7 +470,7 @@ class TFT5GenerationIntegrationTests(unittest.TestCase):
|
||||
self.assertListEqual(expected_output_string, output_strings_xla)
|
||||
|
||||
@slow
|
||||
def test_greedy_generate(self):
|
||||
def test_t5_greedy_generate(self):
|
||||
model = TFT5ForConditionalGeneration.from_pretrained("google-t5/t5-small")
|
||||
tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-small")
|
||||
|
||||
@ -520,7 +520,7 @@ class TFT5GenerationIntegrationTests(unittest.TestCase):
|
||||
self.assertListEqual(expected_output_string_xla, output_strings_xla)
|
||||
|
||||
@slow
|
||||
def test_sample_generate(self):
|
||||
def test_t5_sample_generate(self):
|
||||
model = TFT5ForConditionalGeneration.from_pretrained("google-t5/t5-small")
|
||||
tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-small")
|
||||
|
||||
|
@ -118,8 +118,6 @@ class UnivNetModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
is_encoder_decoder = False
|
||||
has_attentions = False
|
||||
|
||||
input_name = "input_features"
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = UnivNetModelTester(self)
|
||||
self.config_tester = ConfigTester(
|
||||
|
@ -167,8 +167,6 @@ class VitsModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
test_torchscript = False
|
||||
has_attentions = False
|
||||
|
||||
input_name = "input_ids"
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = VitsModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=VitsConfig, hidden_size=37)
|
||||
|
@ -395,8 +395,6 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
# `0.5` is for `test_disk_offload` (which also works for `test_model_parallelism`)
|
||||
model_split_percents = [0.5, 0.8, 0.9]
|
||||
|
||||
input_name = "input_features"
|
||||
|
||||
# TODO: Fix the failed tests
|
||||
def is_pipeline_test_to_skip(
|
||||
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
||||
@ -868,48 +866,6 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
def test_generate_without_input_ids(self):
|
||||
pass
|
||||
|
||||
def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
|
||||
batch_size, mel, seq_length = input_ids.shape
|
||||
subsampled_seq_length = self.model_tester.get_subsampled_output_lengths(seq_length)
|
||||
num_sequences_in_output = batch_size * num_return_sequences
|
||||
gen_len = (
|
||||
output.sequences.shape[-1] - 1 if config.is_encoder_decoder else output.sequences.shape[-1] - seq_length
|
||||
)
|
||||
|
||||
# scores
|
||||
self._check_scores(num_sequences_in_output, output.scores, length=gen_len, config=config)
|
||||
|
||||
# Attentions
|
||||
# encoder
|
||||
self._check_encoder_attention_for_generate(
|
||||
output.encoder_attentions, batch_size, config, subsampled_seq_length
|
||||
)
|
||||
# decoder
|
||||
self._check_attentions_for_generate(
|
||||
num_sequences_in_output,
|
||||
output.decoder_attentions,
|
||||
min_length=1,
|
||||
max_length=output.sequences.shape[-1],
|
||||
config=config,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
|
||||
# Hidden States
|
||||
# encoder
|
||||
self._check_encoder_hidden_states_for_generate(
|
||||
output.encoder_hidden_states, batch_size, config, subsampled_seq_length
|
||||
)
|
||||
|
||||
# decoder
|
||||
self._check_hidden_states_for_generate(
|
||||
num_sequences_in_output,
|
||||
output.decoder_hidden_states,
|
||||
min_length=1,
|
||||
max_length=output.sequences.shape[-1],
|
||||
config=config,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@pytest.mark.flash_attn_test
|
||||
@ -3511,8 +3467,6 @@ class WhisperEncoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.
|
||||
test_pruning = False
|
||||
test_missing_keys = False
|
||||
|
||||
input_name = "input_features"
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = WhisperEncoderModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=WhisperConfig)
|
||||
|
Loading…
Reference in New Issue
Block a user