Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
Fix bnb training test failure (#34414)
* Fix bnb training test: compatibility with OPTSdpaAttention
parent 186b8dc190
commit e447185b1f
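Why the string check broke: when the test model uses the SDPA attention path, its attention modules are OPTSdpaAttention instances, whose class name does not contain the substring "OPTAttention", so the old repr()-based match silently skips every layer and no adapters get attached. Below is a minimal sketch of the difference, assuming a transformers version contemporary with this commit in which OPTSdpaAttention lives in modeling_opt and subclasses OPTAttention:

# Sketch only: OPTSdpaAttention is assumed to be importable from this module
# and to subclass OPTAttention, as it did at the time of this commit.
from transformers.models.opt.modeling_opt import OPTAttention, OPTSdpaAttention

# Old test check, applied to repr(type(module)) for each module:
# "OPTAttention" is not a substring of "OPTSdpaAttention", so this is False.
print("OPTAttention" in repr(OPTSdpaAttention))    # False

# New test check: isinstance(module, OPTAttention) on instances, i.e. the
# subclass relationship checked here, matches eager and SDPA variants alike.
print(issubclass(OPTSdpaAttention, OPTAttention))  # True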
@@ -29,6 +29,7 @@ from transformers import (
     BitsAndBytesConfig,
     pipeline,
 )
+from transformers.models.opt.modeling_opt import OPTAttention
 from transformers.testing_utils import (
     apply_skip_if_not_implemented,
     is_bitsandbytes_available,
@@ -565,7 +566,7 @@ class Bnb4BitTestTraining(Base4bitTest):
 
         # Step 2: add adapters
         for _, module in model.named_modules():
-            if "OPTAttention" in repr(type(module)):
+            if isinstance(module, OPTAttention):
                 module.q_proj = LoRALayer(module.q_proj, rank=16)
                 module.k_proj = LoRALayer(module.k_proj, rank=16)
                 module.v_proj = LoRALayer(module.v_proj, rank=16)
@@ -29,6 +29,7 @@ from transformers import (
     BitsAndBytesConfig,
     pipeline,
 )
+from transformers.models.opt.modeling_opt import OPTAttention
 from transformers.testing_utils import (
     apply_skip_if_not_implemented,
     is_accelerate_available,
@@ -868,7 +869,7 @@ class MixedInt8TestTraining(BaseMixedInt8Test):
 
         # Step 2: add adapters
         for _, module in model.named_modules():
-            if "OPTAttention" in repr(type(module)):
+            if isinstance(module, OPTAttention):
                 module.q_proj = LoRALayer(module.q_proj, rank=16)
                 module.k_proj = LoRALayer(module.k_proj, rank=16)
                 module.v_proj = LoRALayer(module.v_proj, rank=16)
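For context, LoRALayer is a small test-only helper defined in these test files that wraps a frozen (quantized) projection with a trainable low-rank adapter; the isinstance fix only changes which modules get wrapped. The sketch below is an illustrative reimplementation of such a wrapper, not the exact helper from the tests — the initialization scheme and forward signature are assumptions:

import torch.nn as nn

class LoRALayer(nn.Module):
    """Illustrative low-rank adapter around a frozen linear projection."""

    def __init__(self, module: nn.Module, rank: int):
        super().__init__()
        self.module = module  # the original (possibly quantized) q/k/v projection
        # low-rank update path: in_features -> rank -> out_features
        self.adapter = nn.Sequential(
            nn.Linear(module.in_features, rank, bias=False),
            nn.Linear(rank, module.out_features, bias=False),
        )
        # start as a no-op so training begins from the base model's behavior
        nn.init.zeros_(self.adapter[1].weight)

    def forward(self, x, *args, **kwargs):
        # base output plus the trainable low-rank correction
        return self.module(x, *args, **kwargs) + self.adapter(x)

Wrapping q_proj, k_proj and v_proj this way gives the training tests a handful of trainable parameters on top of the frozen 4-bit / int8 model, which is what Bnb4BitTestTraining and MixedInt8TestTraining exercise.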