mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 18:51:14 +06:00
Disable DDP for neuron (#21953)
Disable DDP for neuron. Co-authored-by: EC2 Default User <ec2-user@ip-172-31-42-72.us-west-2.compute.internal>
This commit is contained in:
parent
bc33fbf956
commit
0bb17295f0
@ -148,6 +148,7 @@ from .utils import (
|
|||||||
is_sagemaker_dp_enabled,
|
is_sagemaker_dp_enabled,
|
||||||
is_sagemaker_mp_enabled,
|
is_sagemaker_mp_enabled,
|
||||||
is_torch_compile_available,
|
is_torch_compile_available,
|
||||||
|
is_torch_neuroncore_available,
|
||||||
is_torch_tpu_available,
|
is_torch_tpu_available,
|
||||||
logging,
|
logging,
|
||||||
)
|
)
|
||||||
@ -1537,6 +1538,8 @@ class Trainer:
|
|||||||
|
|
||||||
if self.args.ddp_bucket_cap_mb is not None:
|
if self.args.ddp_bucket_cap_mb is not None:
|
||||||
kwargs["bucket_cap_mb"] = self.args.ddp_bucket_cap_mb
|
kwargs["bucket_cap_mb"] = self.args.ddp_bucket_cap_mb
|
||||||
|
if is_torch_neuroncore_available():
|
||||||
|
return model
|
||||||
model = nn.parallel.DistributedDataParallel(
|
model = nn.parallel.DistributedDataParallel(
|
||||||
model,
|
model,
|
||||||
device_ids=[self.args.local_rank] if self.args._n_gpu != 0 else None,
|
device_ids=[self.args.local_rank] if self.args._n_gpu != 0 else None,
|
||||||
|
Loading…
Reference in New Issue
Block a user