mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
A few CI fixes for DocumentQuestionAnsweringPipeline
(#19584)
* Fixes * update expected values * style * fix Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
parent
0b7b07ef03
commit
cbc1abc4af
@ -235,7 +235,6 @@ class DocumentQuestionAnsweringPipeline(ChunkPipeline):
|
||||
`word_boxes`).
|
||||
- **answer** (`str`) -- The answer to the question.
|
||||
- **words** (`list[int]`) -- The index of each word/box pair that is in the answer
|
||||
- **page** (`int`) -- The page of the answer
|
||||
"""
|
||||
if isinstance(question, str):
|
||||
inputs = {"question": question, "image": image}
|
||||
@ -315,7 +314,6 @@ class DocumentQuestionAnsweringPipeline(ChunkPipeline):
|
||||
"p_mask": None,
|
||||
"word_ids": None,
|
||||
"words": None,
|
||||
"page": None,
|
||||
"output_attentions": True,
|
||||
"is_last": True,
|
||||
}
|
||||
@ -339,6 +337,7 @@ class DocumentQuestionAnsweringPipeline(ChunkPipeline):
|
||||
return_overflowing_tokens=True,
|
||||
**tokenizer_kwargs,
|
||||
)
|
||||
encoding.pop("overflow_to_sample_mapping") # We do not use this
|
||||
|
||||
num_spans = len(encoding["input_ids"])
|
||||
|
||||
@ -395,9 +394,6 @@ class DocumentQuestionAnsweringPipeline(ChunkPipeline):
|
||||
words = model_inputs.pop("words", None)
|
||||
is_last = model_inputs.pop("is_last", False)
|
||||
|
||||
if "overflow_to_sample_mapping" in model_inputs:
|
||||
model_inputs.pop("overflow_to_sample_mapping")
|
||||
|
||||
if self.model_type == ModelType.VisionEncoderDecoder:
|
||||
model_outputs = self.model.generate(**model_inputs)
|
||||
else:
|
||||
@ -421,7 +417,7 @@ class DocumentQuestionAnsweringPipeline(ChunkPipeline):
|
||||
return answers
|
||||
|
||||
def postprocess_encoder_decoder_single(self, model_outputs, **kwargs):
|
||||
sequence = self.tokenizer.batch_decode(model_outputs.sequences)[0]
|
||||
sequence = self.tokenizer.batch_decode(model_outputs["sequences"])[0]
|
||||
|
||||
# TODO: A lot of this logic is specific to Donut and should probably be handled in the tokenizer
|
||||
# (see https://github.com/huggingface/transformers/pull/18414/files#r961747408 for more context).
|
||||
|
@ -209,8 +209,8 @@ class DocumentQuestionAnsweringPipelineTests(unittest.TestCase, metaclass=Pipeli
|
||||
self.assertEqual(
|
||||
nested_simplify(outputs, decimals=4),
|
||||
[
|
||||
{"score": 0.9967, "answer": "1102/2019", "start": 22, "end": 22},
|
||||
{"score": 0.996, "answer": "us-001", "start": 15, "end": 15},
|
||||
{"score": 0.9974, "answer": "1110212019", "start": 23, "end": 23},
|
||||
{"score": 0.9948, "answer": "us-001", "start": 16, "end": 16},
|
||||
],
|
||||
)
|
||||
|
||||
@ -218,8 +218,8 @@ class DocumentQuestionAnsweringPipelineTests(unittest.TestCase, metaclass=Pipeli
|
||||
self.assertEqual(
|
||||
nested_simplify(outputs, decimals=4),
|
||||
[
|
||||
{"score": 0.9967, "answer": "1102/2019", "start": 22, "end": 22},
|
||||
{"score": 0.996, "answer": "us-001", "start": 15, "end": 15},
|
||||
{"score": 0.9974, "answer": "1110212019", "start": 23, "end": 23},
|
||||
{"score": 0.9948, "answer": "us-001", "start": 16, "end": 16},
|
||||
],
|
||||
)
|
||||
|
||||
@ -230,8 +230,8 @@ class DocumentQuestionAnsweringPipelineTests(unittest.TestCase, metaclass=Pipeli
|
||||
nested_simplify(outputs, decimals=4),
|
||||
[
|
||||
[
|
||||
{"score": 0.9967, "answer": "1102/2019", "start": 22, "end": 22},
|
||||
{"score": 0.996, "answer": "us-001", "start": 15, "end": 15},
|
||||
{"score": 0.9974, "answer": "1110212019", "start": 23, "end": 23},
|
||||
{"score": 0.9948, "answer": "us-001", "start": 16, "end": 16},
|
||||
]
|
||||
]
|
||||
* 2,
|
||||
@ -320,8 +320,8 @@ class DocumentQuestionAnsweringPipelineTests(unittest.TestCase, metaclass=Pipeli
|
||||
self.assertEqual(
|
||||
nested_simplify(outputs, decimals=4),
|
||||
[
|
||||
{"score": 0.9999, "answer": "us-001", "start": 15, "end": 15},
|
||||
{"score": 0.9924, "answer": "us-001", "start": 15, "end": 15},
|
||||
{"score": 0.9999, "answer": "us-001", "start": 16, "end": 16},
|
||||
{"score": 0.9998, "answer": "us-001", "start": 16, "end": 16},
|
||||
],
|
||||
)
|
||||
|
||||
@ -332,8 +332,8 @@ class DocumentQuestionAnsweringPipelineTests(unittest.TestCase, metaclass=Pipeli
|
||||
nested_simplify(outputs, decimals=4),
|
||||
[
|
||||
[
|
||||
{"score": 0.9999, "answer": "us-001", "start": 15, "end": 15},
|
||||
{"score": 0.9924, "answer": "us-001", "start": 15, "end": 15},
|
||||
{"score": 0.9999, "answer": "us-001", "start": 16, "end": 16},
|
||||
{"score": 0.9998, "answer": "us-001", "start": 16, "end": 16},
|
||||
]
|
||||
]
|
||||
* 2,
|
||||
@ -346,8 +346,8 @@ class DocumentQuestionAnsweringPipelineTests(unittest.TestCase, metaclass=Pipeli
|
||||
self.assertEqual(
|
||||
nested_simplify(outputs, decimals=4),
|
||||
[
|
||||
{"score": 0.9999, "answer": "us-001", "start": 15, "end": 15},
|
||||
{"score": 0.9924, "answer": "us-001", "start": 15, "end": 15},
|
||||
{"score": 0.9999, "answer": "us-001", "start": 16, "end": 16},
|
||||
{"score": 0.9998, "answer": "us-001", "start": 16, "end": 16},
|
||||
],
|
||||
)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user