fix peft ckpts not being pushed to hub (#24578)

* fix push to hub for peft ckpts

* oops
This commit is contained in:
Sourab Mangrulkar 2023-06-30 00:07:44 +05:30 committed by GitHub
parent 232c898f9f
commit 9e28750287
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 4 additions and 0 deletions

View File

@ -119,6 +119,7 @@ from .trainer_utils import (
)
from .training_args import OptimizerNames, ParallelMode, TrainingArguments
from .utils import (
ADAPTER_CONFIG_NAME,
ADAPTER_SAFE_WEIGHTS_NAME,
ADAPTER_WEIGHTS_NAME,
CONFIG_NAME,
@ -3533,6 +3534,8 @@ class Trainer:
output_dir = self.args.output_dir
# To avoid a new synchronization of all model weights, we just copy the file from the checkpoint folder
modeling_files = [CONFIG_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_NAME]
if is_peft_available():
modeling_files.extend([ADAPTER_CONFIG_NAME, ADAPTER_WEIGHTS_NAME, ADAPTER_SAFE_WEIGHTS_NAME])
for modeling_file in modeling_files:
if os.path.isfile(os.path.join(checkpoint_folder, modeling_file)):
shutil.copy(os.path.join(checkpoint_folder, modeling_file), os.path.join(output_dir, modeling_file))

View File

@ -178,6 +178,7 @@ from .import_utils import (
WEIGHTS_NAME = "pytorch_model.bin"
WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
ADAPTER_CONFIG_NAME = "adapter_config.json"
ADAPTER_WEIGHTS_NAME = "adapter_model.bin"
ADAPTER_SAFE_WEIGHTS_NAME = "adapter_model.safetensors"
TF2_WEIGHTS_NAME = "tf_model.h5"