Better test name and enable pipeline test for pix2struct (#24377)

* best test name forever

* fix

* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar 2023-06-20 18:29:30 +02:00 committed by GitHub
parent 6950f70b38
commit 297d769d0e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -35,6 +35,7 @@ from ...test_modeling_common import (
ids_tensor, ids_tensor,
random_attention_mask, random_attention_mask,
) )
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available(): if is_torch_available():
@ -354,7 +355,7 @@ class Pix2StructTextModelTest(ModelTesterMixin, unittest.TestCase):
self.assertIsNotNone(model) self.assertIsNotNone(model)
class Pix2StructTextImageModelsModelTester: class Pix2StructModelTester:
def __init__(self, parent, text_kwargs=None, vision_kwargs=None, is_training=True): def __init__(self, parent, text_kwargs=None, vision_kwargs=None, is_training=True):
if text_kwargs is None: if text_kwargs is None:
text_kwargs = {} text_kwargs = {}
@ -394,8 +395,9 @@ class Pix2StructTextImageModelsModelTester:
@require_torch @require_torch
class Pix2StructTextImageModelTest(ModelTesterMixin, unittest.TestCase): class Pix2StructModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (Pix2StructForConditionalGeneration,) if is_torch_available() else () all_model_classes = (Pix2StructForConditionalGeneration,) if is_torch_available() else ()
pipeline_model_mapping = {"image-to-text": Pix2StructForConditionalGeneration} if is_torch_available() else {}
fx_compatible = False fx_compatible = False
test_head_masking = False test_head_masking = False
test_pruning = False test_pruning = False
@ -404,7 +406,7 @@ class Pix2StructTextImageModelTest(ModelTesterMixin, unittest.TestCase):
test_torchscript = False test_torchscript = False
def setUp(self): def setUp(self):
self.model_tester = Pix2StructTextImageModelsModelTester(self) self.model_tester = Pix2StructModelTester(self)
def test_model(self): def test_model(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()