mirror of https://github.com/huggingface/transformers.git
fix bf16 issue in text classification pipeline (#30996)

* fix logits dtype
* Add bf16/fp16 tests for text_classification pipeline
* Update test_pipelines_text_classification.py
* fix
* fix
parent de460e28e1
commit 6b22a8f2d8
src/transformers/pipelines/text_classification.py
@@ -202,7 +202,7 @@ class TextClassificationPipeline(Pipeline):
             function_to_apply = ClassificationFunction.NONE
 
         outputs = model_outputs["logits"][0]
-        outputs = outputs.numpy()
+        outputs = outputs.float().numpy()
 
         if function_to_apply == ClassificationFunction.SIGMOID:
             scores = sigmoid(outputs)
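Background on the one-line change above: NumPy has no bfloat16 dtype, so calling .numpy() on a bf16 tensor raises a TypeError, which is what broke the pipeline when a model was loaded with torch_dtype=torch.bfloat16. Upcasting with .float() before the conversion avoids this. A minimal standalone sketch of the failure and the workaround (not part of the commit; tensor values are made up):

    import torch

    logits = torch.tensor([0.1, -0.2], dtype=torch.bfloat16)

    # Direct conversion fails because NumPy has no bfloat16 dtype;
    # PyTorch typically reports "Got unsupported ScalarType BFloat16".
    try:
        logits.numpy()
    except TypeError as err:
        print(err)

    # Upcasting to float32 first, as the pipeline now does, succeeds.
    scores = logits.float().numpy()
    print(scores.dtype)  # float32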
tests/pipelines/test_pipelines_text_classification.py
@@ -14,13 +14,24 @@
 
 import unittest
 
+import torch
+
 from transformers import (
     MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
     TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
     TextClassificationPipeline,
     pipeline,
 )
-from transformers.testing_utils import is_pipeline_test, nested_simplify, require_tf, require_torch, slow, torch_device
+from transformers.testing_utils import (
+    is_pipeline_test,
+    nested_simplify,
+    require_tf,
+    require_torch,
+    require_torch_bf16,
+    require_torch_fp16,
+    slow,
+    torch_device,
+)
 
 from .test_pipelines_common import ANY
 
@@ -106,6 +117,32 @@ class TextClassificationPipelineTests(unittest.TestCase):
         outputs = text_classifier("This is great !")
         self.assertEqual(nested_simplify(outputs), [{"label": "LABEL_0", "score": 0.504}])
 
+    @require_torch_fp16
+    def test_accepts_torch_fp16(self):
+        text_classifier = pipeline(
+            task="text-classification",
+            model="hf-internal-testing/tiny-random-distilbert",
+            framework="pt",
+            device=torch_device,
+            torch_dtype=torch.float16,
+        )
+
+        outputs = text_classifier("This is great !")
+        self.assertEqual(nested_simplify(outputs), [{"label": "LABEL_0", "score": 0.504}])
+
+    @require_torch_bf16
+    def test_accepts_torch_bf16(self):
+        text_classifier = pipeline(
+            task="text-classification",
+            model="hf-internal-testing/tiny-random-distilbert",
+            framework="pt",
+            device=torch_device,
+            torch_dtype=torch.bfloat16,
+        )
+
+        outputs = text_classifier("This is great !")
+        self.assertEqual(nested_simplify(outputs), [{"label": "LABEL_0", "score": 0.504}])
+
     @require_tf
     def test_small_model_tf(self):
         text_classifier = pipeline(
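For reference, a usage sketch of what the two new tests exercise: loading the pipeline in reduced precision and classifying text. The checkpoint and input string are taken from the tests above; this is a hedged sketch assuming bf16-capable hardware, not part of the commit:

    import torch
    from transformers import pipeline

    # With the fix, logits are upcast to float32 before the numpy
    # conversion, so reduced-precision models no longer error out here.
    classifier = pipeline(
        task="text-classification",
        model="hf-internal-testing/tiny-random-distilbert",
        framework="pt",
        torch_dtype=torch.bfloat16,
    )
    print(classifier("This is great !"))
    # e.g. [{'label': 'LABEL_0', 'score': 0.504}]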