Fix QA argument handler (#8765)

* Fix QA argument handler

* Attempt to get a better fix for QA (#8768)

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
This commit is contained in:
Lysandre Debut 2020-11-25 14:02:15 -05:00 committed by GitHub
parent 4821ea5aeb
commit 138f45c184
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 47 additions and 1 deletions

View File

@ -1624,7 +1624,17 @@ class QuestionAnsweringArgumentHandler(ArgumentHandler):
elif "data" in kwargs:
inputs = kwargs["data"]
elif "question" in kwargs and "context" in kwargs:
inputs = [{"question": kwargs["question"], "context": kwargs["context"]}]
if isinstance(kwargs["question"], list) and isinstance(kwargs["context"], str):
inputs = [{"question": Q, "context": kwargs["context"]} for Q in kwargs["question"]]
elif isinstance(kwargs["question"], list) and isinstance(kwargs["context"], list):
if len(kwargs["question"]) != len(kwargs["context"]):
raise ValueError("Questions and contexts don't have the same lengths")
inputs = [{"question": Q, "context": C} for Q, C in zip(kwargs["question"], kwargs["context"])]
elif isinstance(kwargs["question"], str) and isinstance(kwargs["context"], str):
inputs = [{"question": kwargs["question"], "context": kwargs["context"]}]
else:
raise ValueError("Arguments can't be understood")
else:
raise ValueError("Unknown arguments {}".format(kwargs))

View File

@ -23,6 +23,17 @@ class QAPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
"question": "In what field is HuggingFace working ?",
"context": "HuggingFace is a startup based in New-York founded in Paris which is trying to solve NLP.",
},
{
"question": ["In what field is HuggingFace working ?", "In what field is HuggingFace working ?"],
"context": "HuggingFace is a startup based in New-York founded in Paris which is trying to solve NLP.",
},
{
"question": ["In what field is HuggingFace working ?", "In what field is HuggingFace working ?"],
"context": [
"HuggingFace is a startup based in New-York founded in Paris which is trying to solve NLP.",
"HuggingFace is a startup based in New-York founded in Paris which is trying to solve NLP.",
],
},
]
def _test_pipeline(self, nlp: Pipeline):
@ -80,6 +91,11 @@ class QAPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
self.assertEqual(len(normalized), 1)
self.assertEqual({type(el) for el in normalized}, {SquadExample})
normalized = qa(question=[Q, Q], context=C)
self.assertEqual(type(normalized), list)
self.assertEqual(len(normalized), 2)
self.assertEqual({type(el) for el in normalized}, {SquadExample})
normalized = qa({"question": Q, "context": C})
self.assertEqual(type(normalized), list)
self.assertEqual(len(normalized), 1)
@ -159,6 +175,26 @@ class QAPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
with self.assertRaises(ValueError):
qa([{"question": Q, "context": C}, {"question": Q, "context": ""}])
with self.assertRaises(ValueError):
qa(question={"This": "Is weird"}, context="This is a context")
with self.assertRaises(ValueError):
qa(question=[Q, Q], context=[C, C, C])
with self.assertRaises(ValueError):
qa(question=[Q, Q, Q], context=[C, C])
def test_argument_handler_old_format(self):
qa = QuestionAnsweringArgumentHandler()
Q = "Where was HuggingFace founded ?"
C = "HuggingFace was founded in Paris"
# Backward compatibility for this
normalized = qa(question=[Q, Q], context=[C, C])
self.assertEqual(type(normalized), list)
self.assertEqual(len(normalized), 2)
self.assertEqual({type(el) for el in normalized}, {SquadExample})
def test_argument_handler_error_handling_odd(self):
qa = QuestionAnsweringArgumentHandler()
with self.assertRaises(ValueError):