FSDP integration enhancements and fixes (#18134)
* FSDP integration enhancements and fixes
* resolving comments
* fsdp fp16 mixed precision requires `ShardedGradScaler`
This commit is contained in:
parent
8e445ca51d
commit
bc8e30bab9
@@ -94,6 +94,7 @@ from .trainer_pt_utils import (
     distributed_broadcast_scalars,
     distributed_concat,
     find_batch_size,
+    get_module_class_from_name,
     get_parameter_names,
     nested_concat,
     nested_detach,
@@ -400,6 +401,8 @@ class Trainer:
                 self.fsdp = ShardingStrategy.FULL_SHARD
             elif FSDPOption.SHARD_GRAD_OP in args.fsdp:
                 self.fsdp = ShardingStrategy.SHARD_GRAD_OP
+            elif FSDPOption.NO_SHARD in args.fsdp:
+                self.fsdp = ShardingStrategy.NO_SHARD

         # one place to sort out whether to place the model on device or not
         # postpone switching model to cuda when:
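For context, here is a minimal sketch (not part of this commit) of how the three base `--fsdp` values map onto PyTorch's `ShardingStrategy`, mirroring the branch added above; the helper name `resolve_sharding_strategy` is illustrative only.

```python
# Illustrative only: mirrors the FULL_SHARD / SHARD_GRAD_OP / NO_SHARD branch above.
from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy


def resolve_sharding_strategy(fsdp_options):
    """Map a list of `--fsdp` option strings to a ShardingStrategy (or None)."""
    if "full_shard" in fsdp_options:
        return ShardingStrategy.FULL_SHARD  # shard params, grads and optimizer states
    if "shard_grad_op" in fsdp_options:
        return ShardingStrategy.SHARD_GRAD_OP  # shard grads and optimizer states only
    if "no_shard" in fsdp_options:
        return ShardingStrategy.NO_SHARD  # no sharding, DDP-like replication
    return None


print(resolve_sharding_strategy(["no_shard", "auto_wrap"]))  # ShardingStrategy.NO_SHARD
```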
@@ -511,11 +514,6 @@ class Trainer:
             args.fp16 = smp.state.cfg.fp16

         if args.fp16 or args.bf16:
-            if self.fsdp is not None:
-                raise ValueError(
-                    "Mixed precision is currently not supported for FSDP."
-                    "Please do not set arguments related to `mixed_precision`"
-                )
             if args.half_precision_backend == "auto":
                 if args.device == torch.device("cpu"):
                     if args.fp16:
@@ -543,6 +541,16 @@ class Trainer:
             self.do_grad_scaling = True
             if self.sharded_ddp is not None:
                 self.scaler = ShardedGradScaler()
+            elif self.fsdp is not None:
+                if self.amp_dtype == torch.float16:
+                    from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
+
+                    self.scaler = ShardedGradScaler()
+                else:
+                    self.do_grad_scaling = False
+                    self.use_cuda_amp = False
+                    self.amp_dtype = None
+
             elif is_torch_tpu_available():
                 from torch_xla.amp import GradScaler

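As the commit message notes, fp16 mixed precision under FSDP needs the FSDP-aware `ShardedGradScaler` rather than `torch.cuda.amp.GradScaler`, because the gradients being unscaled are sharded across ranks. Below is a rough sketch of a single fp16 step with that scaler, assuming the model is already FSDP-wrapped and distributed training is initialized; `model`, `optimizer`, `loss_fn`, `batch` and `labels` are placeholders, not Trainer internals.

```python
# Sketch of an fp16 step with the FSDP-aware scaler; placeholder names throughout.
import torch
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler

scaler = ShardedGradScaler()


def fp16_training_step(model, optimizer, loss_fn, batch, labels):
    optimizer.zero_grad(set_to_none=True)
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = loss_fn(model(batch), labels)
    scaler.scale(loss).backward()  # scale the loss so fp16 gradients do not underflow
    scaler.step(optimizer)         # unscales the sharded grads, skips the step on inf/nan
    scaler.update()
    return loss.detach()
```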
@@ -1316,7 +1324,8 @@ class Trainer:
             # PyTorch FSDP!
             from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
             from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
-            from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
+            from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision
+            from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy

             if FSDPOption.OFFLOAD in self.args.fsdp:
                 cpu_offload = CPUOffload(offload_params=True)
@@ -1329,11 +1338,31 @@ class Trainer:
                     auto_wrap_policy = functools.partial(
                         size_based_auto_wrap_policy, min_num_params=self.args.fsdp_min_num_params
                     )
-
+                elif self.args.fsdp_transformer_layer_cls_to_wrap is not None:
+                    transformer_cls_to_wrap = get_module_class_from_name(
+                        model, self.args.fsdp_transformer_layer_cls_to_wrap
+                    )
+                    auto_wrap_policy = functools.partial(
+                        transformer_auto_wrap_policy,
+                        # Transformer layer class to wrap
+                        transformer_layer_cls={transformer_cls_to_wrap},
+                    )
+            mixed_precision_policy = None
+            dtype = None
+            if self.args.fp16:
+                dtype = torch.float16
+            elif self.args.bf16:
+                dtype = torch.bfloat16
+            if dtype is not None:
+                mixed_precision_policy = MixedPrecision(param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype)
             if type(model) != FSDP:
                 # XXX: Breaking the self.model convention but I see no way around it for now.
                 self.model = model = FSDP(
-                    model, sharding_strategy=self.fsdp, cpu_offload=cpu_offload, auto_wrap_policy=auto_wrap_policy
+                    model,
+                    sharding_strategy=self.fsdp,
+                    cpu_offload=cpu_offload,
+                    auto_wrap_policy=auto_wrap_policy,
+                    mixed_precision=mixed_precision_policy,
                 )
                 if FSDPOption.OFFLOAD not in self.args.fsdp:
                     model.to(self.args.device)
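A condensed, standalone sketch of what the wrapping above does, assuming `torch.distributed` is already initialized; `wrap_with_fsdp` and its arguments are illustrative, not Trainer API. In the Trainer, `transformer_layer_cls` is the class resolved from `--fsdp_transformer_layer_cls_to_wrap` via `get_module_class_from_name`.

```python
# Illustrative helper, not Trainer code: wrap a model the way the diff above does.
import functools

import torch
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    CPUOffload,
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy


def wrap_with_fsdp(model, transformer_layer_cls, fp16=True, offload=False):
    auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={transformer_layer_cls},  # each such block becomes an FSDP unit
    )
    dtype = torch.float16 if fp16 else torch.bfloat16
    mixed_precision_policy = MixedPrecision(param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype)
    return FSDP(
        model,
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        cpu_offload=CPUOffload(offload_params=True) if offload else None,
        auto_wrap_policy=auto_wrap_policy,
        mixed_precision=mixed_precision_policy,
    )
```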
@@ -1033,6 +1033,26 @@ def get_parameter_names(model, forbidden_layer_types):
     return result


+def get_module_class_from_name(module, name):
+    """
+    Gets a class from a module by its name.
+
+    Args:
+        module (`torch.nn.Module`): The module to get the class from.
+        name (`str`): The name of the class.
+    """
+    modules_children = list(module.children())
+    if module.__class__.__name__ == name:
+        return module.__class__
+    elif len(modules_children) == 0:
+        return
+    else:
+        for child_module in modules_children:
+            module_class = get_module_class_from_name(child_module, name)
+            if module_class is not None:
+                return module_class
+
+
 if is_sagemaker_mp_enabled():
     import smdistributed.modelparallel.torch as smp

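A small usage sketch for the helper added above: it walks the module tree and returns the first class whose name matches. `TinyBlock` and `TinyModel` are made-up classes for illustration; in the Trainer the name comes from `--fsdp_transformer_layer_cls_to_wrap` (e.g. `BertLayer`).

```python
# Made-up modules to exercise get_module_class_from_name (defined in the hunk above).
import torch.nn as nn


class TinyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8)


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList([TinyBlock() for _ in range(2)])


block_cls = get_module_class_from_name(TinyModel(), "TinyBlock")
print(block_cls)  # e.g. <class '__main__.TinyBlock'>
print(get_module_class_from_name(TinyModel(), "MissingLayer"))  # None
```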
@@ -653,6 +653,7 @@ def find_executable_batch_size(
 class FSDPOption(ExplicitEnum):
     FULL_SHARD = "full_shard"
     SHARD_GRAD_OP = "shard_grad_op"
+    NO_SHARD = "no_shard"
     OFFLOAD = "offload"
     AUTO_WRAP = "auto_wrap"

@@ -787,10 +787,10 @@ class TrainingArguments:
         metadata={
             "help": (
                 "Whether or not to use PyTorch Fully Sharded Data Parallel (FSDP) training (in distributed training"
-                " only). The base option should be `full_shard` or `shard_grad_op` and you can add CPU-offload to"
-                " `full_shard` or `shard_grad_op` like this: full_shard offload` or `shard_grad_op offload`. You can"
-                " add auto-wrap to `full_shard` or `shard_grad_op` with the same syntax: full_shard auto_wrap` or"
-                " `shard_grad_op auto_wrap`."
+                " only). The base option should be `full_shard`, `shard_grad_op` or `no_shard` and you can add"
+                " CPU-offload to `full_shard` or `shard_grad_op` like this: full_shard offload` or `shard_grad_op"
+                " offload`. You can add auto-wrap to `full_shard` or `shard_grad_op` with the same syntax: full_shard"
+                " auto_wrap` or `shard_grad_op auto_wrap`."
             ),
         },
     )
@@ -803,6 +803,15 @@ class TrainingArguments:
             )
         },
     )
+    fsdp_transformer_layer_cls_to_wrap: str = field(
+        default=None,
+        metadata={
+            "help": (
+                "Transformer layer class name (case-sensitive) to wrap ,e.g, `BertLayer`, `GPTJBlock`, `T5Block` .... "
+                "(useful only when `fsdp` flag is passed).",
+            )
+        },
+    )
     deepspeed: Optional[str] = field(
         default=None,
         metadata={
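For completeness, a hedged sketch of how the new argument is meant to be combined with the existing `--fsdp` flag, shown here through `TrainingArguments` rather than the command line; the space-separated option string follows the help text above, `BertLayer` is just an example class name, and a CUDA machine with a distributed launch (e.g. `torchrun`) is assumed.

```python
# Sketch: the programmatic equivalent of
#   --fsdp "full_shard auto_wrap" --fsdp_transformer_layer_cls_to_wrap BertLayer --fp16
# Values are examples; a real run needs a distributed launch on GPUs.
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",
    fsdp="full_shard auto_wrap",  # base sharding strategy + auto wrapping
    fsdp_transformer_layer_cls_to_wrap="BertLayer",  # case-sensitive layer class name
    fp16=True,  # with this commit, fp16 under FSDP uses ShardedGradScaler
)
print(args.fsdp)  # parsed into a list of FSDPOption values
```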
@@ -1160,6 +1169,14 @@ class TrainingArguments:
         if len(self.fsdp) == 0 and self.fsdp_min_num_params > 0:
             warnings.warn("`--fsdp_min_num_params` is useful only when `--fsdp` is specified.")

+        if len(self.fsdp) == 0 and self.fsdp_transformer_layer_cls_to_wrap is not None:
+            warnings.warn("`--fsdp_transformer_layer_cls_to_wrap` is useful only when `--fsdp` is specified.")
+
+        if len(self.fsdp) > 0 and self.fsdp_min_num_params > 0 and self.fsdp_transformer_layer_cls_to_wrap is not None:
+            raise ValueError(
+                "`--fsdp_min_num_params` and `--fsdp_transformer_layer_cls_to_wrap` are mutually exclusive."
+            )
+
         if self.tpu_metrics_debug:
             warnings.warn(
                 "using `--tpu_metrics_debug` is deprecated and will be removed in version 5 of 🤗 Transformers. Use"