mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 18:22:34 +06:00
* Fix of issue #2941 Reshaped score array to avoid `numpy` ValueError. * Update src/transformers/pipelines.py * Update src/transformers/pipelines.py Co-authored-by: Julien Chaumond <chaumond@gmail.com>
This commit is contained in:
parent
5f4f6b65b3
commit
1cdd2ad2af
@ -656,8 +656,8 @@ class TextClassificationPipeline(Pipeline):
|
|||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
def __call__(self, *args, **kwargs):
|
||||||
outputs = super().__call__(*args, **kwargs)
|
outputs = super().__call__(*args, **kwargs)
|
||||||
scores = np.exp(outputs) / np.exp(outputs).sum(-1)
|
scores = np.exp(outputs) / np.exp(outputs).sum(-1, keepdims=True)
|
||||||
return [{"label": self.model.config.id2label[item.argmax()], "score": item.max()} for item in scores]
|
return [{"label": self.model.config.id2label[item.argmax()], "score": item.max().item()} for item in scores]
|
||||||
|
|
||||||
|
|
||||||
class FillMaskPipeline(Pipeline):
|
class FillMaskPipeline(Pipeline):
|
||||||
|
Loading…
Reference in New Issue
Block a user