fixing error when using sharded ddp (#18435)
huggingface/transformers, commit 22a0dd2ef7 (parent 5096a654b7)
@@ -1344,9 +1344,8 @@ class Trainer:
                     reshard_after_forward=zero_3,
                     cpu_offload=cpu_offload,
                 ).to(self.args.device)
-
         # Distributed training using PyTorch FSDP
-        if self.fsdp is not None:
+        elif self.fsdp is not None:
             # PyTorch FSDP!
             from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
             from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
@@ -1394,7 +1393,6 @@ class Trainer:
             )
             if FSDPOption.OFFLOAD not in self.args.fsdp:
                 model.to(self.args.device)
-
         elif is_sagemaker_dp_enabled():
             model = nn.parallel.DistributedDataParallel(
                 model, device_ids=[int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))]
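
In short, the patch turns the `if self.fsdp is not None:` check into an `elif`, so the PyTorch FSDP branch is chained onto the preceding sharded-DDP branch and is skipped once that branch has already wrapped the model, presumably avoiding a second, conflicting wrap. Below is a minimal, self-contained sketch of that control-flow fix only; `ShardedWrapper`, `FSDPWrapper`, and the `wrap_model_*` functions are hypothetical stand-ins for illustration, not the actual Trainer code or the real FairScale / torch.distributed.fsdp wrappers.

# Hypothetical stand-ins used only to illustrate the if -> elif change.
class ShardedWrapper:
    def __init__(self, model):
        self.inner = model

class FSDPWrapper:
    def __init__(self, model):
        self.inner = model

def wrap_model_before(model, sharded_ddp, fsdp):
    # Buggy flow: with two independent `if`s, a model already wrapped by the
    # sharded-DDP branch falls through and is wrapped a second time by FSDP.
    if sharded_ddp is not None:
        model = ShardedWrapper(model)
    if fsdp is not None:
        model = FSDPWrapper(model)
    return model

def wrap_model_after(model, sharded_ddp, fsdp):
    # Fixed flow: `elif` makes the branches mutually exclusive, mirroring the
    # one-word change in the diff above.
    if sharded_ddp is not None:
        model = ShardedWrapper(model)
    elif fsdp is not None:
        model = FSDPWrapper(model)
    return model

m = object()
print(type(wrap_model_before(m, sharded_ddp="zero_dp_3", fsdp="full_shard")).__name__)  # FSDPWrapper (double-wrapped)
print(type(wrap_model_after(m, sharded_ddp="zero_dp_3", fsdp="full_shard")).__name__)   # ShardedWrapper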
|