Fix bnb training test failure (#34414)

* Fix bnb training test: compatibility with OPTSdpaAttention
Matthew Douglas 2024-10-25 10:23:20 -04:00 committed by GitHub
parent 186b8dc190
commit e447185b1f
2 changed files with 4 additions and 2 deletions


@@ -29,6 +29,7 @@ from transformers import (
     BitsAndBytesConfig,
     pipeline,
 )
+from transformers.models.opt.modeling_opt import OPTAttention
 from transformers.testing_utils import (
     apply_skip_if_not_implemented,
     is_bitsandbytes_available,
@@ -565,7 +566,7 @@ class Bnb4BitTestTraining(Base4bitTest):
         # Step 2: add adapters
         for _, module in model.named_modules():
-            if "OPTAttention" in repr(type(module)):
+            if isinstance(module, OPTAttention):
                 module.q_proj = LoRALayer(module.q_proj, rank=16)
                 module.k_proj = LoRALayer(module.k_proj, rank=16)
                 module.v_proj = LoRALayer(module.v_proj, rank=16)
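
The one-line fix works because the repr-substring check silently stopped matching once OPT's attention modules became OPTSdpaAttention: "OPTAttention" is not a substring of "OPTSdpaAttention", so no LoRA adapters were attached and the training test had nothing to train. isinstance also matches subclasses, which covers the SDPA variant. A minimal standalone sketch (stub classes standing in for the real transformers ones):

# Stub classes mirroring the transformers OPT attention hierarchy,
# where OPTSdpaAttention subclasses OPTAttention.
class OPTAttention:
    pass

class OPTSdpaAttention(OPTAttention):
    pass

module = OPTSdpaAttention()

# Old check: fails, since "OPTAttention" is not a substring of
# "<class '...OPTSdpaAttention'>".
print("OPTAttention" in repr(type(module)))  # False

# New check: succeeds, because isinstance covers subclasses.
print(isinstance(module, OPTAttention))      # True

The same replacement is applied to the int8 training test below.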


@@ -29,6 +29,7 @@ from transformers import (
     BitsAndBytesConfig,
     pipeline,
 )
+from transformers.models.opt.modeling_opt import OPTAttention
 from transformers.testing_utils import (
     apply_skip_if_not_implemented,
     is_accelerate_available,
@@ -868,7 +869,7 @@ class MixedInt8TestTraining(BaseMixedInt8Test):
         # Step 2: add adapters
         for _, module in model.named_modules():
-            if "OPTAttention" in repr(type(module)):
+            if isinstance(module, OPTAttention):
                 module.q_proj = LoRALayer(module.q_proj, rank=16)
                 module.k_proj = LoRALayer(module.k_proj, rank=16)
                 module.v_proj = LoRALayer(module.v_proj, rank=16)
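
For reference, LoRALayer here is a test-local helper, not a transformers API; its definition is outside this diff. A minimal sketch of such a wrapper (the init scheme and details are assumptions, not the exact test code) could look like:

import torch.nn as nn

class LoRALayer(nn.Module):
    """Wraps a frozen linear projection and adds a trainable low-rank update."""

    def __init__(self, module: nn.Linear, rank: int):
        super().__init__()
        self.module = module  # the original (quantized) projection, kept frozen
        self.adapter = nn.Sequential(
            nn.Linear(module.in_features, rank, bias=False),
            nn.Linear(rank, module.out_features, bias=False),
        )
        # Start the adapter close to a no-op so training begins from the
        # base model's behavior.
        nn.init.normal_(self.adapter[0].weight, std=1 / rank)
        nn.init.zeros_(self.adapter[1].weight)

    def forward(self, x):
        return self.module(x) + self.adapter(x)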