[bnb] add warning when no linear (#23894)

* add warning for gpt2-like models

* more details

* adapt from suggestions
Younes Belkada 2023-05-31 16:40:07 +02:00 committed by GitHub
parent 8f915c450d
commit e42869b091
2 changed files with 69 additions and 43 deletions


@@ -2687,7 +2687,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
            model = replace_with_bnb_linear(
                model, modules_to_not_convert=modules_to_not_convert, quantization_config=quantization_config
            )
-
            # training in 8-bit is only available in 0.37.0+
            model._is_quantized_training_enabled = version.parse(
                importlib_metadata.version("bitsandbytes")
@@ -2699,8 +2698,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
        if load_in_8bit and torch_dtype is None:
            logger.warning(
                "You are loading your model in 8bit but you did not specify a `torch_dtype` attribute."
-                "All non-linear modules will be loaded in full precision.",
-                " If you want to load the other modules in other precision, please specify a `torch_dtype` attribute.",
+                "All non-linear modules will be loaded in full precision."
+                " If you want to load the other modules in other precision, please specify a `torch_dtype` attribute."
            )

        if isinstance(device_map, str):
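
The trailing commas removed in the second hunk are not just cosmetic: `logging` treats extra positional arguments as %-format arguments for the message, so the last string was passed as a format argument rather than appended to the warning text. A minimal stand-alone sketch with the stdlib `logging` module (messages shortened, not the exact transformers call):

import logging

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger("demo")

# Old shape: the second string becomes a %-format argument; the message has no
# placeholders, so the record fails to format and a "Logging error" traceback
# is printed instead of the intended warning.
logger.warning(
    "You are loading your model in 8bit but you did not specify a `torch_dtype` attribute.",
    " If you want to load the other modules in other precision, please specify a `torch_dtype` attribute.",
)

# New shape: adjacent string literals are concatenated into one message, so the
# full sentence is emitted as a single warning.
logger.warning(
    "You are loading your model in 8bit but you did not specify a `torch_dtype` attribute."
    " If you want to load the other modules in other precision, please specify a `torch_dtype` attribute."
)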


@@ -3,6 +3,7 @@ from copy import deepcopy
from packaging import version

+from ..utils import logging
from .import_utils import importlib_metadata, is_accelerate_available, is_bitsandbytes_available

@@ -15,6 +16,8 @@ if is_accelerate_available():
    from accelerate import init_empty_weights
    from accelerate.utils import find_tied_parameters

+logger = logging.get_logger(__name__)
+

def set_module_quantized_tensor_to_device(module, tensor_name, device, value=None, fp16_statistics=None):
    """
@@ -106,6 +109,59 @@ def set_module_quantized_tensor_to_device(module, tensor_name, device, value=None, fp16_statistics=None):
        module._parameters[tensor_name] = new_value


+def _replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name=None, quantization_config=None):
+    """
+    Private method that wraps the recursion for module replacement.
+
+    Returns the converted model and a boolean that indicates if the conversion has been successfull or not.
+    """
+    has_been_replaced = False
+    for name, module in model.named_children():
+        if current_key_name is None:
+            current_key_name = []
+
+        if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
+            # Check if the current key is not in the `modules_to_not_convert`
+            if not any(key in ".".join(current_key_name) for key in modules_to_not_convert):
+                with init_empty_weights():
+                    if quantization_config.quantization_method() == "llm_int8":
+                        model._modules[name] = bnb.nn.Linear8bitLt(
+                            module.in_features,
+                            module.out_features,
+                            module.bias is not None,
+                            has_fp16_weights=quantization_config.llm_int8_has_fp16_weight,
+                            threshold=quantization_config.llm_int8_threshold,
+                        )
+                        has_been_replaced = True
+                    else:
+                        if (
+                            quantization_config.llm_int8_skip_modules is not None
+                            and name in quantization_config.llm_int8_skip_modules
+                        ):
+                            pass
+                        else:
+                            model._modules[name] = bnb.nn.Linear4bit(
+                                module.in_features,
+                                module.out_features,
+                                module.bias is not None,
+                                quantization_config.bnb_4bit_compute_dtype,
+                                compress_statistics=quantization_config.bnb_4bit_use_double_quant,
+                                quant_type=quantization_config.bnb_4bit_quant_type,
+                            )
+                            has_been_replaced = True
+                    # Force requires grad to False to avoid unexpected errors
+                    model._modules[name].requires_grad_(False)
+        # Remove the last key for recursion
+        if len(list(module.children())) > 0:
+            _, has_been_replaced = _replace_with_bnb_linear(
+                module,
+                modules_to_not_convert,
+                current_key_name,
+                quantization_config,
+            )
+    return model, has_been_replaced
+
+

def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name=None, quantization_config=None):
    """
    A helper function to replace all `torch.nn.Linear` modules by `bnb.nn.Linear8bit` modules from the `bitsandbytes`
@@ -133,47 +189,18 @@ def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name=None, quantization_config=None):
            `disk`).
    """
    modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert
-    for name, module in model.named_children():
-        if current_key_name is None:
-            current_key_name = []
-
-        if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
-            # Check if the current key is not in the `modules_to_not_convert`
-            if not any(key in ".".join(current_key_name) for key in modules_to_not_convert):
-                with init_empty_weights():
-                    if quantization_config.quantization_method() == "llm_int8":
-                        model._modules[name] = bnb.nn.Linear8bitLt(
-                            module.in_features,
-                            module.out_features,
-                            module.bias is not None,
-                            has_fp16_weights=quantization_config.llm_int8_has_fp16_weight,
-                            threshold=quantization_config.llm_int8_threshold,
-                        )
-                    else:
-                        if (
-                            quantization_config.llm_int8_skip_modules is not None
-                            and name in quantization_config.llm_int8_skip_modules
-                        ):
-                            pass
-                        else:
-                            model._modules[name] = bnb.nn.Linear4bit(
-                                module.in_features,
-                                module.out_features,
-                                module.bias is not None,
-                                quantization_config.bnb_4bit_compute_dtype,
-                                compress_statistics=quantization_config.bnb_4bit_use_double_quant,
-                                quant_type=quantization_config.bnb_4bit_quant_type,
-                            )
-                    # Force requires grad to False to avoid unexpected errors
-                    model._modules[name].requires_grad_(False)
-        # Remove the last key for recursion
-        if len(list(module.children())) > 0:
-            replace_with_bnb_linear(
-                module,
-                modules_to_not_convert,
-                current_key_name,
-                quantization_config,
-            )
+    model, has_been_replaced = _replace_with_bnb_linear(
+        model, modules_to_not_convert, current_key_name, quantization_config
+    )
+
+    if not has_been_replaced:
+        logger.warning(
+            "You are loading your model in 8bit or 4bit but no linear modules were found in your model."
+            " this can happen for some architectures such as gpt2 that uses Conv1D instead of Linear layers."
+            " Please double check your model architecture, or submit an issue on github if you think this is"
+            " a bug."
+        )
+
    return model
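
For reference, a minimal sketch of the case the new warning targets, assuming a CUDA machine with `bitsandbytes` and `accelerate` installed (the checkpoint name and flag are only illustrative): GPT-2's attention and MLP projections are `transformers.pytorch_utils.Conv1D` rather than `nn.Linear`, and its `lm_head` sits on the default skip list, so `_replace_with_bnb_linear` converts nothing, `has_been_replaced` stays False, and the warning fires.

import torch.nn as nn
from transformers import AutoModelForCausalLM
from transformers.pytorch_utils import Conv1D

# Loading a Conv1D-based architecture in 8bit: no nn.Linear gets swapped for a
# bnb.nn.Linear8bitLt module, so the new "no linear modules were found"
# warning is logged during loading.
model = AutoModelForCausalLM.from_pretrained("gpt2", load_in_8bit=True)

# GPT-2 blocks implement c_attn / c_proj / c_fc as Conv1D, not nn.Linear.
print(sum(isinstance(m, Conv1D) for m in model.modules()))                # > 0
print([n for n, m in model.named_modules() if isinstance(m, nn.Linear)])  # essentially just lm_head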