remove ShardedDDP as it is deprecated (#25702)

* remove ShardedDDP as it was deprecated

* apply review suggestion

* make style

* Oops, forgot to remove the compute_loss context manager in Seq2SeqTrainer.

* remove the unnecessary conditional statement

* keep the IPEX logic

* clean code

* mixed precision setup & make fixup

---------

Co-authored-by: statelesshz <jihuazhong1@huawei.com>
statelesshz 2023-10-06 22:03:11 +08:00 committed by GitHub
parent e840aa67e8
commit 27597fea07
11 changed files with 39 additions and 299 deletions
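
For users of the removed flag, a rough migration sketch (not part of this commit; the option mapping is approximate and depends on your setup) — the deprecation warning deleted further down already points to `fsdp` as the replacement:

from transformers import TrainingArguments

# Before (removed in this commit): fairscale-backed sharded DDP
# args = TrainingArguments(output_dir="out", sharded_ddp="zero_dp_3 offload", fp16=True)

# After: PyTorch FSDP via the `fsdp` argument; roughly, "full_shard" covers the
# zero_dp_3 case and "shard_grad_op" the zero_dp_2 case, while "offload" and
# "auto_wrap" keep their meaning as space-separated options
args = TrainingArguments(
    output_dir="out",
    fsdp="full_shard offload auto_wrap",
    fp16=True,
)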

View File

@ -19,7 +19,6 @@ from torch import nn
from torch.utils.data import DistributedSampler, RandomSampler
from transformers import PreTrainedModel, Trainer, logging
from transformers.integrations import is_fairscale_available
from transformers.models.fsmt.configuration_fsmt import FSMTConfig
from transformers.optimization import (
Adafactor,
@ -36,10 +35,6 @@ from transformers.training_args import ParallelMode
from transformers.utils import is_torch_tpu_available
if is_fairscale_available():
from fairscale.optim import OSS
logger = logging.get_logger(__name__)
arg_to_scheduler = {
@ -118,14 +113,7 @@ class Seq2SeqTrainer(Trainer):
"eps": self.args.adam_epsilon,
}
optimizer_kwargs["lr"] = self.args.learning_rate
if self.sharded_ddp:
self.optimizer = OSS(
params=optimizer_grouped_parameters,
optim=optimizer_cls,
**optimizer_kwargs,
)
else:
self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
if self.lr_scheduler is None:
self.lr_scheduler = self._get_lr_scheduler(num_training_steps)

View File

@ -109,7 +109,6 @@ _deps = [
"diffusers",
"dill<0.3.5",
"evaluate>=0.2.0",
"fairscale>0.3",
"faiss-cpu",
"fastapi",
"filelock",
@ -275,7 +274,6 @@ extras["modelcreation"] = deps_list("cookiecutter")
extras["sagemaker"] = deps_list("sagemaker")
extras["deepspeed"] = deps_list("deepspeed") + extras["accelerate"]
extras["fairscale"] = deps_list("fairscale")
extras["optuna"] = deps_list("optuna")
extras["ray"] = deps_list("ray[tune]")
extras["sigopt"] = deps_list("sigopt")

View File

@ -16,7 +16,6 @@ deps = {
"diffusers": "diffusers",
"dill": "dill<0.3.5",
"evaluate": "evaluate>=0.2.0",
"fairscale": "fairscale>0.3",
"faiss-cpu": "faiss-cpu",
"fastapi": "fastapi",
"filelock": "filelock",

View File

@ -57,7 +57,6 @@ _import_structure = {
"is_codecarbon_available",
"is_comet_available",
"is_dagshub_available",
"is_fairscale_available",
"is_flyte_deck_standard_available",
"is_flytekit_available",
"is_mlflow_available",
@ -118,7 +117,6 @@ if TYPE_CHECKING:
is_codecarbon_available,
is_comet_available,
is_dagshub_available,
is_fairscale_available,
is_flyte_deck_standard_available,
is_flytekit_available,
is_mlflow_available,

View File

@ -134,10 +134,6 @@ def is_dagshub_available():
return None not in [importlib.util.find_spec("dagshub"), importlib.util.find_spec("mlflow")]
def is_fairscale_available():
return importlib.util.find_spec("fairscale") is not None
def is_neptune_available():
return _has_neptune

View File

@ -42,7 +42,6 @@ from transformers import logging as transformers_logging
from .integrations import (
is_clearml_available,
is_fairscale_available,
is_optuna_available,
is_ray_available,
is_sigopt_available,
@ -871,13 +870,6 @@ def require_deepspeed(test_case):
return unittest.skipUnless(is_deepspeed_available(), "test requires deepspeed")(test_case)
def require_fairscale(test_case):
"""
Decorator marking a test that requires fairscale
"""
return unittest.skipUnless(is_fairscale_available(), "test requires fairscale")(test_case)
def require_apex(test_case):
"""
Decorator marking a test that requires apex

View File

@ -40,7 +40,6 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Un
from .integrations import (
get_reporting_integration_callbacks,
hp_params,
is_fairscale_available,
)
# isort: on
@ -58,7 +57,6 @@ from . import __version__
from .configuration_utils import PretrainedConfig
from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
from .debug_utils import DebugOption, DebugUnderflowOverflow
from .dependency_versions_check import dep_version_check
from .hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS, default_hp_search_backend
from .integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_available
from .modelcard import TrainingSummary
@ -107,7 +105,6 @@ from .trainer_utils import (
IntervalStrategy,
PredictionOutput,
RemoveColumnsCollator,
ShardedDDPOption,
TrainerMemoryTracker,
TrainOutput,
default_compute_objective,
@ -171,15 +168,6 @@ if is_torch_tpu_available(check_device=False):
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
if is_fairscale_available():
dep_version_check("fairscale")
import fairscale
from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP
from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
from fairscale.nn.wrap import auto_wrap
from fairscale.optim import OSS
from fairscale.optim.grad_scaler import ShardedGradScaler
if is_sagemaker_mp_enabled():
import smdistributed.modelparallel.torch as smp
@ -420,33 +408,6 @@ class Trainer:
" model, please make sure that you have installed `bitsandbytes>=0.37.0`. "
)
# Setup Sharded DDP training
self.sharded_ddp = None
if len(args.sharded_ddp) > 0:
if self.is_deepspeed_enabled:
raise ValueError(
"Using --sharded_ddp xxx together with --deepspeed is not possible, deactivate one of those flags."
)
if len(args.fsdp) > 0:
raise ValueError(
"Using --sharded_ddp xxx together with --fsdp is not possible, deactivate one of those flags."
)
if args.parallel_mode != ParallelMode.DISTRIBUTED:
raise ValueError("Using sharded DDP only works in distributed training.")
elif not is_fairscale_available():
raise ImportError("Sharded DDP training requires fairscale: `pip install fairscale`.")
elif ShardedDDPOption.SIMPLE not in args.sharded_ddp and FullyShardedDDP is None:
raise ImportError(
"Sharded DDP in a mode other than simple training requires fairscale version >= 0.3, found "
f"{fairscale.__version__}. Upgrade your fairscale library: `pip install --upgrade fairscale`."
)
elif ShardedDDPOption.SIMPLE in args.sharded_ddp:
self.sharded_ddp = ShardedDDPOption.SIMPLE
elif ShardedDDPOption.ZERO_DP_2 in args.sharded_ddp:
self.sharded_ddp = ShardedDDPOption.ZERO_DP_2
elif ShardedDDPOption.ZERO_DP_3 in args.sharded_ddp:
self.sharded_ddp = ShardedDDPOption.ZERO_DP_3
self.fsdp = None
if len(args.fsdp) > 0:
if self.is_deepspeed_enabled:
@ -488,14 +449,12 @@ class Trainer:
# 2. fp16-enabled DeepSpeed loads the model in half the size and it doesn't need .to() anyway,
# and we only use deepspeed for training at the moment
# 3. full bf16 or fp16 eval - since the model needs to be cast to the right dtype first
# 4. Sharded DDP - same as MP
# 5. FSDP - same as MP
# 4. FSDP - same as MP
self.place_model_on_device = args.place_model_on_device
if (
self.is_model_parallel
or self.is_deepspeed_enabled
or ((args.fp16_full_eval or args.bf16_full_eval) and not args.do_train)
or (self.sharded_ddp in [ShardedDDPOption.ZERO_DP_2, ShardedDDPOption.ZERO_DP_3])
or (self.fsdp is not None)
or self.is_fsdp_enabled
):
@ -545,11 +504,11 @@ class Trainer:
" `Trainer`. Make sure the lines `import torch_xla.core.xla_model as xm` and"
" `model.to(xm.xla_device())` is performed before the optimizer creation in your script."
)
if ((self.sharded_ddp is not None) or self.is_deepspeed_enabled or (self.fsdp is not None)) and (
if (self.is_deepspeed_enabled or (self.fsdp is not None)) and (
self.optimizer is not None or self.lr_scheduler is not None
):
raise RuntimeError(
"Passing `optimizers` is not allowed if Fairscale, Deepspeed or PyTorch FSDP is enabled."
"Passing `optimizers` is not allowed if Deepspeed or PyTorch FSDP is enabled."
"You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
)
default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
@ -592,7 +551,6 @@ class Trainer:
# Mixed precision setup
self.use_apex = False
self.use_cuda_amp = False
self.use_cpu_amp = False
# Mixed precision setup for SageMaker Model Parallel
@ -617,33 +575,19 @@ class Trainer:
f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16}, "
"but SageMaker Model Parallelism < 1.10 does not support FP16 in trainer."
)
if (args.fp16 or args.bf16) and self.sharded_ddp is not None:
if args.half_precision_backend == "auto":
if args.device == torch.device("cpu"):
if args.fp16:
raise ValueError("Tried to use `fp16` but it is not supported on cpu")
else:
args.half_precision_backend = "cpu_amp"
if (args.fp16 or args.bf16) and args.half_precision_backend == "auto":
if args.device == torch.device("cpu"):
if args.fp16:
raise ValueError("Tried to use `fp16` but it is not supported on cpu")
else:
args.half_precision_backend = "cuda_amp"
args.half_precision_backend = "cpu_amp"
logger.info(f"Using {args.half_precision_backend} half precision backend")
self.do_grad_scaling = False
if (args.fp16 or args.bf16) and not (self.is_deepspeed_enabled or is_sagemaker_mp_enabled()):
# deepspeed and SageMaker Model Parallel manage their own half precision
if self.sharded_ddp is not None:
if args.half_precision_backend == "cuda_amp":
self.use_cuda_amp = True
self.amp_dtype = torch.float16 if args.fp16 else torch.bfloat16
# bf16 does not need grad scaling
self.do_grad_scaling = self.amp_dtype == torch.float16
if self.do_grad_scaling:
self.scaler = ShardedGradScaler()
elif args.half_precision_backend == "cpu_amp":
self.use_cpu_amp = True
self.amp_dtype = torch.bfloat16
if args.half_precision_backend == "cpu_amp":
self.use_cpu_amp = True
self.amp_dtype = torch.bfloat16
elif args.half_precision_backend == "apex":
if not is_apex_available():
raise ImportError(
@ -652,18 +596,6 @@ class Trainer:
)
self.use_apex = True
# FP16 + model parallelism in SageMaker: gradient clipping does not work for now so we raise a helpful error.
if (
is_sagemaker_mp_enabled()
and self.use_cuda_amp
and args.max_grad_norm is not None
and args.max_grad_norm > 0
):
raise ValueError(
"SageMaker Model Parallelism in mixed precision mode does not support gradient clipping yet. Pass "
"along 'max_grad_norm': 0 in your hyperparameters."
)
# Label smoothing
if self.args.label_smoothing_factor != 0:
self.label_smoother = LabelSmoother(epsilon=self.args.label_smoothing_factor)
@ -994,27 +926,20 @@ class Trainer:
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
if self.sharded_ddp == ShardedDDPOption.SIMPLE:
self.optimizer = OSS(
params=optimizer_grouped_parameters,
optim=optimizer_cls,
**optimizer_kwargs,
)
else:
self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
if optimizer_cls.__name__ == "Adam8bit":
import bitsandbytes
self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
if optimizer_cls.__name__ == "Adam8bit":
import bitsandbytes
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
skipped = 0
for module in opt_model.modules():
if isinstance(module, nn.Embedding):
skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
logger.info(f"skipped {module}: {skipped/2**20}M params")
manager.register_module_override(module, "weight", {"optim_bits": 32})
logger.debug(f"bitsandbytes: will optimize {module} in fp32")
logger.info(f"skipped: {skipped/2**20}M params")
skipped = 0
for module in opt_model.modules():
if isinstance(module, nn.Embedding):
skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
logger.info(f"skipped {module}: {skipped/2**20}M params")
manager.register_module_override(module, "weight", {"optim_bits": 32})
logger.debug(f"bitsandbytes: will optimize {module} in fp32")
logger.info(f"skipped: {skipped/2**20}M params")
if is_sagemaker_mp_enabled():
self.optimizer = smp.DistributedOptimizer(self.optimizer)
@ -1333,7 +1258,6 @@ class Trainer:
jit_model(**example_batch)
model = jit_model
self.use_cpu_amp = False
self.use_cuda_amp = False
except (RuntimeError, TypeError, ValueError, NameError, IndexError) as e:
logger.warning(f"failed to use PyTorch jit mode due to: {e}.")
@ -1396,25 +1320,8 @@ class Trainer:
return model
# Distributed training (should be after apex fp16 initialization)
if self.sharded_ddp is not None:
# Sharded DDP!
if self.sharded_ddp == ShardedDDPOption.SIMPLE:
model = ShardedDDP(model, self.optimizer)
else:
mixed_precision = self.args.fp16 or self.args.bf16
cpu_offload = ShardedDDPOption.OFFLOAD in self.args.sharded_ddp
zero_3 = self.sharded_ddp == ShardedDDPOption.ZERO_DP_3
# XXX: Breaking the self.model convention but I see no way around it for now.
if ShardedDDPOption.AUTO_WRAP in self.args.sharded_ddp:
model = auto_wrap(model)
self.model = model = FullyShardedDDP(
model,
mixed_precision=mixed_precision,
reshard_after_forward=zero_3,
cpu_offload=cpu_offload,
).to(self.args.device)
# Distributed training using PyTorch FSDP
elif self.fsdp is not None and self.args.fsdp_config["xla"]:
if self.fsdp is not None and self.args.fsdp_config["xla"]:
try:
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
from torch_xla.distributed.fsdp import checkpoint_module
@ -1669,13 +1576,7 @@ class Trainer:
else:
debug_overflow = DebugUnderflowOverflow(self.model) # noqa
delay_optimizer_creation = (
self.sharded_ddp is not None
and self.sharded_ddp != ShardedDDPOption.SIMPLE
or is_sagemaker_mp_enabled()
or self.fsdp is not None
or self.is_fsdp_enabled
)
delay_optimizer_creation = is_sagemaker_mp_enabled() or self.fsdp is not None or self.is_fsdp_enabled
# We need to reset the scheduler, as its parameters may be different on subsequent calls
if self._created_lr_scheduler:
@ -1716,7 +1617,7 @@ class Trainer:
# as the model is wrapped, don't use `accelerator.prepare`
# this is for unhandled cases such as
# Fairscale Sharded DDP, FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX
# FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX
use_accelerator_prepare = True if model is self.model else False
if delay_optimizer_creation:
@ -1932,14 +1833,6 @@ class Trainer:
if args.max_grad_norm is not None and args.max_grad_norm > 0:
# deepspeed does its own clipping
if self.do_grad_scaling:
# Reduce gradients first for XLA
if is_torch_tpu_available():
gradients = xm._fetch_gradients(self.optimizer)
xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size())
# AMP: gradients need unscaling
self.scaler.unscale_(self.optimizer)
if is_sagemaker_mp_enabled() and args.fp16:
self.optimizer.clip_master_grads(args.max_grad_norm)
elif hasattr(self.optimizer, "clip_grad_norm"):
@ -1961,24 +1854,8 @@ class Trainer:
)
# Optimizer step
optimizer_was_run = True
if is_torch_tpu_available():
if self.do_grad_scaling:
self.scaler.step(self.optimizer)
self.scaler.update()
else:
# tpu-comment: accelerate wrapped optimizers call xm.optimizer_step
self.optimizer.step()
elif self.do_grad_scaling:
scale_before = self.scaler.get_scale()
self.scaler.step(self.optimizer)
self.scaler.update()
scale_after = self.scaler.get_scale()
optimizer_was_run = scale_before <= scale_after
else:
self.optimizer.step()
optimizer_was_run = not self.accelerator.optimizer_step_was_skipped
self.optimizer.step()
optimizer_was_run = not self.accelerator.optimizer_step_was_skipped
if optimizer_was_run:
# Delay optimizer scheduling until metrics are generated
if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
@ -2408,9 +2285,6 @@ class Trainer:
self.model_wrapped.save_checkpoint(output_dir)
# Save optimizer and scheduler
if self.sharded_ddp == ShardedDDPOption.SIMPLE:
self.optimizer.consolidate_state_dict()
if self.fsdp or self.is_fsdp_enabled:
if self.is_fsdp_enabled:
save_fsdp_optimizer(
@ -2455,8 +2329,6 @@ class Trainer:
with warnings.catch_warnings(record=True) as caught_warnings:
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
reissue_pt_warnings(caught_warnings)
if self.do_grad_scaling:
torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
# Determine the new best metric / best model checkpoint
if metrics is not None and self.args.metric_for_best_model is not None:
@ -2600,8 +2472,6 @@ class Trainer:
with warnings.catch_warnings(record=True) as caught_warnings:
self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))
reissue_pt_warnings(caught_warnings)
if self.do_grad_scaling and os.path.isfile(os.path.join(checkpoint, SCALER_NAME)):
self.scaler.load_state_dict(torch.load(os.path.join(checkpoint, SCALER_NAME)))
def hyperparameter_search(
self,
@ -2744,12 +2614,8 @@ class Trainer:
A helper wrapper that creates an appropriate context manager for `autocast` while feeding it the desired
arguments, depending on the situation.
"""
if self.use_cuda_amp or self.use_cpu_amp:
ctx_manager = (
torch.cpu.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)
if self.use_cpu_amp
else torch.cuda.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)
)
if self.use_cpu_amp:
ctx_manager = torch.cpu.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)
else:
ctx_manager = contextlib.nullcontext()
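
Context for the autocast and grad-scaler removals in this file (a hedged sketch, not part of the diff): the Trainer now only opens an explicit autocast context for CPU AMP, while CUDA mixed precision and gradient scaling are delegated to accelerate. `TrainingArguments` exports the fp16/bf16 choice through the `ACCELERATE_MIXED_PRECISION` environment variable (see the `TrainingArguments` hunk further down, where that variable is read and set), and the `Accelerator` applies autocast and its own grad scaler:

import os
from accelerate import Accelerator

# roughly what TrainingArguments(fp16=True) exports before the Accelerator is created
os.environ["ACCELERATE_MIXED_PRECISION"] = "fp16"
accelerator = Accelerator()  # picks up fp16: autocast and grad scaling are managed here

# inside the training loop, accelerator.backward(loss) replaces the manual
# self.scaler.scale(loss).backward() path removed in this commit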
@ -2786,9 +2652,7 @@ class Trainer:
if self.args.n_gpu > 1:
loss = loss.mean() # mean() to average on multi-gpu parallel training
if self.do_grad_scaling:
self.scaler.scale(loss).backward()
elif self.use_apex:
if self.use_apex:
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
scaled_loss.backward()
else:
@ -2872,12 +2736,7 @@ class Trainer:
if IS_SAGEMAKER_MP_POST_1_10:
# 'user_content.pt' indicates model state_dict saved with smp >= 1.10
Path(os.path.join(output_dir, "user_content.pt")).touch()
elif (
ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp
or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp
or self.fsdp is not None
or self.is_fsdp_enabled
):
elif self.fsdp is not None or self.is_fsdp_enabled:
state_dict = self.model.state_dict() if not self.is_fsdp_enabled else {}
if self.args.should_save:
self._save(output_dir, state_dict=state_dict)

View File

@ -266,7 +266,6 @@ class Seq2SeqTrainer(Trainer):
has_labels = "labels" in inputs
inputs = self._prepare_inputs(inputs)
# XXX: adapt synced_gpus for fairscale as well
# Priority (handled in generate):
# non-`None` gen_kwargs > model.generation_config > default GenerationConfig()
if len(gen_kwargs) == 0 and hasattr(self, "_gen_kwargs"):

View File

@ -651,14 +651,6 @@ def number_of_arguments(func):
return len(inspect.signature(func).parameters)
class ShardedDDPOption(ExplicitEnum):
SIMPLE = "simple"
ZERO_DP_2 = "zero_dp_2"
ZERO_DP_3 = "zero_dp_3"
OFFLOAD = "offload"
AUTO_WRAP = "auto_wrap"
def find_executable_batch_size(
function: callable = None, starting_batch_size: int = 128, auto_find_batch_size: bool = False
):

View File

@ -34,7 +34,6 @@ from .trainer_utils import (
HubStrategy,
IntervalStrategy,
SchedulerType,
ShardedDDPOption,
)
from .utils import (
ExplicitEnum,
@ -328,9 +327,9 @@ class TrainingArguments:
fp16_backend (`str`, *optional*, defaults to `"auto"`):
This argument is deprecated. Use `half_precision_backend` instead.
half_precision_backend (`str`, *optional*, defaults to `"auto"`):
The backend to use for mixed precision training. Must be one of `"auto", "cuda_amp", "apex", "cpu_amp"`.
`"auto"` will use CPU/CUDA AMP or APEX depending on the PyTorch version detected, while the other choices
will force the requested backend.
The backend to use for mixed precision training. Must be one of `"auto", "apex", "cpu_amp"`. `"auto"` will
use CPU/CUDA AMP or APEX depending on the PyTorch version detected, while the other choices will force the
requested backend.
bf16_full_eval (`bool`, *optional*, defaults to `False`):
Whether to use full bfloat16 evaluation instead of 32-bit. This will be faster and save memory but can harm
metric values. This is an experimental API and it may change.
@ -410,21 +409,6 @@ class TrainingArguments:
When resuming training, whether or not to skip the epochs and batches to get the data loading at the same
stage as in the previous training. If set to `True`, the training will begin faster (as that skipping step
can take a long time) but will not yield the same results as the interrupted training would have.
sharded_ddp (`bool`, `str` or list of [`~trainer_utils.ShardedDDPOption`], *optional*, defaults to `''`):
Use Sharded DDP training from [FairScale](https://github.com/facebookresearch/fairscale) (in distributed
training only). This is an experimental feature.
A list of options along the following:
- `"simple"`: to use first instance of sharded DDP released by fairscale (`ShardedDDP`) similar to ZeRO-2.
- `"zero_dp_2"`: to use the second instance of sharded DPP released by fairscale (`FullyShardedDDP`) in
Zero-2 mode (with `reshard_after_forward=False`).
- `"zero_dp_3"`: to use the second instance of sharded DPP released by fairscale (`FullyShardedDDP`) in
Zero-3 mode (with `reshard_after_forward=True`).
- `"offload"`: to add ZeRO-offload (only compatible with `"zero_dp_2"` and `"zero_dp_3"`).
If a string is passed, it will be split on space. If a bool is passed, it will be converted to an empty
list for `False` and `["simple"]` for `True`.
fsdp (`bool`, `str` or list of [`~trainer_utils.FSDPOption`], *optional*, defaults to `''`):
Use PyTorch Distributed Parallel Training (in distributed training only).
@ -877,7 +861,7 @@ class TrainingArguments:
default="auto",
metadata={
"help": "The backend to be used for half precision.",
"choices": ["auto", "cuda_amp", "apex", "cpu_amp"],
"choices": ["auto", "apex", "cpu_amp"],
},
)
bf16_full_eval: bool = field(
@ -996,17 +980,6 @@ class TrainingArguments:
)
},
)
sharded_ddp: Optional[Union[List[ShardedDDPOption], str]] = field(
default="",
metadata={
"help": (
"Whether or not to use sharded DDP training (in distributed training only). The base option should be"
" `simple`, `zero_dp_2` or `zero_dp_3` and you can add CPU-offload to `zero_dp_2` or `zero_dp_3` like"
" this: zero_dp_2 offload` or `zero_dp_3 offload`. You can add auto-wrap to `zero_dp_2` or `zero_dp_3`"
" with the same syntax: zero_dp_2 auto_wrap` or `zero_dp_3 auto_wrap`."
),
},
)
fsdp: Optional[Union[List[FSDPOption], str]] = field(
default="",
metadata={
@ -1154,7 +1127,7 @@ class TrainingArguments:
default="auto",
metadata={
"help": "Deprecated. Use half_precision_backend instead",
"choices": ["auto", "cuda_amp", "apex", "cpu_amp"],
"choices": ["auto", "apex", "cpu_amp"],
},
)
push_to_hub_model_id: Optional[str] = field(
@ -1407,8 +1380,6 @@ class TrainingArguments:
" `--half_precision_backend apex`: GPU bf16 is not supported by apex. Use"
" `--half_precision_backend cuda_amp` instead"
)
if not (self.sharded_ddp == "" or not self.sharded_ddp):
raise ValueError("sharded_ddp is not supported with bf16")
if self.lr_scheduler_type == SchedulerType.REDUCE_ON_PLATEAU:
if self.evaluation_strategy == IntervalStrategy.NO:
@ -1508,7 +1479,7 @@ class TrainingArguments:
# no need to assert on else
# if training args is specified, it will override the one specified in the accelerate config
if self.half_precision_backend != "apex" and len(self.sharded_ddp) == 0:
if self.half_precision_backend != "apex":
mixed_precision_dtype = os.environ.get("ACCELERATE_MIXED_PRECISION", "no")
if self.fp16:
mixed_precision_dtype = "fp16"
@ -1541,26 +1512,6 @@ class TrainingArguments:
" during training"
)
if not (self.sharded_ddp == "" or not self.sharded_ddp):
warnings.warn(
"using `sharded_ddp` is deprecated and will be removed in version 4.33"
" of 🤗 Transformers. Use `fsdp` instead",
FutureWarning,
)
if isinstance(self.sharded_ddp, bool):
self.sharded_ddp = "simple" if self.sharded_ddp else ""
if isinstance(self.sharded_ddp, str):
self.sharded_ddp = [ShardedDDPOption(s) for s in self.sharded_ddp.split()]
if self.sharded_ddp == [ShardedDDPOption.OFFLOAD]:
raise ValueError(
"`--sharded_ddp offload` can't work on its own. It needs to be added to `--sharded_ddp zero_dp_2` or "
'`--sharded_ddp zero_dp_3`. For example, `--sharded_ddp "zero_dp_2 offload"`.'
)
elif len(self.sharded_ddp) > 1 and ShardedDDPOption.SIMPLE in self.sharded_ddp:
raise ValueError("`--sharded_ddp simple` is not compatible with any other option.")
elif ShardedDDPOption.ZERO_DP_2 in self.sharded_ddp and ShardedDDPOption.ZERO_DP_3 in self.sharded_ddp:
raise ValueError("`--sharded_ddp zero_dp_2` is not compatible with `--sharded_ddp zero_dp_3`.")
if isinstance(self.fsdp, bool):
self.fsdp = "full_shard" if self.fsdp else ""
if isinstance(self.fsdp, str):

View File

@ -16,7 +16,6 @@ import math
import os
import re
import sys
import unittest
from pathlib import Path
from typing import Tuple
from unittest.mock import patch
@ -32,7 +31,6 @@ from transformers.testing_utils import (
get_torch_dist_unique_port,
require_apex,
require_bitsandbytes,
require_fairscale,
require_torch,
require_torch_gpu,
require_torch_multi_gpu,
@ -105,36 +103,6 @@ class TestTrainerExt(TestCasePlus):
def test_run_seq2seq_ddp(self):
self.run_seq2seq_quick(distributed=True)
# test --sharded_ddp w/o --fp16
@unittest.skip("Requires an update of the env running those tests")
@require_torch_multi_gpu
@require_fairscale
def test_run_seq2seq_sharded_ddp(self):
self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp simple")
# test --sharded_ddp w/ --fp16
@unittest.skip("Requires an update of the env running those tests")
@require_torch_multi_gpu
@require_fairscale
def test_run_seq2seq_sharded_ddp_fp16(self):
self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp simple --fp16")
# test --sharded_ddp zero_dp_2 w/o --fp16
@unittest.skip("Requires an update of the env running those tests")
@require_torch_multi_gpu
@require_fairscale
def test_run_seq2seq_fully_sharded_ddp(self):
self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp zero_dp_2", predict_with_generate=False)
# test --sharded_ddp zero_dp_2 w/ --fp16
@unittest.skip("Requires an update of the env running those tests")
@require_torch_multi_gpu
@require_fairscale
def test_run_seq2seq_fully_sharded_ddp_fp16(self):
self.run_seq2seq_quick(
distributed=True, extra_args_str="--sharded_ddp zero_dp_2 --fp16", predict_with_generate=False
)
@require_apex
@require_torch_gpu
def test_run_seq2seq_apex(self):