[pipeline] revisit device check for pipeline (#25207)

* revisit device check for pipeline

* let's raise an error.
This commit is contained in:
Younes Belkada 2023-07-31 18:43:21 +02:00 committed by GitHub
parent 5220606607
commit e0c50b274a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -775,12 +775,20 @@ class Pipeline(_ScikitCompat):
self.modelcard = modelcard
self.framework = framework
# `accelerate` device map
hf_device_map = getattr(self.model, "hf_device_map", None)
if hf_device_map is not None and device is not None:
raise ValueError(
"The model has been loaded with `accelerate` and therefore cannot be moved to a specific device. Please "
"discard the `device` argument when creating your pipeline object."
)
# We shouldn't call `model.to()` for models loaded with accelerate
if self.framework == "pt" and device is not None and not (isinstance(device, int) and device < 0):
self.model.to(device)
if device is None:
# `accelerate` device map
hf_device_map = getattr(self.model, "hf_device_map", None)
if hf_device_map is not None:
# Take the first device used by `accelerate`.
device = next(iter(hf_device_map.values()))