Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
[bnb] add warning when no linear (#23894)

* add warning for gpt2-like models
* more details
* adapt from suggestions

This commit is contained in:
parent 8f915c450d
commit e42869b091
@@ -2687,7 +2687,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
             model = replace_with_bnb_linear(
                 model, modules_to_not_convert=modules_to_not_convert, quantization_config=quantization_config
             )
 
             # training in 8-bit is only available in 0.37.0+
             model._is_quantized_training_enabled = version.parse(
                 importlib_metadata.version("bitsandbytes")
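For context, the comment above refers to the bitsandbytes release that first supports training in int8. A minimal sketch of that version gate, written against the standard library rather than transformers' `importlib_metadata` shim; the helper name is illustrative, not a transformers API:

```python
import importlib.metadata

from packaging import version


def is_int8_training_available() -> bool:
    # Illustrative helper: int8 training is only enabled when the installed
    # bitsandbytes is 0.37.0 or newer.
    try:
        bnb_version = importlib.metadata.version("bitsandbytes")
    except importlib.metadata.PackageNotFoundError:
        return False
    return version.parse(bnb_version) >= version.parse("0.37.0")
```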
@@ -2699,8 +2698,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
         if load_in_8bit and torch_dtype is None:
             logger.warning(
                 "You are loading your model in 8bit but you did not specify a `torch_dtype` attribute."
-                "All non-linear modules will be loaded in full precision.",
-                " If you want to load the other modules in other precision, please specify a `torch_dtype` attribute.",
+                "All non-linear modules will be loaded in full precision."
+                " If you want to load the other modules in other precision, please specify a `torch_dtype` attribute."
             )
 
         if isinstance(device_map, str):
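The warning above is avoided by passing a `torch_dtype` explicitly when loading in 8-bit. A hedged usage sketch (the checkpoint name is only an example; running it requires bitsandbytes and a CUDA GPU):

```python
import torch
from transformers import AutoModelForCausalLM

# Passing torch_dtype alongside load_in_8bit keeps the non-linear modules
# (norms, embeddings) in float16 instead of full float32 precision.
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",  # illustrative checkpoint
    load_in_8bit=True,
    torch_dtype=torch.float16,
)
```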
@@ -3,6 +3,7 @@ from copy import deepcopy
 
 from packaging import version
 
+from ..utils import logging
 from .import_utils import importlib_metadata, is_accelerate_available, is_bitsandbytes_available
 
 
@@ -15,6 +16,8 @@ if is_accelerate_available():
     from accelerate import init_empty_weights
     from accelerate.utils import find_tied_parameters
 
+logger = logging.get_logger(__name__)
+
 
 def set_module_quantized_tensor_to_device(module, tensor_name, device, value=None, fp16_statistics=None):
     """
@@ -106,6 +109,59 @@ def set_module_quantized_tensor_to_device(module, tensor_name, device, value=None, fp16_statistics=None):
         module._parameters[tensor_name] = new_value
 
 
+def _replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name=None, quantization_config=None):
+    """
+    Private method that wraps the recursion for module replacement.
+
+    Returns the converted model and a boolean that indicates if the conversion has been successfull or not.
+    """
+    has_been_replaced = False
+    for name, module in model.named_children():
+        if current_key_name is None:
+            current_key_name = []
+
+        if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
+            # Check if the current key is not in the `modules_to_not_convert`
+            if not any(key in ".".join(current_key_name) for key in modules_to_not_convert):
+                with init_empty_weights():
+                    if quantization_config.quantization_method() == "llm_int8":
+                        model._modules[name] = bnb.nn.Linear8bitLt(
+                            module.in_features,
+                            module.out_features,
+                            module.bias is not None,
+                            has_fp16_weights=quantization_config.llm_int8_has_fp16_weight,
+                            threshold=quantization_config.llm_int8_threshold,
+                        )
+                        has_been_replaced = True
+                    else:
+                        if (
+                            quantization_config.llm_int8_skip_modules is not None
+                            and name in quantization_config.llm_int8_skip_modules
+                        ):
+                            pass
+                        else:
+                            model._modules[name] = bnb.nn.Linear4bit(
+                                module.in_features,
+                                module.out_features,
+                                module.bias is not None,
+                                quantization_config.bnb_4bit_compute_dtype,
+                                compress_statistics=quantization_config.bnb_4bit_use_double_quant,
+                                quant_type=quantization_config.bnb_4bit_quant_type,
+                            )
+                            has_been_replaced = True
+                    # Force requires grad to False to avoid unexpected errors
+                    model._modules[name].requires_grad_(False)
+        # Remove the last key for recursion
+        if len(list(module.children())) > 0:
+            _, has_been_replaced = _replace_with_bnb_linear(
+                module,
+                modules_to_not_convert,
+                current_key_name,
+                quantization_config,
+            )
+    return model, has_been_replaced
+
+
 def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name=None, quantization_config=None):
     """
     A helper function to replace all `torch.nn.Linear` modules by `bnb.nn.Linear8bit` modules from the `bitsandbytes`
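The skip logic in the new `_replace_with_bnb_linear` compares each entry of `modules_to_not_convert` against the dotted key path built up during recursion. A small self-contained illustration of that check (a standalone sketch, not library code; the path values are made up):

```python
modules_to_not_convert = ["lm_head"]


def should_skip(current_key_name):
    # Mirrors the check above: skip conversion if any excluded key appears
    # anywhere in the module's dotted path.
    return any(key in ".".join(current_key_name) for key in modules_to_not_convert)


print(should_skip(["transformer", "h", "0", "mlp", "fc_in"]))  # False -> convert
print(should_skip(["lm_head"]))                                # True  -> keep as nn.Linear
```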
@@ -133,47 +189,18 @@ def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name=None, quantization_config=None):
             `disk`).
     """
     modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert
-    for name, module in model.named_children():
-        if current_key_name is None:
-            current_key_name = []
+    model, has_been_replaced = _replace_with_bnb_linear(
+        model, modules_to_not_convert, current_key_name, quantization_config
+    )
+
+    if not has_been_replaced:
+        logger.warning(
+            "You are loading your model in 8bit or 4bit but no linear modules were found in your model."
+            " this can happen for some architectures such as gpt2 that uses Conv1D instead of Linear layers."
+            " Please double check your model architecture, or submit an issue on github if you think this is"
+            " a bug."
+        )
 
-        if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
-            # Check if the current key is not in the `modules_to_not_convert`
-            if not any(key in ".".join(current_key_name) for key in modules_to_not_convert):
-                with init_empty_weights():
-                    if quantization_config.quantization_method() == "llm_int8":
-                        model._modules[name] = bnb.nn.Linear8bitLt(
-                            module.in_features,
-                            module.out_features,
-                            module.bias is not None,
-                            has_fp16_weights=quantization_config.llm_int8_has_fp16_weight,
-                            threshold=quantization_config.llm_int8_threshold,
-                        )
-                    else:
-                        if (
-                            quantization_config.llm_int8_skip_modules is not None
-                            and name in quantization_config.llm_int8_skip_modules
-                        ):
-                            pass
-                        else:
-                            model._modules[name] = bnb.nn.Linear4bit(
-                                module.in_features,
-                                module.out_features,
-                                module.bias is not None,
-                                quantization_config.bnb_4bit_compute_dtype,
-                                compress_statistics=quantization_config.bnb_4bit_use_double_quant,
-                                quant_type=quantization_config.bnb_4bit_quant_type,
-                            )
-                    # Force requires grad to False to avoid unexpected errors
-                    model._modules[name].requires_grad_(False)
-        # Remove the last key for recursion
-        if len(list(module.children())) > 0:
-            replace_with_bnb_linear(
-                module,
-                modules_to_not_convert,
-                current_key_name,
-                quantization_config,
-            )
     return model
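The new warning targets architectures such as GPT-2, whose projection layers are `transformers.pytorch_utils.Conv1D` rather than `torch.nn.Linear`, so the replacement loop finds nothing to convert. A hedged check of that claim (downloads the gpt2 checkpoint; counts are expectations, not library-guaranteed values):

```python
import torch.nn as nn
from transformers import AutoModel
from transformers.pytorch_utils import Conv1D

model = AutoModel.from_pretrained("gpt2")
num_linear = sum(isinstance(m, nn.Linear) for m in model.modules())
num_conv1d = sum(isinstance(m, Conv1D) for m in model.modules())
# Expectation: no nn.Linear layers but many Conv1D projections, which is why
# 8-bit/4-bit loading previously converted nothing without any message.
print(f"nn.Linear: {num_linear}, Conv1D: {num_conv1d}")
```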