fix bf16 issue in text classification pipeline (#30996)

* fix logits dtype

* Add bf16/fp16 tests for text_classification pipeline

* Update test_pipelines_text_classification.py

* fix

* fix
This commit is contained in:
Chujie Zheng 2024-06-04 03:20:48 -07:00 committed by GitHub
parent de460e28e1
commit 6b22a8f2d8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 39 additions and 2 deletions

View File

@ -202,7 +202,7 @@ class TextClassificationPipeline(Pipeline):
function_to_apply = ClassificationFunction.NONE
outputs = model_outputs["logits"][0]
outputs = outputs.numpy()
outputs = outputs.float().numpy()
if function_to_apply == ClassificationFunction.SIGMOID:
scores = sigmoid(outputs)

View File

@ -14,13 +14,24 @@
import unittest
import torch
from transformers import (
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
TextClassificationPipeline,
pipeline,
)
from transformers.testing_utils import is_pipeline_test, nested_simplify, require_tf, require_torch, slow, torch_device
from transformers.testing_utils import (
is_pipeline_test,
nested_simplify,
require_tf,
require_torch,
require_torch_bf16,
require_torch_fp16,
slow,
torch_device,
)
from .test_pipelines_common import ANY
@ -106,6 +117,32 @@ class TextClassificationPipelineTests(unittest.TestCase):
outputs = text_classifier("This is great !")
self.assertEqual(nested_simplify(outputs), [{"label": "LABEL_0", "score": 0.504}])
@require_torch_fp16
def test_accepts_torch_fp16(self):
text_classifier = pipeline(
task="text-classification",
model="hf-internal-testing/tiny-random-distilbert",
framework="pt",
device=torch_device,
torch_dtype=torch.float16,
)
outputs = text_classifier("This is great !")
self.assertEqual(nested_simplify(outputs), [{"label": "LABEL_0", "score": 0.504}])
@require_torch_bf16
def test_accepts_torch_bf16(self):
text_classifier = pipeline(
task="text-classification",
model="hf-internal-testing/tiny-random-distilbert",
framework="pt",
device=torch_device,
torch_dtype=torch.bfloat16,
)
outputs = text_classifier("This is great !")
self.assertEqual(nested_simplify(outputs), [{"label": "LABEL_0", "score": 0.504}])
@require_tf
def test_small_model_tf(self):
text_classifier = pipeline(