fixing error when using sharded ddp (#18435)
parent 5096a654b7
commit 22a0dd2ef7
@@ -1344,9 +1344,8 @@ class Trainer:
                     reshard_after_forward=zero_3,
                     cpu_offload=cpu_offload,
                 ).to(self.args.device)
-
         # Distributed training using PyTorch FSDP
-        if self.fsdp is not None:
+        elif self.fsdp is not None:
             # PyTorch FSDP!
             from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
             from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
@@ -1394,7 +1393,6 @@ class Trainer:
             )
             if FSDPOption.OFFLOAD not in self.args.fsdp:
                 model.to(self.args.device)
-
         elif is_sagemaker_dp_enabled():
             model = nn.parallel.DistributedDataParallel(
                 model, device_ids=[int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))]
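The substance of the change is turning the standalone `if self.fsdp is not None:` into an `elif`, so the FSDP and later distributed branches are skipped once the model has already been handled by the sharded-DDP branch. The minimal sketch below illustrates that branch structure; `wrap_model_sketch` and the tuple "wrappers" are hypothetical stand-ins rather than the actual `Trainer._wrap_model` code, and the double-wrapping reading is an inference from the diff, not a statement taken from the commit message.

# Illustrative sketch only: mimics the if/elif chain changed by this commit.
# `wrap_model_sketch` and the tuple "wrappers" are hypothetical stand-ins.
def wrap_model_sketch(model, sharded_ddp=None, fsdp=None, local_rank=-1):
    if sharded_ddp is not None:
        model = ("ShardedDDP", model)   # fairscale sharded-DDP wrapping
    elif fsdp is not None:              # the fix: `elif` instead of a fresh `if`
        model = ("FSDP", model)         # PyTorch FSDP wrapping
    elif local_rank != -1:
        model = ("DDP", model)          # nn.parallel.DistributedDataParallel
    return model

# After the fix, a sharded-DDP run is wrapped exactly once:
print(wrap_model_sketch("model", sharded_ddp="zero_dp_2", local_rank=0))
# -> ('ShardedDDP', 'model')
# With the old standalone `if fsdp is not None:`, the same call would also fall
# through to the DDP branch and produce ('DDP', ('ShardedDDP', 'model')).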