FSDP integration enhancements and fixes (#18134)

* FSDP integration enhancements and fixes

* resolving comments

* fsdp fp16 mixed precision requires `ShardedGradScaler`
Sourab Mangrulkar 2022-07-19 00:02:10 +05:30 committed by GitHub
parent 8e445ca51d
commit bc8e30bab9
4 changed files with 79 additions and 12 deletions

src/transformers/trainer.py

@@ -94,6 +94,7 @@ from .trainer_pt_utils import (
distributed_broadcast_scalars,
distributed_concat,
find_batch_size,
get_module_class_from_name,
get_parameter_names,
nested_concat,
nested_detach,
@@ -400,6 +401,8 @@ class Trainer:
self.fsdp = ShardingStrategy.FULL_SHARD
elif FSDPOption.SHARD_GRAD_OP in args.fsdp:
self.fsdp = ShardingStrategy.SHARD_GRAD_OP
elif FSDPOption.NO_SHARD in args.fsdp:
self.fsdp = ShardingStrategy.NO_SHARD
# one place to sort out whether to place the model on device or not
# postpone switching model to cuda when:
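The new `no_shard` branch maps to `ShardingStrategy.NO_SHARD`, which keeps parameters, gradients and optimizer state fully replicated on every rank (DDP-like behaviour, but still through the FSDP wrapper). A minimal sketch of what that strategy means when calling FSDP directly, assuming an already-initialized process group and one CUDA device per rank:

```python
import torch.nn as nn
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    FullyShardedDataParallel as FSDP,
    ShardingStrategy,
)

# Assumes torch.distributed.init_process_group() has already run
# (e.g. launched via torchrun) and a CUDA device is set for this rank.
model = nn.Linear(16, 16).cuda()

# NO_SHARD keeps the full parameter/gradient/optimizer-state copy on each
# rank, so communication reduces to gradient all-reduce, much like DDP.
model = FSDP(model, sharding_strategy=ShardingStrategy.NO_SHARD)
```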
@@ -511,11 +514,6 @@ class Trainer:
args.fp16 = smp.state.cfg.fp16
if args.fp16 or args.bf16:
if self.fsdp is not None:
raise ValueError(
"Mixed precision is currently not supported for FSDP."
"Please do not set arguments related to `mixed_precision`"
)
if args.half_precision_backend == "auto":
if args.device == torch.device("cpu"):
if args.fp16:
@@ -543,6 +541,16 @@ class Trainer:
self.do_grad_scaling = True
if self.sharded_ddp is not None:
self.scaler = ShardedGradScaler()
elif self.fsdp is not None:
if self.amp_dtype == torch.float16:
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
self.scaler = ShardedGradScaler()
else:
self.do_grad_scaling = False
self.use_cuda_amp = False
self.amp_dtype = None
elif is_torch_tpu_available():
from torch_xla.amp import GradScaler
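Why the switch: with fp16 under FSDP each rank only holds a shard of the gradients, so the stock `torch.cuda.amp.GradScaler` cannot perform its inf/NaN checks correctly; FSDP ships a `ShardedGradScaler` with the same interface that operates on the sharded gradients. A rough sketch of the training-step pattern (the model, optimizer and batch are placeholders, not the Trainer's internals):

```python
import torch
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler

scaler = ShardedGradScaler()


def training_step(model, optimizer, batch):
    # Forward pass under autocast so activations are computed in fp16.
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = model(**batch).loss
    # Scale the loss before backward to avoid fp16 gradient underflow.
    scaler.scale(loss).backward()
    # step() unscales the (sharded) gradients and skips the optimizer step
    # if any are inf/NaN; update() adjusts the scale factor for next time.
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()
    return loss.detach()
```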
@@ -1316,7 +1324,8 @@ class Trainer:
# PyTorch FSDP!
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy
if FSDPOption.OFFLOAD in self.args.fsdp:
cpu_offload = CPUOffload(offload_params=True)
@@ -1329,11 +1338,31 @@ class Trainer:
auto_wrap_policy = functools.partial(
size_based_auto_wrap_policy, min_num_params=self.args.fsdp_min_num_params
)
elif self.args.fsdp_transformer_layer_cls_to_wrap is not None:
transformer_cls_to_wrap = get_module_class_from_name(
model, self.args.fsdp_transformer_layer_cls_to_wrap
)
auto_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
# Transformer layer class to wrap
transformer_layer_cls={transformer_cls_to_wrap},
)
mixed_precision_policy = None
dtype = None
if self.args.fp16:
dtype = torch.float16
elif self.args.bf16:
dtype = torch.bfloat16
if dtype is not None:
mixed_precision_policy = MixedPrecision(param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype)
if type(model) != FSDP:
# XXX: Breaking the self.model convention but I see no way around it for now.
self.model = model = FSDP(
model, sharding_strategy=self.fsdp, cpu_offload=cpu_offload, auto_wrap_policy=auto_wrap_policy
model,
sharding_strategy=self.fsdp,
cpu_offload=cpu_offload,
auto_wrap_policy=auto_wrap_policy,
mixed_precision=mixed_precision_policy,
)
if FSDPOption.OFFLOAD not in self.args.fsdp:
model.to(self.args.device)
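Outside the Trainer, the same wrapping logic looks roughly like the sketch below: a transformer-block auto-wrap policy, an fp16 `MixedPrecision` policy, and optional CPU offload. The `Block` class is a stand-in for a real transformer layer, and a distributed process group is assumed to be initialized already (e.g. via `torchrun`):

```python
import functools

import torch
import torch.nn as nn
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    CPUOffload,
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy


class Block(nn.Module):
    """Stand-in for a transformer layer such as BertLayer or T5Block."""

    def __init__(self, dim: int = 64):
        super().__init__()
        self.ff = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))

    def forward(self, x):
        return x + self.ff(x)


# Assumes torch.distributed is already initialized and a CUDA device is set.
model = nn.Sequential(*[Block() for _ in range(4)]).cuda()

# Wrap every Block in its own FSDP unit, like
# `--fsdp_transformer_layer_cls_to_wrap BertLayer` does inside the Trainer.
auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={Block},
)

# fp16 for parameters, gradient reduction and buffers, mirroring `--fp16`.
mixed_precision_policy = MixedPrecision(
    param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16
)

model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    cpu_offload=CPUOffload(offload_params=True),  # optional: `--fsdp "full_shard offload"`
    auto_wrap_policy=auto_wrap_policy,
    mixed_precision=mixed_precision_policy,
)
```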

src/transformers/trainer_pt_utils.py

@@ -1033,6 +1033,26 @@ def get_parameter_names(model, forbidden_layer_types):
return result
def get_module_class_from_name(module, name):
"""
Gets a class from a module by its name.
Args:
module (`torch.nn.Module`): The module to get the class from.
name (`str`): The name of the class.
"""
modules_children = list(module.children())
if module.__class__.__name__ == name:
return module.__class__
elif len(modules_children) == 0:
return
else:
for child_module in modules_children:
module_class = get_module_class_from_name(child_module, name)
if module_class is not None:
return module_class
if is_sagemaker_mp_enabled():
import smdistributed.modelparallel.torch as smp
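For illustration, the helper recursively walks the module tree and returns the first class whose name matches. It could be used on its own along these lines (the checkpoint name is just an example):

```python
from transformers import AutoModelForSequenceClassification
from transformers.trainer_pt_utils import get_module_class_from_name

model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
layer_cls = get_module_class_from_name(model, "BertLayer")
print(layer_cls)
# <class 'transformers.models.bert.modeling_bert.BertLayer'>
```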

src/transformers/trainer_utils.py

@@ -653,6 +653,7 @@ def find_executable_batch_size(
class FSDPOption(ExplicitEnum):
FULL_SHARD = "full_shard"
SHARD_GRAD_OP = "shard_grad_op"
NO_SHARD = "no_shard"
OFFLOAD = "offload"
AUTO_WRAP = "auto_wrap"
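With the new enum member, `no_shard` becomes a valid base option next to `full_shard` and `shard_grad_op`, while `offload` and `auto_wrap` remain composable add-ons. A hedged example of passing these combinations programmatically (values are illustrative):

```python
from transformers import TrainingArguments

# Pure replication through the FSDP wrapper (DDP-like), using the new option:
args = TrainingArguments(output_dir="out", fsdp="no_shard")

# Full sharding plus CPU offload and size-based auto-wrapping:
args = TrainingArguments(
    output_dir="out",
    fsdp="full_shard offload auto_wrap",
    fsdp_min_num_params=1_000_000,
)
```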

src/transformers/training_args.py

@@ -787,10 +787,10 @@ class TrainingArguments:
metadata={
"help": (
"Whether or not to use PyTorch Fully Sharded Data Parallel (FSDP) training (in distributed training"
" only). The base option should be `full_shard` or `shard_grad_op` and you can add CPU-offload to"
" `full_shard` or `shard_grad_op` like this: full_shard offload` or `shard_grad_op offload`. You can"
" add auto-wrap to `full_shard` or `shard_grad_op` with the same syntax: full_shard auto_wrap` or"
" `shard_grad_op auto_wrap`."
" only). The base option should be `full_shard`, `shard_grad_op` or `no_shard` and you can add"
" CPU-offload to `full_shard` or `shard_grad_op` like this: full_shard offload` or `shard_grad_op"
" offload`. You can add auto-wrap to `full_shard` or `shard_grad_op` with the same syntax: full_shard"
" auto_wrap` or `shard_grad_op auto_wrap`."
),
},
)
@@ -803,6 +803,15 @@ class TrainingArguments:
)
},
)
fsdp_transformer_layer_cls_to_wrap: str = field(
default=None,
metadata={
"help": (
"Transformer layer class name (case-sensitive) to wrap ,e.g, `BertLayer`, `GPTJBlock`, `T5Block` .... "
"(useful only when `fsdp` flag is passed).",
)
},
)
deepspeed: Optional[str] = field(
default=None,
metadata={
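The new argument selects which transformer block class to auto-wrap and is an alternative to the size-based `fsdp_min_num_params` policy; as the validation added further below enforces, the two must not be combined. A sketch of how it might be set (class name and output path are illustrative):

```python
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",
    fsdp="full_shard auto_wrap",
    # Wrap every BertLayer in its own FSDP unit (case-sensitive class name).
    fsdp_transformer_layer_cls_to_wrap="BertLayer",
    # fsdp_min_num_params=2000,  # would raise ValueError: mutually exclusive
)
```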
@@ -1160,6 +1169,14 @@ class TrainingArguments:
if len(self.fsdp) == 0 and self.fsdp_min_num_params > 0:
warnings.warn("`--fsdp_min_num_params` is useful only when `--fsdp` is specified.")
if len(self.fsdp) == 0 and self.fsdp_transformer_layer_cls_to_wrap is not None:
warnings.warn("`--fsdp_transformer_layer_cls_to_wrap` is useful only when `--fsdp` is specified.")
if len(self.fsdp) > 0 and self.fsdp_min_num_params > 0 and self.fsdp_transformer_layer_cls_to_wrap is not None:
raise ValueError(
"`--fsdp_min_num_params` and `--fsdp_transformer_layer_cls_to_wrap` are mutually exclusive."
)
if self.tpu_metrics_debug:
warnings.warn(
"using `--tpu_metrics_debug` is deprecated and will be removed in version 5 of 🤗 Transformers. Use"