fix saving models in distributed setting examples

thomwolf 2019-04-15 16:43:56 +02:00
parent d616022455
commit 3571187ef6
2 changed files with 2 additions and 1 deletion


@@ -859,6 +859,7 @@ def main():
optimizer.zero_grad()
global_step += 1
+ if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
# Save a trained model, configuration and tokenizer
model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self


@@ -1020,7 +1020,7 @@ def main():
optimizer.zero_grad()
global_step += 1
- if args.do_train:
+ if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
# Save a trained model, configuration and tokenizer
model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self
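
Both hunks apply the same pattern: only the main process writes the checkpoint, and any DataParallel/DistributedDataParallel wrapper is unwrapped via `.module` before saving. The sketch below illustrates that pattern under stated assumptions; the helper name `save_checkpoint`, the standalone `local_rank` parameter (the example scripts read `args.local_rank`), and the output file name are illustrative, not taken from the changed files.

import os
import torch


def save_checkpoint(model, output_dir, local_rank):
    """Sketch of the rank-0-only save pattern the commit enforces (names are illustrative)."""
    # local_rank == -1 means single-process training; in a distributed run only
    # rank 0 saves, so the workers do not all write (and race on) the same file.
    if local_rank != -1 and torch.distributed.get_rank() != 0:
        return
    # DataParallel / DistributedDataParallel wrap the real model in `.module`;
    # save the underlying model so the checkpoint reloads without the wrapper.
    model_to_save = model.module if hasattr(model, 'module') else model
    os.makedirs(output_dir, exist_ok=True)
    torch.save(model_to_save.state_dict(), os.path.join(output_dir, "pytorch_model.bin"))

Saving the unwrapped module is what keeps the checkpoint loadable in a plain, single-GPU script, since the wrapper would otherwise prefix every parameter name with "module.".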