update create_model_card to properly save peft details when using Trainer with PEFT (#27754)

* update `create_model_card` to properly save peft details when using Trainer with PEFT

* nit

* Apply suggestions from code review

Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

---------

Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
This commit is contained in:
Sourab Mangrulkar 2023-12-07 17:36:02 +05:30 committed by GitHub
parent 52746922b0
commit 5324bf9c07
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -48,7 +48,7 @@ import huggingface_hub.utils as hf_hub_utils
import numpy as np
import torch
import torch.distributed as dist
from huggingface_hub import create_repo, upload_folder
from huggingface_hub import ModelCard, create_repo, upload_folder
from packaging import version
from torch import nn
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
@ -3494,6 +3494,12 @@ class Trainer:
if not self.is_world_process_zero():
return
model_card_filepath = os.path.join(self.args.output_dir, "README.md")
is_peft_library = False
if os.path.exists(model_card_filepath):
library_name = ModelCard.load(model_card_filepath).data.get("library_name")
is_peft_library = library_name == "peft"
training_summary = TrainingSummary.from_trainer(
self,
language=language,
@ -3507,9 +3513,12 @@ class Trainer:
dataset_args=dataset_args,
)
model_card = training_summary.to_model_card()
with open(os.path.join(self.args.output_dir, "README.md"), "w") as f:
with open(model_card_filepath, "w") as f:
f.write(model_card)
if is_peft_library:
unwrap_model(self.model).create_or_update_model_card(self.args.output_dir)
def _push_from_checkpoint(self, checkpoint_folder):
# Only push from one node.
if not self.is_world_process_zero() or self.args.hub_strategy == HubStrategy.END: