Fixing Backward compatiblity for zero-shot (#13855)

Fixes #13846
This commit is contained in:
Nicolas Patry 2021-10-06 05:06:47 +02:00 committed by GitHub
parent 9f58becc8d
commit 013bdc6d65
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 21 additions and 4 deletions

View File

@ -191,10 +191,7 @@ class ZeroShotClassificationPipeline(Pipeline):
else:
raise ValueError(f"Unable to understand extra arguments {args}")
result = super().__call__(sequences, **kwargs)
if len(result) == 1:
return result[0]
return result
return super().__call__(sequences, **kwargs)
def preprocess(self, inputs, candidate_labels=None, hypothesis_template="This example is {}."):
sequence_pairs, sequences = self._args_parser(inputs, candidate_labels, hypothesis_template)
@ -264,4 +261,6 @@ class ZeroShotClassificationPipeline(Pipeline):
"scores": scores[iseq, top_inds].tolist(),
}
)
if len(result) == 1:
return result[0]
return result

View File

@ -61,6 +61,24 @@ class ZeroShotClassificationPipelineTests(unittest.TestCase, metaclass=PipelineT
)
self.assertEqual(outputs, {"sequence": ANY(str), "labels": [ANY(str)], "scores": [ANY(float)]})
# https://github.com/huggingface/transformers/issues/13846
outputs = classifier(["I am happy"], ["positive", "negative"])
self.assertEqual(
outputs,
[
{"sequence": ANY(str), "labels": [ANY(str), ANY(str)], "scores": [ANY(float), ANY(float)]}
for i in range(1)
],
)
outputs = classifier(["I am happy", "I am sad"], ["positive", "negative"])
self.assertEqual(
outputs,
[
{"sequence": ANY(str), "labels": [ANY(str), ANY(str)], "scores": [ANY(float), ANY(float)]}
for i in range(2)
],
)
with self.assertRaises(ValueError):
classifier("", candidate_labels="politics")