mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Fix QA argument handler (#8765)
* Fix QA argument handler * Attempt to get a better fix for QA (#8768) Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
This commit is contained in:
parent
4821ea5aeb
commit
138f45c184
@ -1624,7 +1624,17 @@ class QuestionAnsweringArgumentHandler(ArgumentHandler):
|
||||
elif "data" in kwargs:
|
||||
inputs = kwargs["data"]
|
||||
elif "question" in kwargs and "context" in kwargs:
|
||||
inputs = [{"question": kwargs["question"], "context": kwargs["context"]}]
|
||||
if isinstance(kwargs["question"], list) and isinstance(kwargs["context"], str):
|
||||
inputs = [{"question": Q, "context": kwargs["context"]} for Q in kwargs["question"]]
|
||||
elif isinstance(kwargs["question"], list) and isinstance(kwargs["context"], list):
|
||||
if len(kwargs["question"]) != len(kwargs["context"]):
|
||||
raise ValueError("Questions and contexts don't have the same lengths")
|
||||
|
||||
inputs = [{"question": Q, "context": C} for Q, C in zip(kwargs["question"], kwargs["context"])]
|
||||
elif isinstance(kwargs["question"], str) and isinstance(kwargs["context"], str):
|
||||
inputs = [{"question": kwargs["question"], "context": kwargs["context"]}]
|
||||
else:
|
||||
raise ValueError("Arguments can't be understood")
|
||||
else:
|
||||
raise ValueError("Unknown arguments {}".format(kwargs))
|
||||
|
||||
|
@ -23,6 +23,17 @@ class QAPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
|
||||
"question": "In what field is HuggingFace working ?",
|
||||
"context": "HuggingFace is a startup based in New-York founded in Paris which is trying to solve NLP.",
|
||||
},
|
||||
{
|
||||
"question": ["In what field is HuggingFace working ?", "In what field is HuggingFace working ?"],
|
||||
"context": "HuggingFace is a startup based in New-York founded in Paris which is trying to solve NLP.",
|
||||
},
|
||||
{
|
||||
"question": ["In what field is HuggingFace working ?", "In what field is HuggingFace working ?"],
|
||||
"context": [
|
||||
"HuggingFace is a startup based in New-York founded in Paris which is trying to solve NLP.",
|
||||
"HuggingFace is a startup based in New-York founded in Paris which is trying to solve NLP.",
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
def _test_pipeline(self, nlp: Pipeline):
|
||||
@ -80,6 +91,11 @@ class QAPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
|
||||
self.assertEqual(len(normalized), 1)
|
||||
self.assertEqual({type(el) for el in normalized}, {SquadExample})
|
||||
|
||||
normalized = qa(question=[Q, Q], context=C)
|
||||
self.assertEqual(type(normalized), list)
|
||||
self.assertEqual(len(normalized), 2)
|
||||
self.assertEqual({type(el) for el in normalized}, {SquadExample})
|
||||
|
||||
normalized = qa({"question": Q, "context": C})
|
||||
self.assertEqual(type(normalized), list)
|
||||
self.assertEqual(len(normalized), 1)
|
||||
@ -159,6 +175,26 @@ class QAPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
|
||||
with self.assertRaises(ValueError):
|
||||
qa([{"question": Q, "context": C}, {"question": Q, "context": ""}])
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
qa(question={"This": "Is weird"}, context="This is a context")
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
qa(question=[Q, Q], context=[C, C, C])
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
qa(question=[Q, Q, Q], context=[C, C])
|
||||
|
||||
def test_argument_handler_old_format(self):
|
||||
qa = QuestionAnsweringArgumentHandler()
|
||||
|
||||
Q = "Where was HuggingFace founded ?"
|
||||
C = "HuggingFace was founded in Paris"
|
||||
# Backward compatibility for this
|
||||
normalized = qa(question=[Q, Q], context=[C, C])
|
||||
self.assertEqual(type(normalized), list)
|
||||
self.assertEqual(len(normalized), 2)
|
||||
self.assertEqual({type(el) for el in normalized}, {SquadExample})
|
||||
|
||||
def test_argument_handler_error_handling_odd(self):
|
||||
qa = QuestionAnsweringArgumentHandler()
|
||||
with self.assertRaises(ValueError):
|
||||
|
Loading…
Reference in New Issue
Block a user