mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
parent
9f58becc8d
commit
013bdc6d65
@ -191,10 +191,7 @@ class ZeroShotClassificationPipeline(Pipeline):
|
||||
else:
|
||||
raise ValueError(f"Unable to understand extra arguments {args}")
|
||||
|
||||
result = super().__call__(sequences, **kwargs)
|
||||
if len(result) == 1:
|
||||
return result[0]
|
||||
return result
|
||||
return super().__call__(sequences, **kwargs)
|
||||
|
||||
def preprocess(self, inputs, candidate_labels=None, hypothesis_template="This example is {}."):
|
||||
sequence_pairs, sequences = self._args_parser(inputs, candidate_labels, hypothesis_template)
|
||||
@ -264,4 +261,6 @@ class ZeroShotClassificationPipeline(Pipeline):
|
||||
"scores": scores[iseq, top_inds].tolist(),
|
||||
}
|
||||
)
|
||||
if len(result) == 1:
|
||||
return result[0]
|
||||
return result
|
||||
|
@ -61,6 +61,24 @@ class ZeroShotClassificationPipelineTests(unittest.TestCase, metaclass=PipelineT
|
||||
)
|
||||
self.assertEqual(outputs, {"sequence": ANY(str), "labels": [ANY(str)], "scores": [ANY(float)]})
|
||||
|
||||
# https://github.com/huggingface/transformers/issues/13846
|
||||
outputs = classifier(["I am happy"], ["positive", "negative"])
|
||||
self.assertEqual(
|
||||
outputs,
|
||||
[
|
||||
{"sequence": ANY(str), "labels": [ANY(str), ANY(str)], "scores": [ANY(float), ANY(float)]}
|
||||
for i in range(1)
|
||||
],
|
||||
)
|
||||
outputs = classifier(["I am happy", "I am sad"], ["positive", "negative"])
|
||||
self.assertEqual(
|
||||
outputs,
|
||||
[
|
||||
{"sequence": ANY(str), "labels": [ANY(str), ANY(str)], "scores": [ANY(float), ANY(float)]}
|
||||
for i in range(2)
|
||||
],
|
||||
)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
classifier("", candidate_labels="politics")
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user