* Fix of issue #2941

Reshaped score array to avoid `numpy` ValueError.

* Update src/transformers/pipelines.py

* Update src/transformers/pipelines.py

Co-authored-by: Julien Chaumond <chaumond@gmail.com>
This commit is contained in:
Zhiyu Lin 2020-05-02 11:20:30 -04:00 committed by GitHub
parent 5f4f6b65b3
commit 1cdd2ad2af
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -656,8 +656,8 @@ class TextClassificationPipeline(Pipeline):
def __call__(self, *args, **kwargs):
outputs = super().__call__(*args, **kwargs)
scores = np.exp(outputs) / np.exp(outputs).sum(-1)
return [{"label": self.model.config.id2label[item.argmax()], "score": item.max()} for item in scores]
scores = np.exp(outputs) / np.exp(outputs).sum(-1, keepdims=True)
return [{"label": self.model.config.id2label[item.argmax()], "score": item.max().item()} for item in scores]
class FillMaskPipeline(Pipeline):