Better error message for bitsandbytes import (#27764)

* better error message

* fix logic

* fix log
This commit is contained in:
Marc Sun 2023-12-01 11:59:14 -05:00 committed by GitHub
parent 7b6324e18e
commit abd4cbd775
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1010,7 +1010,7 @@ class ModuleUtilsMixin:
else:
raise ValueError(
"bitsandbytes is not installed but it seems that the model has been loaded in 4bit precision, something went wrong"
" make sure to install bitsandbytes with `pip install bitsandbytes`."
" make sure to install bitsandbytes with `pip install bitsandbytes`. You also need a GPU. "
)
for param in total_parameters:
@ -2746,11 +2746,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
)
if load_in_8bit or load_in_4bit:
if not torch.cuda.is_available():
raise RuntimeError("No GPU found. A GPU is needed for quantization.")
if not (is_accelerate_available() and is_bitsandbytes_available()):
raise ImportError(
"Using `load_in_8bit=True` requires Accelerate: `pip install accelerate` and the latest version of"
" bitsandbytes `pip install -i https://test.pypi.org/simple/ bitsandbytes` or"
" pip install bitsandbytes` "
" `pip install bitsandbytes`."
)
if torch_dtype is None:
@ -2764,10 +2766,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
torch_dtype = torch.float16
if device_map is None:
if torch.cuda.is_available():
device_map = {"": torch.cuda.current_device()}
else:
raise RuntimeError("No GPU found. A GPU is needed for quantization.")
device_map = {"": torch.cuda.current_device()}
logger.info(
"The device_map was not initialized. "
"Setting device_map to {'':torch.cuda.current_device()}. "