mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-04 05:10:06 +06:00
[BugFix] QA pipeline edge case: align_to_words=True
in QuestionAnsweringPipeline
can lead to duplicate answers (#38761)
* fix the problem of align_to_words=True leading to duplicate answers * add tests * some fixes * some fixes * change handle_duplicate_answers to False by default * some fixes * some fixes * make duplicate handling the default behaviour and merge duplicates * make duplicate handling the default behaviour
This commit is contained in:
parent
18c7f32daa
commit
a7593a1d1f
@ -595,12 +595,19 @@ class QuestionAnsweringPipeline(ChunkPipeline):
|
|||||||
# - we start by finding the right word containing the token with `token_to_word`
|
# - we start by finding the right word containing the token with `token_to_word`
|
||||||
# - then we convert this word in a character span with `word_to_chars`
|
# - then we convert this word in a character span with `word_to_chars`
|
||||||
sequence_index = 1 if question_first else 0
|
sequence_index = 1 if question_first else 0
|
||||||
|
|
||||||
for s, e, score in zip(starts, ends, scores):
|
for s, e, score in zip(starts, ends, scores):
|
||||||
s = s - offset
|
s = s - offset
|
||||||
e = e - offset
|
e = e - offset
|
||||||
|
|
||||||
start_index, end_index = self.get_indices(enc, s, e, sequence_index, align_to_words)
|
start_index, end_index = self.get_indices(enc, s, e, sequence_index, align_to_words)
|
||||||
|
|
||||||
|
target_answer = example.context_text[start_index:end_index]
|
||||||
|
answer = self.get_answer(answers, target_answer)
|
||||||
|
|
||||||
|
if answer:
|
||||||
|
answer["score"] += score.item()
|
||||||
|
else:
|
||||||
answers.append(
|
answers.append(
|
||||||
{
|
{
|
||||||
"score": score.item(),
|
"score": score.item(),
|
||||||
@ -617,6 +624,12 @@ class QuestionAnsweringPipeline(ChunkPipeline):
|
|||||||
return answers[0]
|
return answers[0]
|
||||||
return answers
|
return answers
|
||||||
|
|
||||||
|
def get_answer(self, answers: List[Dict], target: str) -> Optional[Dict]:
    """
    Return the first entry of `answers` whose `"answer"` text equals `target`, ignoring case.

    Args:
        answers: answer dictionaries collected so far; each must carry an `"answer"` string.
        target: candidate answer text to look up.

    Returns:
        The matching dictionary itself (the same object, not a copy), or `None` when no
        entry matches.
    """
    # Lower-case the target once instead of on every loop iteration.
    lowered_target = target.lower()
    for answer in answers:
        if answer["answer"].lower() == lowered_target:
            return answer
    return None
|
||||||
|
|
||||||
def get_indices(
|
def get_indices(
|
||||||
self, enc: "tokenizers.Encoding", s: int, e: int, sequence_index: int, align_to_words: bool
|
self, enc: "tokenizers.Encoding", s: int, e: int, sequence_index: int, align_to_words: bool
|
||||||
) -> Tuple[int, int]:
|
) -> Tuple[int, int]:
|
||||||
|
@ -138,7 +138,11 @@ class QAPipelineTests(unittest.TestCase):
|
|||||||
question="Where was HuggingFace founded ?", context="HuggingFace was founded in Paris.", top_k=20
|
question="Where was HuggingFace founded ?", context="HuggingFace was founded in Paris.", top_k=20
|
||||||
)
|
)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
outputs, [{"answer": ANY(str), "start": ANY(int), "end": ANY(int), "score": ANY(float)} for i in range(20)]
|
outputs,
|
||||||
|
[
|
||||||
|
{"answer": ANY(str), "start": ANY(int), "end": ANY(int), "score": ANY(float)}
|
||||||
|
for i in range(len(outputs))
|
||||||
|
],
|
||||||
)
|
)
|
||||||
for single_output in outputs:
|
for single_output in outputs:
|
||||||
compare_pipeline_output_to_hub_spec(single_output, QuestionAnsweringOutputElement)
|
compare_pipeline_output_to_hub_spec(single_output, QuestionAnsweringOutputElement)
|
||||||
@ -279,6 +283,19 @@ class QAPipelineTests(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
self.assertEqual(nested_simplify(outputs), {"score": 0.988, "start": 0, "end": 0, "answer": ""})
|
self.assertEqual(nested_simplify(outputs), {"score": 0.988, "start": 0, "end": 0, "answer": ""})
|
||||||
|
|
||||||
|
@require_torch
def test_duplicate_handling(self):
    """Check that the question-answering pipeline never returns the same answer text twice."""
    question_answerer = pipeline("question-answering", model="deepset/tinyroberta-squad2")

    outputs = question_answerer(
        question="Who is the chancellor of Germany?",
        context="Angela Merkel was the chancellor of Germany.",
        top_k=10,
    )
    # Merging duplicates can leave a single answer, and the pipeline may then return a bare
    # dict instead of a list; wrap it so the check below still iterates over answers.
    if isinstance(outputs, dict):
        outputs = [outputs]

    answers = [output["answer"] for output in outputs]
    self.assertEqual(len(answers), len(set(answers)), "There are duplicate answers in the outputs.")
|
||||||
|
|
||||||
@require_tf
|
@require_tf
|
||||||
def test_small_model_tf(self):
|
def test_small_model_tf(self):
|
||||||
question_answerer = pipeline(
|
question_answerer = pipeline(
|
||||||
|
Loading…
Reference in New Issue
Block a user