Generate: skip tests on unsupported models instead of passing (#27265)

Joao Gante 2023-11-07 12:08:28 +00:00 committed by GitHub
parent 26d8d5f211
commit 90b4adc1f1

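For context on the commit title: in a unittest-style suite such as GenerationTesterMixin, an early `return` ends the test method and the runner reports it as passed even though nothing was checked, whereas `self.skipTest(reason)` raises `unittest.SkipTest`, so the test is reported as skipped with the reason visible in the output. The snippet below is a minimal standalone sketch of that difference (illustrative only, not code from this PR):

import unittest

class UnsupportedFeatureTest(unittest.TestCase):
    def test_silently_passing(self):
        # Early return: the runner marks this test as passed, hiding the fact
        # that the check below was never reached.
        return
        self.fail("never reached")  # unreachable

    def test_explicitly_skipped(self):
        # skipTest raises unittest.SkipTest, so the runner reports this test as
        # skipped with the given reason instead of a misleading pass.
        self.skipTest("This model doesn't support caching")
        self.fail("never reached")  # unreachable

if __name__ == "__main__":
    unittest.main()
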
@@ -749,8 +749,7 @@ class GenerationTesterMixin:
             config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
             if not hasattr(config, "use_cache"):
-                # only relevant if model has "use_cache"
-                return
+                self.skipTest("This model doesn't support caching")
             config.use_cache = True
             config.is_decoder = True
@@ -983,8 +982,7 @@ class GenerationTesterMixin:
             config.forced_eos_token_id = None
             if not hasattr(config, "use_cache"):
-                # only relevant if model has "use_cache"
-                return
+                self.skipTest("This model doesn't support caching")
             model = model_class(config).to(torch_device).eval()
             if model.config.is_encoder_decoder:
@@ -1420,13 +1418,13 @@ class GenerationTesterMixin:
         for model_class in self.all_generative_model_classes:
             # won't fix: FSMT and Reformer have a different cache variable type (and format).
             if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
-                return
+                self.skipTest("Won't fix: old model with different cache format")
             config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
             # NOTE: contrastive search only works with cache on at the moment.
             if not hasattr(config, "use_cache"):
-                return
+                self.skipTest("This model doesn't support caching")
             config.use_cache = True
             config.is_decoder = True
@@ -1441,14 +1439,14 @@ class GenerationTesterMixin:
         for model_class in self.all_generative_model_classes:
             # won't fix: FSMT and Reformer have a different cache variable type (and format).
             if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
-                return
+                self.skipTest("Won't fix: old model with different cache format")
             # enable cache
             config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
             # NOTE: contrastive search only works with cache on at the moment.
             if not hasattr(config, "use_cache"):
-                return
+                self.skipTest("This model doesn't support caching")
             config.use_cache = True
             config.is_decoder = True
@@ -1472,18 +1470,16 @@ class GenerationTesterMixin:
     def test_contrastive_generate_low_memory(self):
         # Check that choosing 'low_memory' does not change the model output
         for model_class in self.all_generative_model_classes:
-            # won't fix: FSMT, Reformer, gptbigcode, and speech2text have a different cache variable type (and format).
-            if any(
-                model_name in model_class.__name__.lower()
-                for model_name in ["fsmt", "reformer", "gptbigcode", "speech2text"]
-            ):
-                return
+            if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer", "speech2text"]):
+                self.skipTest("Won't fix: old model with different cache format")
+            if any(model_name in model_class.__name__.lower() for model_name in ["gptbigcode"]):
+                self.skipTest("TODO: fix me")
             config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=1)
             # NOTE: contrastive search only works with cache on at the moment.
             if not hasattr(config, "use_cache"):
-                return
+                self.skipTest("This model doesn't support caching")
             config.use_cache = True
             config.is_decoder = True
@@ -1510,8 +1506,6 @@ class GenerationTesterMixin:
             )
             self.assertListEqual(low_output.tolist(), high_output.tolist())
-        return

     @slow  # TODO(Joao): remove this. Some models (e.g. data2vec, xcom, roberta) have an error rate between 1 and 10%.
     def test_assisted_decoding_matches_greedy_search(self):
         # This test ensures that the assisted generation does not introduce output changes over greedy search.
@@ -1522,15 +1516,13 @@ class GenerationTesterMixin:
         # - assisted_decoding does not support `batch_size > 1`
         for model_class in self.all_generative_model_classes:
             # won't fix: FSMT and Reformer have a different cache variable type (and format).
             if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
-                return
-            # may fix in the future: the following models fail with assisted decoding, and need model-specific fixes
+                self.skipTest("Won't fix: old model with different cache format")
             if any(
                 model_name in model_class.__name__.lower()
                 for model_name in ["bigbirdpegasus", "led", "mega", "speech2text", "git", "prophetnet"]
             ):
-                return
+                self.skipTest("May fix in the future: need model-specific fixes")
             # This for loop is a naive and temporary effort to make the test less flaky.
             failed = 0
@@ -1540,7 +1532,7 @@ class GenerationTesterMixin:
                 # NOTE: assisted generation only works with cache on at the moment.
                 if not hasattr(config, "use_cache"):
-                    return
+                    self.skipTest("This model doesn't support caching")
                 config.use_cache = True
                 config.is_decoder = True
@@ -1587,24 +1579,21 @@ class GenerationTesterMixin:
     def test_assisted_decoding_sample(self):
         # Seeded assisted decoding will not match sample for the same seed, as the forward pass does not return the
         # exact same logits (the forward pass of the main model, now with several tokens at once, has causal masking).
         for model_class in self.all_generative_model_classes:
             # won't fix: FSMT and Reformer have a different cache variable type (and format).
             if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
-                return
-            # may fix in the future: the following models fail with assisted decoding, and need model-specific fixes
+                self.skipTest("Won't fix: old model with different cache format")
             if any(
                 model_name in model_class.__name__.lower()
                 for model_name in ["bigbirdpegasus", "led", "mega", "speech2text", "git", "prophetnet", "seamlessm4t"]
             ):
-                return
+                self.skipTest("May fix in the future: need model-specific fixes")
             # enable cache
             config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=1)
             # NOTE: assisted generation only works with cache on at the moment.
             if not hasattr(config, "use_cache"):
-                return
+                self.skipTest("This model doesn't support caching")
             config.use_cache = True
             config.is_decoder = True
@@ -1716,7 +1705,7 @@ class GenerationTesterMixin:
             # If it doesn't support cache, pass the test
             if not hasattr(config, "use_cache"):
-                return
+                self.skipTest("This model doesn't support caching")
             model = model_class(config).to(torch_device)
             if "use_cache" not in inputs:
@@ -1725,7 +1714,7 @@ class GenerationTesterMixin:
             # If "past_key_values" is not returned, pass the test (e.g. RWKV uses a different cache name and format)
             if "past_key_values" not in outputs:
-                return
+                self.skipTest("This model doesn't return `past_key_values`")
             num_hidden_layers = (
                 getattr(config, "decoder_layers", None)
@@ -1832,18 +1821,15 @@ class GenerationTesterMixin:
     def test_generate_continue_from_past_key_values(self):
         # Tests that we can continue generating from past key values, returned from a previous `generate` call
         for model_class in self.all_generative_model_classes:
             # won't fix: old models with unique inputs/caches/others
             if any(model_name in model_class.__name__.lower() for model_name in ["imagegpt"]):
-                return
-            # may fix in the future: needs modeling or test input preparation fixes for compatibility
+                self.skipTest("Won't fix: old model with unique inputs/caches/other")
             if any(model_name in model_class.__name__.lower() for model_name in ["umt5"]):
-                return
+                self.skipTest("TODO: needs modeling or test input preparation fixes for compatibility")
             config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
             # If it doesn't support cache, pass the test
             if not hasattr(config, "use_cache"):
-                return
+                self.skipTest("This model doesn't support caching")
             # Let's make it always:
             # 1. use cache (for obvious reasons)
@@ -1862,10 +1848,10 @@ class GenerationTesterMixin:
             model.generation_config.pad_token_id = model.generation_config.eos_token_id = -1
             model.generation_config.forced_eos_token_id = None
-            # If "past_key_values" is not returned, pass the test (e.g. RWKV uses a different cache name and format)
+            # If "past_key_values" is not returned, skip the test (e.g. RWKV uses a different cache name and format)
             outputs = model(**inputs)
             if "past_key_values" not in outputs:
-                return
+                self.skipTest("This model doesn't return `past_key_values`")
             # Traditional way of generating text, with `return_dict_in_generate` to return the past key values
             outputs = model.generate(**inputs, do_sample=False, max_new_tokens=4, return_dict_in_generate=True)