Support PeftModel signature inspect (#27865)

* Support PeftModel signature inspect

* Use get_base_model() to get the base model

---------

Co-authored-by: shujunhua1 <shujunhua1@jd.com>
This commit is contained in:
dancingpipi 2023-12-12 03:30:11 +08:00 committed by GitHub
parent 35478182ce
commit e5079b0b2a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -695,7 +695,10 @@ class Trainer:
def _set_signature_columns_if_needed(self):
if self._signature_columns is None:
# Inspect model forward signature to keep only the arguments it accepts.
signature = inspect.signature(self.model.forward)
model_to_inspect = self.model
if is_peft_available() and isinstance(self.model, PeftModel):
model_to_inspect = self.model.get_base_model()
signature = inspect.signature(model_to_inspect.forward)
self._signature_columns = list(signature.parameters.keys())
# Labels may be named label or label_ids, the default data collator handles that.
self._signature_columns += list(set(["label", "label_ids"] + self.label_names))