[bnb] Fix bnb skip modules (#24043)

* fix skip modules test
* oops
* address comments

parent a1160185ff
commit 4795219228
```diff
@@ -109,16 +109,18 @@ def set_module_quantized_tensor_to_device(module, tensor_name, device, value=Non
             module._parameters[tensor_name] = new_value


-def _replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name=None, quantization_config=None):
+def _replace_with_bnb_linear(
+    model, modules_to_not_convert=None, current_key_name=None, quantization_config=None, has_been_replaced=False
+):
     """
     Private method that wraps the recursion for module replacement.

     Returns the converted model and a boolean that indicates if the conversion has been successful or not.
     """
-    has_been_replaced = False
     for name, module in model.named_children():
         if current_key_name is None:
             current_key_name = []
         current_key_name.append(name)

         if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
             # Check if the current key is not in the `modules_to_not_convert`
```
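The new `has_been_replaced` parameter lets the flag accumulate across recursive calls instead of being reset to `False` at every level. A minimal sketch of that accumulator pattern, using a hypothetical `count_linear` helper that is not part of transformers:

```python
# Sketch only: `count_linear` is a made-up helper illustrating how a value
# threaded through a recursive module walk accumulates instead of resetting.
import torch.nn as nn


def count_linear(module: nn.Module, found: int = 0) -> int:
    for _, child in module.named_children():
        if isinstance(child, nn.Linear):
            found += 1
        if len(list(child.children())) > 0:
            # Pass the running total down and take the updated value back,
            # mirroring how `has_been_replaced` is propagated above.
            found = count_linear(child, found)
    return found


model = nn.Sequential(nn.Linear(4, 4), nn.Sequential(nn.Linear(4, 2)))
assert count_linear(model) == 2
```

Because the running value is handed down and returned back up, a subtree that contains nothing of interest cannot wipe out what an earlier subtree already reported.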
```diff
@@ -151,14 +153,16 @@ def _replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_nam
                 has_been_replaced = True
                 # Force requires grad to False to avoid unexpected errors
                 model._modules[name].requires_grad_(False)
-        # Remove the last key for recursion
         if len(list(module.children())) > 0:
-            _, has_been_replaced = _replace_with_bnb_linear(
-                module, modules_to_not_convert, current_key_name, quantization_config
-            )
+            _, has_been_replaced = _replace_with_bnb_linear(
+                module,
+                modules_to_not_convert,
+                current_key_name,
+                quantization_config,
+                has_been_replaced=has_been_replaced,
+            )
+        # Remove the last key for recursion
         current_key_name.pop(-1)
     return model, has_been_replaced
```
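The `current_key_name` list above follows an append/pop discipline: each level appends the child's name before recursing and pops it afterwards, so joining the list always gives the dotted path of the module currently being visited, which is what the `modules_to_not_convert` check compares against. A small illustrative sketch, with a hypothetical `collect_module_paths` helper that is not transformers code:

```python
# Sketch only: shows the append/pop bookkeeping used to build dotted module paths.
import torch.nn as nn


def collect_module_paths(module, current_key_name=None, paths=None):
    if current_key_name is None:
        current_key_name = []
    if paths is None:
        paths = []
    for name, child in module.named_children():
        current_key_name.append(name)
        paths.append(".".join(current_key_name))
        if len(list(child.children())) > 0:
            collect_module_paths(child, current_key_name, paths)
        # Remove the last key so siblings do not inherit this child's name
        current_key_name.pop(-1)
    return paths


model = nn.Sequential(nn.Linear(2, 2), nn.Sequential(nn.Linear(2, 2)))
print(collect_module_paths(model))  # ['0', '1', '1.0']
```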
```diff
@@ -146,6 +146,26 @@ class MixedInt8Test(BaseMixedInt8Test):
                 if name not in ["lm_head"] + T5PreTrainedModel._keep_in_fp32_modules:
                     self.assertTrue(module.weight.dtype == torch.int8)

+    def test_llm_skip(self):
+        r"""
+        A simple test to check if `llm_int8_skip_modules` works as expected
+        """
+        import bitsandbytes as bnb
+
+        quantization_config = BitsAndBytesConfig(load_in_8bit=True, llm_int8_skip_modules=["classifier"])
+        seq_classification_model = AutoModelForSequenceClassification.from_pretrained(
+            "roberta-large-mnli", quantization_config=quantization_config
+        )
+        self.assertTrue(seq_classification_model.roberta.encoder.layer[0].output.dense.weight.dtype == torch.int8)
+        self.assertTrue(
+            isinstance(seq_classification_model.roberta.encoder.layer[0].output.dense, bnb.nn.Linear8bitLt)
+        )
+
+        self.assertTrue(isinstance(seq_classification_model.classifier.dense, nn.Linear))
+        self.assertTrue(seq_classification_model.classifier.dense.weight.dtype != torch.int8)
+        self.assertTrue(isinstance(seq_classification_model.classifier.out_proj, nn.Linear))
+        self.assertTrue(seq_classification_model.classifier.out_proj.weight.dtype != torch.int8)
+
     def test_generate_quality(self):
         r"""
         Test the generation quality of the quantized model and see that we are matching the expected output.
```
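Outside the test suite, the behaviour covered by `test_llm_skip` can be exercised roughly as below. This is a usage sketch rather than part of the commit; it assumes a CUDA device with the `bitsandbytes` and `accelerate` packages installed:

```python
# Usage sketch: modules listed in `llm_int8_skip_modules` are left as regular
# nn.Linear layers, while the rest of the model is converted to 8-bit.
from transformers import AutoModelForSequenceClassification, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_skip_modules=["classifier"],  # keep the classification head unquantized
)
model = AutoModelForSequenceClassification.from_pretrained(
    "roberta-large-mnli", quantization_config=quantization_config
)

print(model.roberta.encoder.layer[0].output.dense.weight.dtype)  # torch.int8
print(model.classifier.dense.weight.dtype)                       # not torch.int8
```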