Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-03 12:50:06 +06:00)
[generation] bring back tests on vision models (#38603)
* bring back generation tests on VLMs
* remove overwritten head mask tests
This commit is contained in:
parent 90c4b90a10
commit dbfc79c17c
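The change applied throughout the generation test suite below replaces direct `config.is_encoder_decoder` checks with a lookup on the decoder text sub-config, so composite vision-language configs (whose flag lives on the nested text config) take the right branch. A minimal sketch of the resulting length check, assuming a transformers version that exposes `PretrainedConfig.get_text_config(decoder=...)` (the API this commit relies on); the helper name `expected_generate_length` is ours, not part of the library:

from transformers import GPT2Config, T5Config


def expected_generate_length(config, max_new_tokens: int, prompt_length: int) -> int:
    # Composite (e.g. vision-language) configs keep `is_encoder_decoder` on their text
    # sub-config, so the check goes through get_text_config(decoder=True) rather than
    # the top-level attribute.
    if config.get_text_config(decoder=True).is_encoder_decoder:
        # Encoder-decoder models start decoding from a single decoder start token.
        return max_new_tokens + 1
    # Decoder-only models return the prompt as part of the generated sequence.
    return max_new_tokens + prompt_length


print(expected_generate_length(T5Config(), max_new_tokens=5, prompt_length=3))    # expected: 6
print(expected_generate_length(GPT2Config(), max_new_tokens=5, prompt_length=3))  # expected: 8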
@@ -28,10 +28,11 @@ class DecoderConfig(PretrainedConfig):
 
     model_type = "fsmt_decoder"
 
-    def __init__(self, vocab_size=0, bos_token_id=0):
+    def __init__(self, vocab_size=0, bos_token_id=0, is_encoder_decoder=True):
         super().__init__()
         self.vocab_size = vocab_size
         self.bos_token_id = bos_token_id
+        self.is_encoder_decoder = is_encoder_decoder
 
 
 class FSMTConfig(PretrainedConfig):
@@ -187,7 +188,9 @@ class FSMTConfig(PretrainedConfig):
         self.init_std = init_std  # Normal(0, this parameter)
         self.activation_function = activation_function
 
-        self.decoder = DecoderConfig(vocab_size=tgt_vocab_size, bos_token_id=eos_token_id)
+        self.decoder = DecoderConfig(
+            vocab_size=tgt_vocab_size, bos_token_id=eos_token_id, is_encoder_decoder=is_encoder_decoder
+        )
         if "decoder" in common_kwargs:
             del common_kwargs["decoder"]
 
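With `is_encoder_decoder` now stored on FSMT's nested `DecoderConfig`, a decoder-side text-config lookup resolves to `True` instead of falling back to the `PretrainedConfig` default. A small sketch of what the updated tests rely on, assuming `get_text_config(decoder=True)` picks up the nested `decoder` sub-config for `FSMTConfig`:

from transformers import FSMTConfig

config = FSMTConfig()
# The nested decoder sub-config now carries the flag explicitly.
print(config.decoder.is_encoder_decoder)                        # expected: True
# The decoder-side text-config lookup used by the updated tests therefore
# reports an encoder-decoder model for FSMT as well.
print(config.get_text_config(decoder=True).is_encoder_decoder)  # expected: True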
@@ -499,7 +499,7 @@ class GenerationTesterMixin:
             model = model_class(config).to(torch_device).eval()
             output_generate = self._greedy_generate(model=model, inputs_dict=inputs_dict)

-            if model.config.is_encoder_decoder:
+            if model.config.get_text_config(decoder=True).is_encoder_decoder:
                 self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1)
             else:
                 self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1])
@@ -523,7 +523,7 @@ class GenerationTesterMixin:
                 use_cache=False,
             )

-            if model.config.is_encoder_decoder:
+            if model.config.get_text_config(decoder=True).is_encoder_decoder:
                 self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1)
                 self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
                 # Retrocompatibility check
@@ -563,7 +563,7 @@ class GenerationTesterMixin:
                 use_cache=True,  # Enable cache
             )

-            if model.config.is_encoder_decoder:
+            if model.config.get_text_config(decoder=True).is_encoder_decoder:
                 self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1)
             else:
                 self.assertTrue(
@@ -580,7 +580,7 @@ class GenerationTesterMixin:
             model = model_class(config).to(torch_device).eval()
             output_generate = self._sample_generate(model=model, inputs_dict=inputs_dict, num_return_sequences=1)

-            if model.config.is_encoder_decoder:
+            if model.config.get_text_config(decoder=True).is_encoder_decoder:
                 self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1)
             else:
                 self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1])
@@ -605,7 +605,7 @@ class GenerationTesterMixin:
                 use_cache=False,
             )

-            if model.config.is_encoder_decoder:
+            if model.config.get_text_config(decoder=True).is_encoder_decoder:
                 self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1)
                 self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
                 # Retrocompatibility check
@@ -630,7 +630,7 @@ class GenerationTesterMixin:
             beam_kwargs = self._get_beam_kwargs()
             output_generate = self._beam_search_generate(model=model, inputs_dict=inputs_dict, beam_kwargs=beam_kwargs)

-            if model.config.is_encoder_decoder:
+            if model.config.get_text_config(decoder=True).is_encoder_decoder:
                 self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1)
             else:
                 self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1])
@@ -655,7 +655,7 @@ class GenerationTesterMixin:
                 return_dict_in_generate=True,
                 use_cache=False,
             )
-            if model.config.is_encoder_decoder:
+            if model.config.get_text_config(decoder=True).is_encoder_decoder:
                 self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1)
                 self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
                 # Retrocompatibility check
@@ -704,7 +704,7 @@ class GenerationTesterMixin:
                 use_cache=True,  # Enable cache
             )

-            if model.config.is_encoder_decoder:
+            if model.config.get_text_config(decoder=True).is_encoder_decoder:
                 self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1)
             else:
                 self.assertTrue(
@@ -757,7 +757,7 @@ class GenerationTesterMixin:
                 beam_kwargs=beam_kwargs,
             )

-            if model.config.is_encoder_decoder:
+            if model.config.get_text_config(decoder=True).is_encoder_decoder:
                 self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1)
             else:
                 self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1])
@@ -784,7 +784,7 @@ class GenerationTesterMixin:
                 use_cache=False,
             )

-            if model.config.is_encoder_decoder:
+            if model.config.get_text_config(decoder=True).is_encoder_decoder:
                 self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1)
                 self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
                 # Retrocompatibility check
@@ -838,7 +838,7 @@ class GenerationTesterMixin:
                 inputs_dict=inputs_dict,
                 beam_kwargs=beam_kwargs,
             )
-            if model.config.is_encoder_decoder:
+            if model.config.get_text_config(decoder=True).is_encoder_decoder:
                 self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1)
             else:
                 self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1])
@@ -851,7 +851,7 @@ class GenerationTesterMixin:
                 inputs_dict=inputs_dict,
                 beam_kwargs=beam_kwargs,
             )
-            if model.config.is_encoder_decoder:
+            if model.config.get_text_config(decoder=True).is_encoder_decoder:
                 self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1)
             else:
                 self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1])
@@ -876,7 +876,7 @@ class GenerationTesterMixin:
                 return_dict_in_generate=True,
                 use_cache=False,
             )
-            if model.config.is_encoder_decoder:
+            if model.config.get_text_config(decoder=True).is_encoder_decoder:
                 self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1)
                 self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
                 # Retrocompatibility check
@@ -921,7 +921,7 @@ class GenerationTesterMixin:
                 beam_kwargs=beam_kwargs,
             )

-            if model.config.is_encoder_decoder:
+            if model.config.get_text_config(decoder=True).is_encoder_decoder:
                 self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1)
             else:
                 self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1])
@@ -945,7 +945,7 @@ class GenerationTesterMixin:
                 beam_kwargs=beam_kwargs,
             )

-            if model.config.is_encoder_decoder:
+            if model.config.get_text_config(decoder=True).is_encoder_decoder:
                 self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1)
             else:
                 self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1])
@@ -985,7 +985,7 @@ class GenerationTesterMixin:
                 use_cache=False,
             )

-            if model.config.is_encoder_decoder:
+            if model.config.get_text_config(decoder=True).is_encoder_decoder:
                 self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1)
                 self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
                 # Retrocompatibility check
@@ -1029,7 +1029,7 @@ class GenerationTesterMixin:
                 inputs_dict=inputs_dict,
                 use_cache=True,  # Enable cache
             )
-            if model.config.is_encoder_decoder:
+            if model.config.get_text_config(decoder=True).is_encoder_decoder:
                 self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1)
             else:
                 self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1])
@@ -1065,7 +1065,7 @@ class GenerationTesterMixin:
                 use_cache=True,  # Enable cache
             )

-            if model.config.is_encoder_decoder:
+            if model.config.get_text_config(decoder=True).is_encoder_decoder:
                 self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1)
             else:
                 self.assertTrue(
@@ -1297,7 +1297,7 @@ class GenerationTesterMixin:
             config._attn_implementation = "eager"

             # Encoder-decoder models are not supported
-            if config.is_encoder_decoder:
+            if config.get_text_config(decoder=True).is_encoder_decoder:
                 self.skipTest("DoLa is not supported for encoder-decoder models")
             config.is_decoder = True
             model = model_class(config).to(torch_device).eval()
@@ -1427,52 +1427,6 @@ class GenerationTesterMixin:
             # PLD shouldn't propose any new tokens based on eos-match
             self.assertTrue(output_prompt_lookup.shape[-1] == 10)

-    @pytest.mark.generate
-    def test_generate_with_head_masking(self):
-        """Test designed for encoder-decoder models to ensure the attention head masking is used."""
-        attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
-        for model_class in self.all_generative_model_classes:
-            config, inputs_dict = self.prepare_config_and_inputs_for_generate()
-            config._attn_implementation = "eager"  # head mask works only in eager mode and will be removed soon
-            text_config = config.get_text_config()
-            if self.has_attentions:
-                config._attn_implementation = "eager"  # can't output attentions otherwise
-
-            # We want to test only encoder-decoder models
-            if not text_config.is_encoder_decoder:
-                continue
-            model = model_class(config).to(torch_device)
-
-            head_masking = {
-                "head_mask": torch.zeros(
-                    text_config.encoder_layers, text_config.encoder_attention_heads, device=torch_device
-                ),
-                "decoder_head_mask": torch.zeros(
-                    text_config.decoder_layers, text_config.decoder_attention_heads, device=torch_device
-                ),
-                "cross_attn_head_mask": torch.zeros(
-                    text_config.decoder_layers, text_config.decoder_attention_heads, device=torch_device
-                ),
-            }
-
-            signature = inspect.signature(model.forward)
-            # We want to test only models where encoder/decoder head masking is implemented
-            if not set(head_masking.keys()) < {*signature.parameters.keys()}:
-                continue
-
-            for attn_name, (name, mask) in zip(attention_names, head_masking.items()):
-                out = model.generate(
-                    num_beams=1,
-                    output_attentions=self.has_attentions,
-                    return_dict_in_generate=True,
-                    remove_invalid_values=True,
-                    **{name: mask},
-                    **inputs_dict,
-                )
-                # We check the state of decoder_attentions and cross_attentions just from the last step
-                attn_weights = out[attn_name] if attn_name == attention_names[0] else out[attn_name][-1]
-                self.assertEqual(sum([w.sum().item() for w in attn_weights]), 0.0)
-
     @pytest.mark.generate
     def test_left_padding_compatibility(self):
         # NOTE: left-padding results in small numerical differences. This is expected.
@@ -1491,7 +1445,7 @@ class GenerationTesterMixin:
         decoder_only_classes = []
         for model_class in self.all_generative_model_classes:
             config, _ = self.prepare_config_and_inputs_for_generate()
-            if config.is_encoder_decoder:
+            if config.get_text_config(decoder=True).is_encoder_decoder:
                 continue
             else:
                 decoder_only_classes.append(model_class)
@@ -1696,7 +1650,7 @@ class GenerationTesterMixin:

             # This test is for decoder-only models (encoder-decoder models have native input embeddings support in the
             # decoder)
-            if config.is_encoder_decoder:
+            if config.get_text_config(decoder=True).is_encoder_decoder:
                 continue
             config.is_decoder = True

@@ -1790,7 +1744,7 @@ class GenerationTesterMixin:

             config, inputs_dict = self.prepare_config_and_inputs_for_generate()

-            if config.is_encoder_decoder:
+            if config.get_text_config(decoder=True).is_encoder_decoder:
                 self.skipTest(reason="This model is encoder-decoder and has Encoder-Decoder Cache")

             model = model_class(config).to(torch_device).eval()
@@ -1952,7 +1906,7 @@ class GenerationTesterMixin:
             if "token_type_ids" in inputs_dict:
                 del inputs_dict["token_type_ids"]

-            if config.is_encoder_decoder:
+            if config.get_text_config(decoder=True).is_encoder_decoder:
                 self.skipTest(reason="This model is encoder-decoder")
             # TODO (joao, raushan): the correct line below is `if not hasattr(config.get_text_config(), "use_cache")`,
             # but it breaks a few models. Fix and then apply `_check_similar_generate_outputs` pattern
@@ -2031,7 +1985,7 @@ class GenerationTesterMixin:
             set_config_for_less_flaky_test(config)
             main_input = inputs_dict[model_class.main_input_name]

-            if config.is_encoder_decoder:
+            if config.get_text_config(decoder=True).is_encoder_decoder:
                 self.skipTest(reason="This model is encoder-decoder and has Encoder-Decoder Cache")

             config.is_decoder = True
@@ -2183,7 +2137,7 @@ class GenerationTesterMixin:
             if not has_defined_cache_implementation:
                 decoder_cache = (
                     gen_out.past_key_values.self_attention_cache
-                    if config.is_encoder_decoder
+                    if config.get_text_config(decoder=True).is_encoder_decoder
                     else gen_out.past_key_values
                 )
                 self.assertTrue(isinstance(decoder_cache, DynamicCache))
@@ -2209,7 +2163,7 @@ class GenerationTesterMixin:
             # sanity checks
             decoder_cache = (
                 gen_out.past_key_values.self_attention_cache
-                if config.is_encoder_decoder
+                if config.get_text_config(decoder=True).is_encoder_decoder
                 else gen_out.past_key_values
             )
             self.assertFalse(isinstance(decoder_cache, DynamicCache))
@@ -2283,7 +2237,7 @@ class GenerationTesterMixin:
             else:
                 self.assertTrue(hasattr(model, "_compiled_call"))  # our auto compile should have been called

-            if model.config.is_encoder_decoder:
+            if model.config.get_text_config(decoder=True).is_encoder_decoder:
                 self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1)
                 self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
             else:
@@ -5154,7 +5108,6 @@ class TestAssistedCandidateGeneratorUpdateStrategy(unittest.TestCase):

     @parameterized.expand([(is_sklearn_available(),), (False,)])
     def test_update_candidate_strategy_no_matches_short(self, sklearn_available):
-        print("test_update_candidate_strategy_no_matches_short")
         self.original_matches = []
         self.candidate_generator.matches = self.original_matches
         self.num_matches = 0
@@ -468,13 +468,6 @@ class BigBirdPegasusModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
     def test_load_save_without_tied_weights(self):
         pass

-    def test_generate_with_head_masking(self):
-        # overwritten to temporarily switch the attention type to `original_full`
-        original_self_attention_type = self.model_tester.attention_type
-        self.model_tester.attention_type = "original_full"
-        super().test_generate_with_head_masking()
-        self.model_tester.attention_type = original_self_attention_type
-

 @require_torch
 @require_sentencepiece
@@ -782,7 +782,7 @@ class BlipVQAModelTester:
 @require_vision
 class BlipVQAModelTest(ModelTesterMixin, unittest.TestCase):
     all_model_classes = (BlipForQuestionAnswering,) if is_torch_available() else ()
-    # Doesn't run generation tests. There are interface mismatches when using `generate` -- TODO @gante
+    # Doesn't run generation tests due to custom generation logic -- won't fix
     all_generative_model_classes = ()
     fx_compatible = False
     test_head_masking = False
@@ -1091,7 +1091,7 @@ class BlipTextRetrievalModelTest(ModelTesterMixin, unittest.TestCase):
 @require_torch
 class BlipTextImageModelTest(ModelTesterMixin, unittest.TestCase):
     all_model_classes = (BlipForConditionalGeneration,) if is_torch_available() else ()
-    # Doesn't run generation tests. There are interface mismatches when using `generate` -- TODO @gante
+    # Doesn't run generation tests due to custom generation logic -- wont fix
     all_generative_model_classes = ()
     fx_compatible = False
     test_head_masking = False
@@ -774,6 +774,7 @@ class Blip2TextModelTester:
             bos_token_id=self.pad_token_id,
             pad_token_id=self.pad_token_id,
             decoder_start_token_id=self.decoder_start_token_id,
+            is_encoder_decoder=True,
         )


@@ -795,6 +796,9 @@ class Blip2ModelTester:
         self.text_model_tester = Blip2TextModelTester(parent, **text_kwargs)
         self.batch_size = self.text_model_tester.batch_size  # need bs for batching_equivalence test
         self.seq_length = self.text_model_tester.seq_length  # need seq_length for common tests
+        self.encoder_seq_length = (
+            self.text_model_tester.encoder_seq_length + num_query_tokens
+        )  # need enc seq_length for gen tests
         self.is_training = is_training
         self.num_query_tokens = num_query_tokens

@@ -859,11 +863,9 @@ class Blip2ModelTester:


 @require_torch
-class Blip2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
+class Blip2ModelTest(ModelTesterMixin, PipelineTesterMixin, GenerationTesterMixin, unittest.TestCase):
     all_model_classes = (Blip2ForConditionalGeneration, Blip2Model) if is_torch_available() else ()
     additional_model_inputs = ["input_ids", "decoder_input_ids"]
-    # Doesn't run generation tests. TODO: fix generation tests for Blip2ForConditionalGeneration
-    all_generative_model_classes = ()
     pipeline_model_mapping = (
         {
             "feature-extraction": Blip2Model,
@@ -324,10 +324,8 @@ class IdeficsModelTester:


 @require_torch
-class IdeficsModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
+class IdeficsModelTest(ModelTesterMixin, PipelineTesterMixin, GenerationTesterMixin, unittest.TestCase):
     all_model_classes = (IdeficsModel, IdeficsForVisionText2Text) if is_torch_available() else ()
-    # Doesn't run generation tests here -- idefics has a dedicated tester for generation tests below
-    all_generative_model_classes = ()
     pipeline_model_mapping = (
         {"feature-extraction": IdeficsModel, "image-text-to-text": IdeficsForVisionText2Text}
         if is_torch_available()
@@ -336,6 +334,7 @@ class IdeficsModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
     test_pruning = False
     test_headmasking = False
     test_torchscript = False
+    has_attentions = False  # only supports SDOA and thus no attention probs returned

     def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
         inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
@@ -494,6 +493,31 @@ class IdeficsModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
     def test_retain_grad_hidden_states_attentions(self):
         return

+    @pytest.mark.generate
+    @unittest.skip(reason="""IDEFICS cannot generate with no images provided!""")
+    def test_generate_without_input_ids(self):
+        pass
+
+    @pytest.mark.generate
+    @unittest.skip(reason="""IDEFICS cannot generate with no images provided!""")
+    def test_generate_continue_from_inputs_embeds(self):
+        pass
+
+    @pytest.mark.generate
+    @unittest.skip(reason="""IDEFICS cannot do contrastive generation yet and it is not worth fixing""")
+    def test_contrastive_generate(self):
+        pass
+
+    @pytest.mark.generate
+    @unittest.skip(reason="""IDEFICS cannot do contrastive generation yet and it is not worth fixing""")
+    def test_contrastive_generate_low_memory(self):
+        pass
+
+    @pytest.mark.generate
+    @unittest.skip(reason="""IDEFICS cannot do contrastive generation yet and it is not worth fixing""")
+    def test_contrastive_generate_dict_outputs_use_cache(self):
+        pass
+
     def test_attention_outputs(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         config.return_dict = True
@@ -626,40 +626,6 @@ class LongT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
         model = LongT5Model.from_pretrained(model_name)
         self.assertIsNotNone(model)

-    def test_generate_with_head_masking(self):
-        attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
-        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        config = config_and_inputs[0]
-        max_length = config_and_inputs[1].shape[-1] + 3
-        model = LongT5ForConditionalGeneration(config).eval()
-        model.to(torch_device)
-
-        head_masking = {
-            "head_mask": torch.zeros(config.num_layers, config.num_heads, device=torch_device),
-            "decoder_head_mask": torch.zeros(config.num_decoder_layers, config.num_heads, device=torch_device),
-            "cross_attn_head_mask": torch.zeros(config.num_decoder_layers, config.num_heads, device=torch_device),
-        }
-
-        for attn_name, (name, mask) in zip(attention_names, head_masking.items()):
-            head_masks = {name: mask}
-            # Explicitly pass decoder_head_mask as it is required from LONGT5 model when head_mask specified
-            if name == "head_mask":
-                head_masks["decoder_head_mask"] = torch.ones(
-                    config.num_decoder_layers, config.num_heads, device=torch_device
-                )
-
-            out = model.generate(
-                config_and_inputs[1],
-                num_beams=1,
-                max_length=max_length,
-                output_attentions=True,
-                return_dict_in_generate=True,
-                **head_masks,
-            )
-            # We check the state of decoder_attentions and cross_attentions just from the last step
-            attn_weights = out[attn_name] if attn_name == attention_names[0] else out[attn_name][-1]
-            self.assertEqual(sum([w.sum().item() for w in attn_weights]), 0.0)
-
     def test_attention_outputs(self):
         if not self.has_attentions:
             self.skipTest(reason="has_attentions is set to False")
@@ -868,40 +868,6 @@ class MT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
         model = MT5Model.from_pretrained(model_name)
         self.assertIsNotNone(model)

-    def test_generate_with_head_masking(self):
-        attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
-        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        config = config_and_inputs[0]
-        max_length = config_and_inputs[1].shape[-1] + 3
-        model = MT5ForConditionalGeneration(config).eval()
-        model.to(torch_device)
-
-        head_masking = {
-            "head_mask": torch.zeros(config.num_layers, config.num_heads, device=torch_device),
-            "decoder_head_mask": torch.zeros(config.num_decoder_layers, config.num_heads, device=torch_device),
-            "cross_attn_head_mask": torch.zeros(config.num_decoder_layers, config.num_heads, device=torch_device),
-        }
-
-        for attn_name, (name, mask) in zip(attention_names, head_masking.items()):
-            head_masks = {name: mask}
-            # Explicitly pass decoder_head_mask as it is required from MT5 model when head_mask specified
-            if name == "head_mask":
-                head_masks["decoder_head_mask"] = torch.ones(
-                    config.num_decoder_layers, config.num_heads, device=torch_device
-                )
-
-            out = model.generate(
-                config_and_inputs[1],
-                num_beams=1,
-                max_length=max_length,
-                output_attentions=True,
-                return_dict_in_generate=True,
-                **head_masks,
-            )
-            # We check the state of decoder_attentions and cross_attentions just from the last step
-            attn_weights = out[attn_name] if attn_name == attention_names[0] else out[attn_name][-1]
-            self.assertEqual(sum([w.sum().item() for w in attn_weights]), 0.0)
-

 # Copied from tests.models.t5.test_modeling_t5.T5EncoderOnlyModelTester with T5->MT5
 class MT5EncoderOnlyModelTester:
@@ -1117,10 +1117,6 @@ class ProphetNetModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
         self.assertIsNotNone(encoder_hidden_states.grad)
         self.assertIsNotNone(encoder_attentions.grad)

-    @unittest.skip(reason="Generating with head_masking has not been implemented for ProphetNet models yet.")
-    def test_generate_with_head_masking(self):
-        pass
-

 @require_torch
 class ProphetNetStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
@@ -741,10 +741,6 @@ class SpeechT5ForSpeechToTextTest(ModelTesterMixin, unittest.TestCase, Generatio
             if hasattr(module, "masked_spec_embed") and module.masked_spec_embed is not None:
                 module.masked_spec_embed.data.fill_(3)

-    @unittest.skip(reason="Temporarily broken")  # TODO (joao, eustache): have a look at this test
-    def test_generate_with_head_masking(self):
-        pass
-
     @unittest.skip(reason="Temporarily broken")  # TODO (joao, eustache): have a look at this test
     def test_generate_without_input_ids(self):
         pass
@@ -709,40 +709,6 @@ class SwitchTransformersModelTest(ModelTesterMixin, GenerationTesterMixin, Pipel
         model = SwitchTransformersModel.from_pretrained(model_name)
         self.assertIsNotNone(model)

-    def test_generate_with_head_masking(self):
-        attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
-        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        config = config_and_inputs[0]
-        max_length = config_and_inputs[1].shape[-1] + 3
-        model = SwitchTransformersForConditionalGeneration(config).eval()
-        model.to(torch_device)
-
-        head_masking = {
-            "head_mask": torch.zeros(config.num_layers, config.num_heads, device=torch_device),
-            "decoder_head_mask": torch.zeros(config.num_decoder_layers, config.num_heads, device=torch_device),
-            "cross_attn_head_mask": torch.zeros(config.num_decoder_layers, config.num_heads, device=torch_device),
-        }
-
-        for attn_name, (name, mask) in zip(attention_names, head_masking.items()):
-            head_masks = {name: mask}
-            # Explicitly pass decoder_head_mask as it is required from SWITCH_TRANSFORMERS model when head_mask specified
-            if name == "head_mask":
-                head_masks["decoder_head_mask"] = torch.ones(
-                    config.num_decoder_layers, config.num_heads, device=torch_device
-                )
-
-            out = model.generate(
-                config_and_inputs[1],
-                num_beams=1,
-                max_length=max_length,
-                output_attentions=True,
-                return_dict_in_generate=True,
-                **head_masks,
-            )
-            # We check the state of decoder_attentions and cross_attentions just from the last step
-            attn_weights = out[attn_name] if attn_name == attention_names[0] else out[attn_name][-1]
-            self.assertEqual(sum([w.sum().item() for w in attn_weights]), 0.0)
-
     @unittest.skip(
         reason="This architecture has tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245"
     )
@@ -873,40 +873,6 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
         model = T5Model.from_pretrained(model_name)
         self.assertIsNotNone(model)

-    def test_generate_with_head_masking(self):
-        attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
-        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        config = config_and_inputs[0]
-        max_length = config_and_inputs[1].shape[-1] + 3
-        model = T5ForConditionalGeneration(config).eval()
-        model.to(torch_device)
-
-        head_masking = {
-            "head_mask": torch.zeros(config.num_layers, config.num_heads, device=torch_device),
-            "decoder_head_mask": torch.zeros(config.num_decoder_layers, config.num_heads, device=torch_device),
-            "cross_attn_head_mask": torch.zeros(config.num_decoder_layers, config.num_heads, device=torch_device),
-        }
-
-        for attn_name, (name, mask) in zip(attention_names, head_masking.items()):
-            head_masks = {name: mask}
-            # Explicitly pass decoder_head_mask as it is required from T5 model when head_mask specified
-            if name == "head_mask":
-                head_masks["decoder_head_mask"] = torch.ones(
-                    config.num_decoder_layers, config.num_heads, device=torch_device
-                )
-
-            out = model.generate(
-                config_and_inputs[1],
-                num_beams=1,
-                max_length=max_length,
-                output_attentions=True,
-                return_dict_in_generate=True,
-                **head_masks,
-            )
-            # We check the state of decoder_attentions and cross_attentions just from the last step
-            attn_weights = out[attn_name] if attn_name == attention_names[0] else out[attn_name][-1]
-            self.assertEqual(sum([w.sum().item() for w in attn_weights]), 0.0)
-

 class T5EncoderOnlyModelTester:
     def __init__(
@@ -419,10 +419,6 @@ class UdopModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
         model = UdopForConditionalGeneration.from_pretrained(model_name)
         self.assertIsNotNone(model)

-    @unittest.skip(reason="TODO: Fix me @joao")
-    def test_generate_with_head_masking(self):
-        pass
-
     @unittest.skip(reason="TODO: Fix me @joao")
     def test_generate_without_input_ids(self):
         pass
@@ -489,39 +489,6 @@ class UMT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_model_fp16_forward(*config_and_inputs)

-    def test_generate_with_head_masking(self):
-        attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
-        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        config = config_and_inputs[0]
-        model = UMT5ForConditionalGeneration(config).eval()
-        model.to(torch_device)
-
-        head_masking = {
-            "head_mask": torch.zeros(config.num_layers, config.num_heads, device=torch_device),
-            "decoder_head_mask": torch.zeros(config.num_decoder_layers, config.num_heads, device=torch_device),
-            "cross_attn_head_mask": torch.zeros(config.num_decoder_layers, config.num_heads, device=torch_device),
-        }
-
-        for attn_name, (name, mask) in zip(attention_names, head_masking.items()):
-            head_masks = {name: mask}
-            # Explicitly pass decoder_head_mask as it is required from T5 model when head_mask specified
-            if name == "head_mask":
-                head_masks["decoder_head_mask"] = torch.ones(
-                    config.num_decoder_layers, config.num_heads, device=torch_device
-                )
-
-            out = model.generate(
-                config_and_inputs[1]["input_ids"],
-                num_beams=1,
-                max_length=3,
-                output_attentions=True,
-                return_dict_in_generate=True,
-                **head_masks,
-            )
-            # We check the state of decoder_attentions and cross_attentions just from the last step
-            attn_weights = out[attn_name] if attn_name == attention_names[0] else out[attn_name][-1]
-            self.assertEqual(sum([w.sum().item() for w in attn_weights]), 0.0)
-
     @unittest.skip(
         reason="This architecture seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
     )