Add bitsandbytes support for gpt2 models (#24504)

* Add bitsandbytes support for gpt2 models

* Guard Conv1D import to pass tensorflow test

* Appease ruff linter

* Fix 4bit test and remove int8 test boilerplate

* Update tests/bnb/test_mixed_int8.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

---------

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
This commit is contained in:
Dario Sučić 2023-06-28 05:55:32 +02:00 committed by GitHub
parent 89b6ee49fd
commit 12240925cf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 60 additions and 14 deletions

View File

@ -12,6 +12,8 @@ if is_bitsandbytes_available():
import torch
import torch.nn as nn
from ..pytorch_utils import Conv1D
if is_accelerate_available():
from accelerate import init_empty_weights
from accelerate.utils import find_tied_parameters
@ -84,6 +86,11 @@ def set_module_quantized_tensor_to_device(module, tensor_name, device, value=Non
else:
new_value = torch.tensor(value, device="cpu")
# Support models using `Conv1D` in place of `nn.Linear` (e.g. gpt2) by transposing the weight matrix prior to quantization.
# Since weights are saved in the correct "orientation", we skip transposing when loading.
if issubclass(module.source_cls, Conv1D) and fp16_statistics is None:
new_value = new_value.T
kwargs = old_value.__dict__
if is_8bit:
new_value = bnb.nn.Int8Params(new_value, requires_grad=False, **kwargs).to(device)
@ -122,14 +129,20 @@ def _replace_with_bnb_linear(
current_key_name = []
current_key_name.append(name)
if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
if (isinstance(module, nn.Linear) or isinstance(module, Conv1D)) and name not in modules_to_not_convert:
# Check if the current key is not in the `modules_to_not_convert`
if not any(key in ".".join(current_key_name) for key in modules_to_not_convert):
with init_empty_weights():
if isinstance(module, Conv1D):
in_features, out_features = module.weight.shape
else:
in_features = module.in_features
out_features = module.out_features
if quantization_config.quantization_method() == "llm_int8":
model._modules[name] = bnb.nn.Linear8bitLt(
module.in_features,
module.out_features,
in_features,
out_features,
module.bias is not None,
has_fp16_weights=quantization_config.llm_int8_has_fp16_weight,
threshold=quantization_config.llm_int8_threshold,
@ -143,14 +156,16 @@ def _replace_with_bnb_linear(
pass
else:
model._modules[name] = bnb.nn.Linear4bit(
module.in_features,
module.out_features,
in_features,
out_features,
module.bias is not None,
quantization_config.bnb_4bit_compute_dtype,
compress_statistics=quantization_config.bnb_4bit_use_double_quant,
quant_type=quantization_config.bnb_4bit_quant_type,
)
has_been_replaced = True
# Store the module class in case we need to transpose the weight later
model._modules[name].source_cls = type(module)
# Force requires grad to False to avoid unexpected errors
model._modules[name].requires_grad_(False)
if len(list(module.children())) > 0:
@ -200,7 +215,6 @@ def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name
if not has_been_replaced:
logger.warning(
"You are loading your model in 8bit or 4bit but no linear modules were found in your model."
" this can happen for some architectures such as gpt2 that uses Conv1D instead of Linear layers."
" Please double check your model architecture, or submit an issue on github if you think this is"
" a bug."
)

View File

@ -39,6 +39,12 @@ from transformers.testing_utils import (
from transformers.utils.versions import importlib_metadata
def get_some_linear_layer(model):
if model.config.model_type == "gpt2":
return model.transformer.h[0].mlp.c_fc
return model.transformer.h[0].mlp.dense_4h_to_h
if is_torch_available():
import torch
import torch.nn as nn
@ -83,6 +89,7 @@ class Base4bitTest(unittest.TestCase):
EXPECTED_OUTPUTS = set()
EXPECTED_OUTPUTS.add("Hello my name is John and I am a professional photographer. I")
EXPECTED_OUTPUTS.add("Hello my name is John.\nI am a friend of your father.\n")
EXPECTED_OUTPUTS.add("Hello my name is John Doe, I am a student at the University")
MAX_NEW_TOKENS = 10
def setUp(self):
@ -135,7 +142,8 @@ class Bnb4BitTest(Base4bitTest):
mem_4bit = self.model_4bit.get_memory_footprint()
self.assertAlmostEqual(mem_fp16 / mem_4bit, self.EXPECTED_RELATIVE_DIFFERENCE)
self.assertTrue(self.model_4bit.transformer.h[0].mlp.dense_4h_to_h.weight.__class__ == Params4bit)
linear = get_some_linear_layer(self.model_4bit)
self.assertTrue(linear.weight.__class__ == Params4bit)
def test_linear_are_4bit(self):
r"""
@ -473,3 +481,8 @@ class Bnb4BitTestTraining(Base4bitTest):
self.assertTrue(module.adapter[1].weight.grad.norm().item() > 0)
elif isinstance(module, nn.Embedding):
self.assertTrue(module.weight.grad is None)
class Bnb4BitGPT2Test(Bnb4BitTest):
model_name = "gpt2-xl"
EXPECTED_RELATIVE_DIFFERENCE = 3.3191854854152187

View File

@ -41,6 +41,12 @@ from transformers.testing_utils import (
from transformers.utils.versions import importlib_metadata
def get_some_linear_layer(model):
if model.config.model_type == "gpt2":
return model.transformer.h[0].mlp.c_fc
return model.transformer.h[0].mlp.dense_4h_to_h
if is_accelerate_available():
from accelerate import PartialState
from accelerate.logging import get_logger
@ -142,7 +148,7 @@ class MixedInt8Test(BaseMixedInt8Test):
mem_8bit = self.model_8bit.get_memory_footprint()
self.assertAlmostEqual(mem_fp16 / mem_8bit, self.EXPECTED_RELATIVE_DIFFERENCE)
self.assertTrue(self.model_8bit.transformer.h[0].mlp.dense_4h_to_h.weight.__class__ == Int8Params)
self.assertTrue(get_some_linear_layer(self.model_8bit).weight.__class__ == Int8Params)
def test_linear_are_8bit(self):
r"""
@ -292,8 +298,9 @@ class MixedInt8Test(BaseMixedInt8Test):
model_from_saved = AutoModelForCausalLM.from_pretrained(tmpdirname, load_in_8bit=True, device_map="auto")
self.assertTrue(model_from_saved.transformer.h[0].mlp.dense_4h_to_h.weight.__class__ == Int8Params)
self.assertTrue(hasattr(model_from_saved.transformer.h[0].mlp.dense_4h_to_h.weight, "SCB"))
linear = get_some_linear_layer(model_from_saved)
self.assertTrue(linear.weight.__class__ == Int8Params)
self.assertTrue(hasattr(linear.weight, "SCB"))
# generate
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
@ -318,8 +325,9 @@ class MixedInt8Test(BaseMixedInt8Test):
model_from_saved = AutoModelForCausalLM.from_pretrained(tmpdirname)
self.assertTrue(model_from_saved.transformer.h[0].mlp.dense_4h_to_h.weight.__class__ == Int8Params)
self.assertTrue(hasattr(model_from_saved.transformer.h[0].mlp.dense_4h_to_h.weight, "SCB"))
linear = get_some_linear_layer(model_from_saved)
self.assertTrue(linear.weight.__class__ == Int8Params)
self.assertTrue(hasattr(linear.weight, "SCB"))
# generate
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
@ -339,8 +347,9 @@ class MixedInt8Test(BaseMixedInt8Test):
model = AutoModelForCausalLM.from_pretrained(model_id)
self.assertTrue(model.transformer.h[0].mlp.dense_4h_to_h.weight.__class__ == Int8Params)
self.assertTrue(hasattr(model.transformer.h[0].mlp.dense_4h_to_h.weight, "SCB"))
linear = get_some_linear_layer(model)
self.assertTrue(linear.weight.__class__ == Int8Params)
self.assertTrue(hasattr(linear.weight, "SCB"))
# generate
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
@ -748,3 +757,13 @@ class MixedInt8TestTraining(BaseMixedInt8Test):
self.assertTrue(module.adapter[1].weight.grad.norm().item() > 0)
elif isinstance(module, nn.Embedding):
self.assertTrue(module.weight.grad is None)
class MixedInt8GPT2Test(MixedInt8Test):
model_name = "gpt2-xl"
EXPECTED_RELATIVE_DIFFERENCE = 1.8720077507258357
EXPECTED_OUTPUT = "Hello my name is John Doe, and I am a member of the"
def test_int8_from_pretrained(self):
# TODO @younesbelkada: Test loading quantized gpt2 model from the hub.
pass