Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 02:02:21 +06:00
[bnb] fix bnb decoders bug (#21688)

* fix `bnb` decoders bug

* make fixup
This commit is contained in:
parent f56174ac5b
commit c9a0671477
@@ -171,4 +171,13 @@ def get_keys_to_not_convert(model):
     intersection = set(list_last_module) - set(tied_keys)
     list_untouched = tied_keys + list(intersection)
 
-    return [module_name.split(".")[0] for module_name in list_untouched]
+    # remove ".weight" from the keys
+    names_to_remove = [".weight", ".bias"]
+    filtered_module_names = []
+    for name in list_untouched:
+        for name_to_remove in names_to_remove:
+            if name_to_remove in name:
+                name = name.replace(name_to_remove, "")
+        filtered_module_names.append(name)
+
+    return filtered_module_names
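For context on why this change matters, here is a small, self-contained sketch (not part of the commit) contrasting the old and new filtering on a hypothetical tied key from a T5-style model. The key names "decoder.embed_tokens.weight" and "lm_head.weight" are illustrative assumptions, not values taken from the repository.

# Hypothetical tied keys, for illustration only.
list_untouched = ["decoder.embed_tokens.weight", "lm_head.weight"]

# Old behaviour: keep only the first path component, so a tied key living
# under the decoder would exclude the entire "decoder" module from
# int8 conversion.
old_names = [module_name.split(".")[0] for module_name in list_untouched]
print(old_names)  # ['decoder', 'lm_head']

# New behaviour: strip only the ".weight" / ".bias" suffix, so only the
# tied parameter's own module is skipped.
names_to_remove = [".weight", ".bias"]
filtered_module_names = []
for name in list_untouched:
    for name_to_remove in names_to_remove:
        if name_to_remove in name:
            name = name.replace(name_to_remove, "")
    filtered_module_names.append(name)
print(filtered_module_names)  # ['decoder.embed_tokens', 'lm_head']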
@@ -269,10 +269,16 @@ class MixedInt8T5Test(unittest.TestCase):
         `flan-t5-small` uses `T5DenseGatedActDense` whereas `t5-small` uses `T5DenseReluDense`. We need to test
         both cases.
         """
+        import bitsandbytes as bnb
+
         from transformers import T5ForConditionalGeneration
 
         # test with `t5-small`
         model = T5ForConditionalGeneration.from_pretrained(self.model_name, load_in_8bit=True, device_map="auto")
 
+        # there was a bug with decoders - this test checks that it is fixed
+        self.assertTrue(isinstance(model.decoder.block[0].layer[0].SelfAttention.q, bnb.nn.Linear8bitLt))
+
         encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(0)
         _ = model.generate(**encoded_input)
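Below is a minimal, hedged sketch (not part of the commit) of how the fix could be verified outside the test suite. It assumes a CUDA GPU with bitsandbytes and accelerate installed and uses the public "t5-small" checkpoint; the test's self.model_name and self.input_text are assumed to be comparable.

import bitsandbytes as bnb
from transformers import AutoTokenizer, T5ForConditionalGeneration

# Load an 8-bit quantized T5 model; device_map="auto" dispatches it to the GPU.
model = T5ForConditionalGeneration.from_pretrained(
    "t5-small", load_in_8bit=True, device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("t5-small")

# With the fix, decoder attention projections are 8-bit linear layers too.
assert isinstance(model.decoder.block[0].layer[0].SelfAttention.q, bnb.nn.Linear8bitLt)

# A quick generation pass to confirm the quantized decoder runs end to end.
encoded_input = tokenizer(
    "Translate English to German: Hello, my dog is cute.", return_tensors="pt"
).to(0)
_ = model.generate(**encoded_input)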