Smangrul/accelerate ddp integrate (#23151)
* mixed precision support via accelerate
* fix issues
* fix for the sharded ddp case
* fix flax and tf failing tests
* refactor the place to create `Accelerator` object
* move ddp prep to accelerate
* fix 😅
* resolving comments
This commit is contained in:
parent
9f0646a555
commit
1cf148a6aa
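
The gist of the change: instead of wrapping the model in `nn.parallel.DistributedDataParallel` itself, the `Trainer` now hands the DDP options to Accelerate and lets `Accelerator.prepare()` do the wrapping, with mixed precision handled by the same object. A minimal standalone sketch of that Accelerate mechanism (the model, optimizer, and argument values below are placeholders, not the Trainer's actual code):

import torch
from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs

# DDP options are expressed as a kwargs handler instead of being passed to
# nn.parallel.DistributedDataParallel directly.
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)

# Mixed precision and DDP preparation both go through the Accelerator.
# "fp16" needs a CUDA/accelerator device; pass "no" when running on CPU.
accelerator = Accelerator(mixed_precision="fp16", kwargs_handlers=[ddp_kwargs])

model = torch.nn.Linear(8, 2)                              # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)   # placeholder optimizer

# Under torch.distributed, prepare() wraps the model in DDP using the handler's
# options; on a single process it essentially only handles device placement.
model, optimizer = accelerator.prepare(model, optimizer)
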
@@ -213,6 +213,7 @@ if is_accelerate_available():
     from accelerate import skip_first_batches
 
     from accelerate import Accelerator
+    from accelerate.utils import DistributedDataParallelKwargs
 
 
 if TYPE_CHECKING:
@@ -1591,6 +1592,8 @@ class Trainer:
                 model, device_ids=[int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))]
             )
         elif self.args.parallel_mode == ParallelMode.DISTRIBUTED:
+            if is_torch_neuroncore_available():
+                return model
             kwargs = {}
             if self.args.ddp_find_unused_parameters is not None:
                 kwargs["find_unused_parameters"] = self.args.ddp_find_unused_parameters
@@ -1603,15 +1606,8 @@ class Trainer:
 
             if self.args.ddp_bucket_cap_mb is not None:
                 kwargs["bucket_cap_mb"] = self.args.ddp_bucket_cap_mb
-            if is_torch_neuroncore_available():
-                return model
-            if any(p.requires_grad for p in model.parameters()):
-                model = nn.parallel.DistributedDataParallel(
-                    model,
-                    device_ids=[self.args.local_rank] if self.args._n_gpu != 0 else None,
-                    output_device=self.args.local_rank if self.args._n_gpu != 0 else None,
-                    **kwargs,
-                )
+
+            self.accelerator.ddp_handler = DistributedDataParallelKwargs(**kwargs)
 
         # torch.compile() needs to be called after wrapping the model with FSDP or DDP
         # to ensure that it accounts for the graph breaks required by those wrappers
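
The removed block and the added line do the same job at different layers: the kwargs dict assembled from `ddp_find_unused_parameters` and `ddp_bucket_cap_mb` now populates `DistributedDataParallelKwargs`, and Accelerate forwards those options to `torch.nn.parallel.DistributedDataParallel` when it wraps the model during `prepare()`. A small sketch of that mapping (the values are made up for illustration):

from accelerate.utils import DistributedDataParallelKwargs

# Same shape as the dict built in the hunk above, with illustrative values.
kwargs = {"find_unused_parameters": True, "bucket_cap_mb": 100}

# The dataclass mirrors DDP's optional constructor arguments field for field.
handler = DistributedDataParallelKwargs(**kwargs)

# to_kwargs() returns the fields that differ from the DDP defaults; these are the
# arguments Accelerate passes on when it builds the DDP wrapper.
print(handler.to_kwargs())  # e.g. {'bucket_cap_mb': 100, 'find_unused_parameters': True}
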