From cbc1abc4affdd1ac6fc327fcd243a4dd6752343b Mon Sep 17 00:00:00 2001 From: Ankur Goyal Date: Mon, 17 Oct 2022 06:35:27 -0700 Subject: [PATCH] A few CI fixes for `DocumentQuestionAnsweringPipeline` (#19584) * Fixes * update expected values * style * fix Co-authored-by: ydshieh --- .../pipelines/document_question_answering.py | 8 ++----- ...t_pipelines_document_question_answering.py | 24 +++++++++---------- 2 files changed, 14 insertions(+), 18 deletions(-) diff --git a/src/transformers/pipelines/document_question_answering.py b/src/transformers/pipelines/document_question_answering.py index 1b14c1f4801..a0389e013bd 100644 --- a/src/transformers/pipelines/document_question_answering.py +++ b/src/transformers/pipelines/document_question_answering.py @@ -235,7 +235,6 @@ class DocumentQuestionAnsweringPipeline(ChunkPipeline): `word_boxes`). - **answer** (`str`) -- The answer to the question. - **words** (`list[int]`) -- The index of each word/box pair that is in the answer - - **page** (`int`) -- The page of the answer """ if isinstance(question, str): inputs = {"question": question, "image": image} @@ -315,7 +314,6 @@ class DocumentQuestionAnsweringPipeline(ChunkPipeline): "p_mask": None, "word_ids": None, "words": None, - "page": None, "output_attentions": True, "is_last": True, } @@ -339,6 +337,7 @@ class DocumentQuestionAnsweringPipeline(ChunkPipeline): return_overflowing_tokens=True, **tokenizer_kwargs, ) + encoding.pop("overflow_to_sample_mapping") # We do not use this num_spans = len(encoding["input_ids"]) @@ -395,9 +394,6 @@ class DocumentQuestionAnsweringPipeline(ChunkPipeline): words = model_inputs.pop("words", None) is_last = model_inputs.pop("is_last", False) - if "overflow_to_sample_mapping" in model_inputs: - model_inputs.pop("overflow_to_sample_mapping") - if self.model_type == ModelType.VisionEncoderDecoder: model_outputs = self.model.generate(**model_inputs) else: @@ -421,7 +417,7 @@ class DocumentQuestionAnsweringPipeline(ChunkPipeline): return answers def postprocess_encoder_decoder_single(self, model_outputs, **kwargs): - sequence = self.tokenizer.batch_decode(model_outputs.sequences)[0] + sequence = self.tokenizer.batch_decode(model_outputs["sequences"])[0] # TODO: A lot of this logic is specific to Donut and should probably be handled in the tokenizer # (see https://github.com/huggingface/transformers/pull/18414/files#r961747408 for more context). diff --git a/tests/pipelines/test_pipelines_document_question_answering.py b/tests/pipelines/test_pipelines_document_question_answering.py index fa272d64921..c73decda0a4 100644 --- a/tests/pipelines/test_pipelines_document_question_answering.py +++ b/tests/pipelines/test_pipelines_document_question_answering.py @@ -209,8 +209,8 @@ class DocumentQuestionAnsweringPipelineTests(unittest.TestCase, metaclass=Pipeli self.assertEqual( nested_simplify(outputs, decimals=4), [ - {"score": 0.9967, "answer": "1102/2019", "start": 22, "end": 22}, - {"score": 0.996, "answer": "us-001", "start": 15, "end": 15}, + {"score": 0.9974, "answer": "1110212019", "start": 23, "end": 23}, + {"score": 0.9948, "answer": "us-001", "start": 16, "end": 16}, ], ) @@ -218,8 +218,8 @@ class DocumentQuestionAnsweringPipelineTests(unittest.TestCase, metaclass=Pipeli self.assertEqual( nested_simplify(outputs, decimals=4), [ - {"score": 0.9967, "answer": "1102/2019", "start": 22, "end": 22}, - {"score": 0.996, "answer": "us-001", "start": 15, "end": 15}, + {"score": 0.9974, "answer": "1110212019", "start": 23, "end": 23}, + {"score": 0.9948, "answer": "us-001", "start": 16, "end": 16}, ], ) @@ -230,8 +230,8 @@ class DocumentQuestionAnsweringPipelineTests(unittest.TestCase, metaclass=Pipeli nested_simplify(outputs, decimals=4), [ [ - {"score": 0.9967, "answer": "1102/2019", "start": 22, "end": 22}, - {"score": 0.996, "answer": "us-001", "start": 15, "end": 15}, + {"score": 0.9974, "answer": "1110212019", "start": 23, "end": 23}, + {"score": 0.9948, "answer": "us-001", "start": 16, "end": 16}, ] ] * 2, @@ -320,8 +320,8 @@ class DocumentQuestionAnsweringPipelineTests(unittest.TestCase, metaclass=Pipeli self.assertEqual( nested_simplify(outputs, decimals=4), [ - {"score": 0.9999, "answer": "us-001", "start": 15, "end": 15}, - {"score": 0.9924, "answer": "us-001", "start": 15, "end": 15}, + {"score": 0.9999, "answer": "us-001", "start": 16, "end": 16}, + {"score": 0.9998, "answer": "us-001", "start": 16, "end": 16}, ], ) @@ -332,8 +332,8 @@ class DocumentQuestionAnsweringPipelineTests(unittest.TestCase, metaclass=Pipeli nested_simplify(outputs, decimals=4), [ [ - {"score": 0.9999, "answer": "us-001", "start": 15, "end": 15}, - {"score": 0.9924, "answer": "us-001", "start": 15, "end": 15}, + {"score": 0.9999, "answer": "us-001", "start": 16, "end": 16}, + {"score": 0.9998, "answer": "us-001", "start": 16, "end": 16}, ] ] * 2, @@ -346,8 +346,8 @@ class DocumentQuestionAnsweringPipelineTests(unittest.TestCase, metaclass=Pipeli self.assertEqual( nested_simplify(outputs, decimals=4), [ - {"score": 0.9999, "answer": "us-001", "start": 15, "end": 15}, - {"score": 0.9924, "answer": "us-001", "start": 15, "end": 15}, + {"score": 0.9999, "answer": "us-001", "start": 16, "end": 16}, + {"score": 0.9998, "answer": "us-001", "start": 16, "end": 16}, ], )