Move the model type check (#19027)

Co-authored-by: Ankur Goyal <ankur@impira.com>
This commit is contained in:
Ankur Goyal 2022-09-26 06:43:34 -07:00 committed by GitHub
parent ea75e9f10e
commit 216b2f9e80
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -116,16 +116,17 @@ class DocumentQuestionAnsweringPipeline(Pipeline):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.check_model_type(MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING)
if self.model.config.__class__.__name__ == "VisionEncoderDecoderConfig":
self.model_type = ModelType.VisionEncoderDecoder
if self.model.config.encoder.model_type != "donut-swin":
raise ValueError("Currently, the only supported VisionEncoderDecoder model is Donut")
elif self.model.config.__class__.__name__ == "LayoutLMConfig":
self.model_type = ModelType.LayoutLM
else:
self.model_type = ModelType.LayoutLMv2andv3
self.check_model_type(MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING)
if self.model.config.__class__.__name__ == "LayoutLMConfig":
self.model_type = ModelType.LayoutLM
else:
self.model_type = ModelType.LayoutLMv2andv3
def _sanitize_parameters(
self,