Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
Fix flaky test_batching_equivalence (#35564)
* yes!
* oh no!!!
* oh no!!!
* style
* oh no!!!
* oh no!!!
* oh no!!!
* oh no!!!

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent 4adc415b6d
commit 1b2f942af7
src/transformers/testing_utils.py

@@ -1409,17 +1409,42 @@ def assert_screenout(out, what):
 
 
 def set_model_tester_for_less_flaky_test(test_case):
-    if hasattr(test_case.model_tester, "num_hidden_layers"):
-        test_case.model_tester.num_hidden_layers = 1
+    target_num_hidden_layers = 1
+    # TODO (if possible): Avoid exceptional cases
+    exceptional_classes = [
+        "ZambaModelTester",
+        "RwkvModelTester",
+        "AriaVisionText2TextModelTester",
+        "GPTNeoModelTester",
+        "DPTModelTester",
+    ]
+    if test_case.model_tester.__class__.__name__ in exceptional_classes:
+        target_num_hidden_layers = None
+    if hasattr(test_case.model_tester, "out_features") or hasattr(test_case.model_tester, "out_indices"):
+        target_num_hidden_layers = None
+
+    if hasattr(test_case.model_tester, "num_hidden_layers") and target_num_hidden_layers is not None:
+        test_case.model_tester.num_hidden_layers = target_num_hidden_layers
     if (
         hasattr(test_case.model_tester, "vision_config")
         and "num_hidden_layers" in test_case.model_tester.vision_config
+        and target_num_hidden_layers is not None
     ):
         test_case.model_tester.vision_config = copy.deepcopy(test_case.model_tester.vision_config)
-        test_case.model_tester.vision_config["num_hidden_layers"] = 1
-    if hasattr(test_case.model_tester, "text_config") and "num_hidden_layers" in test_case.model_tester.text_config:
+        test_case.model_tester.vision_config["num_hidden_layers"] = target_num_hidden_layers
+    if (
+        hasattr(test_case.model_tester, "text_config")
+        and "num_hidden_layers" in test_case.model_tester.text_config
+        and target_num_hidden_layers is not None
+    ):
         test_case.model_tester.text_config = copy.deepcopy(test_case.model_tester.text_config)
-        test_case.model_tester.text_config["num_hidden_layers"] = 1
+        test_case.model_tester.text_config["num_hidden_layers"] = target_num_hidden_layers
+
+    # A few model class specific handling
+
+    # For Albert
+    if hasattr(test_case.model_tester, "num_hidden_groups"):
+        test_case.model_tester.num_hidden_groups = test_case.model_tester.num_hidden_layers
 
 
 def set_config_for_less_flaky_test(config):
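For illustration (not part of the commit): a minimal sketch of what the rewritten helper does to an ordinary multimodal tester. DummyModelTester and DummyTestCase are hypothetical stand-ins; only set_model_tester_for_less_flaky_test is real.

from transformers.testing_utils import set_model_tester_for_less_flaky_test


class DummyModelTester:  # hypothetical stand-in for a real *ModelTester
    def __init__(self):
        self.num_hidden_layers = 4
        self.vision_config = {"num_hidden_layers": 6}
        self.text_config = {"num_hidden_layers": 6}


class DummyTestCase:  # hypothetical stand-in for a ModelTesterMixin test case
    def __init__(self):
        self.model_tester = DummyModelTester()


case = DummyTestCase()
set_model_tester_for_less_flaky_test(case)

# Not in exceptional_classes, and no out_features/out_indices attribute, so
# every depth collapses to the target of a single hidden layer.
assert case.model_tester.num_hidden_layers == 1
assert case.model_tester.vision_config["num_hidden_layers"] == 1
assert case.model_tester.text_config["num_hidden_layers"] == 1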
tests/models/upernet/test_modeling_upernet.py

@@ -82,7 +82,6 @@ class UperNetModelTester:
         self.out_features = out_features
         self.num_labels = num_labels
         self.scope = scope
-        self.num_hidden_layers = num_stages
 
     def prepare_config_and_inputs(self):
         pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
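A companion sketch of why the UperNet tester can drop that alias: the helper now skips any tester exposing out_features (or out_indices), since collapsing a backbone's depth could leave the requested stages dangling. DummyBackboneTester and DummyBackboneCase are hypothetical.

from transformers.testing_utils import set_model_tester_for_less_flaky_test


class DummyBackboneTester:  # hypothetical backbone-style tester
    def __init__(self):
        self.num_hidden_layers = 4
        self.out_features = ["stage1", "stage4"]  # presence alone triggers the skip


class DummyBackboneCase:
    def __init__(self):
        self.model_tester = DummyBackboneTester()


case = DummyBackboneCase()
set_model_tester_for_less_flaky_test(case)
assert case.model_tester.num_hidden_layers == 4  # depth left untouched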
tests/test_modeling_common.py

@@ -816,7 +816,10 @@ class ModelTesterMixin:
                 ),
             )
 
+        set_model_tester_for_less_flaky_test(self)
+
         config, batched_input = self.model_tester.prepare_config_and_inputs_for_common()
+        set_config_for_less_flaky_test(config)
         equivalence = get_tensor_equivalence_function(batched_input)
 
         for model_class in self.all_model_classes:
@@ -827,6 +830,7 @@ class ModelTesterMixin:
             config, batched_input = self.model_tester.prepare_config_and_inputs_for_model_class(model_class)
             batched_input_prepared = self._prepare_for_class(batched_input, model_class)
             model = model_class(config).to(torch_device).eval()
+            set_model_for_less_flaky_test(model)
 
             batch_size = self.model_tester.batch_size
             single_row_input = {}
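For context, an outline of the check that test_batching_equivalence performs and these hooks stabilize: slice each batched input down to one row, run both versions through the model, and require the overlapping outputs to agree within tolerance. A sketch under assumed names, not the test's literal body: model, batched_input_prepared and batch_size are taken from the loop above, torch.testing.assert_close stands in for the equivalence function, and last_hidden_state is assumed to exist on the output.

import torch

# Build a single-row version of every batched tensor input.
single_row_input = {}
for key, value in batched_input_prepared.items():
    if isinstance(value, torch.Tensor) and value.shape[0] % batch_size == 0:
        single_row_input[key] = value[: value.shape[0] // batch_size]
    else:
        single_row_input[key] = value

with torch.no_grad():
    model_batched_output = model(**batched_input_prepared)
    model_row_output = model(**single_row_input)

# The flakiness came from small numerical drift between the two paths; a
# model shrunk to one hidden layer keeps that drift inside the tolerance.
torch.testing.assert_close(
    model_batched_output.last_hidden_state[:1],
    model_row_output.last_hidden_state,
    rtol=1e-4,
    atol=1e-4,
)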