From e42869b091c4bba9f5b2007196d4adceb54d3b28 Mon Sep 17 00:00:00 2001
From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Date: Wed, 31 May 2023 16:40:07 +0200
Subject: [PATCH] [`bnb`] add warning when no linear (#23894)

* add warning for gpt2-like models

* more details

* adapt from suggestions
---
 src/transformers/modeling_utils.py     |   5 +-
 src/transformers/utils/bitsandbytes.py | 107 ++++++++++++++++---------
 2 files changed, 69 insertions(+), 43 deletions(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 588f7f6134b..aa982873664 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -2687,7 +2687,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             model = replace_with_bnb_linear(
                 model, modules_to_not_convert=modules_to_not_convert, quantization_config=quantization_config
             )
-
             # training in 8-bit is only available in 0.37.0+
             model._is_quantized_training_enabled = version.parse(
                 importlib_metadata.version("bitsandbytes")
@@ -2699,8 +2698,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         if load_in_8bit and torch_dtype is None:
             logger.warning(
                 "You are loading your model in 8bit but you did not specify a `torch_dtype` attribute."
-                "All non-linear modules will be loaded in full precision.",
-                " If you want to load the other modules in other precision, please specify a `torch_dtype` attribute.",
+                " All non-linear modules will be loaded in full precision."
+                " If you want to load the other modules in another precision, please specify a `torch_dtype` attribute."
             )
 
         if isinstance(device_map, str):
diff --git a/src/transformers/utils/bitsandbytes.py b/src/transformers/utils/bitsandbytes.py
index 5cb82c44d69..d9b4c037b4d 100644
--- a/src/transformers/utils/bitsandbytes.py
+++ b/src/transformers/utils/bitsandbytes.py
@@ -3,6 +3,7 @@ from copy import deepcopy
 
 from packaging import version
 
+from ..utils import logging
 from .import_utils import importlib_metadata, is_accelerate_available, is_bitsandbytes_available
 
 
@@ -15,6 +16,8 @@ if is_accelerate_available():
     from accelerate import init_empty_weights
     from accelerate.utils import find_tied_parameters
 
+logger = logging.get_logger(__name__)
+
 
 def set_module_quantized_tensor_to_device(module, tensor_name, device, value=None, fp16_statistics=None):
     """
@@ -106,6 +109,59 @@ def set_module_quantized_tensor_to_device(module, tensor_name, device, value=Non
         module._parameters[tensor_name] = new_value
 
 
+def _replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name=None, quantization_config=None):
+    """
+    Private method that wraps the recursion for module replacement.
+
+    Returns the converted model and a boolean that indicates if the conversion has been successful or not.
+ """ + has_been_replaced = False + for name, module in model.named_children(): + if current_key_name is None: + current_key_name = [] + + if isinstance(module, nn.Linear) and name not in modules_to_not_convert: + # Check if the current key is not in the `modules_to_not_convert` + if not any(key in ".".join(current_key_name) for key in modules_to_not_convert): + with init_empty_weights(): + if quantization_config.quantization_method() == "llm_int8": + model._modules[name] = bnb.nn.Linear8bitLt( + module.in_features, + module.out_features, + module.bias is not None, + has_fp16_weights=quantization_config.llm_int8_has_fp16_weight, + threshold=quantization_config.llm_int8_threshold, + ) + has_been_replaced = True + else: + if ( + quantization_config.llm_int8_skip_modules is not None + and name in quantization_config.llm_int8_skip_modules + ): + pass + else: + model._modules[name] = bnb.nn.Linear4bit( + module.in_features, + module.out_features, + module.bias is not None, + quantization_config.bnb_4bit_compute_dtype, + compress_statistics=quantization_config.bnb_4bit_use_double_quant, + quant_type=quantization_config.bnb_4bit_quant_type, + ) + has_been_replaced = True + # Force requires grad to False to avoid unexpected errors + model._modules[name].requires_grad_(False) + # Remove the last key for recursion + if len(list(module.children())) > 0: + _, has_been_replaced = _replace_with_bnb_linear( + module, + modules_to_not_convert, + current_key_name, + quantization_config, + ) + return model, has_been_replaced + + def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name=None, quantization_config=None): """ A helper function to replace all `torch.nn.Linear` modules by `bnb.nn.Linear8bit` modules from the `bitsandbytes` @@ -133,47 +189,18 @@ def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name `disk`). """ modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert - for name, module in model.named_children(): - if current_key_name is None: - current_key_name = [] + model, has_been_replaced = _replace_with_bnb_linear( + model, modules_to_not_convert, current_key_name, quantization_config + ) + + if not has_been_replaced: + logger.warning( + "You are loading your model in 8bit or 4bit but no linear modules were found in your model." + " this can happen for some architectures such as gpt2 that uses Conv1D instead of Linear layers." + " Please double check your model architecture, or submit an issue on github if you think this is" + " a bug." 
+        )
 
-        if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
-            # Check if the current key is not in the `modules_to_not_convert`
-            if not any(key in ".".join(current_key_name) for key in modules_to_not_convert):
-                with init_empty_weights():
-                    if quantization_config.quantization_method() == "llm_int8":
-                        model._modules[name] = bnb.nn.Linear8bitLt(
-                            module.in_features,
-                            module.out_features,
-                            module.bias is not None,
-                            has_fp16_weights=quantization_config.llm_int8_has_fp16_weight,
-                            threshold=quantization_config.llm_int8_threshold,
-                        )
-                    else:
-                        if (
-                            quantization_config.llm_int8_skip_modules is not None
-                            and name in quantization_config.llm_int8_skip_modules
-                        ):
-                            pass
-                        else:
-                            model._modules[name] = bnb.nn.Linear4bit(
-                                module.in_features,
-                                module.out_features,
-                                module.bias is not None,
-                                quantization_config.bnb_4bit_compute_dtype,
-                                compress_statistics=quantization_config.bnb_4bit_use_double_quant,
-                                quant_type=quantization_config.bnb_4bit_quant_type,
-                            )
-                    # Force requires grad to False to avoid unexpected errors
-                    model._modules[name].requires_grad_(False)
-        # Remove the last key for recursion
-        if len(list(module.children())) > 0:
-            replace_with_bnb_linear(
-                module,
-                modules_to_not_convert,
-                current_key_name,
-                quantization_config,
-            )
     return model
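
The new `has_been_replaced` flag only flips when the recursion actually meets an `nn.Linear`, and the warning is emitted once, from the public wrapper, after the whole traversal. A minimal sketch of why this fires for GPT-2-like models follows; `ToyConv1D` and `ToyGPT2Block` are hypothetical stand-ins for `transformers.pytorch_utils.Conv1D` and a GPT-2 block, not code from this patch. Because no submodule at any depth is an `nn.Linear`, `_replace_with_bnb_linear` converts nothing and the warning path is taken.

import torch
import torch.nn as nn


class ToyConv1D(nn.Module):
    # Hypothetical stand-in for transformers.pytorch_utils.Conv1D: a linear
    # projection that stores its weight transposed and is NOT an nn.Linear.
    def __init__(self, nf, nx):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(nx, nf) * 0.02)
        self.bias = nn.Parameter(torch.zeros(nf))

    def forward(self, x):
        size_out = x.size()[:-1] + (self.weight.size(1),)
        return torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight).view(size_out)


class ToyGPT2Block(nn.Module):
    # Hypothetical GPT-2-like block: every projection is a Conv1D, none is nn.Linear.
    def __init__(self, hidden=64):
        super().__init__()
        self.c_attn = ToyConv1D(3 * hidden, hidden)
        self.c_proj = ToyConv1D(hidden, hidden)


def contains_linear(model):
    # Same condition the recursion in _replace_with_bnb_linear relies on:
    # only nn.Linear children are eligible for replacement.
    return any(isinstance(m, nn.Linear) for m in model.modules())


print(contains_linear(ToyGPT2Block()))     # False -> replace_with_bnb_linear would now warn
print(contains_linear(nn.Linear(64, 64)))  # True  -> at least one module gets converted

Returning the flag from the private helper, rather than warning inside the recursion itself, keeps the warning to a single emission per top-level call.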