mirror of https://github.com/huggingface/transformers.git
Fix flaky test_custom_4d_attention_mask (#35606)
* fix

* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent f63829c87b
commit bbc00046b9
@@ -1431,14 +1431,20 @@ def set_model_tester_for_less_flaky_test(test_case):
         and target_num_hidden_layers is not None
     ):
         test_case.model_tester.vision_config = copy.deepcopy(test_case.model_tester.vision_config)
-        test_case.model_tester.vision_config["num_hidden_layers"] = target_num_hidden_layers
+        if isinstance(test_case.model_tester.vision_config, dict):
+            test_case.model_tester.vision_config["num_hidden_layers"] = 1
+        else:
+            test_case.model_tester.vision_config.num_hidden_layers = 1
     if (
         hasattr(test_case.model_tester, "text_config")
         and "num_hidden_layers" in test_case.model_tester.text_config
         and target_num_hidden_layers is not None
     ):
         test_case.model_tester.text_config = copy.deepcopy(test_case.model_tester.text_config)
-        test_case.model_tester.text_config["num_hidden_layers"] = target_num_hidden_layers
+        if isinstance(test_case.model_tester.text_config, dict):
+            test_case.model_tester.text_config["num_hidden_layers"] = 1
+        else:
+            test_case.model_tester.text_config.num_hidden_layers = 1
 
     # A few model class specific handling
 
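Note: the isinstance branches added above exist because a model tester can hold vision_config / text_config either as a plain dict or as an already-instantiated config object, and the two need different update syntax. A minimal self-contained sketch of that pattern follows; SimpleConfig and shrink_num_hidden_layers are hypothetical names used for illustration, not part of this commit.

import copy


class SimpleConfig:
    """Hypothetical stand-in for an instantiated config object."""

    def __init__(self, num_hidden_layers):
        self.num_hidden_layers = num_hidden_layers


def shrink_num_hidden_layers(config):
    # Deep-copy first so shared tester state is never mutated in place,
    # then handle both storage styles: plain dict vs. attribute access.
    config = copy.deepcopy(config)
    if isinstance(config, dict):
        config["num_hidden_layers"] = 1
    else:
        config.num_hidden_layers = 1
    return config


assert shrink_num_hidden_layers({"num_hidden_layers": 4})["num_hidden_layers"] == 1
assert shrink_num_hidden_layers(SimpleConfig(4)).num_hidden_layers == 1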
@@ -4707,13 +4707,17 @@ class ModelTesterMixin:
                 reason="Model architecture has no generative classes, and thus not necessarily supporting 4D masks"
             )
 
+        set_model_tester_for_less_flaky_test(self)
+
         for model_class in self.all_generative_model_classes:
             if not model_class._supports_static_cache:
                 self.skipTest(f"{model_class.__name__} is not guaranteed to work with custom 4D attention masks")
             config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+            set_config_for_less_flaky_test(config)
             if getattr(config, "sliding_window", 0) is not None and getattr(config, "sliding_window", 0) > 0:
                 self.skipTest(f"{model_class.__name__} with sliding window attention is not supported by this test")
             model = model_class(config).to(device=torch_device, dtype=torch.float32)
+            set_model_for_less_flaky_test(model)
 
             (
                 input_ids,