diff --git a/docs/source/main_classes/trainer.rst b/docs/source/main_classes/trainer.rst
index a7e3134eab0..8f3a07d423d 100644
--- a/docs/source/main_classes/trainer.rst
+++ b/docs/source/main_classes/trainer.rst
@@ -335,8 +335,8 @@ Known caveats:
 
 - This feature is incompatible with :obj:`--predict_with_generate` in the `run_seq2seq.py` script.
 - Using :obj:`--sharded_ddp zero_dp_3` requires wrapping each layer of the model in the special container
-  :obj:`FullyShardedDataParallelism` of fairscale. This is not done automatically by any of the example scripts of the
-  :class:`~transformers.Trainer`.
+  :obj:`FullyShardedDataParallelism` of fairscale. It should be used with the option :obj:`auto_wrap` if you are not
+  doing this yourself: :obj:`--sharded_ddp "zero_dp_3 auto_wrap"`.
 
 
 DeepSpeed
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 7e2df0bf551..7a0ab029faa 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -144,6 +144,7 @@ if is_fairscale_available():
 
     if version.parse(fairscale.__version__) >= version.parse("0.3"):
         from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP
+        from fairscale.nn.wrap import auto_wrap
     else:
         FullyShardedDDP = None
 
@@ -775,8 +776,13 @@ class Trainer:
                cpu_offload = ShardedDDPOption.OFFLOAD in self.args.sharded_ddp
                zero_3 = self.sharded_ddp == ShardedDDPOption.ZERO_DP_3
                # XXX: Breaking the self.model convention but I see no way around it for now.
+               if ShardedDDPOption.AUTO_WRAP in self.args.sharded_ddp:
+                   model = auto_wrap(model)
                self.model = model = FullyShardedDDP(
-                   model, mixed_precision=mixed_precision, reshard_after_forward=zero_3, cpu_offload=cpu_offload
+                   model,
+                   mixed_precision=mixed_precision,
+                   reshard_after_forward=zero_3,
+                   cpu_offload=cpu_offload,
                ).to(self.args.device)
 
        elif is_sagemaker_distributed_available():
diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py
index 5d7deed2e80..0df6eba5444 100644
--- a/src/transformers/trainer_utils.py
+++ b/src/transformers/trainer_utils.py
@@ -446,3 +446,4 @@ class ShardedDDPOption(ExplicitEnum):
     ZERO_DP_2 = "zero_dp_2"
     ZERO_DP_3 = "zero_dp_3"
     OFFLOAD = "offload"
+    AUTO_WRAP = "auto_wrap"
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 133a04c9488..36422a83675 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -470,10 +470,10 @@ class TrainingArguments:
     sharded_ddp: str = field(
         default="",
         metadata={
-            "choices": ["simple", "zero_dp_2", "zero_dp_3", "zero_dp_2 offload", "zero_dp_3 offload"],
             "help": "Whether or not to use sharded DDP training (in distributed training only). The base option "
             "should be `simple`, `zero_dp_2` or `zero_dp_3` and you can add CPU-offload to `zero_dp_2` or `zero_dp_3` "
-            "like this: zero_dp_2 offload` or `zero_dp_3 offload`",
+            "like this: `zero_dp_2 offload` or `zero_dp_3 offload`. You can add auto-wrap to `zero_dp_2` or "
+            "`zero_dp_3` with the same syntax: `zero_dp_2 auto_wrap` or `zero_dp_3 auto_wrap`.",
         },
     )
     deepspeed: Optional[str] = field(
@@ -570,7 +570,7 @@ class TrainingArguments:
                "`--sharded_ddp offload` can't work on its own. It needs to be added to `--sharded_ddp zero_dp_2` or "
                '`--sharded_ddp zero_dp_3`. For example, `--sharded_ddp "zero_dp_2 offload"`.'
            )
-        elif len(self.sharded_ddp) > 1 and ShardedDDPOption.Simple in self.sharded_ddp:
+        elif len(self.sharded_ddp) > 1 and ShardedDDPOption.SIMPLE in self.sharded_ddp:
            raise ValueError("`--sharded_ddp simple` is not compatible with any other option.")
        elif ShardedDDPOption.ZERO_DP_2 in self.sharded_ddp and ShardedDDPOption.ZERO_DP_3 in self.sharded_ddp:
            raise ValueError("`--sharded_ddp zero_dp_2` is not compatible with `--sharded_ddp zero_dp_3`.")
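
Taken together, the `trainer.py` hunks boil down to the wrapping sequence sketched below. This is a minimal, illustrative sketch rather than the actual `Trainer` code path: the `shard_model` helper name is made up for the example, and it assumes fairscale >= 0.3 is installed and that `torch.distributed` has already been initialized by the launcher.

```python
# Illustrative sketch only (hypothetical `shard_model` helper) -- mirrors the calls added in
# trainer.py above. Assumes fairscale >= 0.3 and an already-initialized distributed process group.
import torch

from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP
from fairscale.nn.wrap import auto_wrap


def shard_model(
    model: torch.nn.Module,
    device: torch.device,
    mixed_precision: bool = False,
    zero_3: bool = True,
    cpu_offload: bool = False,
) -> torch.nn.Module:
    # `auto_wrap` is what the new `--sharded_ddp "... auto_wrap"` token triggers: it takes care of
    # wrapping the model's layers so the user does not have to do it by hand for zero_dp_3.
    model = auto_wrap(model)
    # The outer wrapper mirrors how the Trainer maps the `--sharded_ddp` options to keyword
    # arguments: `reshard_after_forward=True` for zero_dp_3, `cpu_offload=True` for offload.
    return FullyShardedDDP(
        model,
        mixed_precision=mixed_precision,
        reshard_after_forward=zero_3,
        cpu_offload=cpu_offload,
    ).to(device)
```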
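
From the user side, the new option is just another token in the `--sharded_ddp` string. A hedged sketch with hypothetical values (`output_dir` is a placeholder, and the option only has an effect when the script runs under distributed training):

```python
# Hypothetical values for illustration; sharded DDP only takes effect in distributed training.
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="output",                # placeholder path
    sharded_ddp="zero_dp_3 auto_wrap",  # space-separated options, checked by the validation shown above
)
# After parsing, `args.sharded_ddp` holds `ShardedDDPOption` members, which is what the
# `ShardedDDPOption.AUTO_WRAP in self.args.sharded_ddp` test in trainer.py relies on.
```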