mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
Move the model type check (#19027)
Co-authored-by: Ankur Goyal <ankur@impira.com>
This commit is contained in:
parent
ea75e9f10e
commit
216b2f9e80
@ -116,16 +116,17 @@ class DocumentQuestionAnsweringPipeline(Pipeline):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.check_model_type(MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING)
|
||||
|
||||
if self.model.config.__class__.__name__ == "VisionEncoderDecoderConfig":
|
||||
self.model_type = ModelType.VisionEncoderDecoder
|
||||
if self.model.config.encoder.model_type != "donut-swin":
|
||||
raise ValueError("Currently, the only supported VisionEncoderDecoder model is Donut")
|
||||
elif self.model.config.__class__.__name__ == "LayoutLMConfig":
|
||||
self.model_type = ModelType.LayoutLM
|
||||
else:
|
||||
self.model_type = ModelType.LayoutLMv2andv3
|
||||
self.check_model_type(MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING)
|
||||
if self.model.config.__class__.__name__ == "LayoutLMConfig":
|
||||
self.model_type = ModelType.LayoutLM
|
||||
else:
|
||||
self.model_type = ModelType.LayoutLMv2andv3
|
||||
|
||||
def _sanitize_parameters(
|
||||
self,
|
||||
|
Loading…
Reference in New Issue
Block a user