diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py index 86655192637..e871942ce92 100644 --- a/src/transformers/pipelines/base.py +++ b/src/transformers/pipelines/base.py @@ -1033,7 +1033,7 @@ class Pipeline(_ScikitCompat, PushToHubMixin): else: self.device = device if device is not None else -1 - if is_torch_available() and torch.distributed.is_initialized(): + if is_torch_available() and torch.distributed.is_available() and torch.distributed.is_initialized(): self.device = self.model.device logger.warning(f"Device set to use {self.device}")