fixing error when using sharded ddp (#18435)

Sourab Mangrulkar 2022-08-03 08:39:58 +05:30 committed by GitHub
parent 5096a654b7
commit 22a0dd2ef7

@@ -1344,9 +1344,8 @@ class Trainer:
                     reshard_after_forward=zero_3,
                     cpu_offload=cpu_offload,
                 ).to(self.args.device)
-        # Distributed training using PyTorch FSDP
-        if self.fsdp is not None:
+        elif self.fsdp is not None:
             # PyTorch FSDP!
             from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
             from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
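
Review note: the hunk above chains the PyTorch FSDP check onto the sharded-DDP branch with `elif` instead of starting a fresh `if`. A minimal sketch of the control-flow issue follows; the flag names and wrapper strings are made up for illustration and are not the real `Trainer._wrap_model` attributes. With a standalone `if`, a model that fairscale's FullyShardedDDP has already wrapped can still reach a later `elif` and get wrapped a second time in DistributedDataParallel, which is presumably the "error when using sharded ddp" the title refers to.

# Minimal sketch of the if/elif fix. Flag names and wrapper strings are
# illustrative only, not the actual Trainer logic.

def wrap_buggy(model, use_sharded_ddp, use_fsdp, is_distributed):
    if use_sharded_ddp:
        model = f"FullyShardedDDP({model})"  # sharded DDP wraps the model
    if use_fsdp:  # standalone `if` starts a new chain
        model = f"FSDP({model})"
    elif is_distributed:
        # Reached even though sharded DDP already wrapped the model above.
        model = f"DistributedDataParallel({model})"
    return model


def wrap_fixed(model, use_sharded_ddp, use_fsdp, is_distributed):
    if use_sharded_ddp:
        model = f"FullyShardedDDP({model})"
    elif use_fsdp:  # chained: skipped once the sharded DDP branch ran
        model = f"FSDP({model})"
    elif is_distributed:
        model = f"DistributedDataParallel({model})"
    return model


print(wrap_buggy("model", True, False, True))
# -> DistributedDataParallel(FullyShardedDDP(model))  (double wrapping)
print(wrap_fixed("model", True, False, True))
# -> FullyShardedDDP(model)

The second hunk below shows the downstream `elif is_sagemaker_dp_enabled()` branch that now sits in the same chain.
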
@@ -1394,7 +1393,6 @@
             )
             if FSDPOption.OFFLOAD not in self.args.fsdp:
                 model.to(self.args.device)
         elif is_sagemaker_dp_enabled():
             model = nn.parallel.DistributedDataParallel(
                 model, device_ids=[int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))]