mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 03:01:07 +06:00
[pipeline
] Fix str device issue (#24396)
* fix str device issue * fixup * adapt from suggestions * forward contrib credits from suggestions * better fix * added backward compatibility for older PT versions * final fixes * oops * Attempting something with less branching. --------- Co-authored-by: amyeroberts <amyeroberts@users.noreply.github.com> Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
This commit is contained in:
parent
892399c5ff
commit
914289ac4b
@ -903,9 +903,10 @@ class Pipeline(_ScikitCompat):
|
|||||||
yield
|
yield
|
||||||
else:
|
else:
|
||||||
if self.device.type == "cuda":
|
if self.device.type == "cuda":
|
||||||
torch.cuda.set_device(self.device)
|
with torch.cuda.device(self.device):
|
||||||
|
yield
|
||||||
yield
|
else:
|
||||||
|
yield
|
||||||
|
|
||||||
def ensure_tensor_on_device(self, **inputs):
|
def ensure_tensor_on_device(self, **inputs):
|
||||||
"""
|
"""
|
||||||
|
@ -46,6 +46,7 @@ from transformers.testing_utils import (
|
|||||||
require_tensorflow_probability,
|
require_tensorflow_probability,
|
||||||
require_tf,
|
require_tf,
|
||||||
require_torch,
|
require_torch,
|
||||||
|
require_torch_gpu,
|
||||||
require_torch_or_tf,
|
require_torch_or_tf,
|
||||||
slow,
|
slow,
|
||||||
)
|
)
|
||||||
@ -542,6 +543,20 @@ class PipelineUtilsTest(unittest.TestCase):
|
|||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
@slow
|
||||||
|
@require_torch
|
||||||
|
@require_torch_gpu
|
||||||
|
def test_pipeline_cuda(self):
|
||||||
|
pipe = pipeline("text-generation", device="cuda")
|
||||||
|
_ = pipe("Hello")
|
||||||
|
|
||||||
|
@slow
|
||||||
|
@require_torch
|
||||||
|
@require_torch_gpu
|
||||||
|
def test_pipeline_cuda_indexed(self):
|
||||||
|
pipe = pipeline("text-generation", device="cuda:0")
|
||||||
|
_ = pipe("Hello")
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_tf
|
@require_tf
|
||||||
@require_tensorflow_probability
|
@require_tensorflow_probability
|
||||||
|
Loading…
Reference in New Issue
Block a user