Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 18:22:34 +06:00)
[bnb] add warning when no linear (#23894)

* add warning for gpt2-like models
* more details
* adapt from suggestions

parent 8f915c450d
commit e42869b091
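
Context: the warning added in this PR targets architectures such as GPT-2 whose projection layers are transformers' Conv1D rather than torch.nn.Linear, so bitsandbytes quantization finds nothing to replace. A hedged sketch of that scenario, assuming bitsandbytes and accelerate are installed; the checkpoint name and keyword arguments below are illustrative, not taken from the PR:

# Illustrative only: a Conv1D-based model loaded in 8-bit now logs the new warning,
# since its only nn.Linear module (lm_head) is excluded from conversion.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2", load_in_8bit=True, device_map="auto")
# expected log: "You are loading your model in 8bit or 4bit but no linear modules were found in your model. ..."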
src/transformers/modeling_utils.py

@@ -2687,7 +2687,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
             model = replace_with_bnb_linear(
                 model, modules_to_not_convert=modules_to_not_convert, quantization_config=quantization_config
             )
-
             # training in 8-bit is only available in 0.37.0+
             model._is_quantized_training_enabled = version.parse(
                 importlib_metadata.version("bitsandbytes")
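
The assignment above is cut off by the hunk boundary; per the comment, it flags whether 8-bit training is possible based on the installed bitsandbytes version. A minimal sketch of that check, assuming the comparison is against 0.37.0 as the comment suggests:

import importlib.metadata as importlib_metadata

from packaging import version

# Assumed reconstruction: 8-bit training is only flagged for bitsandbytes >= 0.37.0.
is_quantized_training_enabled = version.parse(
    importlib_metadata.version("bitsandbytes")
) >= version.parse("0.37.0")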
@@ -2699,8 +2698,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
         if load_in_8bit and torch_dtype is None:
             logger.warning(
                 "You are loading your model in 8bit but you did not specify a `torch_dtype` attribute."
-                "All non-linear modules will be loaded in full precision.",
-                " If you want to load the other modules in other precision, please specify a `torch_dtype` attribute.",
+                "All non-linear modules will be loaded in full precision."
+                " If you want to load the other modules in other precision, please specify a `torch_dtype` attribute."
             )

         if isinstance(device_map, str):
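
Removing the trailing commas makes the three string literals concatenate into a single message; with the commas, the second and third strings were passed to `logger.warning` as extra positional arguments and were therefore dropped from the output. A hedged usage sketch of what the warning asks the caller to do (the checkpoint name and `device_map` value are placeholders):

import torch

from transformers import AutoModelForCausalLM

# Passing torch_dtype keeps the modules that are not converted to 8-bit
# (LayerNorm, embeddings, lm_head, ...) in float16 instead of full float32.
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m", load_in_8bit=True, torch_dtype=torch.float16, device_map="auto"
)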
src/transformers/utils/bitsandbytes.py

@@ -3,6 +3,7 @@ from copy import deepcopy

 from packaging import version

+from ..utils import logging
 from .import_utils import importlib_metadata, is_accelerate_available, is_bitsandbytes_available

@@ -15,6 +16,8 @@ if is_accelerate_available():
     from accelerate import init_empty_weights
     from accelerate.utils import find_tied_parameters

+logger = logging.get_logger(__name__)
+

 def set_module_quantized_tensor_to_device(module, tensor_name, device, value=None, fp16_statistics=None):
     """
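
The module-level `logger` added here goes through transformers' logging wrapper, so the new warning respects the library's verbosity settings. A small usage sketch (the logger name is illustrative):

from transformers.utils import logging

logging.set_verbosity_warning()  # make sure library warnings are emitted
logger = logging.get_logger("transformers")
logger.warning("messages from the bnb utilities are routed through this logger")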
@@ -106,6 +109,59 @@ def set_module_quantized_tensor_to_device(module, tensor_name, device, value=None, fp16_statistics=None):
         module._parameters[tensor_name] = new_value


+def _replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name=None, quantization_config=None):
+    """
+    Private method that wraps the recursion for module replacement.
+
+    Returns the converted model and a boolean that indicates if the conversion has been successfull or not.
+    """
+    has_been_replaced = False
+    for name, module in model.named_children():
+        if current_key_name is None:
+            current_key_name = []
+
+        if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
+            # Check if the current key is not in the `modules_to_not_convert`
+            if not any(key in ".".join(current_key_name) for key in modules_to_not_convert):
+                with init_empty_weights():
+                    if quantization_config.quantization_method() == "llm_int8":
+                        model._modules[name] = bnb.nn.Linear8bitLt(
+                            module.in_features,
+                            module.out_features,
+                            module.bias is not None,
+                            has_fp16_weights=quantization_config.llm_int8_has_fp16_weight,
+                            threshold=quantization_config.llm_int8_threshold,
+                        )
+                        has_been_replaced = True
+                    else:
+                        if (
+                            quantization_config.llm_int8_skip_modules is not None
+                            and name in quantization_config.llm_int8_skip_modules
+                        ):
+                            pass
+                        else:
+                            model._modules[name] = bnb.nn.Linear4bit(
+                                module.in_features,
+                                module.out_features,
+                                module.bias is not None,
+                                quantization_config.bnb_4bit_compute_dtype,
+                                compress_statistics=quantization_config.bnb_4bit_use_double_quant,
+                                quant_type=quantization_config.bnb_4bit_quant_type,
+                            )
+                            has_been_replaced = True
+                    # Force requires grad to False to avoid unexpected errors
+                    model._modules[name].requires_grad_(False)
+        # Remove the last key for recursion
+        if len(list(module.children())) > 0:
+            _, has_been_replaced = _replace_with_bnb_linear(
+                module,
+                modules_to_not_convert,
+                current_key_name,
+                quantization_config,
+            )
+    return model, has_been_replaced
+
+
 def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name=None, quantization_config=None):
     """
     A helper function to replace all `torch.nn.Linear` modules by `bnb.nn.Linear8bit` modules from the `bitsandbytes`
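
A hedged sketch of what the new helper does when it succeeds, assuming bitsandbytes is installed and that, at this commit, the public wrapper is importable from `transformers.utils.bitsandbytes`; the toy model is illustrative:

import torch.nn as nn

from transformers import BitsAndBytesConfig
from transformers.utils.bitsandbytes import replace_with_bnb_linear

# Every nn.Linear outside modules_to_not_convert (default ["lm_head"]) is replaced,
# on the meta device via init_empty_weights, by a bnb.nn.Linear8bitLt.
toy = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))
toy = replace_with_bnb_linear(toy, quantization_config=BitsAndBytesConfig(load_in_8bit=True))
print(type(toy[0]).__name__)  # expected: Linear8bitLt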
@@ -133,47 +189,18 @@ def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name=None, quantization_config=None):
         `disk`).
     """
     modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert
-    for name, module in model.named_children():
-        if current_key_name is None:
-            current_key_name = []
-
-        if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
-            # Check if the current key is not in the `modules_to_not_convert`
-            if not any(key in ".".join(current_key_name) for key in modules_to_not_convert):
-                with init_empty_weights():
-                    if quantization_config.quantization_method() == "llm_int8":
-                        model._modules[name] = bnb.nn.Linear8bitLt(
-                            module.in_features,
-                            module.out_features,
-                            module.bias is not None,
-                            has_fp16_weights=quantization_config.llm_int8_has_fp16_weight,
-                            threshold=quantization_config.llm_int8_threshold,
-                        )
-                    else:
-                        if (
-                            quantization_config.llm_int8_skip_modules is not None
-                            and name in quantization_config.llm_int8_skip_modules
-                        ):
-                            pass
-                        else:
-                            model._modules[name] = bnb.nn.Linear4bit(
-                                module.in_features,
-                                module.out_features,
-                                module.bias is not None,
-                                quantization_config.bnb_4bit_compute_dtype,
-                                compress_statistics=quantization_config.bnb_4bit_use_double_quant,
-                                quant_type=quantization_config.bnb_4bit_quant_type,
-                            )
-                    # Force requires grad to False to avoid unexpected errors
-                    model._modules[name].requires_grad_(False)
-        # Remove the last key for recursion
-        if len(list(module.children())) > 0:
-            replace_with_bnb_linear(
-                module,
-                modules_to_not_convert,
-                current_key_name,
-                quantization_config,
-            )
+    model, has_been_replaced = _replace_with_bnb_linear(
+        model, modules_to_not_convert, current_key_name, quantization_config
+    )
+
+    if not has_been_replaced:
+        logger.warning(
+            "You are loading your model in 8bit or 4bit but no linear modules were found in your model."
+            " this can happen for some architectures such as gpt2 that uses Conv1D instead of Linear layers."
+            " Please double check your model architecture, or submit an issue on github if you think this is"
+            " a bug."
+        )
+
     return model
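
For the case the warning now covers, module inspection shows why `has_been_replaced` stays False on GPT-2: every projection layer is transformers' `Conv1D`, and the lone `nn.Linear` (`lm_head`) sits in `modules_to_not_convert`. A hedged sketch, not part of the PR:

import torch.nn as nn

from transformers import AutoModelForCausalLM
from transformers.pytorch_utils import Conv1D

model = AutoModelForCausalLM.from_pretrained("gpt2")
print([n for n, m in model.named_modules() if isinstance(m, nn.Linear)])  # ['lm_head']
print(any(isinstance(m, Conv1D) for m in model.modules()))                # True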