diff --git a/examples/lm_finetuning/finetune_on_pregenerated.py b/examples/lm_finetuning/finetune_on_pregenerated.py
index b0be0f7260d..9fcc5f2cb1a 100644
--- a/examples/lm_finetuning/finetune_on_pregenerated.py
+++ b/examples/lm_finetuning/finetune_on_pregenerated.py
@@ -322,14 +322,8 @@ def main():
     # Save a trained model
     if n_gpu > 1 and torch.distributed.get_rank() == 0 or n_gpu <=1 :
         logging.info("** ** * Saving fine-tuned model ** ** * ")
-        model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
-
-        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
-        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
-
-        torch.save(model_to_save.state_dict(), output_model_file)
-        model_to_save.config.to_json_file(output_config_file)
-        tokenizer.save_vocabulary(args.output_dir)
+        model.save_pretrained(args.output_dir)
+        tokenizer.save_pretrained(args.output_dir)
 
 
 if __name__ == '__main__':
diff --git a/examples/lm_finetuning/simple_lm_finetuning.py b/examples/lm_finetuning/simple_lm_finetuning.py
index adb86b2ec99..ba5f8328273 100644
--- a/examples/lm_finetuning/simple_lm_finetuning.py
+++ b/examples/lm_finetuning/simple_lm_finetuning.py
@@ -32,7 +32,7 @@ from tqdm import tqdm, trange
 from pytorch_transformers import WEIGHTS_NAME, CONFIG_NAME
 from pytorch_transformers.modeling_bert import BertForPreTraining
 from pytorch_transformers.tokenization_bert import BertTokenizer
-from pytorch_transformers.optimization import BertAdam, WarmupLinearSchedule
+from pytorch_transformers.optimization import AdamW, WarmupLinearSchedule
 
 logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                     datefmt='%m/%d/%Y %H:%M:%S',
@@ -610,12 +610,8 @@ def main():
     # Save a trained model
     if args.do_train and ( n_gpu > 1 and torch.distributed.get_rank() == 0 or n_gpu <=1):
         logger.info("** ** * Saving fine - tuned model ** ** * ")
-        model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
-        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
-        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
-        torch.save(model_to_save.state_dict(), output_model_file)
-        model_to_save.config.to_json_file(output_config_file)
-        tokenizer.save_vocabulary(args.output_dir)
+        model.save_pretrained(args.output_dir)
+        tokenizer.save_pretrained(args.output_dir)
 
 
 def _truncate_seq_pair(tokens_a, tokens_b, max_length):
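
For context, a minimal sketch (not part of the diff) of the round trip these changes rely on: `save_pretrained` bundles the weight, config, and vocabulary writes that the removed lines did by hand, and `AdamW` replaces `BertAdam`, which handled the learning-rate schedule internally. The output directory, checkpoint name, and hyperparameter values below are illustrative assumptions, not taken from these scripts.

```python
import os

from pytorch_transformers.modeling_bert import BertForPreTraining
from pytorch_transformers.tokenization_bert import BertTokenizer
from pytorch_transformers.optimization import AdamW, WarmupLinearSchedule

output_dir = './finetuned_bert'   # stands in for args.output_dir (illustrative)
num_train_steps = 1000            # illustrative value

model = BertForPreTraining.from_pretrained('bert-base-uncased')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Unlike the removed BertAdam, AdamW does not schedule the learning rate
# or clip gradients internally, so the schedule is stepped explicitly:
optimizer = AdamW(model.parameters(), lr=3e-5, correct_bias=False)
scheduler = WarmupLinearSchedule(optimizer, warmup_steps=100,
                                 t_total=num_train_steps)

# ... inside the training loop, after loss.backward():
#     optimizer.step()
#     scheduler.step()
#     optimizer.zero_grad()

# save_pretrained() expects the directory to exist, then writes the weights
# (pytorch_model.bin), config (config.json), and vocabulary into it,
# replacing the manual torch.save / config.to_json_file / save_vocabulary
# calls removed in the diff.
os.makedirs(output_dir, exist_ok=True)
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)

# The directory round-trips through from_pretrained():
model = BertForPreTraining.from_pretrained(output_dir)
tokenizer = BertTokenizer.from_pretrained(output_dir)
```

Note that because `AdamW` no longer owns the warmup/decay logic, any script migrating off `BertAdam` has to step a schedule such as `WarmupLinearSchedule` itself, as sketched above.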