Accepting real pytorch device as arguments. (#17318)

* Accepting real pytorch device as arguments.

* is_torch_available.
This commit is contained in:
Nicolas Patry 2022-05-18 16:06:24 +02:00 committed by GitHub
parent 1c9d1f4ca8
commit 2cb2ea3fa1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 19 additions and 2 deletions

View File

@ -693,7 +693,7 @@ PIPELINE_INIT_ARGS = r"""
Reference to the object in charge of parsing supplied pipeline parameters.
device (`int`, *optional*, defaults to -1):
Device ordinal for CPU/GPU supports. Setting this to -1 will leverage CPU, a positive will run the model on
the associated CUDA device id.
the associated CUDA device id. You can pass native `torch.device` too.
binary_output (`bool`, *optional*, defaults to `False`):
Flag indicating if the output the pipeline should happen in a binary format (i.e., pickle) or as raw text.
"""
@ -750,7 +750,10 @@ class Pipeline(_ScikitCompat):
self.feature_extractor = feature_extractor
self.modelcard = modelcard
self.framework = framework
self.device = device if framework == "tf" else torch.device("cpu" if device < 0 else f"cuda:{device}")
if is_torch_available() and isinstance(device, torch.device):
self.device = device
else:
self.device = device if framework == "tf" else torch.device("cpu" if device < 0 else f"cuda:{device}")
self.binary_output = binary_output
# Special handling

View File

@ -39,6 +39,20 @@ class TextClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTestC
outputs = text_classifier("This is great !")
self.assertEqual(nested_simplify(outputs), [{"label": "LABEL_0", "score": 0.504}])
@require_torch
def test_accepts_torch_device(self):
import torch
text_classifier = pipeline(
task="text-classification",
model="hf-internal-testing/tiny-random-distilbert",
framework="pt",
device=torch.device("cpu"),
)
outputs = text_classifier("This is great !")
self.assertEqual(nested_simplify(outputs), [{"label": "LABEL_0", "score": 0.504}])
@require_tf
def test_small_model_tf(self):
text_classifier = pipeline(