[pipeline] Fix str device issue (#24396)

* fix str device issue

* fixup

* adapt from suggestions

* forward contrib credits from suggestions

* better fix

* added backward compatibility for older PT versions

* final fixes

* oops

* Attempting something with less branching.

---------

Co-authored-by: amyeroberts <amyeroberts@users.noreply.github.com>
Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
This commit is contained in:
Younes Belkada 2023-06-26 13:58:36 +02:00 committed by GitHub
parent 892399c5ff
commit 914289ac4b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 19 additions and 3 deletions

View File

@ -903,9 +903,10 @@ class Pipeline(_ScikitCompat):
yield yield
else: else:
if self.device.type == "cuda": if self.device.type == "cuda":
torch.cuda.set_device(self.device) with torch.cuda.device(self.device):
yield
yield else:
yield
def ensure_tensor_on_device(self, **inputs): def ensure_tensor_on_device(self, **inputs):
""" """

View File

@ -46,6 +46,7 @@ from transformers.testing_utils import (
require_tensorflow_probability, require_tensorflow_probability,
require_tf, require_tf,
require_torch, require_torch,
require_torch_gpu,
require_torch_or_tf, require_torch_or_tf,
slow, slow,
) )
@ -542,6 +543,20 @@ class PipelineUtilsTest(unittest.TestCase):
gc.collect() gc.collect()
torch.cuda.empty_cache() torch.cuda.empty_cache()
@slow
@require_torch
@require_torch_gpu
def test_pipeline_cuda(self):
pipe = pipeline("text-generation", device="cuda")
_ = pipe("Hello")
@slow
@require_torch
@require_torch_gpu
def test_pipeline_cuda_indexed(self):
pipe = pipeline("text-generation", device="cuda:0")
_ = pipe("Hello")
@slow @slow
@require_tf @require_tf
@require_tensorflow_probability @require_tensorflow_probability