mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
fix bf16 issue in text classification pipeline (#30996)
* fix logits dtype * Add bf16/fp16 tests for text_classification pipeline * Update test_pipelines_text_classification.py * fix * fix
This commit is contained in:
parent
de460e28e1
commit
6b22a8f2d8
@ -202,7 +202,7 @@ class TextClassificationPipeline(Pipeline):
|
||||
function_to_apply = ClassificationFunction.NONE
|
||||
|
||||
outputs = model_outputs["logits"][0]
|
||||
outputs = outputs.numpy()
|
||||
outputs = outputs.float().numpy()
|
||||
|
||||
if function_to_apply == ClassificationFunction.SIGMOID:
|
||||
scores = sigmoid(outputs)
|
||||
|
@ -14,13 +14,24 @@
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from transformers import (
|
||||
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||
TextClassificationPipeline,
|
||||
pipeline,
|
||||
)
|
||||
from transformers.testing_utils import is_pipeline_test, nested_simplify, require_tf, require_torch, slow, torch_device
|
||||
from transformers.testing_utils import (
|
||||
is_pipeline_test,
|
||||
nested_simplify,
|
||||
require_tf,
|
||||
require_torch,
|
||||
require_torch_bf16,
|
||||
require_torch_fp16,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from .test_pipelines_common import ANY
|
||||
|
||||
@ -106,6 +117,32 @@ class TextClassificationPipelineTests(unittest.TestCase):
|
||||
outputs = text_classifier("This is great !")
|
||||
self.assertEqual(nested_simplify(outputs), [{"label": "LABEL_0", "score": 0.504}])
|
||||
|
||||
@require_torch_fp16
|
||||
def test_accepts_torch_fp16(self):
|
||||
text_classifier = pipeline(
|
||||
task="text-classification",
|
||||
model="hf-internal-testing/tiny-random-distilbert",
|
||||
framework="pt",
|
||||
device=torch_device,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
|
||||
outputs = text_classifier("This is great !")
|
||||
self.assertEqual(nested_simplify(outputs), [{"label": "LABEL_0", "score": 0.504}])
|
||||
|
||||
@require_torch_bf16
|
||||
def test_accepts_torch_bf16(self):
|
||||
text_classifier = pipeline(
|
||||
task="text-classification",
|
||||
model="hf-internal-testing/tiny-random-distilbert",
|
||||
framework="pt",
|
||||
device=torch_device,
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
|
||||
outputs = text_classifier("This is great !")
|
||||
self.assertEqual(nested_simplify(outputs), [{"label": "LABEL_0", "score": 0.504}])
|
||||
|
||||
@require_tf
|
||||
def test_small_model_tf(self):
|
||||
text_classifier = pipeline(
|
||||
|
Loading…
Reference in New Issue
Block a user