mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Support PeftModel signature inspect (#27865)
* Support PeftModel signature inspect * Use get_base_model() to get the base model --------- Co-authored-by: shujunhua1 <shujunhua1@jd.com>
This commit is contained in:
parent
35478182ce
commit
e5079b0b2a
@ -695,7 +695,10 @@ class Trainer:
|
||||
def _set_signature_columns_if_needed(self):
|
||||
if self._signature_columns is None:
|
||||
# Inspect model forward signature to keep only the arguments it accepts.
|
||||
signature = inspect.signature(self.model.forward)
|
||||
model_to_inspect = self.model
|
||||
if is_peft_available() and isinstance(self.model, PeftModel):
|
||||
model_to_inspect = self.model.get_base_model()
|
||||
signature = inspect.signature(model_to_inspect.forward)
|
||||
self._signature_columns = list(signature.parameters.keys())
|
||||
# Labels may be named label or label_ids, the default data collator handles that.
|
||||
self._signature_columns += list(set(["label", "label_ids"] + self.label_names))
|
||||
|
Loading…
Reference in New Issue
Block a user