mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
* Fix of issue #2941 Reshaped score array to avoid `numpy` ValueError. * Update src/transformers/pipelines.py * Update src/transformers/pipelines.py Co-authored-by: Julien Chaumond <chaumond@gmail.com>
This commit is contained in:
parent
5f4f6b65b3
commit
1cdd2ad2af
@ -656,8 +656,8 @@ class TextClassificationPipeline(Pipeline):
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
outputs = super().__call__(*args, **kwargs)
|
||||
scores = np.exp(outputs) / np.exp(outputs).sum(-1)
|
||||
return [{"label": self.model.config.id2label[item.argmax()], "score": item.max()} for item in scores]
|
||||
scores = np.exp(outputs) / np.exp(outputs).sum(-1, keepdims=True)
|
||||
return [{"label": self.model.config.id2label[item.argmax()], "score": item.max().item()} for item in scores]
|
||||
|
||||
|
||||
class FillMaskPipeline(Pipeline):
|
||||
|
Loading…
Reference in New Issue
Block a user