mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-25 23:38:59 +06:00
Merge pull request #672 from oliverguhr/master
Add vocabulary and model config to the finetune output
This commit is contained in:
commit
cad88e19de
@ -1,5 +1,6 @@
|
|||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
import os
|
||||||
import torch
|
import torch
|
||||||
import logging
|
import logging
|
||||||
import json
|
import json
|
||||||
@ -12,6 +13,7 @@ from torch.utils.data import DataLoader, Dataset, RandomSampler
|
|||||||
from torch.utils.data.distributed import DistributedSampler
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from pytorch_pretrained_bert import WEIGHTS_NAME, CONFIG_NAME
|
||||||
from pytorch_pretrained_bert.modeling import BertForPreTraining
|
from pytorch_pretrained_bert.modeling import BertForPreTraining
|
||||||
from pytorch_pretrained_bert.tokenization import BertTokenizer
|
from pytorch_pretrained_bert.tokenization import BertTokenizer
|
||||||
from pytorch_pretrained_bert.optimization import BertAdam, WarmupLinearSchedule
|
from pytorch_pretrained_bert.optimization import BertAdam, WarmupLinearSchedule
|
||||||
@ -325,8 +327,13 @@ def main():
|
|||||||
# Save a trained model
|
# Save a trained model
|
||||||
logging.info("** ** * Saving fine-tuned model ** ** * ")
|
logging.info("** ** * Saving fine-tuned model ** ** * ")
|
||||||
model_to_save = model.module if hasattr(model, 'module') else model # Only save the model itself
|
model_to_save = model.module if hasattr(model, 'module') else model # Only save the model itself
|
||||||
output_model_file = args.output_dir / "pytorch_model.bin"
|
|
||||||
torch.save(model_to_save.state_dict(), str(output_model_file))
|
output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
|
||||||
|
output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
|
||||||
|
|
||||||
|
torch.save(model_to_save.state_dict(), output_model_file)
|
||||||
|
model_to_save.config.to_json_file(output_config_file)
|
||||||
|
tokenizer.save_vocabulary(args.output_dir)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
@ -29,6 +29,7 @@ from torch.utils.data import DataLoader, Dataset, RandomSampler
|
|||||||
from torch.utils.data.distributed import DistributedSampler
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
from tqdm import tqdm, trange
|
from tqdm import tqdm, trange
|
||||||
|
|
||||||
|
from pytorch_pretrained_bert import WEIGHTS_NAME, CONFIG_NAME
|
||||||
from pytorch_pretrained_bert.modeling import BertForPreTraining
|
from pytorch_pretrained_bert.modeling import BertForPreTraining
|
||||||
from pytorch_pretrained_bert.tokenization import BertTokenizer
|
from pytorch_pretrained_bert.tokenization import BertTokenizer
|
||||||
from pytorch_pretrained_bert.optimization import BertAdam, WarmupLinearSchedule
|
from pytorch_pretrained_bert.optimization import BertAdam, WarmupLinearSchedule
|
||||||
@ -614,9 +615,12 @@ def main():
|
|||||||
# Save a trained model
|
# Save a trained model
|
||||||
logger.info("** ** * Saving fine - tuned model ** ** * ")
|
logger.info("** ** * Saving fine - tuned model ** ** * ")
|
||||||
model_to_save = model.module if hasattr(model, 'module') else model # Only save the model itself
|
model_to_save = model.module if hasattr(model, 'module') else model # Only save the model itself
|
||||||
output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
|
output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
|
||||||
|
output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
|
||||||
if args.do_train:
|
if args.do_train:
|
||||||
torch.save(model_to_save.state_dict(), output_model_file)
|
torch.save(model_to_save.state_dict(), output_model_file)
|
||||||
|
model_to_save.config.to_json_file(output_config_file)
|
||||||
|
tokenizer.save_vocabulary(args.output_dir)
|
||||||
|
|
||||||
|
|
||||||
def _truncate_seq_pair(tokens_a, tokens_b, max_length):
|
def _truncate_seq_pair(tokens_a, tokens_b, max_length):
|
||||||
|
Loading…
Reference in New Issue
Block a user