mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Generate: skip tests on unsupported models instead of passing (#27265)
This commit is contained in:
parent
26d8d5f211
commit
90b4adc1f1
@ -749,8 +749,7 @@ class GenerationTesterMixin:
|
||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
||||
|
||||
if not hasattr(config, "use_cache"):
|
||||
# only relevant if model has "use_cache"
|
||||
return
|
||||
self.skipTest("This model doesn't support caching")
|
||||
|
||||
config.use_cache = True
|
||||
config.is_decoder = True
|
||||
@ -983,8 +982,7 @@ class GenerationTesterMixin:
|
||||
config.forced_eos_token_id = None
|
||||
|
||||
if not hasattr(config, "use_cache"):
|
||||
# only relevant if model has "use_cache"
|
||||
return
|
||||
self.skipTest("This model doesn't support caching")
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
if model.config.is_encoder_decoder:
|
||||
@ -1420,13 +1418,13 @@ class GenerationTesterMixin:
|
||||
for model_class in self.all_generative_model_classes:
|
||||
# won't fix: FSMT and Reformer have a different cache variable type (and format).
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
|
||||
return
|
||||
self.skipTest("Won't fix: old model with different cache format")
|
||||
|
||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
||||
|
||||
# NOTE: contrastive search only works with cache on at the moment.
|
||||
if not hasattr(config, "use_cache"):
|
||||
return
|
||||
self.skipTest("This model doesn't support caching")
|
||||
config.use_cache = True
|
||||
config.is_decoder = True
|
||||
|
||||
@ -1441,14 +1439,14 @@ class GenerationTesterMixin:
|
||||
for model_class in self.all_generative_model_classes:
|
||||
# won't fix: FSMT and Reformer have a different cache variable type (and format).
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
|
||||
return
|
||||
self.skipTest("Won't fix: old model with different cache format")
|
||||
|
||||
# enable cache
|
||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
||||
|
||||
# NOTE: contrastive search only works with cache on at the moment.
|
||||
if not hasattr(config, "use_cache"):
|
||||
return
|
||||
self.skipTest("This model doesn't support caching")
|
||||
config.use_cache = True
|
||||
config.is_decoder = True
|
||||
|
||||
@ -1472,18 +1470,16 @@ class GenerationTesterMixin:
|
||||
def test_contrastive_generate_low_memory(self):
|
||||
# Check that choosing 'low_memory' does not change the model output
|
||||
for model_class in self.all_generative_model_classes:
|
||||
# won't fix: FSMT, Reformer, gptbigcode, and speech2text have a different cache variable type (and format).
|
||||
if any(
|
||||
model_name in model_class.__name__.lower()
|
||||
for model_name in ["fsmt", "reformer", "gptbigcode", "speech2text"]
|
||||
):
|
||||
return
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer", "speech2text"]):
|
||||
self.skipTest("Won't fix: old model with different cache format")
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["gptbigcode"]):
|
||||
self.skipTest("TODO: fix me")
|
||||
|
||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=1)
|
||||
|
||||
# NOTE: contrastive search only works with cache on at the moment.
|
||||
if not hasattr(config, "use_cache"):
|
||||
return
|
||||
self.skipTest("This model doesn't support caching")
|
||||
|
||||
config.use_cache = True
|
||||
config.is_decoder = True
|
||||
@ -1510,8 +1506,6 @@ class GenerationTesterMixin:
|
||||
)
|
||||
self.assertListEqual(low_output.tolist(), high_output.tolist())
|
||||
|
||||
return
|
||||
|
||||
@slow # TODO(Joao): remove this. Some models (e.g. data2vec, xcom, roberta) have an error rate between 1 and 10%.
|
||||
def test_assisted_decoding_matches_greedy_search(self):
|
||||
# This test ensures that the assisted generation does not introduce output changes over greedy search.
|
||||
@ -1522,15 +1516,13 @@ class GenerationTesterMixin:
|
||||
# - assisted_decoding does not support `batch_size > 1`
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
# won't fix: FSMT and Reformer have a different cache variable type (and format).
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
|
||||
return
|
||||
# may fix in the future: the following models fail with assisted decoding, and need model-specific fixes
|
||||
self.skipTest("Won't fix: old model with different cache format")
|
||||
if any(
|
||||
model_name in model_class.__name__.lower()
|
||||
for model_name in ["bigbirdpegasus", "led", "mega", "speech2text", "git", "prophetnet"]
|
||||
):
|
||||
return
|
||||
self.skipTest("May fix in the future: need model-specific fixes")
|
||||
|
||||
# This for loop is a naive and temporary effort to make the test less flaky.
|
||||
failed = 0
|
||||
@ -1540,7 +1532,7 @@ class GenerationTesterMixin:
|
||||
|
||||
# NOTE: assisted generation only works with cache on at the moment.
|
||||
if not hasattr(config, "use_cache"):
|
||||
return
|
||||
self.skipTest("This model doesn't support caching")
|
||||
|
||||
config.use_cache = True
|
||||
config.is_decoder = True
|
||||
@ -1587,24 +1579,21 @@ class GenerationTesterMixin:
|
||||
def test_assisted_decoding_sample(self):
|
||||
# Seeded assisted decoding will not match sample for the same seed, as the forward pass does not return the
|
||||
# exact same logits (the forward pass of the main model, now with several tokens at once, has causal masking).
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
# won't fix: FSMT and Reformer have a different cache variable type (and format).
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
|
||||
return
|
||||
# may fix in the future: the following models fail with assisted decoding, and need model-specific fixes
|
||||
self.skipTest("Won't fix: old model with different cache format")
|
||||
if any(
|
||||
model_name in model_class.__name__.lower()
|
||||
for model_name in ["bigbirdpegasus", "led", "mega", "speech2text", "git", "prophetnet", "seamlessm4t"]
|
||||
):
|
||||
return
|
||||
self.skipTest("May fix in the future: need model-specific fixes")
|
||||
|
||||
# enable cache
|
||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=1)
|
||||
|
||||
# NOTE: assisted generation only works with cache on at the moment.
|
||||
if not hasattr(config, "use_cache"):
|
||||
return
|
||||
self.skipTest("This model doesn't support caching")
|
||||
|
||||
config.use_cache = True
|
||||
config.is_decoder = True
|
||||
@ -1716,7 +1705,7 @@ class GenerationTesterMixin:
|
||||
|
||||
# If it doesn't support cache, pass the test
|
||||
if not hasattr(config, "use_cache"):
|
||||
return
|
||||
self.skipTest("This model doesn't support caching")
|
||||
|
||||
model = model_class(config).to(torch_device)
|
||||
if "use_cache" not in inputs:
|
||||
@ -1725,7 +1714,7 @@ class GenerationTesterMixin:
|
||||
|
||||
# If "past_key_values" is not returned, pass the test (e.g. RWKV uses a different cache name and format)
|
||||
if "past_key_values" not in outputs:
|
||||
return
|
||||
self.skipTest("This model doesn't return `past_key_values`")
|
||||
|
||||
num_hidden_layers = (
|
||||
getattr(config, "decoder_layers", None)
|
||||
@ -1832,18 +1821,15 @@ class GenerationTesterMixin:
|
||||
def test_generate_continue_from_past_key_values(self):
|
||||
# Tests that we can continue generating from past key values, returned from a previous `generate` call
|
||||
for model_class in self.all_generative_model_classes:
|
||||
# won't fix: old models with unique inputs/caches/others
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["imagegpt"]):
|
||||
return
|
||||
# may fix in the future: needs modeling or test input preparation fixes for compatibility
|
||||
self.skipTest("Won't fix: old model with unique inputs/caches/other")
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["umt5"]):
|
||||
return
|
||||
self.skipTest("TODO: needs modeling or test input preparation fixes for compatibility")
|
||||
|
||||
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
# If it doesn't support cache, pass the test
|
||||
if not hasattr(config, "use_cache"):
|
||||
return
|
||||
self.skipTest("This model doesn't support caching")
|
||||
|
||||
# Let's make it always:
|
||||
# 1. use cache (for obvious reasons)
|
||||
@ -1862,10 +1848,10 @@ class GenerationTesterMixin:
|
||||
model.generation_config.pad_token_id = model.generation_config.eos_token_id = -1
|
||||
model.generation_config.forced_eos_token_id = None
|
||||
|
||||
# If "past_key_values" is not returned, pass the test (e.g. RWKV uses a different cache name and format)
|
||||
# If "past_key_values" is not returned, skip the test (e.g. RWKV uses a different cache name and format)
|
||||
outputs = model(**inputs)
|
||||
if "past_key_values" not in outputs:
|
||||
return
|
||||
self.skipTest("This model doesn't return `past_key_values`")
|
||||
|
||||
# Traditional way of generating text, with `return_dict_in_generate` to return the past key values
|
||||
outputs = model.generate(**inputs, do_sample=False, max_new_tokens=4, return_dict_in_generate=True)
|
||||
|
Loading…
Reference in New Issue
Block a user