Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 02:02:21 +06:00
fix some errors for distributed lm_finetuning
parent 5fe0b378d8
commit 60a1bdcdac
@@ -504,7 +504,7 @@ def main():
 
     if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
         raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
-    if not os.path.exists(args.output_dir):
+    if not os.path.exists(args.output_dir) and torch.distributed.get_rank() == 0:
         os.makedirs(args.output_dir)
 
     tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
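The first hunk restricts output-directory creation to the rank-0 process, so concurrent workers no longer race to call os.makedirs on the same path. A minimal sketch of that guard, assuming torch.distributed has already been initialized via init_process_group and a local_rank argument where -1 means a non-distributed run; the helper name below is illustrative, not part of the commit:

import os

import torch.distributed as dist


def create_output_dir_rank0(output_dir, local_rank):
    # Hypothetical helper: only the main process creates the directory.
    is_main = local_rank == -1 or dist.get_rank() == 0
    if is_main and not os.path.exists(output_dir):
        os.makedirs(output_dir)
    # In a distributed run, make the other ranks wait until rank 0 has
    # created the directory before they try to write into it.
    if local_rank != -1:
        dist.barrier()

The barrier goes a step beyond what the commit adds: it guarantees the directory exists on every rank before any of them proceeds.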
@@ -613,11 +613,11 @@ def main():
                     global_step += 1
 
     # Save a trained model
-    logger.info("** ** * Saving fine - tuned model ** ** * ")
     model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
     output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
     output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
-    if args.do_train:
+    if args.do_train and torch.distributed.get_rank() == 0:
+        logger.info("** ** * Saving fine - tuned model ** ** * ")
         torch.save(model_to_save.state_dict(), output_model_file)
         model_to_save.config.to_json_file(output_config_file)
         tokenizer.save_vocabulary(args.output_dir)