Mirror of https://github.com/huggingface/transformers.git
Add bitsandbytes support for gpt2 models (#24504)
* Add bitsandbytes support for gpt2 models
* Guard Conv1D import to pass tensorflow test
* Appease ruff linter
* Fix 4bit test and remove int8 test boilerplate
* Update tests/bnb/test_mixed_int8.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
This commit is contained in:
parent 89b6ee49fd
commit 12240925cf
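In practical terms, this change lets GPT-2-family checkpoints, which use Conv1D projections instead of nn.Linear, pass through the existing bitsandbytes integration. A minimal usage sketch, not part of this diff, assuming a CUDA GPU with bitsandbytes and accelerate installed:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

    # 8-bit loading (LLM.int8): GPT-2's Conv1D layers are converted to bnb.nn.Linear8bitLt
    model_8bit = AutoModelForCausalLM.from_pretrained("gpt2-xl", load_in_8bit=True, device_map="auto")

    # 4-bit loading with an explicit config: Conv1D layers become bnb.nn.Linear4bit
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
    )
    model_4bit = AutoModelForCausalLM.from_pretrained("gpt2-xl", quantization_config=bnb_config, device_map="auto")

    tokenizer = AutoTokenizer.from_pretrained("gpt2-xl")
    inputs = tokenizer("Hello my name is", return_tensors="pt").to(0)  # first CUDA device
    print(tokenizer.decode(model_4bit.generate(**inputs, max_new_tokens=10)[0]))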
@@ -12,6 +12,8 @@ if is_bitsandbytes_available():
     import torch
     import torch.nn as nn
 
+    from ..pytorch_utils import Conv1D
+
 if is_accelerate_available():
     from accelerate import init_empty_weights
     from accelerate.utils import find_tied_parameters
@@ -84,6 +86,11 @@ def set_module_quantized_tensor_to_device(module, tensor_name, device, value=None
             else:
                 new_value = torch.tensor(value, device="cpu")
 
+            # Support models using `Conv1D` in place of `nn.Linear` (e.g. gpt2) by transposing the weight matrix prior to quantization.
+            # Since weights are saved in the correct "orientation", we skip transposing when loading.
+            if issubclass(module.source_cls, Conv1D) and fp16_statistics is None:
+                new_value = new_value.T
+
             kwargs = old_value.__dict__
             if is_8bit:
                 new_value = bnb.nn.Int8Params(new_value, requires_grad=False, **kwargs).to(device)
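The transpose is needed because GPT-2's Conv1D stores its weight as (in_features, out_features), i.e. the transpose of the (out_features, in_features) layout that nn.Linear and the bitsandbytes parameter classes expect. A quick illustration, not part of the diff:

    import torch
    from torch import nn
    from transformers.pytorch_utils import Conv1D

    conv = Conv1D(nf=12, nx=4)   # GPT-2 style: nx=4 inputs -> nf=12 outputs
    linear = nn.Linear(4, 12)    # equivalent nn.Linear

    print(conv.weight.shape)     # torch.Size([4, 12])  -> (in_features, out_features)
    print(linear.weight.shape)   # torch.Size([12, 4])  -> (out_features, in_features)
    # Hence `new_value = new_value.T` above before the tensor is wrapped in Int8Params/Params4bit.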
@@ -122,14 +129,20 @@ def _replace_with_bnb_linear(
             current_key_name = []
         current_key_name.append(name)
 
-        if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
+        if (isinstance(module, nn.Linear) or isinstance(module, Conv1D)) and name not in modules_to_not_convert:
             # Check if the current key is not in the `modules_to_not_convert`
             if not any(key in ".".join(current_key_name) for key in modules_to_not_convert):
                 with init_empty_weights():
+                    if isinstance(module, Conv1D):
+                        in_features, out_features = module.weight.shape
+                    else:
+                        in_features = module.in_features
+                        out_features = module.out_features
+
                     if quantization_config.quantization_method() == "llm_int8":
                         model._modules[name] = bnb.nn.Linear8bitLt(
-                            module.in_features,
-                            module.out_features,
+                            in_features,
+                            out_features,
                             module.bias is not None,
                             has_fp16_weights=quantization_config.llm_int8_has_fp16_weight,
                             threshold=quantization_config.llm_int8_threshold,
@@ -143,14 +156,16 @@ def _replace_with_bnb_linear(
                             pass
                         else:
                             model._modules[name] = bnb.nn.Linear4bit(
-                                module.in_features,
-                                module.out_features,
+                                in_features,
+                                out_features,
                                 module.bias is not None,
                                 quantization_config.bnb_4bit_compute_dtype,
                                 compress_statistics=quantization_config.bnb_4bit_use_double_quant,
                                 quant_type=quantization_config.bnb_4bit_quant_type,
                             )
                             has_been_replaced = True
+                    # Store the module class in case we need to transpose the weight later
+                    model._modules[name].source_cls = type(module)
                     # Force requires grad to False to avoid unexpected errors
                     model._modules[name].requires_grad_(False)
         if len(list(module.children())) > 0:
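Taken together with the previous hunk, `_replace_with_bnb_linear` now swaps GPT-2's Conv1D projections for bitsandbytes linear layers and remembers the original class on `source_cls`. A quick way to see the result (illustrative only; module paths follow the standard GPT-2 layout, CUDA GPU required):

    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("gpt2", load_in_8bit=True, device_map="auto")

    block = model.transformer.h[0]
    print(type(block.mlp.c_fc).__name__)      # Linear8bitLt
    print(type(block.attn.c_attn).__name__)   # Linear8bitLt
    print(block.mlp.c_fc.source_cls)          # <class 'transformers.pytorch_utils.Conv1D'>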
@@ -200,7 +215,6 @@ def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name
     if not has_been_replaced:
         logger.warning(
             "You are loading your model in 8bit or 4bit but no linear modules were found in your model."
-            " this can happen for some architectures such as gpt2 that uses Conv1D instead of Linear layers."
             " Please double check your model architecture, or submit an issue on github if you think this is"
             " a bug."
         )
@@ -39,6 +39,12 @@ from transformers.testing_utils import (
 from transformers.utils.versions import importlib_metadata
 
 
+def get_some_linear_layer(model):
+    if model.config.model_type == "gpt2":
+        return model.transformer.h[0].mlp.c_fc
+    return model.transformer.h[0].mlp.dense_4h_to_h
+
+
 if is_torch_available():
     import torch
     import torch.nn as nn
@@ -83,6 +89,7 @@ class Base4bitTest(unittest.TestCase):
     EXPECTED_OUTPUTS = set()
     EXPECTED_OUTPUTS.add("Hello my name is John and I am a professional photographer. I")
     EXPECTED_OUTPUTS.add("Hello my name is John.\nI am a friend of your father.\n")
+    EXPECTED_OUTPUTS.add("Hello my name is John Doe, I am a student at the University")
     MAX_NEW_TOKENS = 10
 
     def setUp(self):
@@ -135,7 +142,8 @@ class Bnb4BitTest(Base4bitTest):
         mem_4bit = self.model_4bit.get_memory_footprint()
 
         self.assertAlmostEqual(mem_fp16 / mem_4bit, self.EXPECTED_RELATIVE_DIFFERENCE)
-        self.assertTrue(self.model_4bit.transformer.h[0].mlp.dense_4h_to_h.weight.__class__ == Params4bit)
+        linear = get_some_linear_layer(self.model_4bit)
+        self.assertTrue(linear.weight.__class__ == Params4bit)
 
     def test_linear_are_4bit(self):
         r"""
@@ -473,3 +481,8 @@ class Bnb4BitTestTraining(Base4bitTest):
                 self.assertTrue(module.adapter[1].weight.grad.norm().item() > 0)
             elif isinstance(module, nn.Embedding):
                 self.assertTrue(module.weight.grad is None)
+
+
+class Bnb4BitGPT2Test(Bnb4BitTest):
+    model_name = "gpt2-xl"
+    EXPECTED_RELATIVE_DIFFERENCE = 3.3191854854152187
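The EXPECTED_RELATIVE_DIFFERENCE constant is the fp16-to-4bit memory-footprint ratio asserted in the assertAlmostEqual shown above. A sketch of how such a figure could be reproduced locally (illustrative only; needs a CUDA GPU with enough memory for both copies):

    import torch
    from transformers import AutoModelForCausalLM

    name = "gpt2-xl"
    model_fp16 = AutoModelForCausalLM.from_pretrained(name, torch_dtype=torch.float16, device_map="auto")
    model_4bit = AutoModelForCausalLM.from_pretrained(name, load_in_4bit=True, device_map="auto")

    # Ratio of fp16 to 4-bit footprint; the test above pins this at ~3.32 for gpt2-xl.
    print(model_fp16.get_memory_footprint() / model_4bit.get_memory_footprint())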
@@ -41,6 +41,12 @@ from transformers.testing_utils import (
 from transformers.utils.versions import importlib_metadata
 
 
+def get_some_linear_layer(model):
+    if model.config.model_type == "gpt2":
+        return model.transformer.h[0].mlp.c_fc
+    return model.transformer.h[0].mlp.dense_4h_to_h
+
+
 if is_accelerate_available():
     from accelerate import PartialState
     from accelerate.logging import get_logger
@@ -142,7 +148,7 @@ class MixedInt8Test(BaseMixedInt8Test):
         mem_8bit = self.model_8bit.get_memory_footprint()
 
         self.assertAlmostEqual(mem_fp16 / mem_8bit, self.EXPECTED_RELATIVE_DIFFERENCE)
-        self.assertTrue(self.model_8bit.transformer.h[0].mlp.dense_4h_to_h.weight.__class__ == Int8Params)
+        self.assertTrue(get_some_linear_layer(self.model_8bit).weight.__class__ == Int8Params)
 
     def test_linear_are_8bit(self):
         r"""
@@ -292,8 +298,9 @@ class MixedInt8Test(BaseMixedInt8Test):
 
             model_from_saved = AutoModelForCausalLM.from_pretrained(tmpdirname, load_in_8bit=True, device_map="auto")
 
-            self.assertTrue(model_from_saved.transformer.h[0].mlp.dense_4h_to_h.weight.__class__ == Int8Params)
-            self.assertTrue(hasattr(model_from_saved.transformer.h[0].mlp.dense_4h_to_h.weight, "SCB"))
+            linear = get_some_linear_layer(model_from_saved)
+            self.assertTrue(linear.weight.__class__ == Int8Params)
+            self.assertTrue(hasattr(linear.weight, "SCB"))
 
             # generate
             encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
@@ -318,8 +325,9 @@ class MixedInt8Test(BaseMixedInt8Test):
 
             model_from_saved = AutoModelForCausalLM.from_pretrained(tmpdirname)
 
-            self.assertTrue(model_from_saved.transformer.h[0].mlp.dense_4h_to_h.weight.__class__ == Int8Params)
-            self.assertTrue(hasattr(model_from_saved.transformer.h[0].mlp.dense_4h_to_h.weight, "SCB"))
+            linear = get_some_linear_layer(model_from_saved)
+            self.assertTrue(linear.weight.__class__ == Int8Params)
+            self.assertTrue(hasattr(linear.weight, "SCB"))
 
             # generate
             encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
@@ -339,8 +347,9 @@ class MixedInt8Test(BaseMixedInt8Test):
 
         model = AutoModelForCausalLM.from_pretrained(model_id)
 
-        self.assertTrue(model.transformer.h[0].mlp.dense_4h_to_h.weight.__class__ == Int8Params)
-        self.assertTrue(hasattr(model.transformer.h[0].mlp.dense_4h_to_h.weight, "SCB"))
+        linear = get_some_linear_layer(model)
+        self.assertTrue(linear.weight.__class__ == Int8Params)
+        self.assertTrue(hasattr(linear.weight, "SCB"))
 
         # generate
         encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
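The three hunks above switch the int8 serialization tests to the model-agnostic helper. For orientation, the round-trip they exercise looks roughly like this (an illustrative sketch using gpt2-xl; needs a CUDA GPU and a recent bitsandbytes release that supports int8 serialization):

    import tempfile
    from transformers import AutoModelForCausalLM

    model_8bit = AutoModelForCausalLM.from_pretrained("gpt2-xl", load_in_8bit=True, device_map="auto")

    with tempfile.TemporaryDirectory() as tmpdirname:
        model_8bit.save_pretrained(tmpdirname)  # int8 weights plus their SCB scales are written out
        reloaded = AutoModelForCausalLM.from_pretrained(tmpdirname, load_in_8bit=True, device_map="auto")

    linear = reloaded.transformer.h[0].mlp.c_fc
    print(linear.weight.__class__.__name__)   # Int8Params
    print(hasattr(linear.weight, "SCB"))      # True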
@@ -748,3 +757,13 @@ class MixedInt8TestTraining(BaseMixedInt8Test):
                 self.assertTrue(module.adapter[1].weight.grad.norm().item() > 0)
             elif isinstance(module, nn.Embedding):
                 self.assertTrue(module.weight.grad is None)
+
+
+class MixedInt8GPT2Test(MixedInt8Test):
+    model_name = "gpt2-xl"
+    EXPECTED_RELATIVE_DIFFERENCE = 1.8720077507258357
+    EXPECTED_OUTPUT = "Hello my name is John Doe, and I am a member of the"
+
+    def test_int8_from_pretrained(self):
+        # TODO @younesbelkada: Test loading quantized gpt2 model from the hub.
+        pass