From 22a0dd2ef754676ede6009cd8e86b10b053ac2cc Mon Sep 17 00:00:00 2001
From: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com>
Date: Wed, 3 Aug 2022 08:39:58 +0530
Subject: [PATCH] fixing error when using sharded ddp (#18435)

---
 src/transformers/trainer.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 59a1ca19a62..37a21b0939c 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -1344,9 +1344,8 @@ class Trainer:
                     reshard_after_forward=zero_3,
                     cpu_offload=cpu_offload,
                 ).to(self.args.device)
-
         # Distributed training using PyTorch FSDP
-        if self.fsdp is not None:
+        elif self.fsdp is not None:
             # PyTorch FSDP!
             from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
             from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
@@ -1394,7 +1393,6 @@ class Trainer:
             )
             if FSDPOption.OFFLOAD not in self.args.fsdp:
                 model.to(self.args.device)
-
         elif is_sagemaker_dp_enabled():
             model = nn.parallel.DistributedDataParallel(
                 model, device_ids=[int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))]
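
Note on the change (not part of the upstream patch): the diff suggests that in
Trainer._wrap_model the fairscale sharded-DDP branch and the FSDP / SageMaker DP /
DistributedDataParallel chain used to be independent `if` statements, so a model
already wrapped by fairscale's FullyShardedDDP could fall through and be wrapped a
second time further down the chain. Below is a minimal, self-contained Python sketch
of that control-flow difference; ShardedWrapper, DDPWrapper and wrap_model_* are
hypothetical stand-ins for illustration, not Trainer or library code.

    # Minimal sketch of the if -> elif fix; stand-in classes only.
    class ShardedWrapper:
        """Stands in for fairscale's FullyShardedDDP."""
        def __init__(self, model):
            self.model = model

    class DDPWrapper:
        """Stands in for torch.nn.parallel.DistributedDataParallel."""
        def __init__(self, model):
            self.model = model

    def wrap_model_buggy(model, use_sharded_ddp, use_fsdp, local_rank):
        # Before the patch: two separate `if` statements, so a model already
        # wrapped by the sharded-DDP branch can reach the DDP branch below.
        if use_sharded_ddp:
            model = ShardedWrapper(model)
        if use_fsdp:
            pass  # FSDP wrapping elided
        elif local_rank != -1:
            model = DDPWrapper(model)  # second wrap: the reported error case
        return model

    def wrap_model_fixed(model, use_sharded_ddp, use_fsdp, local_rank):
        # After the patch: one if/elif chain, so at most one wrapper applies.
        if use_sharded_ddp:
            model = ShardedWrapper(model)
        elif use_fsdp:
            pass  # FSDP wrapping elided
        elif local_rank != -1:
            model = DDPWrapper(model)
        return model

    if __name__ == "__main__":
        base = object()
        buggy = wrap_model_buggy(base, use_sharded_ddp=True, use_fsdp=False, local_rank=0)
        fixed = wrap_model_fixed(base, use_sharded_ddp=True, use_fsdp=False, local_rank=0)
        print(type(buggy).__name__)  # DDPWrapper around ShardedWrapper: double wrap
        print(type(fixed).__name__)  # ShardedWrapper only: intended behavior

With the `elif`, the FSDP / SageMaker / DDP branches are skipped entirely once the
sharded-DDP branch has wrapped the model, which is consistent with the subject line
"fixing error when using sharded ddp".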