mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
fix peft ckpts not being pushed to hub (#24578)
* fix push to hub for peft ckpts * oops
This commit is contained in:
parent
232c898f9f
commit
9e28750287
@ -119,6 +119,7 @@ from .trainer_utils import (
|
||||
)
|
||||
from .training_args import OptimizerNames, ParallelMode, TrainingArguments
|
||||
from .utils import (
|
||||
ADAPTER_CONFIG_NAME,
|
||||
ADAPTER_SAFE_WEIGHTS_NAME,
|
||||
ADAPTER_WEIGHTS_NAME,
|
||||
CONFIG_NAME,
|
||||
@ -3533,6 +3534,8 @@ class Trainer:
|
||||
output_dir = self.args.output_dir
|
||||
# To avoid a new synchronization of all model weights, we just copy the file from the checkpoint folder
|
||||
modeling_files = [CONFIG_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_NAME]
|
||||
if is_peft_available():
|
||||
modeling_files.extend([ADAPTER_CONFIG_NAME, ADAPTER_WEIGHTS_NAME, ADAPTER_SAFE_WEIGHTS_NAME])
|
||||
for modeling_file in modeling_files:
|
||||
if os.path.isfile(os.path.join(checkpoint_folder, modeling_file)):
|
||||
shutil.copy(os.path.join(checkpoint_folder, modeling_file), os.path.join(output_dir, modeling_file))
|
||||
|
@ -178,6 +178,7 @@ from .import_utils import (
|
||||
|
||||
WEIGHTS_NAME = "pytorch_model.bin"
|
||||
WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
|
||||
ADAPTER_CONFIG_NAME = "adapter_config.json"
|
||||
ADAPTER_WEIGHTS_NAME = "adapter_model.bin"
|
||||
ADAPTER_SAFE_WEIGHTS_NAME = "adapter_model.safetensors"
|
||||
TF2_WEIGHTS_NAME = "tf_model.h5"
|
||||
|
Loading…
Reference in New Issue
Block a user