move fsdp handling to accelerate (#23158)

* mixed precision support via accelerate

* fix issues

* fix for the sharded ddp case

* fix flax and tf failing tests

* refactor the place to create `Accelerator` object

* move ddp prep to accelerate

* fix 😅

* resolving comments

* move fsdp handling to accelerate

* fixes

* fix saving
Sourab Mangrulkar 2023-05-31 14:10:46 +05:30 committed by GitHub
parent 015829e6c4
commit 0b774074a5
2 changed files with 105 additions and 116 deletions

src/transformers/trainer.py

@@ -343,6 +343,12 @@ class Trainer:
# create accelerator object
self.accelerator = Accelerator()
# post accelerator creation setup
if getattr(self.accelerator.state, "fsdp_plugin", None) is not None:
fsdp_plugin = self.accelerator.state.fsdp_plugin
fsdp_plugin.limit_all_gathers = self.args.fsdp_config.get("limit_all_gathers", False)
fsdp_plugin.use_orig_params = self.args.fsdp_config.get("use_orig_params", False)
# memory metrics - must set up as early as possible
self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics)
self._memory_tracker.start()
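The two plugin attributes patched here come straight from the user-facing `fsdp_config`. A minimal sketch of the corresponding `TrainingArguments` setup (the values and the `BertLayer` wrap class are illustrative, not defaults):

```python
from transformers import TrainingArguments

# Illustrative only: enable FSDP through the Trainer and surface the two options
# that the new post-accelerator-creation hook copies onto accelerator.state.fsdp_plugin.
args = TrainingArguments(
    output_dir="out",
    fsdp="full_shard auto_wrap",
    fsdp_config={
        "limit_all_gathers": True,   # -> fsdp_plugin.limit_all_gathers
        "use_orig_params": False,    # -> fsdp_plugin.use_orig_params
        "fsdp_transformer_layer_cls_to_wrap": ["BertLayer"],
    },
)
```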
@@ -464,7 +470,7 @@ class Trainer:
self.fsdp = ShardingStrategy.NO_SHARD
self.backward_prefetch = BackwardPrefetch.BACKWARD_PRE
if "backward_prefetch" in self.args.fsdp_config and "backward_pos" in self.args.fsdp_config.get(
if "backward_prefetch" in self.args.fsdp_config and "backward_post" in self.args.fsdp_config.get(
"backward_prefetch", []
):
self.backward_prefetch = BackwardPrefetch.BACKWARD_POST
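For reference, the corrected mapping in this hunk (the old check looked for the misspelled `"backward_pos"`) boils down to the sketch below; the helper is illustrative, not Trainer code:

```python
from torch.distributed.fsdp import BackwardPrefetch

def resolve_backward_prefetch(fsdp_config: dict) -> BackwardPrefetch:
    """Illustrative: map the Trainer's fsdp_config entry to the torch enum."""
    if "backward_post" in fsdp_config.get("backward_prefetch", []):
        return BackwardPrefetch.BACKWARD_POST
    # BACKWARD_PRE remains the default prefetch strategy.
    return BackwardPrefetch.BACKWARD_PRE

# resolve_backward_prefetch({"backward_prefetch": ["backward_post"]}) -> BackwardPrefetch.BACKWARD_POST
```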
@@ -1479,114 +1485,58 @@ class Trainer:
cpu_offload=cpu_offload,
).to(self.args.device)
# Distributed training using PyTorch FSDP
elif self.fsdp is not None:
if not self.args.fsdp_config["xla"]:
# PyTorch FSDP!
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy
if FSDPOption.OFFLOAD in self.args.fsdp:
cpu_offload = CPUOffload(offload_params=True)
else:
cpu_offload = CPUOffload(offload_params=False)
auto_wrap_policy = None
if FSDPOption.AUTO_WRAP in self.args.fsdp:
if self.args.fsdp_config["fsdp_min_num_params"] > 0:
auto_wrap_policy = functools.partial(
size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"]
)
elif self.args.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
transformer_cls_to_wrap = set()
for layer_class in self.args.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]:
transformer_cls = get_module_class_from_name(model, layer_class)
if transformer_cls is None:
raise Exception("Could not find the transformer layer class to wrap in the model.")
else:
transformer_cls_to_wrap.add(transformer_cls)
auto_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
# Transformer layer class to wrap
transformer_layer_cls=transformer_cls_to_wrap,
)
mixed_precision_policy = None
dtype = None
if self.args.fp16:
dtype = torch.float16
elif self.args.bf16:
dtype = torch.bfloat16
if dtype is not None:
mixed_precision_policy = MixedPrecision(param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype)
if type(model) != FSDP:
# XXX: Breaking the self.model convention but I see no way around it for now.
signature = inspect.signature(FSDP.__init__).parameters.keys()
kwargs = {}
for arg in ["limit_all_gathers", "forward_prefetch", "backward_prefetch"]:
if arg in signature:
kwargs[arg] = getattr(self, arg)
self.model = model = FSDP(
model,
sharding_strategy=self.fsdp,
cpu_offload=cpu_offload,
auto_wrap_policy=auto_wrap_policy,
mixed_precision=mixed_precision_policy,
device_id=self.args.device,
**kwargs,
)
else:
try:
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
from torch_xla.distributed.fsdp import checkpoint_module
from torch_xla.distributed.fsdp.wrap import (
size_based_auto_wrap_policy,
transformer_auto_wrap_policy,
)
except ImportError:
raise ImportError("Missing XLA FSDP related module; please make sure to use torch-xla >= 2.0.")
auto_wrap_policy = None
auto_wrapper_callable = None
if self.args.fsdp_config["fsdp_min_num_params"] > 0:
auto_wrap_policy = functools.partial(
size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"]
)
elif self.args.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
transformer_cls_to_wrap = set()
for layer_class in self.args.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]:
transformer_cls = get_module_class_from_name(model, layer_class)
if transformer_cls is None:
raise Exception("Could not find the transformer layer class to wrap in the model.")
else:
transformer_cls_to_wrap.add(transformer_cls)
auto_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
# Transformer layer class to wrap
transformer_layer_cls=transformer_cls_to_wrap,
)
fsdp_kwargs = self.args.xla_fsdp_config
if self.args.fsdp_config["xla_fsdp_grad_ckpt"]:
# Apply gradient checkpointing to auto-wrapped sub-modules if specified
def auto_wrapper_callable(m, *args, **kwargs):
return FSDP(checkpoint_module(m), *args, **kwargs)
# Wrap the base model with an outer FSDP wrapper
self.model = model = FSDP(
model,
auto_wrap_policy=auto_wrap_policy,
auto_wrapper_callable=auto_wrapper_callable,
**fsdp_kwargs,
elif self.fsdp is not None and self.args.fsdp_config["xla"]:
try:
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
from torch_xla.distributed.fsdp import checkpoint_module
from torch_xla.distributed.fsdp.wrap import (
size_based_auto_wrap_policy,
transformer_auto_wrap_policy,
)
except ImportError:
raise ImportError("Missing XLA FSDP related module; please make sure to use torch-xla >= 2.0.")
auto_wrap_policy = None
auto_wrapper_callable = None
if self.args.fsdp_config["fsdp_min_num_params"] > 0:
auto_wrap_policy = functools.partial(
size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"]
)
elif self.args.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
transformer_cls_to_wrap = set()
for layer_class in self.args.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]:
transformer_cls = get_module_class_from_name(model, layer_class)
if transformer_cls is None:
raise Exception("Could not find the transformer layer class to wrap in the model.")
else:
transformer_cls_to_wrap.add(transformer_cls)
auto_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
# Transformer layer class to wrap
transformer_layer_cls=transformer_cls_to_wrap,
)
fsdp_kwargs = self.args.xla_fsdp_config
if self.args.fsdp_config["xla_fsdp_grad_ckpt"]:
# Apply gradient checkpointing to auto-wrapped sub-modules if specified
def auto_wrapper_callable(m, *args, **kwargs):
return FSDP(checkpoint_module(m), *args, **kwargs)
# Patch `xm.optimizer_step` should not reduce gradients in this case,
# as FSDP does not need gradient reduction over sharded parameters.
def patched_optimizer_step(optimizer, barrier=False, optimizer_args={}):
loss = optimizer.step(**optimizer_args)
if barrier:
xm.mark_step()
return loss
# Wrap the base model with an outer FSDP wrapper
self.model = model = FSDP(
model,
auto_wrap_policy=auto_wrap_policy,
auto_wrapper_callable=auto_wrapper_callable,
**fsdp_kwargs,
)
xm.optimizer_step = patched_optimizer_step
# Patch `xm.optimizer_step` should not reduce gradients in this case,
# as FSDP does not need gradient reduction over sharded parameters.
def patched_optimizer_step(optimizer, barrier=False, optimizer_args={}):
loss = optimizer.step(**optimizer_args)
if barrier:
xm.mark_step()
return loss
xm.optimizer_step = patched_optimizer_step
elif is_sagemaker_dp_enabled():
model = nn.parallel.DistributedDataParallel(
model, device_ids=[int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))]
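With the non-XLA branch deleted above, wrapping the model in PyTorch FSDP is now delegated to accelerate. Roughly, the equivalent standalone accelerate usage looks like the sketch below; the plugin arguments are illustrative (the Trainer instead configures the plugin through environment variables), and the snippet assumes a process group launched via `accelerate launch`:

```python
import torch
from accelerate import Accelerator, FullyShardedDataParallelPlugin
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload

# Sketch: accelerate builds and applies the FSDP wrapper during `prepare`,
# replacing the hand-rolled FSDP(...) call removed in this hunk.
fsdp_plugin = FullyShardedDataParallelPlugin(
    cpu_offload=CPUOffload(offload_params=False),
    limit_all_gathers=True,
    use_orig_params=False,
)
accelerator = Accelerator(fsdp_plugin=fsdp_plugin)

model = torch.nn.Linear(16, 16)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
# Under `accelerate launch`, `prepare` returns the model already wrapped in
# FullyShardedDataParallel.
model, optimizer = accelerator.prepare(model, optimizer)
```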
@@ -1796,17 +1746,26 @@ class Trainer:
if is_sagemaker_mp_enabled() and resume_from_checkpoint is not None:
self._load_from_checkpoint(resume_from_checkpoint, model)
# for the rest of this function `model` is the outside model, whether it was wrapped or not
if model is not self.model:
self.model_wrapped = model
# as the model is wrapped, don't use `accelerator.prepare`
# this is for unhandled cases such as
# Fairscale Sharded DDP, FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX
use_accelerator_prepare = True if model is self.model else False
if delay_optimizer_creation:
self.create_optimizer_and_scheduler(num_training_steps=max_steps)
# prepare using `accelerator` prepare
model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
self.model, self.optimizer, self.lr_scheduler
)
if use_accelerator_prepare:
model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
self.model, self.optimizer, self.lr_scheduler
)
if getattr(self.accelerator.state, "fsdp_plugin", None) is not None:
self.model = model
# for the rest of this function `model` is the outside model, whether it was wrapped or not
if model is not self.model:
self.model_wrapped = model
# Check if saved optimizer or scheduler states exist
self._load_optimizer_and_scheduler(resume_from_checkpoint)
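The new flow only calls `accelerator.prepare` when the model has not already been wrapped elsewhere (Fairscale sharded DDP, FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX), and under FSDP the prepared module also replaces `self.model` so later saving and evaluation operate on the wrapped instance. A compressed, illustrative sketch of that pattern outside the Trainer:

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()
model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

# Mirrors the new Trainer flow: model, optimizer and scheduler go through one prepare call.
model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler)

# When an FSDP plugin is active, the returned model is the FSDP-wrapped module,
# so keep working with it (the Trainer re-points self.model at it, as in the hunk above).
if getattr(accelerator.state, "fsdp_plugin", None) is not None:
    wrapped_model = model
```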
@@ -2894,11 +2853,15 @@ class Trainer:
ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp
or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp
or self.fsdp is not None
or getattr(self.accelerator.state, "fsdp_plugin", None) is not None
):
state_dict = self.model.state_dict()
if getattr(self.accelerator.state, "fsdp_plugin", None) is not None:
self.accelerator.state.fsdp_plugin.save_model(self.accelerator, self.model, output_dir)
else:
state_dict = self.model.state_dict()
if self.args.should_save:
self._save(output_dir, state_dict=state_dict)
if self.args.should_save:
self._save(output_dir, state_dict=state_dict)
elif self.deepspeed:
# this takes care of everything as long as we aren't under zero3
if self.args.should_save:
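`fsdp_plugin.save_model` gathers the sharded weights before writing them out. Conceptually it amounts to something like the sketch below built on PyTorch's FSDP state-dict API; this is a simplified illustration with an assumed output filename, not the accelerate implementation:

```python
import os
import torch
from torch.distributed.fsdp import (
    FullStateDictConfig,
    FullyShardedDataParallel as FSDP,
    StateDictType,
)

def save_fsdp_full_state_dict(model: FSDP, output_dir: str) -> None:
    """Simplified illustration: gather the full state dict and save it on rank 0."""
    cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
        state_dict = model.state_dict()
    if torch.distributed.get_rank() == 0:
        torch.save(state_dict, os.path.join(output_dir, "pytorch_model.bin"))
```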

src/transformers/training_args.py

@@ -442,7 +442,7 @@ class TrainingArguments:
- `"backward_pre"` : Prefetches the next set of parameters before the current set of parameter's
gradient
computation.
- `"backward_pos"` : This prefetches the next set of parameters after the current set of
- `"backward_post"` : This prefetches the next set of parameters after the current set of
parameters
gradient computation.
- fsdp_forward_prefetch (`bool`, *optional*, defaults to `False`)
@@ -1504,6 +1504,32 @@ class TrainingArguments:
if self.fsdp_config["xla_fsdp_grad_ckpt"]:
warnings.warn("`--xla_fsdp_grad_ckpt` is useful only when `--xla` is set to true.")
# accelerate integration for FSDP
if len(self.fsdp) > 0 and not self.fsdp_config["xla"]:
os.environ["ACCELERATE_USE_FSDP"] = "true"
from accelerate.utils.constants import (
FSDP_AUTO_WRAP_POLICY,
FSDP_SHARDING_STRATEGY,
)
for fsdp_option in self.fsdp:
if fsdp_option.upper() in FSDP_SHARDING_STRATEGY:
# set environment variable for FSDP sharding strategy
os.environ["FSDP_SHARDING_STRATEGY"] = str(FSDP_SHARDING_STRATEGY.index(fsdp_option.upper()) + 1)
elif fsdp_option == FSDPOption.OFFLOAD:
os.environ["FSDP_OFFLOAD_PARAMS"] = "true"
elif fsdp_option == FSDPOption.AUTO_WRAP:
if self.fsdp_config["fsdp_min_num_params"] > 0:
os.environ["FSDP_MIN_NUM_PARAMS"] = str(self.fsdp_config["fsdp_min_num_params"])
os.environ["FSDP_AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[1]
elif self.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
os.environ["FSDP_TRANSFORMER_CLS_TO_WRAP"] = ",".join(
self.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]
)
os.environ["FSDP_AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[0]
prefetch_policy = self.fsdp_config.get("fsdp_backward_prefetch", "NO_PREFETCH")
os.environ["FSDP_BACKWARD_PREFETCH"] = prefetch_policy.upper()
if self.tpu_metrics_debug:
warnings.warn(
"using `--tpu_metrics_debug` is deprecated and will be removed in version 5 of 🤗 Transformers. Use"