Add auto_wrap option in fairscale integration (#10673)

* Add auto_wrap option in fairscale integration

* Style
Sylvain Gugger 2021-03-12 07:50:20 -05:00 committed by GitHub
parent 184ef8ecd0
commit e8246f78f9
4 changed files with 13 additions and 6 deletions


@@ -335,8 +335,8 @@ Known caveats:
 - This feature is incompatible with :obj:`--predict_with_generate` in the `run_seq2seq.py` script.
 - Using :obj:`--sharded_ddp zero_dp_3` requires wrapping each layer of the model in the special container
-  :obj:`FullyShardedDataParallelism` of fairscale. This is not done automatically by any of the example scripts of the
-  :class:`~transformers.Trainer`.
+  :obj:`FullyShardedDataParallelism` of fairscale. It should be used with the option :obj:`auto_wrap` if you are not
+  doing this yourself: :obj:`--sharded_ddp "zero_dp_3 auto_wrap"`.
 
 DeepSpeed
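
If the Trainer is configured from Python rather than the command line, the documented flag maps onto the `sharded_ddp` training argument. A minimal sketch, assuming a placeholder `output_dir` (actually training with it still needs fairscale >= 0.3 and a distributed launch):

# Minimal sketch: programmatic equivalent of `--sharded_ddp "zero_dp_3 auto_wrap"`.
# `output_dir` is a placeholder; every other argument keeps its default.
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="output",
    sharded_ddp="zero_dp_3 auto_wrap",  # ZeRO DP 3 sharding plus automatic layer wrapping
)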


@@ -144,6 +144,7 @@ if is_fairscale_available():
     if version.parse(fairscale.__version__) >= version.parse("0.3"):
         from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP
+        from fairscale.nn.wrap import auto_wrap
     else:
         FullyShardedDDP = None
@@ -775,8 +776,13 @@ class Trainer:
             cpu_offload = ShardedDDPOption.OFFLOAD in self.args.sharded_ddp
             zero_3 = self.sharded_ddp == ShardedDDPOption.ZERO_DP_3
             # XXX: Breaking the self.model convention but I see no way around it for now.
+            if ShardedDDPOption.AUTO_WRAP in self.args.sharded_ddp:
+                model = auto_wrap(model)
             self.model = model = FullyShardedDDP(
-                model, mixed_precision=mixed_precision, reshard_after_forward=zero_3, cpu_offload=cpu_offload
+                model,
+                mixed_precision=mixed_precision,
+                reshard_after_forward=zero_3,
+                cpu_offload=cpu_offload,
             ).to(self.args.device)
         elif is_sagemaker_distributed_available():
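
The same two-step wrapping can be sketched outside the Trainer with fairscale directly. This is only an illustration of the calls used above, not a drop-in replacement for the Trainer internals; it assumes fairscale >= 0.3, a CUDA device, and a script started with `python -m torch.distributed.launch` so the process group can be initialized from environment variables. Depending on the fairscale version, `auto_wrap` may only take effect inside fairscale's `enable_wrap` context.

# Illustration of the wrapping done above, outside the Trainer (see assumptions in the text).
import torch
import torch.distributed as dist
import torch.nn as nn
from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP
from fairscale.nn.wrap import auto_wrap

if not dist.is_initialized():
    dist.init_process_group(backend="nccl")  # relies on env vars set by the launcher
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

# A toy model; any nn.Module works.
model = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 512))

# Step 1 (what the new `auto_wrap` option toggles): wrap eligible submodules
# in their own sharded container before sharding the whole model.
model = auto_wrap(model)

# Step 2: shard the (possibly wrapped) model. reshard_after_forward=True is what
# zero_dp_3 sets; mixed_precision and cpu_offload map to the other flags above.
model = FullyShardedDDP(
    model,
    mixed_precision=False,
    reshard_after_forward=True,
    cpu_offload=False,
).to(torch.device("cuda"))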


@@ -446,3 +446,4 @@ class ShardedDDPOption(ExplicitEnum):
     ZERO_DP_2 = "zero_dp_2"
     ZERO_DP_3 = "zero_dp_3"
     OFFLOAD = "offload"
+    AUTO_WRAP = "auto_wrap"
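
As a rough illustration of how the space-separated value of `--sharded_ddp` relates to these members (the helper below is hypothetical; the real parsing and validation live in `TrainingArguments`):

# Hypothetical helper showing how a space-separated option string maps onto
# ShardedDDPOption members; transformers does this inside TrainingArguments.
from transformers.trainer_utils import ShardedDDPOption

def parse_sharded_ddp(value: str):
    # e.g. "zero_dp_3 auto_wrap" -> [ShardedDDPOption.ZERO_DP_3, ShardedDDPOption.AUTO_WRAP]
    return [ShardedDDPOption(opt) for opt in value.split()]

print(parse_sharded_ddp("zero_dp_3 auto_wrap"))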


@@ -470,10 +470,10 @@ class TrainingArguments:
     sharded_ddp: str = field(
         default="",
         metadata={
-            "choices": ["simple", "zero_dp_2", "zero_dp_3", "zero_dp_2 offload", "zero_dp_3 offload"],
             "help": "Whether or not to use sharded DDP training (in distributed training only). The base option "
             "should be `simple`, `zero_dp_2` or `zero_dp_3` and you can add CPU-offload to `zero_dp_2` or `zero_dp_3` "
-            "like this: zero_dp_2 offload` or `zero_dp_3 offload`",
+            "like this: `zero_dp_2 offload` or `zero_dp_3 offload`. You can add auto-wrap to `zero_dp_2` or "
+            "`zero_dp_3` with the same syntax: `zero_dp_2 auto_wrap` or `zero_dp_3 auto_wrap`.",
         },
     )
     deepspeed: Optional[str] = field(
@@ -570,7 +570,7 @@ class TrainingArguments:
                 "`--sharded_ddp offload` can't work on its own. It needs to be added to `--sharded_ddp zero_dp_2` or "
                 '`--sharded_ddp zero_dp_3`. For example, `--sharded_ddp "zero_dp_2 offload"`.'
             )
-        elif len(self.sharded_ddp) > 1 and ShardedDDPOption.Simple in self.sharded_ddp:
+        elif len(self.sharded_ddp) > 1 and ShardedDDPOption.SIMPLE in self.sharded_ddp:
             raise ValueError("`--sharded_ddp simple` is not compatible with any other option.")
         elif ShardedDDPOption.ZERO_DP_2 in self.sharded_ddp and ShardedDDPOption.ZERO_DP_3 in self.sharded_ddp:
             raise ValueError("`--sharded_ddp zero_dp_2` is not compatible with `--sharded_ddp zero_dp_3`.")
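
A quick way to see these checks fire, with `output_dir` as a placeholder (constructing `TrainingArguments` should be enough; no distributed setup is needed for the validation itself):

# Sketch: the __post_init__ checks above reject inconsistent combinations.
from transformers import TrainingArguments

try:
    TrainingArguments(output_dir="out", sharded_ddp="simple auto_wrap")
except ValueError as err:
    # "`--sharded_ddp simple` is not compatible with any other option."
    print(err)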