[BitsandBytes] Verify if GPU is available (#30533)

Change order
This commit is contained in:
NielsRogge 2024-05-08 12:42:58 +02:00 committed by GitHub
parent 998dbe068b
commit 1872bde7fc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 5 additions and 6 deletions

View File

@ -58,6 +58,8 @@ class Bnb4BitHfQuantizer(HfQuantizer):
self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules
def validate_environment(self, *args, **kwargs):
if not torch.cuda.is_available():
raise RuntimeError("No GPU found. A GPU is needed for quantization.")
if not (is_accelerate_available() and is_bitsandbytes_available()):
raise ImportError(
"Using `bitsandbytes` 8-bit quantization requires Accelerate: `pip install accelerate` "
@ -70,9 +72,6 @@ class Bnb4BitHfQuantizer(HfQuantizer):
" sure the weights are in PyTorch format."
)
if not torch.cuda.is_available():
raise RuntimeError("No GPU found. A GPU is needed for quantization.")
device_map = kwargs.get("device_map", None)
if (
device_map is not None

View File

@ -58,6 +58,9 @@ class Bnb8BitHfQuantizer(HfQuantizer):
self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules
def validate_environment(self, *args, **kwargs):
if not torch.cuda.is_available():
raise RuntimeError("No GPU found. A GPU is needed for quantization.")
if not (is_accelerate_available() and is_bitsandbytes_available()):
raise ImportError(
"Using `bitsandbytes` 8-bit quantization requires Accelerate: `pip install accelerate` "
@ -70,9 +73,6 @@ class Bnb8BitHfQuantizer(HfQuantizer):
" sure the weights are in PyTorch format."
)
if not torch.cuda.is_available():
raise RuntimeError("No GPU found. A GPU is needed for quantization.")
device_map = kwargs.get("device_map", None)
if (
device_map is not None