Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-05 13:50:13 +06:00)

Generate tests: modality-agnostic input preparation (#33685)

commit d29738f5b4 (parent f2bf4fcf3d)
File diff suppressed because it is too large
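The recurring change in the hunks below: per-model overrides of `_get_input_ids_and_config`, which returned a text-shaped `(config, input_ids, attention_mask, ...)` tuple, are deleted in favor of `prepare_config_and_inputs_for_generate`, which returns `(config, inputs_dict)` and keeps modality-specific tensors (2D `input_ids`, 3D `input_features`, multi-codebook audio tokens) inside the dict. A rough sketch of the two shapes; the stub classes and dict handling here are hypothetical stand-ins, not the actual mixin code:

    # Rough sketch only: `model_tester` and the dict contents are hypothetical
    # stand-ins, not the actual transformers mixin code.
    import torch

    class OldStyle:
        # Before: text-centric, returns loose tensors plus leftover kwargs.
        def _get_input_ids_and_config(self, batch_size=2):
            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
            input_ids = inputs_dict.pop("input_ids")[:batch_size, :]
            attention_mask = torch.ones_like(input_ids, dtype=torch.long)
            return config, input_ids, attention_mask, inputs_dict

    class NewStyle:
        # After: modality-agnostic; the whole inputs_dict travels together, so a
        # 3D `input_features` tensor needs no per-model mask surgery.
        def prepare_config_and_inputs_for_generate(self, batch_size=2):
            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
            inputs_dict = {
                k: v[:batch_size, ...] if isinstance(v, torch.Tensor) else v
                for k, v in inputs_dict.items()
            }
            return config, inputs_dict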
@@ -283,28 +283,6 @@ class BigBirdPegasusModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
 
         return False
 
-    # overwrite from GenerationTesterMixin to solve problem
-    # with conflicting random seeds
-    def _get_input_ids_and_config(self, batch_size=2):
-        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        config.attention_type = "original_full"
-
-        input_ids = inputs_dict.pop(self.input_name)
-        _ = inputs_dict.pop("attention_mask", None)
-        _ = inputs_dict.pop("decoder_input_ids", None)
-        _ = inputs_dict.pop("decoder_attention_mask", None)
-        attention_mask = torch.ones_like(input_ids, dtype=torch.long)
-
-        # cut to half length & take max batch_size 3
-        sequence_length = input_ids.shape[-1] // 2
-        input_ids = input_ids[:batch_size, :sequence_length]
-        attention_mask = attention_mask[:batch_size, :sequence_length]
-
-        if config.eos_token_id is not None and config.pad_token_id is None:
-            # hack to allow generate for models such as GPT2 as is done in `generate()`
-            config.pad_token_id = config.eos_token_id
-        return config, input_ids, attention_mask, inputs_dict
-
     def setUp(self):
         self.model_tester = BigBirdPegasusModelTester(self)
         self.config_tester = ConfigTester(self, config_class=BigBirdPegasusConfig)
@@ -485,6 +463,13 @@ class BigBirdPegasusModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
     def test_load_save_without_tied_weights(self):
         pass
 
+    def test_generate_with_head_masking(self):
+        # overwritten to temporarily switch the attention type to `original_full`
+        original_self_attention_type = self.model_tester.attention_type
+        self.model_tester.attention_type = "original_full"
+        super().test_generate_with_head_masking()
+        self.model_tester.attention_type = original_self_attention_type
+
 
 @require_torch
 @require_sentencepiece
@@ -116,7 +116,7 @@ class ChameleonModelTester:
 
         input_mask = None
         if self.use_input_mask:
-            input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
+            input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
 
         sequence_labels = None
         token_labels = None
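The same one-line mask change repeats across the decoder-only testers below (Cohere, Gemma, Granite, Llama, Mistral, Olmo, Qwen2, StableLm, Starcoder2, ...). `torch.tril(torch.ones_like(input_ids))` builds the same lower-triangular mask as the explicit `(batch_size, seq_length)` version, but ties its shape, dtype, and device to `input_ids`, so the testers keep working when the input shape is changed in one place. A standalone sketch of the equivalence:

    import torch

    batch_size, seq_length, vocab_size = 2, 5, 99
    input_ids = torch.randint(0, vocab_size, (batch_size, seq_length))

    old_mask = torch.tril(torch.ones(batch_size, seq_length))
    new_mask = torch.tril(torch.ones_like(input_ids))

    # Same lower-triangular pattern; the new form inherits shape/dtype/device.
    assert torch.equal(old_mask.to(new_mask.dtype), new_mask)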
@@ -95,7 +95,7 @@ class CohereModelTester:
 
         input_mask = None
         if self.use_input_mask:
-            input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
+            input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
 
         token_type_ids = None
         if self.use_token_type_ids:
@@ -123,7 +123,6 @@ class DacModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
     test_headmasking = False
     test_resize_embeddings = False
     pipeline_model_mapping = {"feature-extraction": DacModel} if is_torch_available() else {}
-    input_name = "input_values"
 
     def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
         # model does not have attention and does not support returning hidden states
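Many audio-model tests below likewise drop their `input_name` class attribute. Presumably the reworked mixin no longer needs the hint because every `PreTrainedModel` subclass already names its main input via the `main_input_name` class attribute; the lookup sketched here is an assumption about the mechanism, not code from this commit:

    import torch

    def main_input_from(model_class, inputs_dict):
        # `main_input_name` exists on every transformers PreTrainedModel
        # ("input_ids", "input_values", "input_features", ...), so a generic
        # test can find the main tensor without a per-test `input_name`.
        return inputs_dict[model_class.main_input_name]

    class FakeAudioModel:  # hypothetical stand-in for a real model class
        main_input_name = "input_values"

    inputs = {"input_values": torch.randn(2, 400), "padding_mask": torch.ones(2, 400)}
    print(main_input_from(FakeAudioModel, inputs).shape)  # torch.Size([2, 400])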
@@ -141,7 +141,6 @@ class EncodecModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
     test_headmasking = False
     test_resize_embeddings = False
     pipeline_model_mapping = {"feature-extraction": EncodecModel} if is_torch_available() else {}
-    input_name = "input_values"
 
     def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
         # model does not have attention and does not support returning hidden states
@@ -119,7 +119,7 @@ class GemmaModelTester:
 
         input_mask = None
         if self.use_input_mask:
-            input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
+            input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
 
         token_type_ids = None
         if self.use_token_type_ids:
@@ -106,7 +106,7 @@ class GraniteModelTester:
 
         input_mask = None
         if self.use_input_mask:
-            input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
+            input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
 
         token_type_ids = None
         if self.use_token_type_ids:
@@ -105,7 +105,7 @@ class GraniteMoeModelTester:
 
         input_mask = None
         if self.use_input_mask:
-            input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
+            input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
 
         token_type_ids = None
         if self.use_token_type_ids:
@@ -338,13 +338,11 @@ class LEDModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
         config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
         self.model_tester.check_global_attention(*config_and_inputs)
 
-    def _get_input_ids_and_config(self, batch_size=2):
-        config, input_ids, attention_mask, inputs_dict = GenerationTesterMixin._get_input_ids_and_config(
-            self, batch_size=batch_size
-        )
+    def prepare_config_and_inputs_for_generate(self, *args, **kwargs):
+        config, inputs_dict = super().prepare_config_and_inputs_for_generate(*args, **kwargs)
         # LED computes attention scores based on mask indices if `is_global`
         inputs_dict.pop("global_attention_mask")
-        return config, input_ids, attention_mask, inputs_dict
+        return config, inputs_dict
 
     # LEDForSequenceClassification does not support inputs_embeds
     def test_inputs_embeds(self):
@@ -112,7 +112,7 @@ class LlamaModelTester:
 
         input_mask = None
         if self.use_input_mask:
-            input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
+            input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
 
         token_type_ids = None
         if self.use_token_type_ids:
@@ -170,7 +170,6 @@ class MimiModelTest(ModelTesterMixin, unittest.TestCase):
     test_headmasking = False
     test_resize_embeddings = False
     test_torchscript = False
-    input_name = "input_values"
 
     def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
         # model does support returning hidden states
@@ -112,7 +112,7 @@ class MistralModelTester:
 
         input_mask = None
         if self.use_input_mask:
-            input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
+            input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
 
         token_type_ids = None
         if self.use_token_type_ids:
@@ -108,7 +108,7 @@ class MixtralModelTester:
 
         input_mask = None
         if self.use_input_mask:
-            input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
+            input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
 
         token_type_ids = None
         if self.use_token_type_ids:
@@ -60,10 +60,6 @@ if is_torch_available():
         MusicgenModel,
         set_seed,
     )
-    from transformers.generation import (
-        GenerateDecoderOnlyOutput,
-        GenerateEncoderDecoderOutput,
-    )
 
 
 def _config_zero_init(config):
@@ -124,6 +120,7 @@ class MusicgenDecoderTester:
         pad_token_id=99,
         bos_token_id=99,
         num_codebooks=4,
+        audio_channels=1,
     ):
         self.parent = parent
         self.batch_size = batch_size
@@ -141,6 +138,7 @@ class MusicgenDecoderTester:
         self.pad_token_id = pad_token_id
         self.bos_token_id = bos_token_id
         self.num_codebooks = num_codebooks
+        self.audio_channels = audio_channels
 
     def prepare_config_and_inputs(self):
         input_ids = ids_tensor([self.batch_size * self.num_codebooks, self.seq_length], self.vocab_size)
@@ -166,6 +164,7 @@ class MusicgenDecoderTester:
             bos_token_id=self.bos_token_id,
             num_codebooks=self.num_codebooks,
             tie_word_embeddings=False,
+            audio_channels=self.audio_channels,
         )
         return config
 
@@ -282,47 +281,15 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
     def test_tied_weights_keys(self):
        pass
 
-    def _get_input_ids_and_config(self, batch_size=2):
-        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        input_ids = inputs_dict["input_ids"]
-
-        _ = inputs_dict.pop("attention_mask", None)
-        inputs_dict = {
-            k: v[:batch_size, ...]
-            for k, v in inputs_dict.items()
-            if "head_mask" not in k and isinstance(v, torch.Tensor)
-        }
-
-        # take max batch_size
-        sequence_length = input_ids.shape[-1]
-        input_ids = input_ids[: batch_size * config.num_codebooks, :]
-
-        attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long)
-        return config, input_ids, attention_mask, inputs_dict
-
     def _get_logits_processor_kwargs(self, do_sample=False, config=None):
         logits_processor_kwargs = {}
         return logits_processor_kwargs
 
     def test_greedy_generate_stereo_outputs(self):
-        for model_class in self.greedy_sample_model_classes:
-            config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config()
-            config.audio_channels = 2
-            model = model_class(config).to(torch_device).eval()
-            output_generate = self._greedy_generate(
-                model=model,
-                input_ids=input_ids.to(torch_device),
-                attention_mask=attention_mask.to(torch_device),
-                output_scores=True,
-                output_hidden_states=True,
-                output_attentions=True,
-                return_dict_in_generate=True,
-                inputs_dict={},
-            )
-
-            self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
-
-            self.assertNotIn(config.pad_token_id, output_generate)
+        original_audio_channels = self.model_tester.audio_channels
+        self.model_tester.audio_channels = 2
+        super().test_greedy_generate_dict_outputs()
+        self.model_tester.audio_channels = original_audio_channels
 
     @require_flash_attn
     @require_torch_gpu
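The rewritten stereo test above is the template this commit applies whenever a test needs non-default inputs: flip the relevant attribute on the shared `model_tester`, delegate to the inherited test, and restore the attribute (the Reformer hunk further below does the same with `seq_length`). A self-contained sketch of the pattern with dummy classes; the `try/finally` is an addition here, the diff itself restores unconditionally:

    class DummyTester:  # hypothetical stand-in for e.g. MusicgenDecoderTester
        def __init__(self):
            self.audio_channels = 1

    class BaseGenerationTests:
        def test_greedy_generate_dict_outputs(self):
            # stand-in for the real shared generation checks
            assert self.model_tester.audio_channels in (1, 2)

    class StereoTests(BaseGenerationTests):
        def test_greedy_generate_stereo_outputs(self):
            original = self.model_tester.audio_channels
            self.model_tester.audio_channels = 2
            try:
                super().test_greedy_generate_dict_outputs()
            finally:
                # restore even if the delegated test fails
                self.model_tester.audio_channels = original

    t = StereoTests()
    t.model_tester = DummyTester()
    t.test_greedy_generate_stereo_outputs()
    assert t.model_tester.audio_channels == 1  # restored afterwards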
@@ -998,6 +965,7 @@ class MusicgenTester:
         num_codebooks=4,
         num_filters=4,
         codebook_size=128,
+        audio_channels=1,
     ):
         self.parent = parent
         self.batch_size = batch_size
@@ -1017,6 +985,7 @@ class MusicgenTester:
         self.num_codebooks = num_codebooks
         self.num_filters = num_filters
         self.codebook_size = codebook_size
+        self.audio_channels = audio_channels
 
     def prepare_config_and_inputs(self):
         input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
@@ -1052,6 +1021,7 @@ class MusicgenTester:
             bos_token_id=self.bos_token_id,
             num_codebooks=self.num_codebooks,
             tie_word_embeddings=False,
+            audio_channels=self.audio_channels,
         )
         config = MusicgenConfig.from_sub_models_config(text_encoder_config, audio_encoder_config, decoder_config)
         return config
@@ -1415,170 +1385,10 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
         lm_heads = model.get_output_embeddings()
         self.assertTrue(lm_heads is None or isinstance(lm_heads[0], torch.nn.Linear))
 
-    def _get_input_ids_and_config(self, batch_size=2):
-        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        input_ids = inputs_dict["input_ids"]
-
-        # take max batch_size
-        sequence_length = input_ids.shape[-1]
-        input_ids = input_ids[:batch_size, :]
-        attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long)
-
-        return config, input_ids, attention_mask
-
-    # override since the `input_ids` cannot be used as the `decoder_input_ids` for musicgen (input / outputs are
-    # different modalities -> different shapes)
-    def _greedy_generate(
-        self,
-        model,
-        input_ids,
-        attention_mask,
-        output_scores=False,
-        output_attentions=False,
-        output_hidden_states=False,
-        return_dict_in_generate=False,
-    ):
-        model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
-        output_generate = model.generate(
-            input_ids,
-            do_sample=False,
-            num_beams=1,
-            max_new_tokens=self.max_new_tokens,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            output_scores=output_scores,
-            return_dict_in_generate=return_dict_in_generate,
-            remove_invalid_values=True,
-            **model_kwargs,
-        )
-
-        return output_generate
-
-    # override since the `input_ids` cannot be used as the `decoder_input_ids` for musicgen (input / outputs are
-    # different modalities -> different shapes)
-    def _sample_generate(
-        self,
-        model,
-        input_ids,
-        attention_mask,
-        num_return_sequences,
-        output_scores=False,
-        output_attentions=False,
-        output_hidden_states=False,
-        return_dict_in_generate=False,
-    ):
-        torch.manual_seed(0)
-        model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
-        output_generate = model.generate(
-            input_ids,
-            do_sample=True,
-            num_beams=1,
-            max_new_tokens=self.max_new_tokens,
-            num_return_sequences=num_return_sequences,
-            output_scores=output_scores,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict_in_generate=return_dict_in_generate,
-            remove_invalid_values=True,
-            **model_kwargs,
-        )
-
-        return output_generate
-
     def _get_logits_processor_kwargs(self, do_sample=False, config=None):
         logits_processor_kwargs = {}
         return logits_processor_kwargs
 
-    def test_greedy_generate_dict_outputs(self):
-        for model_class in self.greedy_sample_model_classes:
-            # disable cache
-            config, input_ids, attention_mask = self._get_input_ids_and_config()
-            config.use_cache = False
-            model = model_class(config).to(torch_device).eval()
-            output_generate = self._greedy_generate(
-                model=model,
-                input_ids=input_ids.to(torch_device),
-                attention_mask=attention_mask.to(torch_device),
-                output_scores=True,
-                output_hidden_states=True,
-                output_attentions=True,
-                return_dict_in_generate=True,
-            )
-
-            self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
-
-            self.assertNotIn(config.pad_token_id, output_generate)
-
-    def test_greedy_generate_dict_outputs_use_cache(self):
-        for model_class in self.greedy_sample_model_classes:
-            # enable cache
-            config, input_ids, attention_mask = self._get_input_ids_and_config()
-
-            config.use_cache = True
-            config.is_decoder = True
-            model = model_class(config).to(torch_device).eval()
-            output_generate = self._greedy_generate(
-                model=model,
-                input_ids=input_ids.to(torch_device),
-                attention_mask=attention_mask.to(torch_device),
-                output_scores=True,
-                output_hidden_states=True,
-                output_attentions=True,
-                return_dict_in_generate=True,
-            )
-
-            self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
-
-    def test_sample_generate(self):
-        for model_class in self.greedy_sample_model_classes:
-            config, input_ids, attention_mask = self._get_input_ids_and_config()
-            model = model_class(config).to(torch_device).eval()
-
-            # check `generate()` and `sample()` are equal
-            output_generate = self._sample_generate(
-                model=model,
-                input_ids=input_ids.to(torch_device),
-                attention_mask=attention_mask.to(torch_device),
-                num_return_sequences=1,
-            )
-            self.assertIsInstance(output_generate, torch.Tensor)
-
-    def test_sample_generate_dict_output(self):
-        for model_class in self.greedy_sample_model_classes:
-            # disable cache
-            config, input_ids, attention_mask = self._get_input_ids_and_config()
-            config.use_cache = False
-            model = model_class(config).to(torch_device).eval()
-
-            output_generate = self._sample_generate(
-                model=model,
-                input_ids=input_ids.to(torch_device),
-                attention_mask=attention_mask.to(torch_device),
-                num_return_sequences=3,
-                output_scores=True,
-                output_hidden_states=True,
-                output_attentions=True,
-                return_dict_in_generate=True,
-            )
-
-            self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
-
-    def test_generate_without_input_ids(self):
-        config, _, _ = self._get_input_ids_and_config()
-
-        # if no bos token id => cannot generate from None
-        if config.bos_token_id is None:
-            self.skipTest(reason="bos_token_id is None")
-
-        for model_class in self.greedy_sample_model_classes:
-            model = model_class(config).to(torch_device)
-            model.eval()
-
-            output_ids_generate = model.generate(
-                do_sample=False, max_new_tokens=self.max_new_tokens, remove_invalid_values=True
-            )
-            self.assertIsNotNone(output_ids_generate)
-
     @require_torch_fp16
     @require_torch_accelerator  # not all operations are supported in fp16 on CPU
     def test_generate_fp16(self):
@@ -1595,24 +1405,10 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
         )
 
     def test_greedy_generate_stereo_outputs(self):
-        for model_class in self.greedy_sample_model_classes:
-            config, input_ids, attention_mask = self._get_input_ids_and_config()
-            config.audio_channels = 2
-
-            model = model_class(config).to(torch_device).eval()
-            output_generate = self._greedy_generate(
-                model=model,
-                input_ids=input_ids.to(torch_device),
-                attention_mask=attention_mask.to(torch_device),
-                output_scores=True,
-                output_hidden_states=True,
-                output_attentions=True,
-                return_dict_in_generate=True,
-            )
-
-            self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
-
-            self.assertNotIn(config.pad_token_id, output_generate)
+        original_audio_channels = self.model_tester.audio_channels
+        self.model_tester.audio_channels = 2
+        super().test_greedy_generate_dict_outputs()
+        self.model_tester.audio_channels = original_audio_channels
 
     @unittest.skip(
         reason="MusicgenModel is actually not the base of MusicgenForCausalLM as the latter is a composit model"
@@ -61,9 +61,6 @@ if is_torch_available():
         MusicgenMelodyModel,
         set_seed,
     )
-    from transformers.generation import (
-        GenerateDecoderOnlyOutput,
-    )
 
 if is_torchaudio_available():
     from transformers import MusicgenMelodyProcessor
@@ -124,6 +121,7 @@ class MusicgenMelodyDecoderTester:
         bos_token_id=99,
         num_codebooks=4,
         conditional_seq_length=4,
+        audio_channels=1,
     ):
         self.parent = parent
         self.batch_size = batch_size
@@ -143,6 +141,7 @@ class MusicgenMelodyDecoderTester:
         self.num_codebooks = num_codebooks
         self.conditional_seq_length = conditional_seq_length
         self.encoder_seq_length = conditional_seq_length + seq_length
+        self.audio_channels = audio_channels
 
     def prepare_config_and_inputs(self):
         input_ids = ids_tensor([self.batch_size * self.num_codebooks, self.seq_length], self.vocab_size)
@@ -168,6 +167,7 @@ class MusicgenMelodyDecoderTester:
             bos_token_id=self.bos_token_id,
             num_codebooks=self.num_codebooks,
             tie_word_embeddings=False,
+            audio_channels=self.audio_channels,
         )
         return config
 
@@ -285,46 +285,15 @@ class MusicgenMelodyDecoderTest(ModelTesterMixin, GenerationTesterMixin, unittes
     def test_tied_weights_keys(self):
         pass
 
-    def _get_input_ids_and_config(self, batch_size=2):
-        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        input_ids = inputs_dict["input_ids"]
-
-        _ = inputs_dict.pop("attention_mask", None)
-        inputs_dict = {
-            k: v[:batch_size, ...]
-            for k, v in inputs_dict.items()
-            if "head_mask" not in k and isinstance(v, torch.Tensor)
-        }
-
-        # take max batch_size
-        sequence_length = input_ids.shape[-1]
-        input_ids = input_ids[: batch_size * config.num_codebooks, :]
-
-        attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long)
-        return config, input_ids, attention_mask, inputs_dict
-
     def _get_logits_processor_kwargs(self, do_sample=False, config=None):
         logits_processor_kwargs = {}
         return logits_processor_kwargs
 
     def test_greedy_generate_stereo_outputs(self):
-        for model_class in self.greedy_sample_model_classes:
-            config, input_ids, attention_mask, _ = self._get_input_ids_and_config()
-            config.audio_channels = 2
-            model = model_class(config).to(torch_device).eval()
-            output_generate = self._greedy_generate(
-                model=model,
-                input_ids=input_ids.to(torch_device),
-                attention_mask=attention_mask.to(torch_device),
-                output_scores=True,
-                output_hidden_states=True,
-                output_attentions=True,
-                return_dict_in_generate=True,
-                inputs_dict={},
-            )
-
-            self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
-            self.assertNotIn(config.pad_token_id, output_generate)
+        original_audio_channels = self.model_tester.audio_channels
+        self.model_tester.audio_channels = 2
+        super().test_greedy_generate_dict_outputs()
+        self.model_tester.audio_channels = original_audio_channels
 
     @require_flash_attn
     @require_torch_gpu
@@ -996,6 +965,7 @@ class MusicgenMelodyTester:
         codebook_size=128,
         conditional_seq_length=3,
         chroma_length=24,
+        audio_channels=1,
     ):
         self.parent = parent
         self.batch_size = batch_size
@@ -1018,6 +988,7 @@ class MusicgenMelodyTester:
         self.conditional_seq_length = conditional_seq_length
         self.chroma_length = chroma_length
         self.encoder_seq_length = conditional_seq_length + seq_length
+        self.audio_channels = audio_channels
 
     def prepare_config_and_inputs(self):
         input_ids = ids_tensor([self.batch_size, self.conditional_seq_length], self.vocab_size)
@@ -1053,6 +1024,7 @@ class MusicgenMelodyTester:
             bos_token_id=self.bos_token_id,
             num_codebooks=self.num_codebooks,
             tie_word_embeddings=False,
+            audio_channels=self.audio_channels,
         )
         config = MusicgenMelodyConfig.from_sub_models_config(
             text_encoder_config, audio_encoder_config, decoder_config, chroma_length=self.chroma_length
@@ -1399,170 +1371,10 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
         lm_heads = model.get_output_embeddings()
         self.assertTrue(lm_heads is None or isinstance(lm_heads[0], torch.nn.Linear))
 
-    def _get_input_ids_and_config(self, batch_size=2):
-        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        input_ids = inputs_dict["input_ids"]
-
-        # take max batch_size
-        sequence_length = input_ids.shape[-1]
-        input_ids = input_ids[:batch_size, :]
-        attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long)
-
-        return config, input_ids, attention_mask
-
-    # override since the `input_ids` cannot be used as the `decoder_input_ids` for musicgen_melody (input / outputs are
-    # different modalities -> different shapes)
-    def _greedy_generate(
-        self,
-        model,
-        input_ids,
-        attention_mask,
-        output_scores=False,
-        output_attentions=False,
-        output_hidden_states=False,
-        return_dict_in_generate=False,
-    ):
-        model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
-        output_generate = model.generate(
-            input_ids,
-            do_sample=False,
-            num_beams=1,
-            max_new_tokens=self.max_new_tokens,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            output_scores=output_scores,
-            return_dict_in_generate=return_dict_in_generate,
-            remove_invalid_values=True,
-            **model_kwargs,
-        )
-
-        return output_generate
-
-    # override since the `input_ids` cannot be used as the `decoder_input_ids` for musicgen_melody (input / outputs are
-    # different modalities -> different shapes)
-    def _sample_generate(
-        self,
-        model,
-        input_ids,
-        attention_mask,
-        num_return_sequences,
-        output_scores=False,
-        output_attentions=False,
-        output_hidden_states=False,
-        return_dict_in_generate=False,
-    ):
-        torch.manual_seed(0)
-        model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
-        output_generate = model.generate(
-            input_ids,
-            do_sample=True,
-            num_beams=1,
-            max_new_tokens=self.max_new_tokens,
-            num_return_sequences=num_return_sequences,
-            output_scores=output_scores,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict_in_generate=return_dict_in_generate,
-            remove_invalid_values=True,
-            **model_kwargs,
-        )
-
-        return output_generate
-
     def _get_logits_processor_kwargs(self, do_sample=False, config=None):
         logits_processor_kwargs = {}
         return logits_processor_kwargs
 
-    def test_greedy_generate_dict_outputs(self):
-        for model_class in self.greedy_sample_model_classes:
-            # disable cache
-            config, input_ids, attention_mask = self._get_input_ids_and_config()
-            config.use_cache = False
-            model = model_class(config).to(torch_device).eval()
-            output_generate = self._greedy_generate(
-                model=model,
-                input_ids=input_ids.to(torch_device),
-                attention_mask=attention_mask.to(torch_device),
-                output_scores=True,
-                output_hidden_states=True,
-                output_attentions=True,
-                return_dict_in_generate=True,
-            )
-
-            self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
-
-            self.assertNotIn(config.pad_token_id, output_generate)
-
-    def test_greedy_generate_dict_outputs_use_cache(self):
-        for model_class in self.greedy_sample_model_classes:
-            # enable cache
-            config, input_ids, attention_mask = self._get_input_ids_and_config()
-
-            config.use_cache = True
-            config.is_decoder = True
-            model = model_class(config).to(torch_device).eval()
-            output_generate = self._greedy_generate(
-                model=model,
-                input_ids=input_ids.to(torch_device),
-                attention_mask=attention_mask.to(torch_device),
-                output_scores=True,
-                output_hidden_states=True,
-                output_attentions=True,
-                return_dict_in_generate=True,
-            )
-
-            self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
-
-    def test_sample_generate(self):
-        for model_class in self.greedy_sample_model_classes:
-            config, input_ids, attention_mask = self._get_input_ids_and_config()
-            model = model_class(config).to(torch_device).eval()
-
-            # check `generate()` and `sample()` are equal
-            output_generate = self._sample_generate(
-                model=model,
-                input_ids=input_ids.to(torch_device),
-                attention_mask=attention_mask.to(torch_device),
-                num_return_sequences=1,
-            )
-            self.assertIsInstance(output_generate, torch.Tensor)
-
-    def test_sample_generate_dict_output(self):
-        for model_class in self.greedy_sample_model_classes:
-            # disable cache
-            config, input_ids, attention_mask = self._get_input_ids_and_config()
-            config.use_cache = False
-            model = model_class(config).to(torch_device).eval()
-
-            output_generate = self._sample_generate(
-                model=model,
-                input_ids=input_ids.to(torch_device),
-                attention_mask=attention_mask.to(torch_device),
-                num_return_sequences=3,
-                output_scores=True,
-                output_hidden_states=True,
-                output_attentions=True,
-                return_dict_in_generate=True,
-            )
-
-            self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
-
-    def test_generate_without_input_ids(self):
-        config, _, _ = self._get_input_ids_and_config()
-
-        # if no bos token id => cannot generate from None
-        if config.bos_token_id is None:
-            self.skipTest(reason="bos_token_id is None")
-
-        for model_class in self.greedy_sample_model_classes:
-            model = model_class(config).to(torch_device)
-            model.eval()
-
-            output_ids_generate = model.generate(
-                do_sample=False, max_new_tokens=self.max_new_tokens, remove_invalid_values=True
-            )
-            self.assertIsNotNone(output_ids_generate)
-
     @require_torch_fp16
     @require_torch_accelerator  # not all operations are supported in fp16 on CPU
     def test_generate_fp16(self):
@@ -1579,24 +1391,10 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
         )
 
     def test_greedy_generate_stereo_outputs(self):
-        for model_class in self.greedy_sample_model_classes:
-            config, input_ids, attention_mask = self._get_input_ids_and_config()
-            config.audio_channels = 2
-
-            model = model_class(config).to(torch_device).eval()
-            output_generate = self._greedy_generate(
-                model=model,
-                input_ids=input_ids.to(torch_device),
-                attention_mask=attention_mask.to(torch_device),
-                output_scores=True,
-                output_hidden_states=True,
-                output_attentions=True,
-                return_dict_in_generate=True,
-            )
-
-            self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
-
-            self.assertNotIn(config.pad_token_id, output_generate)
+        original_audio_channels = self.model_tester.audio_channels
+        self.model_tester.audio_channels = 2
+        super().test_greedy_generate_dict_outputs()
+        self.model_tester.audio_channels = original_audio_channels
 
     @unittest.skip(
         reason="MusicgenMelodyModel is actually not the base of MusicgenMelodyForCausalLM as the latter is a composit model"
@@ -101,7 +101,7 @@ class OlmoModelTester:
 
         input_mask = None
         if self.use_input_mask:
-            input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
+            input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
 
         token_type_ids = None
         if self.use_token_type_ids:
@@ -111,7 +111,7 @@ class OlmoeModelTester:
 
         input_mask = None
         if self.use_input_mask:
-            input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
+            input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
 
         token_type_ids = None
         if self.use_token_type_ids:
@@ -110,7 +110,7 @@ class PersimmonModelTester:
 
         input_mask = None
         if self.use_input_mask:
-            input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
+            input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
 
         token_type_ids = None
         if self.use_token_type_ids:
@@ -151,7 +151,7 @@ class Phi3ModelTester:
 
         input_mask = None
         if self.use_input_mask:
-            input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
+            input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
 
         token_type_ids = None
         if self.use_token_type_ids:
@@ -116,7 +116,7 @@ class Qwen2ModelTester:
 
         input_mask = None
         if self.use_input_mask:
-            input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
+            input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
 
         token_type_ids = None
         if self.use_token_type_ids:
@@ -134,7 +134,7 @@ class Qwen2MoeModelTester:
 
         input_mask = None
         if self.use_input_mask:
-            input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
+            input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
 
         token_type_ids = None
         if self.use_token_type_ids:
@@ -103,7 +103,7 @@ class RecurrentGemmaModelTester:
 
         input_mask = None
         if self.use_input_mask:
-            input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
+            input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
 
         token_type_ids = None
         if self.use_token_type_ids:
@@ -684,20 +684,15 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, GenerationTesterMixin, Mod
     def test_left_padding_compatibility(self):
         pass
 
-    def _get_input_ids_and_config(self, batch_size=2):
+    def prepare_config_and_inputs_for_generate(self, *args, **kwargs):
         # override because overwise we hit max possible seq length for model (4*8=32)
         # decreasing the seq_length in tester causes errors for "training_tests", those need exactly max seq length
         # NOTE: seq_length has to be multiple of 4, otherwise it fails for other tests
-        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        input_ids = inputs_dict.pop(self.input_name)
-        _ = inputs_dict.pop("attention_mask", None)
-        _ = inputs_dict.pop("decoder_input_ids", None)
-        _ = inputs_dict.pop("decoder_attention_mask", None)
-        input_ids = input_ids[:batch_size, :16]
-        attention_mask = torch.ones_like(input_ids, dtype=torch.long)[:batch_size, :16]
-        config.eos_token_id = None
-        config.forced_eos_token_id = None
-        return config, input_ids, attention_mask, inputs_dict
+        original_sequence_length = self.model_tester.seq_length
+        self.model_tester.seq_length = 16
+        test_inputs = super().prepare_config_and_inputs_for_generate(*args, **kwargs)
+        self.model_tester.seq_length = original_sequence_length
+        return test_inputs
 
 
 @require_torch
@@ -360,8 +360,6 @@ class SeamlessM4TModelWithSpeechInputTest(ModelTesterMixin, unittest.TestCase):
     )
     all_generative_model_classes = (SeamlessM4TForSpeechToText,) if is_torch_available() else ()
 
-    input_name = "input_features"
-
     def setUp(self):
         self.model_tester = SeamlessM4TModelTester(self, input_modality="speech")
         self.config_tester = ConfigTester(self, config_class=SeamlessM4TConfig)
@@ -379,26 +377,6 @@ class SeamlessM4TModelWithSpeechInputTest(ModelTesterMixin, unittest.TestCase):
         model = SeamlessM4TModel.from_pretrained(model_name)
         self.assertIsNotNone(model)
 
-    def _get_input_ids_and_config(self, batch_size=2):
-        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        input_ids = inputs_dict[self.input_name]
-
-        # cut to half length & take max batch_size 3
-        sequence_length = input_ids.shape[-1] // 2
-        input_ids = input_ids[:batch_size, :sequence_length]
-
-        # generate max 3 tokens
-        max_length = input_ids.shape[-1] + 3
-        if config.eos_token_id is not None and config.pad_token_id is None:
-            # hack to allow generate for models such as GPT2 as is done in `generate()`
-            if isinstance(config.eos_token_id, int):
-                config.eos_token_id = [config.eos_token_id]
-            config.pad_token_id = config.eos_token_id[0]
-
-        attention_mask = torch.ones(input_ids.shape[:2], dtype=torch.long)[:batch_size, :sequence_length]
-
-        return config, input_ids.float(), attention_mask, max_length
-
     def test_initialization(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -376,8 +376,6 @@ class SeamlessM4Tv2ModelWithSpeechInputTest(ModelTesterMixin, unittest.TestCase)
     )
     all_generative_model_classes = (SeamlessM4Tv2ForSpeechToText,) if is_torch_available() else ()
 
-    input_name = "input_features"
-
     def setUp(self):
         self.model_tester = SeamlessM4Tv2ModelTester(self, input_modality="speech")
         self.config_tester = ConfigTester(self, config_class=SeamlessM4Tv2Config)
@@ -395,26 +393,6 @@ class SeamlessM4Tv2ModelWithSpeechInputTest(ModelTesterMixin, unittest.TestCase)
         model = SeamlessM4Tv2Model.from_pretrained(model_name)
         self.assertIsNotNone(model)
 
-    def _get_input_ids_and_config(self, batch_size=2):
-        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        input_ids = inputs_dict[self.input_name]
-
-        # cut to half length & take max batch_size 3
-        sequence_length = input_ids.shape[-1] // 2
-        input_ids = input_ids[:batch_size, :sequence_length]
-
-        # generate max 3 tokens
-        max_length = input_ids.shape[-1] + 3
-        if config.eos_token_id is not None and config.pad_token_id is None:
-            # hack to allow generate for models such as GPT2 as is done in `generate()`
-            if isinstance(config.eos_token_id, int):
-                config.eos_token_id = [config.eos_token_id]
-            config.pad_token_id = config.eos_token_id[0]
-
-        attention_mask = torch.ones(input_ids.shape[:2], dtype=torch.long)[:batch_size, :sequence_length]
-
-        return config, input_ids.float(), attention_mask, max_length
-
     def test_initialization(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -282,20 +282,6 @@ class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTest
     test_pruning = False
     test_missing_keys = False
 
-    input_name = "input_features"
-
-    def _get_input_ids_and_config(self, batch_size=2):
-        config, input_ids, attention_mask, inputs_dict = GenerationTesterMixin._get_input_ids_and_config(self)
-
-        # `input_ids` is actually `input_features` which is a 3D tensor.
-        # We must overwrite the mask to make it 2D since the original `_get_input_ids_and_config` creates an
-        # attention mask of the same shape as `input_ids`.
-        if len(attention_mask.shape) > 2:
-            sequence_length = input_ids.shape[1]
-            attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long, device=attention_mask.device)
-
-        return config, input_ids, attention_mask, inputs_dict
-
     def setUp(self):
         self.model_tester = Speech2TextModelTester(self)
         self.config_tester = ConfigTester(self, config_class=Speech2TextConfig)
@@ -632,46 +618,12 @@ class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTest
     def test_generate_without_input_ids(self):
         pass
 
-    def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
-        batch_size, seq_length = input_ids.shape[:2]
-        subsampled_seq_length = self.model_tester.get_subsampled_output_lengths(seq_length)
-        num_sequences_in_output = batch_size * num_return_sequences
-        gen_len = (
-            output.sequences.shape[-1] - 1 if config.is_encoder_decoder else output.sequences.shape[-1] - seq_length
-        )
-
-        # scores
-        self._check_scores(num_sequences_in_output, output.scores, length=gen_len, config=config)
-
-        # Attentions
-        # encoder
-        self._check_encoder_attention_for_generate(
-            output.encoder_attentions, batch_size, config, subsampled_seq_length
-        )
-        # decoder
-        self._check_attentions_for_generate(
-            num_sequences_in_output,
-            output.decoder_attentions,
-            min_length=1,
-            max_length=output.sequences.shape[-1],
-            config=config,
-            use_cache=use_cache,
-        )
-
-        # Hidden States
-        # encoder
-        self._check_encoder_hidden_states_for_generate(
-            output.encoder_hidden_states, batch_size, config, subsampled_seq_length
-        )
-
-        # decoder
-        self._check_hidden_states_for_generate(
-            num_sequences_in_output,
-            output.decoder_hidden_states,
-            min_length=1,
-            max_length=output.sequences.shape[-1],
-            config=config,
-            use_cache=use_cache,
+    def _check_outputs(self, output, main_input, config, use_cache=False, num_return_sequences=1):
+        # In this model, the index of `batch_size` and `sequence_length` in `main_input` is different: they are
+        # the first two dimensions of the tensor.
+        main_input = main_input[:, :, 0]
+        super()._check_outputs(
+            output, main_input, config, use_cache=use_cache, num_return_sequences=num_return_sequences
         )
 
     def _create_and_check_torchscript(self, config, inputs_dict):
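The override above works because `_check_outputs` now receives the model's main input, which for Speech2Text is a 3D `(batch, sequence, feature)` spectrogram rather than 2D `input_ids`; slicing `[:, :, 0]` leaves a `(batch, sequence)` view whose first two dimensions are all the shared shape checks read. A shape-only sketch:

    import torch

    # Speech2Text's main input is a 3D log-mel tensor: (batch, sequence, feature).
    input_features = torch.randn(2, 40, 80)

    # The override keeps only the first two dims before delegating upward.
    main_input = input_features[:, :, 0]
    print(main_input.shape)  # torch.Size([2, 40])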
@@ -177,8 +177,6 @@ class SpeechT5ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
     test_headmasking = False
     test_resize_embeddings = False
 
-    input_name = "input_values"
-
     def setUp(self):
         self.model_tester = SpeechT5ModelTester(self)
         self.config_tester = ConfigTester(self, config_class=SpeechT5Config, hidden_size=37)
@@ -375,8 +373,6 @@ class SpeechT5ForSpeechToTextTest(ModelTesterMixin, unittest.TestCase):
     test_pruning = False
     test_headmasking = False
 
-    input_name = "input_values"
-
     def setUp(self):
         self.model_tester = SpeechT5ForSpeechToTextTester(self)
         self.config_tester = ConfigTester(self, config_class=SpeechT5Config, hidden_size=37)
@@ -895,8 +891,6 @@ class SpeechT5ForTextToSpeechTest(ModelTesterMixin, unittest.TestCase):
     test_pruning = False
     test_headmasking = False
 
-    input_name = "input_ids"
-
     def setUp(self):
         self.model_tester = SpeechT5ForTextToSpeechTester(self)
         self.config_tester = ConfigTester(self, config_class=SpeechT5Config, hidden_size=37)
@@ -1441,8 +1435,6 @@ class SpeechT5ForSpeechToSpeechTest(ModelTesterMixin, unittest.TestCase):
     test_headmasking = False
     test_resize_embeddings = False
 
-    input_name = "input_values"
-
     def setUp(self):
         self.model_tester = SpeechT5ForSpeechToSpeechTester(self)
         self.config_tester = ConfigTester(self, config_class=SpeechT5Config, hidden_size=37)
@@ -1854,8 +1846,6 @@ class SpeechT5HifiGanTest(ModelTesterMixin, unittest.TestCase):
     is_encoder_decoder = False
     has_attentions = False
 
-    input_name = "spectrogram"
-
     def setUp(self):
         self.model_tester = SpeechT5HifiGanTester(self)
         self.config_tester = ConfigTester(self, config_class=SpeechT5HifiGanConfig)
@@ -113,7 +113,7 @@ class StableLmModelTester:
 
         input_mask = None
         if self.use_input_mask:
-            input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
+            input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
 
         token_type_ids = None
         if self.use_token_type_ids:
@@ -107,7 +107,7 @@ class Starcoder2ModelTester:
 
         input_mask = None
         if self.use_input_mask:
-            input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
+            input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
 
         token_type_ids = None
         if self.use_token_type_ids:
@@ -470,7 +470,7 @@ class TFT5GenerationIntegrationTests(unittest.TestCase):
         self.assertListEqual(expected_output_string, output_strings_xla)
 
     @slow
-    def test_greedy_generate(self):
+    def test_t5_greedy_generate(self):
         model = TFT5ForConditionalGeneration.from_pretrained("google-t5/t5-small")
         tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-small")
 
@@ -520,7 +520,7 @@ class TFT5GenerationIntegrationTests(unittest.TestCase):
         self.assertListEqual(expected_output_string_xla, output_strings_xla)
 
     @slow
-    def test_sample_generate(self):
+    def test_t5_sample_generate(self):
         model = TFT5ForConditionalGeneration.from_pretrained("google-t5/t5-small")
         tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-small")
 
@@ -118,8 +118,6 @@ class UnivNetModelTest(ModelTesterMixin, unittest.TestCase):
     is_encoder_decoder = False
     has_attentions = False
 
-    input_name = "input_features"
-
     def setUp(self):
         self.model_tester = UnivNetModelTester(self)
         self.config_tester = ConfigTester(
@@ -167,8 +167,6 @@ class VitsModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
     test_torchscript = False
     has_attentions = False
 
-    input_name = "input_ids"
-
     def setUp(self):
         self.model_tester = VitsModelTester(self)
         self.config_tester = ConfigTester(self, config_class=VitsConfig, hidden_size=37)
@@ -395,8 +395,6 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
     # `0.5` is for `test_disk_offload` (which also works for `test_model_parallelism`)
     model_split_percents = [0.5, 0.8, 0.9]
 
-    input_name = "input_features"
-
     # TODO: Fix the failed tests
     def is_pipeline_test_to_skip(
         self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
@@ -868,48 +866,6 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
     def test_generate_without_input_ids(self):
         pass
 
-    def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
-        batch_size, mel, seq_length = input_ids.shape
-        subsampled_seq_length = self.model_tester.get_subsampled_output_lengths(seq_length)
-        num_sequences_in_output = batch_size * num_return_sequences
-        gen_len = (
-            output.sequences.shape[-1] - 1 if config.is_encoder_decoder else output.sequences.shape[-1] - seq_length
-        )
-
-        # scores
-        self._check_scores(num_sequences_in_output, output.scores, length=gen_len, config=config)
-
-        # Attentions
-        # encoder
-        self._check_encoder_attention_for_generate(
-            output.encoder_attentions, batch_size, config, subsampled_seq_length
-        )
-        # decoder
-        self._check_attentions_for_generate(
-            num_sequences_in_output,
-            output.decoder_attentions,
-            min_length=1,
-            max_length=output.sequences.shape[-1],
-            config=config,
-            use_cache=use_cache,
-        )
-
-        # Hidden States
-        # encoder
-        self._check_encoder_hidden_states_for_generate(
-            output.encoder_hidden_states, batch_size, config, subsampled_seq_length
-        )
-
-        # decoder
-        self._check_hidden_states_for_generate(
-            num_sequences_in_output,
-            output.decoder_hidden_states,
-            min_length=1,
-            max_length=output.sequences.shape[-1],
-            config=config,
-            use_cache=use_cache,
-        )
-
     @require_flash_attn
     @require_torch_gpu
     @pytest.mark.flash_attn_test
@@ -3511,8 +3467,6 @@ class WhisperEncoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.
     test_pruning = False
     test_missing_keys = False
 
-    input_name = "input_features"
-
     def setUp(self):
         self.model_tester = WhisperEncoderModelTester(self)
         self.config_tester = ConfigTester(self, config_class=WhisperConfig)