mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 10:12:23 +06:00
[pipeline
] revisit device check for pipeline (#25207)
* revisit device check for pipeline * let's raise an error.
This commit is contained in:
parent
5220606607
commit
e0c50b274a
@ -775,12 +775,20 @@ class Pipeline(_ScikitCompat):
|
||||
self.modelcard = modelcard
|
||||
self.framework = framework
|
||||
|
||||
# `accelerate` device map
|
||||
hf_device_map = getattr(self.model, "hf_device_map", None)
|
||||
|
||||
if hf_device_map is not None and device is not None:
|
||||
raise ValueError(
|
||||
"The model has been loaded with `accelerate` and therefore cannot be moved to a specific device. Please "
|
||||
"discard the `device` argument when creating your pipeline object."
|
||||
)
|
||||
|
||||
# We shouldn't call `model.to()` for models loaded with accelerate
|
||||
if self.framework == "pt" and device is not None and not (isinstance(device, int) and device < 0):
|
||||
self.model.to(device)
|
||||
|
||||
if device is None:
|
||||
# `accelerate` device map
|
||||
hf_device_map = getattr(self.model, "hf_device_map", None)
|
||||
if hf_device_map is not None:
|
||||
# Take the first device used by `accelerate`.
|
||||
device = next(iter(hf_device_map.values()))
|
||||
|
Loading…
Reference in New Issue
Block a user