From 914289ac4b7994507fa7329bf6f54572b32ae061 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Mon, 26 Jun 2023 13:58:36 +0200 Subject: [PATCH] [`pipeline`] Fix str device issue (#24396) * fix str device issue * fixup * adapt from suggestions * forward contrib credits from suggestions * better fix * added backward compatibility for older PT versions * final fixes * oops * Attempting something with less branching. --------- Co-authored-by: amyeroberts Co-authored-by: Nicolas Patry --- src/transformers/pipelines/base.py | 7 ++++--- tests/pipelines/test_pipelines_common.py | 15 +++++++++++++++ 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py index 7360f8b7f5a..ee117e62a18 100644 --- a/src/transformers/pipelines/base.py +++ b/src/transformers/pipelines/base.py @@ -903,9 +903,10 @@ class Pipeline(_ScikitCompat): yield else: if self.device.type == "cuda": - torch.cuda.set_device(self.device) - - yield + with torch.cuda.device(self.device): + yield + else: + yield def ensure_tensor_on_device(self, **inputs): """ diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 7909ad22d65..8c7c66939c3 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -46,6 +46,7 @@ from transformers.testing_utils import ( require_tensorflow_probability, require_tf, require_torch, + require_torch_gpu, require_torch_or_tf, slow, ) @@ -542,6 +543,20 @@ class PipelineUtilsTest(unittest.TestCase): gc.collect() torch.cuda.empty_cache() + @slow + @require_torch + @require_torch_gpu + def test_pipeline_cuda(self): + pipe = pipeline("text-generation", device="cuda") + _ = pipe("Hello") + + @slow + @require_torch + @require_torch_gpu + def test_pipeline_cuda_indexed(self): + pipe = pipeline("text-generation", device="cuda:0") + _ = pipe("Hello") + @slow @require_tf @require_tensorflow_probability