Fix T5 model parallel tes (#9107)

k
This commit is contained in:
Lysandre Debut 2020-12-15 09:51:12 -05:00 committed by GitHub
parent 59da3f2700
commit 6ccea0486f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -484,9 +484,7 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else ()
all_generative_model_classes = (T5ForConditionalGeneration,) if is_torch_available() else ()
all_parallelizable_model_classes = (
(T5Model, T5ForConditionalGeneration, T5EncoderModel) if is_torch_available() else ()
)
all_parallelizable_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else ()
test_pruning = False
test_torchscript = True
test_resize_embeddings = True
@ -689,6 +687,8 @@ class T5EncoderOnlyModelTest(ModelTesterMixin, unittest.TestCase):
test_pruning = False
test_torchscript = True
test_resize_embeddings = False
test_model_parallel = True
all_parallelizable_model_classes = (T5EncoderModel,) if is_torch_available() else ()
def setUp(self):
self.model_tester = T5EncoderOnlyModelTester(self)