FIX FSDP plugin update for QLoRA (#36720)

The _fsdp_qlora_plugin_updates method checks for LoraConfig, but other PEFT
methods (e.g. VeRA) can also support quantized models. Therefore, the
isinstance check now looks for PeftConfig in general.
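
For illustration, a minimal sketch (assuming a peft version that ships VeRA):
every PEFT method config subclasses PeftConfig, so the broader check covers
LoRA, VeRA, and any other method, not just LoRA.

# Sketch only, assumes peft >= 0.10 so that VeraConfig is exported.
from peft import LoraConfig, PeftConfig, VeraConfig

print(isinstance(LoraConfig(), PeftConfig))  # True
print(isinstance(VeraConfig(), PeftConfig))  # True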

Moreover, the fsdp_plugin variable may be undefined by the time the second if
condition is reached, leading to an UnboundLocalError. This is fixed by not
assigning the variable at all and accessing
self.accelerator.state.fsdp_plugin directly where it is needed.
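
The failure mode, reduced to a standalone sketch (not the actual Trainer
code): the name is only bound inside the first branch, so reaching the second
branch without the first raises UnboundLocalError.

# Standalone sketch of the bug pattern, not the Trainer code itself.
def plugin_updates(is_lora, is_bnb_4bit):
    if is_lora:
        fsdp_plugin = "plugin"  # only bound on this branch
        print("auto wrap policy set on", fsdp_plugin)
    if is_bnb_4bit:
        # raises UnboundLocalError when is_lora was False
        print("mixed precision set on", fsdp_plugin)

plugin_updates(is_lora=False, is_bnb_4bit=True)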

I checked for tests that may need updating but only found
test_fsdp_config_transformers_auto_wrap associated with this change.
AFAICT, this test does not cover the changed code, since the test does
not start the training loop. Therefore, I haven't updated any tests. LMK
if/how this fix should be tested.

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Author: Benjamin Bossan
Date: 2025-03-21 10:11:47 +01:00
Committed by: GitHub
Parent: 949cca4061
Commit: 6bb8565f0c

@@ -5202,18 +5202,17 @@ class Trainer:
     def _fsdp_qlora_plugin_updates(self):
         if self.is_fsdp_enabled and _is_peft_model(self.model):
-            from peft import LoraConfig
+            from peft import PeftConfig
             from peft.utils.other import fsdp_auto_wrap_policy
 
-            if isinstance(self.model.active_peft_config, LoraConfig):
-                fsdp_plugin = self.accelerator.state.fsdp_plugin
-                fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(self.model)
+            if isinstance(self.model.active_peft_config, PeftConfig):
+                self.accelerator.state.fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(self.model)
             if (
                 getattr(self.model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES
                 and self.model.hf_quantizer.quantization_config.bnb_4bit_quant_storage.is_floating_point
                 and version.parse(accelerate_version) > version.parse("0.27.0")
             ):
-                fsdp_plugin.set_mixed_precision(
+                self.accelerator.state.fsdp_plugin.set_mixed_precision(
                     self.model.hf_quantizer.quantization_config.bnb_4bit_quant_storage, override=True
                 )