diff --git a/examples/lm_finetuning/finetune_on_pregenerated.py b/examples/lm_finetuning/finetune_on_pregenerated.py
index cf27ef6cc6e..2a5783c2611 100644
--- a/examples/lm_finetuning/finetune_on_pregenerated.py
+++ b/examples/lm_finetuning/finetune_on_pregenerated.py
@@ -1,5 +1,6 @@
 from argparse import ArgumentParser
 from pathlib import Path
+import os
 import torch
 import logging
 import json
@@ -12,6 +13,7 @@
 from torch.utils.data import DataLoader, Dataset, RandomSampler
 from torch.utils.data.distributed import DistributedSampler
 from tqdm import tqdm
+from pytorch_pretrained_bert import WEIGHTS_NAME, CONFIG_NAME
 from pytorch_pretrained_bert.modeling import BertForPreTraining
 from pytorch_pretrained_bert.tokenization import BertTokenizer
 from pytorch_pretrained_bert.optimization import BertAdam, WarmupLinearSchedule
@@ -325,8 +327,13 @@ def main():
     # Save a trained model
     logging.info("** ** * Saving fine-tuned model ** ** * ")
     model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
-    output_model_file = args.output_dir / "pytorch_model.bin"
-    torch.save(model_to_save.state_dict(), str(output_model_file))
+
+    output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
+    output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
+
+    torch.save(model_to_save.state_dict(), output_model_file)
+    model_to_save.config.to_json_file(output_config_file)
+    tokenizer.save_vocabulary(args.output_dir)
 
 
 if __name__ == '__main__':
diff --git a/examples/lm_finetuning/simple_lm_finetuning.py b/examples/lm_finetuning/simple_lm_finetuning.py
index 610912675f4..368d6825c73 100644
--- a/examples/lm_finetuning/simple_lm_finetuning.py
+++ b/examples/lm_finetuning/simple_lm_finetuning.py
@@ -29,6 +29,7 @@
 from torch.utils.data import DataLoader, Dataset, RandomSampler
 from torch.utils.data.distributed import DistributedSampler
 from tqdm import tqdm, trange
+from pytorch_pretrained_bert import WEIGHTS_NAME, CONFIG_NAME
 from pytorch_pretrained_bert.modeling import BertForPreTraining
 from pytorch_pretrained_bert.tokenization import BertTokenizer
 from pytorch_pretrained_bert.optimization import BertAdam, WarmupLinearSchedule
@@ -614,9 +615,12 @@ def main():
     # Save a trained model
     logger.info("** ** * Saving fine - tuned model ** ** * ")
     model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
-    output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
+    output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
+    output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
     if args.do_train:
         torch.save(model_to_save.state_dict(), output_model_file)
+        model_to_save.config.to_json_file(output_config_file)
+        tokenizer.save_vocabulary(args.output_dir)
 
 
 def _truncate_seq_pair(tokens_a, tokens_b, max_length):
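
With the config and vocabulary now written alongside the weights under the standard
file names (WEIGHTS_NAME, CONFIG_NAME), the output directory becomes directly
loadable via from_pretrained. A minimal sketch of reloading the saved artifacts;
the directory name "finetuned_lm/" is a hypothetical value for --output_dir:

    from pytorch_pretrained_bert.modeling import BertForPreTraining
    from pytorch_pretrained_bert.tokenization import BertTokenizer

    # from_pretrained() accepts a local directory containing the weights,
    # the JSON config, and the saved vocabulary produced by the code above.
    model = BertForPreTraining.from_pretrained("finetuned_lm/")
    tokenizer = BertTokenizer.from_pretrained("finetuned_lm/")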