Smangrul/accelerate ddp integrate (#23151)

* mixed precision support via accelerate

* fix issues

* fix for the sharded ddp case

* fix failing Flax and TF tests

* refactor the place where the `Accelerator` object is created

* move ddp prep to accelerate

* fix 😅

* resolving comments
Sourab Mangrulkar committed via GitHub on 2023-05-31 13:42:49 +05:30
commit 1cf148a6aa (parent 9f0646a555)

@@ -213,6 +213,7 @@ if is_accelerate_available():
     from accelerate import skip_first_batches
     from accelerate import Accelerator
+    from accelerate.utils import DistributedDataParallelKwargs
 
 if TYPE_CHECKING:
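
For context on the new import: `DistributedDataParallelKwargs` is Accelerate's container for the keyword arguments that get forwarded to `torch.nn.parallel.DistributedDataParallel` when Accelerate wraps the model. A minimal standalone sketch of that pattern (illustrative only; outside the Trainer the handler is usually passed at construction time via `kwargs_handlers`, whereas this commit sets it on an already-created `Accelerator`):

import torch.nn as nn

from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs

# DDP options that would otherwise be passed straight to nn.parallel.DistributedDataParallel;
# Accelerate applies them when it wraps the model inside prepare().
ddp_kwargs = DistributedDataParallelKwargs(
    find_unused_parameters=True,  # mirrors TrainingArguments.ddp_find_unused_parameters
    bucket_cap_mb=25,             # mirrors TrainingArguments.ddp_bucket_cap_mb
)

accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
model = accelerator.prepare(nn.Linear(8, 8))  # wrapped in DDP only under a distributed launch
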
@@ -1591,6 +1592,8 @@ class Trainer:
                 model, device_ids=[int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))]
             )
         elif self.args.parallel_mode == ParallelMode.DISTRIBUTED:
+            if is_torch_neuroncore_available():
+                return model
             kwargs = {}
             if self.args.ddp_find_unused_parameters is not None:
                 kwargs["find_unused_parameters"] = self.args.ddp_find_unused_parameters
@@ -1603,15 +1606,8 @@ class Trainer:
             if self.args.ddp_bucket_cap_mb is not None:
                 kwargs["bucket_cap_mb"] = self.args.ddp_bucket_cap_mb
-            if is_torch_neuroncore_available():
-                return model
-            if any(p.requires_grad for p in model.parameters()):
-                model = nn.parallel.DistributedDataParallel(
-                    model,
-                    device_ids=[self.args.local_rank] if self.args._n_gpu != 0 else None,
-                    output_device=self.args.local_rank if self.args._n_gpu != 0 else None,
-                    **kwargs,
-                )
+            self.accelerator.ddp_handler = DistributedDataParallelKwargs(**kwargs)
 
         # torch.compile() needs to be called after wrapping the model with FSDP or DDP
         # to ensure that it accounts for the graph breaks required by those wrappers
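
Net effect of the last hunk: the Trainer no longer calls `nn.parallel.DistributedDataParallel` itself; it only records the DDP options on `self.accelerator.ddp_handler`, and Accelerate performs the actual wrapping later when the model goes through `prepare()`. A simplified sketch of that downstream step (the helper name `wrap_with_ddp` and the explicit `local_rank` argument are illustrative, not Accelerate's exact code):

import torch.nn as nn

from accelerate.utils import DistributedDataParallelKwargs

def wrap_with_ddp(model: nn.Module, handler: DistributedDataParallelKwargs, local_rank: int) -> nn.Module:
    # Hypothetical helper: to_kwargs() returns only the fields that differ from their
    # defaults, e.g. {"find_unused_parameters": True, "bucket_cap_mb": 25}.
    ddp_kwargs = handler.to_kwargs()
    return nn.parallel.DistributedDataParallel(
        model,
        device_ids=[local_rank],
        output_device=local_rank,
        **ddp_kwargs,
    )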