[bnb] fix bnb decoders bug (#21688)

* fix `bnb` decoders bug

* make fixup
This commit is contained in:
Younes Belkada 2023-02-20 13:21:58 +01:00 committed by GitHub
parent f56174ac5b
commit c9a0671477
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 16 additions and 1 deletions

View File

@@ -171,4 +171,13 @@ def get_keys_to_not_convert(model):
intersection = set(list_last_module) - set(tied_keys)
list_untouched = tied_keys + list(intersection)
return [module_name.split(".")[0] for module_name in list_untouched]
# remove ".weight" and ".bias" suffixes from the keys
names_to_remove = [".weight", ".bias"]
filtered_module_names = []
for name in list_untouched:
for name_to_remove in names_to_remove:
if name_to_remove in name:
name = name.replace(name_to_remove, "")
filtered_module_names.append(name)
return filtered_module_names

View File

@@ -269,10 +269,16 @@ class MixedInt8T5Test(unittest.TestCase):
`flan-t5-small` uses `T5DenseGatedActDense` whereas `t5-small` uses `T5DenseReluDense`. We need to test
both cases.
"""
import bitsandbytes as bnb
from transformers import T5ForConditionalGeneration
# test with `t5-small`
model = T5ForConditionalGeneration.from_pretrained(self.model_name, load_in_8bit=True, device_map="auto")
# there was a bug with decoders - this test checks that it is fixed
self.assertTrue(isinstance(model.decoder.block[0].layer[0].SelfAttention.q, bnb.nn.Linear8bitLt))
encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(0)
_ = model.generate(**encoded_input)