mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
move fsdp handling to accelerate (#23158)
* mixed precision support via accelerate
* fix issues
* fix for the sharded ddp case
* fix flax and tf failing tests
* `refactor the place to create `Accelerator` object
* move ddp prep to accelerate
* fix 😅
* resolving comments
* move fsdp handling to accelerate
* fixex
* fix saving
This commit is contained in:
parent
015829e6c4
commit
0b774074a5
@ -343,6 +343,12 @@ class Trainer:
|
||||
# create accelerator object
|
||||
self.accelerator = Accelerator()
|
||||
|
||||
# post accelerator creation setup
|
||||
if getattr(self.accelerator.state, "fsdp_plugin", None) is not None:
|
||||
fsdp_plugin = self.accelerator.state.fsdp_plugin
|
||||
fsdp_plugin.limit_all_gathers = self.args.fsdp_config.get("limit_all_gathers", False)
|
||||
fsdp_plugin.use_orig_params = self.args.fsdp_config.get("use_orig_params", False)
|
||||
|
||||
# memory metrics - must set up as early as possible
|
||||
self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics)
|
||||
self._memory_tracker.start()
|
||||
@ -464,7 +470,7 @@ class Trainer:
|
||||
self.fsdp = ShardingStrategy.NO_SHARD
|
||||
|
||||
self.backward_prefetch = BackwardPrefetch.BACKWARD_PRE
|
||||
if "backward_prefetch" in self.args.fsdp_config and "backward_pos" in self.args.fsdp_config.get(
|
||||
if "backward_prefetch" in self.args.fsdp_config and "backward_post" in self.args.fsdp_config.get(
|
||||
"backward_prefetch", []
|
||||
):
|
||||
self.backward_prefetch = BackwardPrefetch.BACKWARD_POST
|
||||
@ -1479,114 +1485,58 @@ class Trainer:
|
||||
cpu_offload=cpu_offload,
|
||||
).to(self.args.device)
|
||||
# Distributed training using PyTorch FSDP
|
||||
elif self.fsdp is not None:
|
||||
if not self.args.fsdp_config["xla"]:
|
||||
# PyTorch FSDP!
|
||||
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision
|
||||
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
|
||||
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy
|
||||
|
||||
if FSDPOption.OFFLOAD in self.args.fsdp:
|
||||
cpu_offload = CPUOffload(offload_params=True)
|
||||
else:
|
||||
cpu_offload = CPUOffload(offload_params=False)
|
||||
|
||||
auto_wrap_policy = None
|
||||
|
||||
if FSDPOption.AUTO_WRAP in self.args.fsdp:
|
||||
if self.args.fsdp_config["fsdp_min_num_params"] > 0:
|
||||
auto_wrap_policy = functools.partial(
|
||||
size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"]
|
||||
)
|
||||
elif self.args.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
|
||||
transformer_cls_to_wrap = set()
|
||||
for layer_class in self.args.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]:
|
||||
transformer_cls = get_module_class_from_name(model, layer_class)
|
||||
if transformer_cls is None:
|
||||
raise Exception("Could not find the transformer layer class to wrap in the model.")
|
||||
else:
|
||||
transformer_cls_to_wrap.add(transformer_cls)
|
||||
auto_wrap_policy = functools.partial(
|
||||
transformer_auto_wrap_policy,
|
||||
# Transformer layer class to wrap
|
||||
transformer_layer_cls=transformer_cls_to_wrap,
|
||||
)
|
||||
mixed_precision_policy = None
|
||||
dtype = None
|
||||
if self.args.fp16:
|
||||
dtype = torch.float16
|
||||
elif self.args.bf16:
|
||||
dtype = torch.bfloat16
|
||||
if dtype is not None:
|
||||
mixed_precision_policy = MixedPrecision(param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype)
|
||||
if type(model) != FSDP:
|
||||
# XXX: Breaking the self.model convention but I see no way around it for now.
|
||||
signature = inspect.signature(FSDP.__init__).parameters.keys()
|
||||
kwargs = {}
|
||||
for arg in ["limit_all_gathers", "forward_prefetch", "backward_prefetch"]:
|
||||
if arg in signature:
|
||||
kwargs[arg] = getattr(self, arg)
|
||||
self.model = model = FSDP(
|
||||
model,
|
||||
sharding_strategy=self.fsdp,
|
||||
cpu_offload=cpu_offload,
|
||||
auto_wrap_policy=auto_wrap_policy,
|
||||
mixed_precision=mixed_precision_policy,
|
||||
device_id=self.args.device,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
try:
|
||||
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
|
||||
from torch_xla.distributed.fsdp import checkpoint_module
|
||||
from torch_xla.distributed.fsdp.wrap import (
|
||||
size_based_auto_wrap_policy,
|
||||
transformer_auto_wrap_policy,
|
||||
)
|
||||
except ImportError:
|
||||
raise ImportError("Missing XLA FSDP related module; please make sure to use torch-xla >= 2.0.")
|
||||
auto_wrap_policy = None
|
||||
auto_wrapper_callable = None
|
||||
if self.args.fsdp_config["fsdp_min_num_params"] > 0:
|
||||
auto_wrap_policy = functools.partial(
|
||||
size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"]
|
||||
)
|
||||
elif self.args.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
|
||||
transformer_cls_to_wrap = set()
|
||||
for layer_class in self.args.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]:
|
||||
transformer_cls = get_module_class_from_name(model, layer_class)
|
||||
if transformer_cls is None:
|
||||
raise Exception("Could not find the transformer layer class to wrap in the model.")
|
||||
else:
|
||||
transformer_cls_to_wrap.add(transformer_cls)
|
||||
auto_wrap_policy = functools.partial(
|
||||
transformer_auto_wrap_policy,
|
||||
# Transformer layer class to wrap
|
||||
transformer_layer_cls=transformer_cls_to_wrap,
|
||||
)
|
||||
fsdp_kwargs = self.args.xla_fsdp_config
|
||||
if self.args.fsdp_config["xla_fsdp_grad_ckpt"]:
|
||||
# Apply gradient checkpointing to auto-wrapped sub-modules if specified
|
||||
def auto_wrapper_callable(m, *args, **kwargs):
|
||||
return FSDP(checkpoint_module(m), *args, **kwargs)
|
||||
|
||||
# Wrap the base model with an outer FSDP wrapper
|
||||
self.model = model = FSDP(
|
||||
model,
|
||||
auto_wrap_policy=auto_wrap_policy,
|
||||
auto_wrapper_callable=auto_wrapper_callable,
|
||||
**fsdp_kwargs,
|
||||
elif self.fsdp is not None and self.args.fsdp_config["xla"]:
|
||||
try:
|
||||
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
|
||||
from torch_xla.distributed.fsdp import checkpoint_module
|
||||
from torch_xla.distributed.fsdp.wrap import (
|
||||
size_based_auto_wrap_policy,
|
||||
transformer_auto_wrap_policy,
|
||||
)
|
||||
except ImportError:
|
||||
raise ImportError("Missing XLA FSDP related module; please make sure to use torch-xla >= 2.0.")
|
||||
auto_wrap_policy = None
|
||||
auto_wrapper_callable = None
|
||||
if self.args.fsdp_config["fsdp_min_num_params"] > 0:
|
||||
auto_wrap_policy = functools.partial(
|
||||
size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"]
|
||||
)
|
||||
elif self.args.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
|
||||
transformer_cls_to_wrap = set()
|
||||
for layer_class in self.args.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]:
|
||||
transformer_cls = get_module_class_from_name(model, layer_class)
|
||||
if transformer_cls is None:
|
||||
raise Exception("Could not find the transformer layer class to wrap in the model.")
|
||||
else:
|
||||
transformer_cls_to_wrap.add(transformer_cls)
|
||||
auto_wrap_policy = functools.partial(
|
||||
transformer_auto_wrap_policy,
|
||||
# Transformer layer class to wrap
|
||||
transformer_layer_cls=transformer_cls_to_wrap,
|
||||
)
|
||||
fsdp_kwargs = self.args.xla_fsdp_config
|
||||
if self.args.fsdp_config["xla_fsdp_grad_ckpt"]:
|
||||
# Apply gradient checkpointing to auto-wrapped sub-modules if specified
|
||||
def auto_wrapper_callable(m, *args, **kwargs):
|
||||
return FSDP(checkpoint_module(m), *args, **kwargs)
|
||||
|
||||
# Patch `xm.optimizer_step` should not reduce gradients in this case,
|
||||
# as FSDP does not need gradient reduction over sharded parameters.
|
||||
def patched_optimizer_step(optimizer, barrier=False, optimizer_args={}):
|
||||
loss = optimizer.step(**optimizer_args)
|
||||
if barrier:
|
||||
xm.mark_step()
|
||||
return loss
|
||||
# Wrap the base model with an outer FSDP wrapper
|
||||
self.model = model = FSDP(
|
||||
model,
|
||||
auto_wrap_policy=auto_wrap_policy,
|
||||
auto_wrapper_callable=auto_wrapper_callable,
|
||||
**fsdp_kwargs,
|
||||
)
|
||||
|
||||
xm.optimizer_step = patched_optimizer_step
|
||||
# Patch `xm.optimizer_step` should not reduce gradients in this case,
|
||||
# as FSDP does not need gradient reduction over sharded parameters.
|
||||
def patched_optimizer_step(optimizer, barrier=False, optimizer_args={}):
|
||||
loss = optimizer.step(**optimizer_args)
|
||||
if barrier:
|
||||
xm.mark_step()
|
||||
return loss
|
||||
|
||||
xm.optimizer_step = patched_optimizer_step
|
||||
elif is_sagemaker_dp_enabled():
|
||||
model = nn.parallel.DistributedDataParallel(
|
||||
model, device_ids=[int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))]
|
||||
@ -1796,17 +1746,26 @@ class Trainer:
|
||||
if is_sagemaker_mp_enabled() and resume_from_checkpoint is not None:
|
||||
self._load_from_checkpoint(resume_from_checkpoint, model)
|
||||
|
||||
# for the rest of this function `model` is the outside model, whether it was wrapped or not
|
||||
if model is not self.model:
|
||||
self.model_wrapped = model
|
||||
# as the model is wrapped, don't use `accelerator.prepare`
|
||||
# this is for unhandled cases such as
|
||||
# Fairscale Sharded DDP, FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX
|
||||
use_accelerator_prepare = True if model is self.model else False
|
||||
|
||||
if delay_optimizer_creation:
|
||||
self.create_optimizer_and_scheduler(num_training_steps=max_steps)
|
||||
|
||||
# prepare using `accelerator` prepare
|
||||
model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
|
||||
self.model, self.optimizer, self.lr_scheduler
|
||||
)
|
||||
if use_accelerator_prepare:
|
||||
model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
|
||||
self.model, self.optimizer, self.lr_scheduler
|
||||
)
|
||||
|
||||
if getattr(self.accelerator.state, "fsdp_plugin", None) is not None:
|
||||
self.model = model
|
||||
|
||||
# for the rest of this function `model` is the outside model, whether it was wrapped or not
|
||||
if model is not self.model:
|
||||
self.model_wrapped = model
|
||||
|
||||
# Check if saved optimizer or scheduler states exist
|
||||
self._load_optimizer_and_scheduler(resume_from_checkpoint)
|
||||
@ -2894,11 +2853,15 @@ class Trainer:
|
||||
ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp
|
||||
or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp
|
||||
or self.fsdp is not None
|
||||
or getattr(self.accelerator.state, "fsdp_plugin", None) is not None
|
||||
):
|
||||
state_dict = self.model.state_dict()
|
||||
if getattr(self.accelerator.state, "fsdp_plugin", None) is not None:
|
||||
self.accelerator.state.fsdp_plugin.save_model(self.accelerator, self.model, output_dir)
|
||||
else:
|
||||
state_dict = self.model.state_dict()
|
||||
|
||||
if self.args.should_save:
|
||||
self._save(output_dir, state_dict=state_dict)
|
||||
if self.args.should_save:
|
||||
self._save(output_dir, state_dict=state_dict)
|
||||
elif self.deepspeed:
|
||||
# this takes care of everything as long as we aren't under zero3
|
||||
if self.args.should_save:
|
||||
|
@ -442,7 +442,7 @@ class TrainingArguments:
|
||||
- `"backward_pre"` : Prefetches the next set of parameters before the current set of parameter's
|
||||
gradient
|
||||
computation.
|
||||
- `"backward_pos"` : This prefetches the next set of parameters after the current set of
|
||||
- `"backward_post"` : This prefetches the next set of parameters after the current set of
|
||||
parameter’s
|
||||
gradient computation.
|
||||
- fsdp_forward_prefetch (`bool`, *optional*, defaults to `False`)
|
||||
@ -1504,6 +1504,32 @@ class TrainingArguments:
|
||||
if self.fsdp_config["xla_fsdp_grad_ckpt"]:
|
||||
warnings.warn("`--xla_fsdp_grad_ckpt` is useful only when `--xla` is set to true.")
|
||||
|
||||
# accelerate integration for FSDP
|
||||
if len(self.fsdp) > 0 and not self.fsdp_config["xla"]:
|
||||
os.environ["ACCELERATE_USE_FSDP"] = "true"
|
||||
from accelerate.utils.constants import (
|
||||
FSDP_AUTO_WRAP_POLICY,
|
||||
FSDP_SHARDING_STRATEGY,
|
||||
)
|
||||
|
||||
for fsdp_option in self.fsdp:
|
||||
if fsdp_option.upper() in FSDP_SHARDING_STRATEGY:
|
||||
# set environment variable for FSDP sharding strategy
|
||||
os.environ["FSDP_SHARDING_STRATEGY"] = str(FSDP_SHARDING_STRATEGY.index(fsdp_option.upper()) + 1)
|
||||
elif fsdp_option == FSDPOption.OFFLOAD:
|
||||
os.environ["FSDP_OFFLOAD_PARAMS"] = "true"
|
||||
elif fsdp_option == FSDPOption.AUTO_WRAP:
|
||||
if self.fsdp_config["fsdp_min_num_params"] > 0:
|
||||
os.environ["FSDP_MIN_NUM_PARAMS"] = str(self.fsdp_config["fsdp_min_num_params"])
|
||||
os.environ["FSDP_AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[1]
|
||||
elif self.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
|
||||
os.environ["FSDP_TRANSFORMER_CLS_TO_WRAP"] = ",".join(
|
||||
self.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]
|
||||
)
|
||||
os.environ["FSDP_AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[0]
|
||||
prefetch_policy = self.fsdp_config.get("fsdp_backward_prefetch", "NO_PREFETCH")
|
||||
os.environ["FSDP_BACKWARD_PREFETCH"] = prefetch_policy.upper()
|
||||
|
||||
if self.tpu_metrics_debug:
|
||||
warnings.warn(
|
||||
"using `--tpu_metrics_debug` is deprecated and will be removed in version 5 of 🤗 Transformers. Use"
|
||||
|
Loading…
Reference in New Issue
Block a user