Running a pipeline of float16. (#17637)
When we're preparing tensors on CPU for postprocessing, we need to upcast `float16` to `float32`, since CPUs generally don't have instructions for `[b]float16`.
This commit is contained in: parent 90ed9ae2d1, commit c38f4e1f1c
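To make the change concrete, here is a minimal, standalone sketch of the casting rule the diff below adds to `Pipeline._ensure_tensor_on_device`. The free-standing helper name and the example tensor are illustrative only, not the library's actual code path.

import torch

# Illustrative sketch of the added rule: when a tensor is moved to CPU for
# postprocessing, upcast half-precision dtypes to float32 first.
def ensure_tensor_on_device(inputs, device):
    if isinstance(inputs, torch.Tensor):
        if device == torch.device("cpu") and inputs.dtype in {torch.float16, torch.bfloat16}:
            inputs = inputs.float()
        return inputs.to(device)
    return inputs

# A half-precision tensor coming off the GPU ends up as float32 on CPU.
t = torch.randn(2, 3, dtype=torch.float16)
print(ensure_tensor_on_device(t, torch.device("cpu")).dtype)  # torch.float32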
@@ -869,6 +869,8 @@ class Pipeline(_ScikitCompat):
         elif isinstance(inputs, tuple):
             return tuple([self._ensure_tensor_on_device(item, device) for item in inputs])
         elif isinstance(inputs, torch.Tensor):
+            if device == torch.device("cpu") and inputs.dtype in {torch.float16, torch.bfloat16}:
+                inputs = inputs.float()
             return inputs.to(device)
         else:
             return inputs
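One concrete way the half dtypes bite during postprocessing is the numpy conversion; the snippet below is an illustrative check (assuming a recent PyTorch), since numpy has no bfloat16 dtype.

import torch

# Converting a bfloat16 tensor to numpy raises a TypeError; upcasting to
# float32 first makes the tensor usable for postprocessing.
x = torch.zeros(2, dtype=torch.bfloat16)
try:
    x.numpy()
except TypeError as err:
    print("direct conversion failed:", err)
print(x.float().numpy())  # works after upcasting to float32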
@@ -16,7 +16,14 @@ import unittest

 from transformers import MODEL_FOR_MASKED_LM_MAPPING, TF_MODEL_FOR_MASKED_LM_MAPPING, FillMaskPipeline, pipeline
 from transformers.pipelines import PipelineException
-from transformers.testing_utils import is_pipeline_test, nested_simplify, require_tf, require_torch, slow
+from transformers.testing_utils import (
+    is_pipeline_test,
+    nested_simplify,
+    require_tf,
+    require_torch,
+    require_torch_gpu,
+    slow,
+)

 from .test_pipelines_common import ANY, PipelineTestCaseMeta
@@ -130,6 +137,19 @@ class FillMaskPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
             ],
         )

+    @require_torch_gpu
+    def test_fp16_casting(self):
+        pipe = pipeline("fill-mask", model="hf-internal-testing/tiny-random-distilbert", device=0, framework="pt")
+
+        # convert model to fp16
+        pipe.model.half()
+
+        response = pipe("Paris is the [MASK] of France.")
+        # We actually don't care about the result, we just want to make sure
+        # it works, meaning the float16 tensor got casted back to float32
+        # for postprocessing.
+        self.assertIsInstance(response, list)
+
     @slow
     @require_torch
     def test_large_model_pt(self):