diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 8533eb109fa..351b3f7ae85 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -5202,18 +5202,17 @@ class Trainer:
     def _fsdp_qlora_plugin_updates(self):
         if self.is_fsdp_enabled and _is_peft_model(self.model):
-            from peft import LoraConfig
+            from peft import PeftConfig
             from peft.utils.other import fsdp_auto_wrap_policy
 
-            if isinstance(self.model.active_peft_config, LoraConfig):
-                fsdp_plugin = self.accelerator.state.fsdp_plugin
-                fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(self.model)
+            if isinstance(self.model.active_peft_config, PeftConfig):
+                self.accelerator.state.fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(self.model)
             if (
                 getattr(self.model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES
                 and self.model.hf_quantizer.quantization_config.bnb_4bit_quant_storage.is_floating_point
                 and version.parse(accelerate_version) > version.parse("0.27.0")
             ):
-                fsdp_plugin.set_mixed_precision(
+                self.accelerator.state.fsdp_plugin.set_mixed_precision(
                     self.model.hf_quantizer.quantization_config.bnb_4bit_quant_storage, override=True
                 )
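
A minimal sketch of the effect of this change, assuming `peft` is installed (the snippet below is illustrative and not part of the patch): `PeftConfig` is the base class that all PEFT method configs inherit from, including `LoraConfig`, so widening the `isinstance` check means the FSDP auto-wrap policy is applied for any active PEFT adapter rather than only LoRA.

```python
# Illustrative only: why checking against PeftConfig instead of LoraConfig
# broadens the FSDP auto-wrap policy to every PEFT method. Assumes `peft`
# is installed; IA3Config is used here just as an example of a non-LoRA config.
from peft import IA3Config, LoraConfig, PeftConfig

lora_cfg = LoraConfig(r=8, lora_alpha=16)
ia3_cfg = IA3Config()

# Old check: only LoRA adapters qualified.
print(isinstance(lora_cfg, LoraConfig), isinstance(ia3_cfg, LoraConfig))  # True False
# New check: any PEFT config qualifies, since they all subclass PeftConfig.
print(isinstance(lora_cfg, PeftConfig), isinstance(ia3_cfg, PeftConfig))  # True True
```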