def get_answer(self, answers: List[Dict], target: str) -> Optional[Dict]:
    """Return the previously collected answer whose text matches ``target`` caselessly.

    Used by ``postprocess`` to merge duplicate spans: when a match is found the
    caller adds the new span's score to the existing entry instead of appending
    a duplicate answer dict.

    Args:
        answers: Answer dicts collected so far (each contains an ``"answer"`` key).
        target: Candidate answer text extracted from the context.

    Returns:
        The matching answer dict, or ``None`` if no collected answer matches.
    """
    # `casefold()` is the documented way to do caseless matching; it covers
    # Unicode foldings (e.g. German "ß" -> "ss") that `lower()` misses.
    # Fold the target once instead of on every iteration.
    folded_target = target.casefold()
    for answer in answers:
        if answer["answer"].casefold() == folded_target:
            return answer
    return None
@require_torch
def test_duplicate_handling(self):
    # Regression test: postprocess must merge spans whose extracted text is
    # identical (case-insensitively) rather than emitting duplicate answers.
    qa_pipe = pipeline("question-answering", model="deepset/tinyroberta-squad2")

    results = qa_pipe(
        question="Who is the chancellor of Germany?",
        context="Angela Merkel was the chancellor of Germany.",
        top_k=10,
    )

    answer_texts = [result["answer"] for result in results]
    self.assertEqual(
        len(answer_texts),
        len(set(answer_texts)),
        "There are duplicate answers in the outputs.",
    )