Merge pull request #896 from zijunsun/master

fix multi-gpu training bug when using fp16

torch.nn.DataParallel must wrap the model after apex's amp.initialize call, so the wrapping is moved from main() into train(), right after fp16 initialization.
commit c054b5ee64 by Thomas Wolf, 2019-07-26 19:31:02 +02:00, committed by GitHub
2 changed files with 8 additions and 4 deletions

(first changed file)

@@ -92,6 +92,10 @@ def train(args, train_dataset, model, tokenizer):
             raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
         model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
 
+    # multi-gpu training (should be after apex fp16 initialization)
+    if args.n_gpu > 1:
+        model = torch.nn.DataParallel(model)
+
     # Distributed training (should be after apex fp16 initialization)
     if args.local_rank != -1:
         model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
@@ -418,8 +422,6 @@ def main():
         torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab
 
     model.to(args.device)
-    if args.n_gpu > 1:
-        model = torch.nn.DataParallel(model)
 
     logger.info("Training/evaluation parameters %s", args)
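For context, the fix is purely about ordering: apex needs to patch the bare model before any parallel wrapper sees it. Below is a minimal sketch of the wrapping order this patch enforces; the helper name `wrap_model` and the bare `args` namespace are illustrative stand-ins, not part of the patch.

```python
# Sketch of the initialization order this patch enforces (helper name is
# illustrative; `args` stands in for the scripts' argparse namespace).
import torch

def wrap_model(args, model, optimizer):
    # 1) apex fp16 first: amp patches the *bare* model and optimizer.
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)

    # 2) Single-process multi-GPU second: DataParallel now replicates the
    #    already-patched module. Before this fix the wrapping happened in
    #    main(), ahead of amp.initialize, so amp received a DataParallel
    #    container instead of the bare model.
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # 3) Distributed last, also after apex fp16 initialization.
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank)

    return model, optimizer
```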

(second changed file)

@@ -101,6 +101,10 @@ def train(args, train_dataset, model, tokenizer):
             raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
         model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
 
+    # multi-gpu training (should be after apex fp16 initialization)
+    if args.n_gpu > 1:
+        model = torch.nn.DataParallel(model)
+
     # Distributed training (should be after apex fp16 initialization)
     if args.local_rank != -1:
         model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
@@ -460,8 +464,6 @@ def main():
         torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab
 
     model.to(args.device)
-    if args.n_gpu > 1:
-        model = torch.nn.DataParallel(model)
 
     logger.info("Training/evaluation parameters %s", args)
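The two branches in the patched train() correspond to how such example scripts are usually launched. A hedged sketch of the typical device setup that produces `n_gpu` and `local_rank` (illustrative only; the exact setup code in the repo may differ from this):

```python
# Hedged sketch: how n_gpu / local_rank typically get their values in these
# example scripts (illustrative; not part of this commit).
import torch

def setup_device(local_rank):
    if local_rank == -1:
        # Plain `python script.py`: one process drives all visible GPUs,
        # which selects the DataParallel branch when n_gpu > 1.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        # `python -m torch.distributed.launch --nproc_per_node=N script.py`:
        # one process per GPU, which selects the DistributedDataParallel branch.
        torch.cuda.set_device(local_rank)
        device = torch.device("cuda", local_rank)
        torch.distributed.init_process_group(backend="nccl")
        n_gpu = 1
    return device, n_gpu
```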