Merge pull request #1040 from FeiWang96/multi_gpu
Fix bug of multi-gpu training in lm finetuning
commit 3b56427a1e
@@ -320,9 +320,10 @@ def main():
                 global_step += 1
 
     # Save a trained model
-    if n_gpu > 1 and torch.distributed.get_rank() == 0 or n_gpu <=1 :
+    if args.local_rank == -1 or torch.distributed.get_rank() == 0:
         logging.info("** ** * Saving fine-tuned model ** ** * ")
-        model.save_pretrained(args.output_dir)
+        model_to_save = model.module if hasattr(model, 'module') else model  # Take care of distributed/parallel training
+        model_to_save.save_pretrained(args.output_dir)
         tokenizer.save_pretrained(args.output_dir)
 
 
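The old guard keyed off n_gpu, which breaks under single-process DataParallel (args.local_rank == -1 but n_gpu > 1): torch.distributed is never initialized there, so torch.distributed.get_rank() raises. Keying the check off args.local_rank avoids that, and unwrapping model.module is needed because DataParallel/DistributedDataParallel wrap the model and do not expose save_pretrained. A minimal sketch of the same pattern, factored into a hypothetical save_model helper (the scripts in this PR do it inline):

import os

import torch.distributed as dist


def save_model(model, tokenizer, output_dir, local_rank):
    """Save a (possibly wrapped) model and tokenizer from the main process only.

    local_rank == -1 means single-process training (one GPU or DataParallel),
    where torch.distributed is not initialized; otherwise only rank 0 writes.
    """
    if local_rank != -1 and dist.get_rank() != 0:
        return  # non-main processes skip saving so they don't clobber the files

    # DataParallel / DistributedDataParallel keep the real model in .module;
    # save_pretrained() lives on the underlying pretrained model.
    model_to_save = model.module if hasattr(model, 'module') else model

    os.makedirs(output_dir, exist_ok=True)
    model_to_save.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)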
@@ -507,7 +507,7 @@ def main():
 
     if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
         raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
-    if not os.path.exists(args.output_dir) and ( n_gpu > 1 and torch.distributed.get_rank() == 0 or n_gpu <=1 ):
+    if not os.path.exists(args.output_dir) and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
         os.makedirs(args.output_dir)
 
     tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
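The same rank check also guards directory creation, so only one process attempts os.makedirs. Sketched below as a hypothetical prepare_output_dir helper using the script's argument names; non-main ranks never create the directory, which is enough here because only the main process writes to it:

import os

import torch.distributed as dist


def prepare_output_dir(output_dir, local_rank):
    """Fail if output_dir is non-empty, and create it from the main process only."""
    if os.path.exists(output_dir) and os.listdir(output_dir):
        raise ValueError("Output directory ({}) already exists and is not empty.".format(output_dir))
    if not os.path.exists(output_dir) and (local_rank == -1 or dist.get_rank() == 0):
        os.makedirs(output_dir)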
@@ -608,9 +608,10 @@ def main():
                 global_step += 1
 
     # Save a trained model
-    if args.do_train and ( n_gpu > 1 and torch.distributed.get_rank() == 0 or n_gpu <=1):
+    if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
         logger.info("** ** * Saving fine - tuned model ** ** * ")
-        model.save_pretrained(args.output_dir)
+        model_to_save = model.module if hasattr(model, 'module') else model  # Take care of distributed/parallel training
+        model_to_save.save_pretrained(args.output_dir)
         tokenizer.save_pretrained(args.output_dir)
 
 