mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
Fixed base model class name extraction from PeftModels (#27162)
* Fixed base model class name extraction from PeftModels * Changes to first unwrap the model then extract the base model name * Changed base_model to base_model.model to stay consistent with peft model abstractions
This commit is contained in:
parent
4991216841
commit
552ff24488
@ -646,7 +646,7 @@ class Trainer:
|
||||
unwrapped_model = unwrap_model(model)
|
||||
|
||||
if is_peft_available() and isinstance(unwrapped_model, PeftModel):
|
||||
embeddings = unwrapped_model.base_model.get_input_embeddings()
|
||||
embeddings = unwrapped_model.base_model.model.get_input_embeddings()
|
||||
else:
|
||||
embeddings = unwrapped_model.get_input_embeddings()
|
||||
|
||||
@ -667,7 +667,7 @@ class Trainer:
|
||||
unwrapped_model = unwrap_model(model)
|
||||
|
||||
if is_peft_available() and isinstance(unwrapped_model, PeftModel):
|
||||
embeddings = unwrapped_model.base_model.get_input_embeddings()
|
||||
embeddings = unwrapped_model.base_model.model.get_input_embeddings()
|
||||
else:
|
||||
embeddings = unwrapped_model.get_input_embeddings()
|
||||
|
||||
@ -2752,10 +2752,11 @@ class Trainer:
|
||||
self._past = outputs[self.args.past_index]
|
||||
|
||||
if labels is not None:
|
||||
if is_peft_available() and isinstance(model, PeftModel):
|
||||
model_name = unwrap_model(model.base_model)._get_name()
|
||||
unwrapped_model = unwrap_model(model)
|
||||
if is_peft_available() and isinstance(unwrapped_model, PeftModel):
|
||||
model_name = unwrapped_model.base_model.model._get_name()
|
||||
else:
|
||||
model_name = unwrap_model(model)._get_name()
|
||||
model_name = unwrapped_model._get_name()
|
||||
if model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
|
||||
loss = self.label_smoother(outputs, labels, shift_labels=True)
|
||||
else:
|
||||
|
Loading…
Reference in New Issue
Block a user