[bnb] Fix bnb skip modules (#24043)

* fix skip modules test

* oops

* address comments
This commit is contained in:
Younes Belkada 2023-06-07 15:27:46 +02:00 committed by GitHub
parent a1160185ff
commit 4795219228
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 27 additions and 3 deletions

View File

@ -109,16 +109,18 @@ def set_module_quantized_tensor_to_device(module, tensor_name, device, value=Non
module._parameters[tensor_name] = new_value
def _replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name=None, quantization_config=None):
def _replace_with_bnb_linear(
model, modules_to_not_convert=None, current_key_name=None, quantization_config=None, has_been_replaced=False
):
"""
Private method that wraps the recursion for module replacement.
Returns the converted model and a boolean that indicates if the conversion has been successfull or not.
"""
has_been_replaced = False
for name, module in model.named_children():
if current_key_name is None:
current_key_name = []
current_key_name.append(name)
if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
# Check if the current key is not in the `modules_to_not_convert`
@ -151,14 +153,16 @@ def _replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_nam
has_been_replaced = True
# Force requires grad to False to avoid unexpected errors
model._modules[name].requires_grad_(False)
# Remove the last key for recursion
if len(list(module.children())) > 0:
_, has_been_replaced = _replace_with_bnb_linear(
module,
modules_to_not_convert,
current_key_name,
quantization_config,
has_been_replaced=has_been_replaced,
)
# Remove the last key for recursion
current_key_name.pop(-1)
return model, has_been_replaced

View File

@ -146,6 +146,26 @@ class MixedInt8Test(BaseMixedInt8Test):
if name not in ["lm_head"] + T5PreTrainedModel._keep_in_fp32_modules:
self.assertTrue(module.weight.dtype == torch.int8)
def test_llm_skip(self):
r"""
A simple test to check if `llm_int8_skip_modules` works as expected
"""
import bitsandbytes as bnb
quantization_config = BitsAndBytesConfig(load_in_8bit=True, llm_int8_skip_modules=["classifier"])
seq_classification_model = AutoModelForSequenceClassification.from_pretrained(
"roberta-large-mnli", quantization_config=quantization_config
)
self.assertTrue(seq_classification_model.roberta.encoder.layer[0].output.dense.weight.dtype == torch.int8)
self.assertTrue(
isinstance(seq_classification_model.roberta.encoder.layer[0].output.dense, bnb.nn.Linear8bitLt)
)
self.assertTrue(isinstance(seq_classification_model.classifier.dense, nn.Linear))
self.assertTrue(seq_classification_model.classifier.dense.weight.dtype != torch.int8)
self.assertTrue(isinstance(seq_classification_model.classifier.out_proj, nn.Linear))
self.assertTrue(seq_classification_model.classifier.out_proj != torch.int8)
def test_generate_quality(self):
r"""
Test the generation quality of the quantized model and see that we are matching the expected output.