mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
Remove redundant test_head_masking = True
flags in test files (#9858)
* Remove redundant test_head_masking = True flags * Remove all redundant test_head_masking flags in PyTorch test_modeling_* files * Make test_head_masking = True as a default choice in test_modeling_tf_common.py * Remove all redundant test_head_masking flags in TensorFlow test_modeling_tf_* files * Put back test_head_masking=False for TFT5 models
This commit is contained in:
parent
caddf9126b
commit
4c3ae89ad3
@ -402,7 +402,6 @@ class BartModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
all_generative_model_classes = (BartForConditionalGeneration,) if is_torch_available() else ()
|
||||
is_encoder_decoder = True
|
||||
test_pruning = False
|
||||
test_head_masking = True
|
||||
test_missing_keys = False
|
||||
|
||||
def setUp(self):
|
||||
|
@ -206,7 +206,6 @@ class BlenderbotModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test
|
||||
all_generative_model_classes = (BlenderbotForConditionalGeneration,) if is_torch_available() else ()
|
||||
is_encoder_decoder = True
|
||||
test_pruning = False
|
||||
test_head_masking = True
|
||||
test_missing_keys = False
|
||||
|
||||
def setUp(self):
|
||||
|
@ -214,7 +214,6 @@ class BlenderbotSmallModelTest(ModelTesterMixin, GenerationTesterMixin, unittest
|
||||
all_generative_model_classes = (BlenderbotSmallForConditionalGeneration,) if is_torch_available() else ()
|
||||
is_encoder_decoder = True
|
||||
test_pruning = False
|
||||
test_head_masking = True
|
||||
test_missing_keys = False
|
||||
|
||||
def setUp(self):
|
||||
|
@ -209,7 +209,6 @@ class DistilBertModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
test_pruning = True
|
||||
test_torchscript = True
|
||||
test_resize_embeddings = True
|
||||
test_head_masking = True
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = DistilBertModelTester(self)
|
||||
|
@ -527,10 +527,6 @@ class LxmertModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
test_pruning = False
|
||||
test_torchscript = False
|
||||
|
||||
test_head_masking = False
|
||||
test_pruning = False
|
||||
test_torchscript = False
|
||||
|
||||
# overwrite function because qa models takes different input label shape
|
||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||
inputs_dict = copy.deepcopy(inputs_dict)
|
||||
|
@ -223,7 +223,6 @@ class MarianModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase
|
||||
all_generative_model_classes = (MarianMTModel,) if is_torch_available() else ()
|
||||
is_encoder_decoder = True
|
||||
test_pruning = False
|
||||
test_head_masking = True
|
||||
test_missing_keys = False
|
||||
|
||||
def setUp(self):
|
||||
|
@ -219,7 +219,6 @@ class MBartModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase)
|
||||
all_generative_model_classes = (MBartForConditionalGeneration,) if is_torch_available() else ()
|
||||
is_encoder_decoder = True
|
||||
test_pruning = False
|
||||
test_head_masking = True
|
||||
test_missing_keys = False
|
||||
|
||||
def setUp(self):
|
||||
|
@ -207,7 +207,6 @@ class PegasusModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
|
||||
all_generative_model_classes = (PegasusForConditionalGeneration,) if is_torch_available() else ()
|
||||
is_encoder_decoder = True
|
||||
test_pruning = False
|
||||
test_head_masking = True
|
||||
test_missing_keys = False
|
||||
|
||||
def setUp(self):
|
||||
|
@ -178,7 +178,6 @@ class TFBartModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
all_generative_model_classes = (TFBartForConditionalGeneration,) if is_tf_available() else ()
|
||||
is_encoder_decoder = True
|
||||
test_pruning = False
|
||||
test_head_masking = True
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFBartModelTester(self)
|
||||
|
@ -177,7 +177,6 @@ class TFBlenderbotModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
all_generative_model_classes = (TFBlenderbotForConditionalGeneration,) if is_tf_available() else ()
|
||||
is_encoder_decoder = True
|
||||
test_pruning = False
|
||||
test_head_masking = True
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFBlenderbotModelTester(self)
|
||||
|
@ -179,7 +179,6 @@ class TFBlenderbotSmallModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
all_generative_model_classes = (TFBlenderbotSmallForConditionalGeneration,) if is_tf_available() else ()
|
||||
is_encoder_decoder = True
|
||||
test_pruning = False
|
||||
test_head_masking = True
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFBlenderbotSmallModelTester(self)
|
||||
|
@ -75,6 +75,7 @@ class TFModelTesterMixin:
|
||||
all_model_classes = ()
|
||||
all_generative_model_classes = ()
|
||||
test_resize_embeddings = True
|
||||
test_head_masking = True
|
||||
is_encoder_decoder = False
|
||||
|
||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False) -> dict:
|
||||
|
@ -179,7 +179,6 @@ class TFMarianModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
all_generative_model_classes = (TFMarianMTModel,) if is_tf_available() else ()
|
||||
is_encoder_decoder = True
|
||||
test_pruning = False
|
||||
test_head_masking = True
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFMarianModelTester(self)
|
||||
|
@ -181,7 +181,6 @@ class TFMBartModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
all_generative_model_classes = (TFMBartForConditionalGeneration,) if is_tf_available() else ()
|
||||
is_encoder_decoder = True
|
||||
test_pruning = False
|
||||
test_head_masking = True
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFMBartModelTester(self)
|
||||
|
@ -177,7 +177,6 @@ class TFPegasusModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
all_generative_model_classes = (TFPegasusForConditionalGeneration,) if is_tf_available() else ()
|
||||
is_encoder_decoder = True
|
||||
test_pruning = False
|
||||
test_head_masking = True
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFPegasusModelTester(self)
|
||||
|
Loading…
Reference in New Issue
Block a user