Improve error msg when using bitsandbytes (#31350)

improve error msg when using bnb
This commit is contained in:
Marc Sun 2024-06-10 14:22:14 +02:00 committed by GitHub
parent 517df566f5
commit dc6eb44841
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 8 additions and 6 deletions

View File

@ -60,10 +60,11 @@ class Bnb4BitHfQuantizer(HfQuantizer):
def validate_environment(self, *args, **kwargs): def validate_environment(self, *args, **kwargs):
if not torch.cuda.is_available(): if not torch.cuda.is_available():
raise RuntimeError("No GPU found. A GPU is needed for quantization.") raise RuntimeError("No GPU found. A GPU is needed for quantization.")
if not (is_accelerate_available() and is_bitsandbytes_available()): if not is_accelerate_available():
raise ImportError("Using `bitsandbytes` 4-bit quantization requires Accelerate: `pip install accelerate`")
if not is_bitsandbytes_available():
raise ImportError( raise ImportError(
"Using `bitsandbytes` 8-bit quantization requires Accelerate: `pip install accelerate` " "Using `bitsandbytes` 4-bit quantization requires the latest version of bitsandbytes: `pip install -U bitsandbytes`"
"and the latest version of bitsandbytes: `pip install -U bitsandbytes`"
) )
if kwargs.get("from_tf", False) or kwargs.get("from_flax", False): if kwargs.get("from_tf", False) or kwargs.get("from_flax", False):

View File

@ -61,10 +61,11 @@ class Bnb8BitHfQuantizer(HfQuantizer):
if not torch.cuda.is_available(): if not torch.cuda.is_available():
raise RuntimeError("No GPU found. A GPU is needed for quantization.") raise RuntimeError("No GPU found. A GPU is needed for quantization.")
if not (is_accelerate_available() and is_bitsandbytes_available()): if not is_accelerate_available():
raise ImportError("Using `bitsandbytes` 8-bit quantization requires Accelerate: `pip install accelerate`")
if not is_bitsandbytes_available():
raise ImportError( raise ImportError(
"Using `bitsandbytes` 8-bit quantization requires Accelerate: `pip install accelerate` " "Using `bitsandbytes` 8-bit quantization requires the latest version of bitsandbytes: `pip install -U bitsandbytes`"
"and the latest version of bitsandbytes: `pip install -U bitsandbytes`"
) )
if kwargs.get("from_tf", False) or kwargs.get("from_flax", False): if kwargs.get("from_tf", False) or kwargs.get("from_flax", False):