mirror of https://github.com/huggingface/transformers.git
enable Pipeline to get device from model (#30534)
* check model.device
* fix
* style fix
* move model device
* remove print
* add comment
* fix
* add unit test
* optimize
* change test names and add more cases
* Update tests/pipelines/test_pipelines_common.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
parent f4dc26d466
commit 69d9bca55a
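In short: when no `device` is passed to `pipeline(...)`, the `Pipeline` constructor now inherits the device of the model it receives instead of silently assuming CPU, and it skips the `model.to()` call when the model is already on the target device. A minimal sketch of the resulting behavior, assuming a CUDA device is available (the checkpoint name is illustrative, not part of this commit):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Load a model and move it to an accelerator yourself.
model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=torch.float16).to("cuda:0")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# No `device` argument: after this commit the pipeline picks up the model's device.
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
print(pipe.model.device)  # cuda:0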
src/transformers/pipelines/base.py

@@ -845,6 +845,8 @@ class Pipeline(_ScikitCompat, PushToHubMixin):
                 device = -1
 
         if is_torch_available() and self.framework == "pt":
+            if device == -1 and self.model.device is not None:
+                device = self.model.device
             if isinstance(device, torch.device):
                 if device.type == "xpu" and not is_torch_xpu_available(check_device=True):
                     raise ValueError(f'{device} is not available, you should use device="cpu" instead')
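The two added lines implement the inheritance: the default sentinel `-1` (meaning "no device requested") is overridden by the model's own device when one is known. A hypothetical standalone paraphrase of the resolution order, with simplified names that are not part of the transformers API:

def resolve_device(requested, model_device, hf_device_map):
    # Mirrors the order in Pipeline.__init__ (simplified).
    if requested is None:
        # Take the first device used by accelerate, else fall back to the -1 sentinel.
        requested = next(iter(hf_device_map.values())) if hf_device_map else -1
    if requested == -1 and model_device is not None:
        # New in this commit: inherit the device the model already lives on.
        requested = model_device
    return requested

print(resolve_device(None, "cuda:0", None))  # cuda:0 instead of -1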
@@ -871,11 +873,10 @@ class Pipeline(_ScikitCompat, PushToHubMixin):
 
         self.device = device if device is not None else -1
 
         self.binary_output = binary_output
 
-        # We shouldn't call `model.to()` for models loaded with accelerate
+        # We shouldn't call `model.to()` for models loaded with accelerate as well as the case that model is already on device
         if (
             self.framework == "pt"
-            and self.device is not None
+            and self.model.device != self.device
             and not (isinstance(self.device, int) and self.device < 0)
             and hf_device_map is None
         ):
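The second hunk tightens the guard around `model.to()`: the new `self.model.device != self.device` clause makes the move a no-op when the model is already where the caller asked. A runnable paraphrase of the condition, with free-standing names that are assumptions rather than the actual attributes:

import torch

def should_move_model(framework, model_device, device, hf_device_map):
    # Paraphrase of the guard after this commit (simplified names).
    return (
        framework == "pt"
        and model_device != device                        # new clause: skip a redundant .to()
        and not (isinstance(device, int) and device < 0)  # -1 sentinel: leave the model alone
        and hf_device_map is None                         # accelerate-managed models are never moved
    )

assert should_move_model("pt", torch.device("cuda:1"), torch.device("cuda:0"), None)
assert not should_move_model("pt", torch.device("cuda:0"), torch.device("cuda:0"), None)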
tests/pipelines/test_pipelines_common.py

@@ -48,6 +48,7 @@ from transformers.testing_utils import (
     require_tf,
     require_torch,
     require_torch_accelerator,
+    require_torch_multi_accelerator,
     require_torch_or_tf,
     slow,
     torch_device,
@@ -519,6 +520,52 @@ class PipelineUtilsTest(unittest.TestCase):
         actual_output = classifier("Test input.")
         self.assertEqual(expected_output, actual_output)
 
+    @require_torch_accelerator
+    def test_pipeline_no_device(self):
+        # Test when no device is passed to pipeline
+        import torch
+
+        from transformers import AutoModelForCausalLM
+
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
+        # Case 1: Model is manually moved to device
+        model = AutoModelForCausalLM.from_pretrained(
+            "hf-internal-testing/tiny-random-bert", torch_dtype=torch.float16
+        ).to(torch_device)
+        model_device = model.device
+        pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
+        self.assertEqual(pipe.model.device, model_device)
+        # Case 2: Model is loaded by accelerate
+        model = AutoModelForCausalLM.from_pretrained(
+            "hf-internal-testing/tiny-random-bert", device_map=torch_device, torch_dtype=torch.float16
+        )
+        model_device = model.device
+        pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
+        self.assertEqual(pipe.model.device, model_device)
+        # Case 3: device_map is passed to model and device is passed to pipeline
+        model = AutoModelForCausalLM.from_pretrained(
+            "hf-internal-testing/tiny-random-bert", device_map=torch_device, torch_dtype=torch.float16
+        )
+        with self.assertRaises(ValueError):
+            pipe = pipeline("text-generation", model=model, device="cpu", tokenizer=tokenizer)
+
+    @require_torch_multi_accelerator
+    def test_pipeline_device_not_equal_model_device(self):
+        # Test when device ids are different, pipeline should move the model to the passed device id
+        import torch
+
+        from transformers import AutoModelForCausalLM
+
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
+        model_device = f"{torch_device}:1"
+        model = AutoModelForCausalLM.from_pretrained(
+            "hf-internal-testing/tiny-random-bert", torch_dtype=torch.float16
+        ).to(model_device)
+        target_device = f"{torch_device}:0"
+        self.assertNotEqual(model_device, target_device)
+        pipe = pipeline("text-generation", model=model, device=target_device, tokenizer=tokenizer)
+        self.assertEqual(pipe.model.device, torch.device(target_device))
+
     @slow
     @require_torch
     def test_load_default_pipelines_pt(self):
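The three cases in test_pipeline_no_device cover a manually moved model, an accelerate-loaded model, and the conflict where both a model device_map and a pipeline device are given (which raises a ValueError). To run just the new tests locally, a typical pytest invocation would be (requires one accelerator device, and two for the multi-device test):

python -m pytest tests/pipelines/test_pipelines_common.py -k "test_pipeline_no_device or test_pipeline_device_not_equal_model_device" -v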