Fixed base model class name extraction from PeftModels (#27162)

* Fixed base model class name extraction from PeftModels

* Changes to first unwrap the model then extract the base model name

* Changed base_model to base_model.model to stay consistent with peft model abstractions
This commit is contained in:
Komal Kumar 2023-11-02 16:08:03 -04:00 committed by GitHub
parent 4991216841
commit 552ff24488
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -646,7 +646,7 @@ class Trainer:
unwrapped_model = unwrap_model(model)
if is_peft_available() and isinstance(unwrapped_model, PeftModel):
embeddings = unwrapped_model.base_model.get_input_embeddings()
embeddings = unwrapped_model.base_model.model.get_input_embeddings()
else:
embeddings = unwrapped_model.get_input_embeddings()
@ -667,7 +667,7 @@ class Trainer:
unwrapped_model = unwrap_model(model)
if is_peft_available() and isinstance(unwrapped_model, PeftModel):
embeddings = unwrapped_model.base_model.get_input_embeddings()
embeddings = unwrapped_model.base_model.model.get_input_embeddings()
else:
embeddings = unwrapped_model.get_input_embeddings()
@ -2752,10 +2752,11 @@ class Trainer:
self._past = outputs[self.args.past_index]
if labels is not None:
if is_peft_available() and isinstance(model, PeftModel):
model_name = unwrap_model(model.base_model)._get_name()
unwrapped_model = unwrap_model(model)
if is_peft_available() and isinstance(unwrapped_model, PeftModel):
model_name = unwrapped_model.base_model.model._get_name()
else:
model_name = unwrap_model(model)._get_name()
model_name = unwrapped_model._get_name()
if model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
loss = self.label_smoother(outputs, labels, shift_labels=True)
else: