Better test name and enable pipeline test for pix2struct (#24377)

* best test name forever

* fix

* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar 2023-06-20 18:29:30 +02:00 committed by GitHub
parent 6950f70b38
commit 297d769d0e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -35,6 +35,7 @@ from ...test_modeling_common import (
ids_tensor,
random_attention_mask,
)
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available():
@ -354,7 +355,7 @@ class Pix2StructTextModelTest(ModelTesterMixin, unittest.TestCase):
self.assertIsNotNone(model)
class Pix2StructTextImageModelsModelTester:
class Pix2StructModelTester:
def __init__(self, parent, text_kwargs=None, vision_kwargs=None, is_training=True):
if text_kwargs is None:
text_kwargs = {}
@ -394,8 +395,9 @@ class Pix2StructTextImageModelsModelTester:
@require_torch
class Pix2StructTextImageModelTest(ModelTesterMixin, unittest.TestCase):
class Pix2StructModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (Pix2StructForConditionalGeneration,) if is_torch_available() else ()
pipeline_model_mapping = {"image-to-text": Pix2StructForConditionalGeneration} if is_torch_available() else {}
fx_compatible = False
test_head_masking = False
test_pruning = False
@ -404,7 +406,7 @@ class Pix2StructTextImageModelTest(ModelTesterMixin, unittest.TestCase):
test_torchscript = False
def setUp(self):
self.model_tester = Pix2StructTextImageModelsModelTester(self)
self.model_tester = Pix2StructModelTester(self)
def test_model(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()