A few CI fixes for DocumentQuestionAnsweringPipeline (#19584)

* Fixes

* update expected values

* style

* fix

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Ankur Goyal 2022-10-17 06:35:27 -07:00 committed by GitHub
parent 0b7b07ef03
commit cbc1abc4af
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 14 additions and 18 deletions

View File

@ -235,7 +235,6 @@ class DocumentQuestionAnsweringPipeline(ChunkPipeline):
`word_boxes`).
- **answer** (`str`) -- The answer to the question.
- **words** (`list[int]`) -- The index of each word/box pair that is in the answer
- **page** (`int`) -- The page of the answer
"""
if isinstance(question, str):
inputs = {"question": question, "image": image}
@ -315,7 +314,6 @@ class DocumentQuestionAnsweringPipeline(ChunkPipeline):
"p_mask": None,
"word_ids": None,
"words": None,
"page": None,
"output_attentions": True,
"is_last": True,
}
@ -339,6 +337,7 @@ class DocumentQuestionAnsweringPipeline(ChunkPipeline):
return_overflowing_tokens=True,
**tokenizer_kwargs,
)
encoding.pop("overflow_to_sample_mapping") # We do not use this
num_spans = len(encoding["input_ids"])
@ -395,9 +394,6 @@ class DocumentQuestionAnsweringPipeline(ChunkPipeline):
words = model_inputs.pop("words", None)
is_last = model_inputs.pop("is_last", False)
if "overflow_to_sample_mapping" in model_inputs:
model_inputs.pop("overflow_to_sample_mapping")
if self.model_type == ModelType.VisionEncoderDecoder:
model_outputs = self.model.generate(**model_inputs)
else:
@ -421,7 +417,7 @@ class DocumentQuestionAnsweringPipeline(ChunkPipeline):
return answers
def postprocess_encoder_decoder_single(self, model_outputs, **kwargs):
sequence = self.tokenizer.batch_decode(model_outputs.sequences)[0]
sequence = self.tokenizer.batch_decode(model_outputs["sequences"])[0]
# TODO: A lot of this logic is specific to Donut and should probably be handled in the tokenizer
# (see https://github.com/huggingface/transformers/pull/18414/files#r961747408 for more context).

View File

@ -209,8 +209,8 @@ class DocumentQuestionAnsweringPipelineTests(unittest.TestCase, metaclass=Pipeli
self.assertEqual(
nested_simplify(outputs, decimals=4),
[
{"score": 0.9967, "answer": "1102/2019", "start": 22, "end": 22},
{"score": 0.996, "answer": "us-001", "start": 15, "end": 15},
{"score": 0.9974, "answer": "1110212019", "start": 23, "end": 23},
{"score": 0.9948, "answer": "us-001", "start": 16, "end": 16},
],
)
@ -218,8 +218,8 @@ class DocumentQuestionAnsweringPipelineTests(unittest.TestCase, metaclass=Pipeli
self.assertEqual(
nested_simplify(outputs, decimals=4),
[
{"score": 0.9967, "answer": "1102/2019", "start": 22, "end": 22},
{"score": 0.996, "answer": "us-001", "start": 15, "end": 15},
{"score": 0.9974, "answer": "1110212019", "start": 23, "end": 23},
{"score": 0.9948, "answer": "us-001", "start": 16, "end": 16},
],
)
@ -230,8 +230,8 @@ class DocumentQuestionAnsweringPipelineTests(unittest.TestCase, metaclass=Pipeli
nested_simplify(outputs, decimals=4),
[
[
{"score": 0.9967, "answer": "1102/2019", "start": 22, "end": 22},
{"score": 0.996, "answer": "us-001", "start": 15, "end": 15},
{"score": 0.9974, "answer": "1110212019", "start": 23, "end": 23},
{"score": 0.9948, "answer": "us-001", "start": 16, "end": 16},
]
]
* 2,
@ -320,8 +320,8 @@ class DocumentQuestionAnsweringPipelineTests(unittest.TestCase, metaclass=Pipeli
self.assertEqual(
nested_simplify(outputs, decimals=4),
[
{"score": 0.9999, "answer": "us-001", "start": 15, "end": 15},
{"score": 0.9924, "answer": "us-001", "start": 15, "end": 15},
{"score": 0.9999, "answer": "us-001", "start": 16, "end": 16},
{"score": 0.9998, "answer": "us-001", "start": 16, "end": 16},
],
)
@ -332,8 +332,8 @@ class DocumentQuestionAnsweringPipelineTests(unittest.TestCase, metaclass=Pipeli
nested_simplify(outputs, decimals=4),
[
[
{"score": 0.9999, "answer": "us-001", "start": 15, "end": 15},
{"score": 0.9924, "answer": "us-001", "start": 15, "end": 15},
{"score": 0.9999, "answer": "us-001", "start": 16, "end": 16},
{"score": 0.9998, "answer": "us-001", "start": 16, "end": 16},
]
]
* 2,
@ -346,8 +346,8 @@ class DocumentQuestionAnsweringPipelineTests(unittest.TestCase, metaclass=Pipeli
self.assertEqual(
nested_simplify(outputs, decimals=4),
[
{"score": 0.9999, "answer": "us-001", "start": 15, "end": 15},
{"score": 0.9924, "answer": "us-001", "start": 15, "end": 15},
{"score": 0.9999, "answer": "us-001", "start": 16, "end": 16},
{"score": 0.9998, "answer": "us-001", "start": 16, "end": 16},
],
)