Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
Remove ShardedDDP as it is deprecated (#25702)
* remove ShardedDDP as it was deprecated
* apply review suggestion
* make style
* Oops, forgot to remove the compute_loss context manager in Seq2SeqTrainer.
* remove the unnecessary conditional statement
* keep the logic of IPEX
* clean code
* mixed precision setup & make fixup

---------

Co-authored-by: statelesshz <jihuazhong1@huawei.com>
This commit is contained in:
parent e840aa67e8 · commit 27597fea07
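Editor's note: this commit removes the FairScale `--sharded_ddp` integration in favour of the already-supported PyTorch FSDP path (`--fsdp`). A minimal migration sketch, assuming a `TrainingArguments`-based setup and illustrative option choices (not taken from this diff):

from transformers import TrainingArguments

# Before this commit (now removed): FairScale ZeRO-3-style sharding with auto-wrapping.
# args = TrainingArguments(output_dir="out", fp16=True, sharded_ddp="zero_dp_3 auto_wrap")

# After this commit: the PyTorch FSDP path, which remains supported.
args = TrainingArguments(output_dir="out", fp16=True, fsdp="full_shard auto_wrap")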
@@ -19,7 +19,6 @@ from torch import nn
from torch.utils.data import DistributedSampler, RandomSampler

from transformers import PreTrainedModel, Trainer, logging
from transformers.integrations import is_fairscale_available
from transformers.models.fsmt.configuration_fsmt import FSMTConfig
from transformers.optimization import (
Adafactor,

@@ -36,10 +35,6 @@ from transformers.training_args import ParallelMode
from transformers.utils import is_torch_tpu_available

if is_fairscale_available():
from fairscale.optim import OSS

logger = logging.get_logger(__name__)

arg_to_scheduler = {

@@ -118,14 +113,7 @@ class Seq2SeqTrainer(Trainer):
"eps": self.args.adam_epsilon,
}
optimizer_kwargs["lr"] = self.args.learning_rate
if self.sharded_ddp:
self.optimizer = OSS(
params=optimizer_grouped_parameters,
optim=optimizer_cls,
**optimizer_kwargs,
)
else:
self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)

if self.lr_scheduler is None:
self.lr_scheduler = self._get_lr_scheduler(num_training_steps)
setup.py
@@ -109,7 +109,6 @@ _deps = [
"diffusers",
"dill<0.3.5",
"evaluate>=0.2.0",
"fairscale>0.3",
"faiss-cpu",
"fastapi",
"filelock",

@@ -275,7 +274,6 @@ extras["modelcreation"] = deps_list("cookiecutter")

extras["sagemaker"] = deps_list("sagemaker")
extras["deepspeed"] = deps_list("deepspeed") + extras["accelerate"]
extras["fairscale"] = deps_list("fairscale")
extras["optuna"] = deps_list("optuna")
extras["ray"] = deps_list("ray[tune]")
extras["sigopt"] = deps_list("sigopt")
@@ -16,7 +16,6 @@ deps = {
"diffusers": "diffusers",
"dill": "dill<0.3.5",
"evaluate": "evaluate>=0.2.0",
"fairscale": "fairscale>0.3",
"faiss-cpu": "faiss-cpu",
"fastapi": "fastapi",
"filelock": "filelock",
@@ -57,7 +57,6 @@ _import_structure = {
"is_codecarbon_available",
"is_comet_available",
"is_dagshub_available",
"is_fairscale_available",
"is_flyte_deck_standard_available",
"is_flytekit_available",
"is_mlflow_available",

@@ -118,7 +117,6 @@ if TYPE_CHECKING:
is_codecarbon_available,
is_comet_available,
is_dagshub_available,
is_fairscale_available,
is_flyte_deck_standard_available,
is_flytekit_available,
is_mlflow_available,
@@ -134,10 +134,6 @@ def is_dagshub_available():
return None not in [importlib.util.find_spec("dagshub"), importlib.util.find_spec("mlflow")]

def is_fairscale_available():
return importlib.util.find_spec("fairscale") is not None

def is_neptune_available():
return _has_neptune
@@ -42,7 +42,6 @@ from transformers import logging as transformers_logging

from .integrations import (
is_clearml_available,
is_fairscale_available,
is_optuna_available,
is_ray_available,
is_sigopt_available,

@@ -871,13 +870,6 @@ def require_deepspeed(test_case):
return unittest.skipUnless(is_deepspeed_available(), "test requires deepspeed")(test_case)

def require_fairscale(test_case):
"""
Decorator marking a test that requires fairscale
"""
return unittest.skipUnless(is_fairscale_available(), "test requires fairscale")(test_case)

def require_apex(test_case):
"""
Decorator marking a test that requires apex
@@ -40,7 +40,6 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Un
from .integrations import (
get_reporting_integration_callbacks,
hp_params,
is_fairscale_available,
)

# isort: on

@@ -58,7 +57,6 @@ from . import __version__
from .configuration_utils import PretrainedConfig
from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
from .debug_utils import DebugOption, DebugUnderflowOverflow
from .dependency_versions_check import dep_version_check
from .hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS, default_hp_search_backend
from .integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_available
from .modelcard import TrainingSummary

@@ -107,7 +105,6 @@ from .trainer_utils import (
IntervalStrategy,
PredictionOutput,
RemoveColumnsCollator,
ShardedDDPOption,
TrainerMemoryTracker,
TrainOutput,
default_compute_objective,

@@ -171,15 +168,6 @@ if is_torch_tpu_available(check_device=False):
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

if is_fairscale_available():
dep_version_check("fairscale")
import fairscale
from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP
from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
from fairscale.nn.wrap import auto_wrap
from fairscale.optim import OSS
from fairscale.optim.grad_scaler import ShardedGradScaler

if is_sagemaker_mp_enabled():
import smdistributed.modelparallel.torch as smp
@@ -420,33 +408,6 @@ class Trainer:
" model, please make sure that you have installed `bitsandbytes>=0.37.0`. "
)

# Setup Sharded DDP training
self.sharded_ddp = None
if len(args.sharded_ddp) > 0:
if self.is_deepspeed_enabled:
raise ValueError(
"Using --sharded_ddp xxx together with --deepspeed is not possible, deactivate one of those flags."
)
if len(args.fsdp) > 0:
raise ValueError(
"Using --sharded_ddp xxx together with --fsdp is not possible, deactivate one of those flags."
)
if args.parallel_mode != ParallelMode.DISTRIBUTED:
raise ValueError("Using sharded DDP only works in distributed training.")
elif not is_fairscale_available():
raise ImportError("Sharded DDP training requires fairscale: `pip install fairscale`.")
elif ShardedDDPOption.SIMPLE not in args.sharded_ddp and FullyShardedDDP is None:
raise ImportError(
"Sharded DDP in a mode other than simple training requires fairscale version >= 0.3, found "
f"{fairscale.__version__}. Upgrade your fairscale library: `pip install --upgrade fairscale`."
)
elif ShardedDDPOption.SIMPLE in args.sharded_ddp:
self.sharded_ddp = ShardedDDPOption.SIMPLE
elif ShardedDDPOption.ZERO_DP_2 in args.sharded_ddp:
self.sharded_ddp = ShardedDDPOption.ZERO_DP_2
elif ShardedDDPOption.ZERO_DP_3 in args.sharded_ddp:
self.sharded_ddp = ShardedDDPOption.ZERO_DP_3

self.fsdp = None
if len(args.fsdp) > 0:
if self.is_deepspeed_enabled:
@@ -488,14 +449,12 @@ class Trainer:
# 2. fp16-enabled DeepSpeed loads the model in half the size and it doesn't need .to() anyway,
# and we only use deepspeed for training at the moment
# 3. full bf16 or fp16 eval - since the model needs to be cast to the right dtype first
# 4. Sharded DDP - same as MP
# 5. FSDP - same as MP
# 4. FSDP - same as MP
self.place_model_on_device = args.place_model_on_device
if (
self.is_model_parallel
or self.is_deepspeed_enabled
or ((args.fp16_full_eval or args.bf16_full_eval) and not args.do_train)
or (self.sharded_ddp in [ShardedDDPOption.ZERO_DP_2, ShardedDDPOption.ZERO_DP_3])
or (self.fsdp is not None)
or self.is_fsdp_enabled
):

@@ -545,11 +504,11 @@ class Trainer:
" `Trainer`. Make sure the lines `import torch_xla.core.xla_model as xm` and"
" `model.to(xm.xla_device())` is performed before the optimizer creation in your script."
)
if ((self.sharded_ddp is not None) or self.is_deepspeed_enabled or (self.fsdp is not None)) and (
if (self.is_deepspeed_enabled or (self.fsdp is not None)) and (
self.optimizer is not None or self.lr_scheduler is not None
):
raise RuntimeError(
"Passing `optimizers` is not allowed if Fairscale, Deepspeed or PyTorch FSDP is enabled."
"Passing `optimizers` is not allowed if Deepspeed or PyTorch FSDP is enabled."
"You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
)
default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)

@@ -592,7 +551,6 @@ class Trainer:

# Mixed precision setup
self.use_apex = False
self.use_cuda_amp = False
self.use_cpu_amp = False

# Mixed precision setup for SageMaker Model Parallel
@@ -617,33 +575,19 @@ class Trainer:
f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16}, "
"but SageMaker Model Parallelism < 1.10 does not support FP16 in trainer."
)

if (args.fp16 or args.bf16) and self.sharded_ddp is not None:
if args.half_precision_backend == "auto":
if args.device == torch.device("cpu"):
if args.fp16:
raise ValueError("Tried to use `fp16` but it is not supported on cpu")
else:
args.half_precision_backend = "cpu_amp"
if (args.fp16 or args.bf16) and args.half_precision_backend == "auto":
if args.device == torch.device("cpu"):
if args.fp16:
raise ValueError("Tried to use `fp16` but it is not supported on cpu")
else:
args.half_precision_backend = "cuda_amp"

args.half_precision_backend = "cpu_amp"
logger.info(f"Using {args.half_precision_backend} half precision backend")

self.do_grad_scaling = False
if (args.fp16 or args.bf16) and not (self.is_deepspeed_enabled or is_sagemaker_mp_enabled()):
# deepspeed and SageMaker Model Parallel manage their own half precision
if self.sharded_ddp is not None:
if args.half_precision_backend == "cuda_amp":
self.use_cuda_amp = True
self.amp_dtype = torch.float16 if args.fp16 else torch.bfloat16
# bf16 does not need grad scaling
self.do_grad_scaling = self.amp_dtype == torch.float16
if self.do_grad_scaling:
self.scaler = ShardedGradScaler()
elif args.half_precision_backend == "cpu_amp":
self.use_cpu_amp = True
self.amp_dtype = torch.bfloat16
if args.half_precision_backend == "cpu_amp":
self.use_cpu_amp = True
self.amp_dtype = torch.bfloat16
elif args.half_precision_backend == "apex":
if not is_apex_available():
raise ImportError(

@@ -652,18 +596,6 @@ class Trainer:
)
self.use_apex = True

# FP16 + model parallelism in SageMaker: gradient clipping does not work for now so we raise a helpful error.
if (
is_sagemaker_mp_enabled()
and self.use_cuda_amp
and args.max_grad_norm is not None
and args.max_grad_norm > 0
):
raise ValueError(
"SageMaker Model Parallelism in mixed precision mode does not support gradient clipping yet. Pass "
"along 'max_grad_norm': 0 in your hyperparameters."
)

# Label smoothing
if self.args.label_smoothing_factor != 0:
self.label_smoother = LabelSmoother(epsilon=self.args.label_smoothing_factor)
@@ -994,27 +926,20 @@ class Trainer:

optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)

if self.sharded_ddp == ShardedDDPOption.SIMPLE:
self.optimizer = OSS(
params=optimizer_grouped_parameters,
optim=optimizer_cls,
**optimizer_kwargs,
)
else:
self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
if optimizer_cls.__name__ == "Adam8bit":
import bitsandbytes
self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
if optimizer_cls.__name__ == "Adam8bit":
import bitsandbytes

manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()

skipped = 0
for module in opt_model.modules():
if isinstance(module, nn.Embedding):
skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
logger.info(f"skipped {module}: {skipped/2**20}M params")
manager.register_module_override(module, "weight", {"optim_bits": 32})
logger.debug(f"bitsandbytes: will optimize {module} in fp32")
logger.info(f"skipped: {skipped/2**20}M params")
skipped = 0
for module in opt_model.modules():
if isinstance(module, nn.Embedding):
skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
logger.info(f"skipped {module}: {skipped/2**20}M params")
manager.register_module_override(module, "weight", {"optim_bits": 32})
logger.debug(f"bitsandbytes: will optimize {module} in fp32")
logger.info(f"skipped: {skipped/2**20}M params")

if is_sagemaker_mp_enabled():
self.optimizer = smp.DistributedOptimizer(self.optimizer)
@@ -1333,7 +1258,6 @@ class Trainer:
jit_model(**example_batch)
model = jit_model
self.use_cpu_amp = False
self.use_cuda_amp = False
except (RuntimeError, TypeError, ValueError, NameError, IndexError) as e:
logger.warning(f"failed to use PyTorch jit mode due to: {e}.")

@@ -1396,25 +1320,8 @@ class Trainer:
return model

# Distributed training (should be after apex fp16 initialization)
if self.sharded_ddp is not None:
# Sharded DDP!
if self.sharded_ddp == ShardedDDPOption.SIMPLE:
model = ShardedDDP(model, self.optimizer)
else:
mixed_precision = self.args.fp16 or self.args.bf16
cpu_offload = ShardedDDPOption.OFFLOAD in self.args.sharded_ddp
zero_3 = self.sharded_ddp == ShardedDDPOption.ZERO_DP_3
# XXX: Breaking the self.model convention but I see no way around it for now.
if ShardedDDPOption.AUTO_WRAP in self.args.sharded_ddp:
model = auto_wrap(model)
self.model = model = FullyShardedDDP(
model,
mixed_precision=mixed_precision,
reshard_after_forward=zero_3,
cpu_offload=cpu_offload,
).to(self.args.device)
# Distributed training using PyTorch FSDP
elif self.fsdp is not None and self.args.fsdp_config["xla"]:
if self.fsdp is not None and self.args.fsdp_config["xla"]:
try:
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
from torch_xla.distributed.fsdp import checkpoint_module
@@ -1669,13 +1576,7 @@ class Trainer:
else:
debug_overflow = DebugUnderflowOverflow(self.model) # noqa

delay_optimizer_creation = (
self.sharded_ddp is not None
and self.sharded_ddp != ShardedDDPOption.SIMPLE
or is_sagemaker_mp_enabled()
or self.fsdp is not None
or self.is_fsdp_enabled
)
delay_optimizer_creation = is_sagemaker_mp_enabled() or self.fsdp is not None or self.is_fsdp_enabled

# We need to reset the scheduler, as its parameters may be different on subsequent calls
if self._created_lr_scheduler:

@@ -1716,7 +1617,7 @@ class Trainer:

# as the model is wrapped, don't use `accelerator.prepare`
# this is for unhandled cases such as
# Fairscale Sharded DDP, FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX
# FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX
use_accelerator_prepare = True if model is self.model else False

if delay_optimizer_creation:
@@ -1932,14 +1833,6 @@ class Trainer:
if args.max_grad_norm is not None and args.max_grad_norm > 0:
# deepspeed does its own clipping

if self.do_grad_scaling:
# Reduce gradients first for XLA
if is_torch_tpu_available():
gradients = xm._fetch_gradients(self.optimizer)
xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size())
# AMP: gradients need unscaling
self.scaler.unscale_(self.optimizer)

if is_sagemaker_mp_enabled() and args.fp16:
self.optimizer.clip_master_grads(args.max_grad_norm)
elif hasattr(self.optimizer, "clip_grad_norm"):

@@ -1961,24 +1854,8 @@ class Trainer:
)

# Optimizer step
optimizer_was_run = True
if is_torch_tpu_available():
if self.do_grad_scaling:
self.scaler.step(self.optimizer)
self.scaler.update()
else:
# tpu-comment: accelerate wrapped optimizers call xm.optimizer_step
self.optimizer.step()
elif self.do_grad_scaling:
scale_before = self.scaler.get_scale()
self.scaler.step(self.optimizer)
self.scaler.update()
scale_after = self.scaler.get_scale()
optimizer_was_run = scale_before <= scale_after
else:
self.optimizer.step()
optimizer_was_run = not self.accelerator.optimizer_step_was_skipped

self.optimizer.step()
optimizer_was_run = not self.accelerator.optimizer_step_was_skipped
if optimizer_was_run:
# Delay optimizer scheduling until metrics are generated
if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
@@ -2408,9 +2285,6 @@ class Trainer:
self.model_wrapped.save_checkpoint(output_dir)

# Save optimizer and scheduler
if self.sharded_ddp == ShardedDDPOption.SIMPLE:
self.optimizer.consolidate_state_dict()

if self.fsdp or self.is_fsdp_enabled:
if self.is_fsdp_enabled:
save_fsdp_optimizer(

@@ -2455,8 +2329,6 @@ class Trainer:
with warnings.catch_warnings(record=True) as caught_warnings:
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
reissue_pt_warnings(caught_warnings)
if self.do_grad_scaling:
torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))

# Determine the new best metric / best model checkpoint
if metrics is not None and self.args.metric_for_best_model is not None:

@@ -2600,8 +2472,6 @@ class Trainer:
with warnings.catch_warnings(record=True) as caught_warnings:
self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))
reissue_pt_warnings(caught_warnings)
if self.do_grad_scaling and os.path.isfile(os.path.join(checkpoint, SCALER_NAME)):
self.scaler.load_state_dict(torch.load(os.path.join(checkpoint, SCALER_NAME)))

def hyperparameter_search(
self,
@@ -2744,12 +2614,8 @@ class Trainer:
A helper wrapper that creates an appropriate context manager for `autocast` while feeding it the desired
arguments, depending on the situation.
"""
if self.use_cuda_amp or self.use_cpu_amp:
ctx_manager = (
torch.cpu.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)
if self.use_cpu_amp
else torch.cuda.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)
)
if self.use_cpu_amp:
ctx_manager = torch.cpu.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)
else:
ctx_manager = contextlib.nullcontext()
@@ -2786,9 +2652,7 @@ class Trainer:
if self.args.n_gpu > 1:
loss = loss.mean() # mean() to average on multi-gpu parallel training

if self.do_grad_scaling:
self.scaler.scale(loss).backward()
elif self.use_apex:
if self.use_apex:
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
scaled_loss.backward()
else:

@@ -2872,12 +2736,7 @@ class Trainer:
if IS_SAGEMAKER_MP_POST_1_10:
# 'user_content.pt' indicates model state_dict saved with smp >= 1.10
Path(os.path.join(output_dir, "user_content.pt")).touch()
elif (
ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp
or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp
or self.fsdp is not None
or self.is_fsdp_enabled
):
elif self.fsdp is not None or self.is_fsdp_enabled:
state_dict = self.model.state_dict() if not self.is_fsdp_enabled else {}
if self.args.should_save:
self._save(output_dir, state_dict=state_dict)
@@ -266,7 +266,6 @@ class Seq2SeqTrainer(Trainer):
has_labels = "labels" in inputs
inputs = self._prepare_inputs(inputs)

# XXX: adapt synced_gpus for fairscale as well
# Priority (handled in generate):
# non-`None` gen_kwargs > model.generation_config > default GenerationConfig()
if len(gen_kwargs) == 0 and hasattr(self, "_gen_kwargs"):
@@ -651,14 +651,6 @@ def number_of_arguments(func):
return len(inspect.signature(func).parameters)

class ShardedDDPOption(ExplicitEnum):
SIMPLE = "simple"
ZERO_DP_2 = "zero_dp_2"
ZERO_DP_3 = "zero_dp_3"
OFFLOAD = "offload"
AUTO_WRAP = "auto_wrap"

def find_executable_batch_size(
function: callable = None, starting_batch_size: int = 128, auto_find_batch_size: bool = False
):
@@ -34,7 +34,6 @@ from .trainer_utils import (
HubStrategy,
IntervalStrategy,
SchedulerType,
ShardedDDPOption,
)
from .utils import (
ExplicitEnum,

@@ -328,9 +327,9 @@ class TrainingArguments:
fp16_backend (`str`, *optional*, defaults to `"auto"`):
This argument is deprecated. Use `half_precision_backend` instead.
half_precision_backend (`str`, *optional*, defaults to `"auto"`):
The backend to use for mixed precision training. Must be one of `"auto", "cuda_amp", "apex", "cpu_amp"`.
`"auto"` will use CPU/CUDA AMP or APEX depending on the PyTorch version detected, while the other choices
will force the requested backend.
The backend to use for mixed precision training. Must be one of `"auto", "apex", "cpu_amp"`. `"auto"` will
use CPU/CUDA AMP or APEX depending on the PyTorch version detected, while the other choices will force the
requested backend.
bf16_full_eval (`bool`, *optional*, defaults to `False`):
Whether to use full bfloat16 evaluation instead of 32-bit. This will be faster and save memory but can harm
metric values. This is an experimental API and it may change.

@@ -410,21 +409,6 @@ class TrainingArguments:
When resuming training, whether or not to skip the epochs and batches to get the data loading at the same
stage as in the previous training. If set to `True`, the training will begin faster (as that skipping step
can take a long time) but will not yield the same results as the interrupted training would have.
sharded_ddp (`bool`, `str` or list of [`~trainer_utils.ShardedDDPOption`], *optional*, defaults to `''`):
Use Sharded DDP training from [FairScale](https://github.com/facebookresearch/fairscale) (in distributed
training only). This is an experimental feature.

A list of options along the following:

- `"simple"`: to use first instance of sharded DDP released by fairscale (`ShardedDDP`) similar to ZeRO-2.
- `"zero_dp_2"`: to use the second instance of sharded DPP released by fairscale (`FullyShardedDDP`) in
Zero-2 mode (with `reshard_after_forward=False`).
- `"zero_dp_3"`: to use the second instance of sharded DPP released by fairscale (`FullyShardedDDP`) in
Zero-3 mode (with `reshard_after_forward=True`).
- `"offload"`: to add ZeRO-offload (only compatible with `"zero_dp_2"` and `"zero_dp_3"`).

If a string is passed, it will be split on space. If a bool is passed, it will be converted to an empty
list for `False` and `["simple"]` for `True`.
fsdp (`bool`, `str` or list of [`~trainer_utils.FSDPOption`], *optional*, defaults to `''`):
Use PyTorch Distributed Parallel Training (in distributed training only).
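Editor's note: the hunk above drops the documentation of the individual `sharded_ddp` options. As a rough guide for readers, they map onto the `fsdp` options that stay documented (`full_shard`, `shard_grad_op`, `offload`, `auto_wrap`); the mapping below is the editor's approximation, not part of the commit:

# Approximate sharded_ddp -> fsdp correspondence (illustrative only, not from this diff).
SHARDED_DDP_TO_FSDP = {
    "simple": "shard_grad_op",     # fairscale ShardedDDP: optimizer/gradient sharding, roughly ZeRO-2
    "zero_dp_2": "shard_grad_op",  # FullyShardedDDP with reshard_after_forward=False
    "zero_dp_3": "full_shard",     # FullyShardedDDP with reshard_after_forward=True
    "offload": "offload",          # CPU offload, combined with a sharding option
    "auto_wrap": "auto_wrap",      # automatic recursive module wrapping
}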
@@ -877,7 +861,7 @@ class TrainingArguments:
default="auto",
metadata={
"help": "The backend to be used for half precision.",
"choices": ["auto", "cuda_amp", "apex", "cpu_amp"],
"choices": ["auto", "apex", "cpu_amp"],
},
)
bf16_full_eval: bool = field(

@@ -996,17 +980,6 @@ class TrainingArguments:
)
},
)
sharded_ddp: Optional[Union[List[ShardedDDPOption], str]] = field(
default="",
metadata={
"help": (
"Whether or not to use sharded DDP training (in distributed training only). The base option should be"
" `simple`, `zero_dp_2` or `zero_dp_3` and you can add CPU-offload to `zero_dp_2` or `zero_dp_3` like"
" this: zero_dp_2 offload` or `zero_dp_3 offload`. You can add auto-wrap to `zero_dp_2` or `zero_dp_3`"
" with the same syntax: zero_dp_2 auto_wrap` or `zero_dp_3 auto_wrap`."
),
},
)
fsdp: Optional[Union[List[FSDPOption], str]] = field(
default="",
metadata={

@@ -1154,7 +1127,7 @@ class TrainingArguments:
default="auto",
metadata={
"help": "Deprecated. Use half_precision_backend instead",
"choices": ["auto", "cuda_amp", "apex", "cpu_amp"],
"choices": ["auto", "apex", "cpu_amp"],
},
)
push_to_hub_model_id: Optional[str] = field(

@@ -1407,8 +1380,6 @@ class TrainingArguments:
" `--half_precision_backend apex`: GPU bf16 is not supported by apex. Use"
" `--half_precision_backend cuda_amp` instead"
)
if not (self.sharded_ddp == "" or not self.sharded_ddp):
raise ValueError("sharded_ddp is not supported with bf16")

if self.lr_scheduler_type == SchedulerType.REDUCE_ON_PLATEAU:
if self.evaluation_strategy == IntervalStrategy.NO:

@@ -1508,7 +1479,7 @@ class TrainingArguments:
# no need to assert on else

# if training args is specified, it will override the one specified in the accelerate config
if self.half_precision_backend != "apex" and len(self.sharded_ddp) == 0:
if self.half_precision_backend != "apex":
mixed_precision_dtype = os.environ.get("ACCELERATE_MIXED_PRECISION", "no")
if self.fp16:
mixed_precision_dtype = "fp16"
@@ -1541,26 +1512,6 @@ class TrainingArguments:
" during training"
)

if not (self.sharded_ddp == "" or not self.sharded_ddp):
warnings.warn(
"using `sharded_ddp` is deprecated and will be removed in version 4.33"
" of 🤗 Transformers. Use `fsdp` instead",
FutureWarning,
)
if isinstance(self.sharded_ddp, bool):
self.sharded_ddp = "simple" if self.sharded_ddp else ""
if isinstance(self.sharded_ddp, str):
self.sharded_ddp = [ShardedDDPOption(s) for s in self.sharded_ddp.split()]
if self.sharded_ddp == [ShardedDDPOption.OFFLOAD]:
raise ValueError(
"`--sharded_ddp offload` can't work on its own. It needs to be added to `--sharded_ddp zero_dp_2` or "
'`--sharded_ddp zero_dp_3`. For example, `--sharded_ddp "zero_dp_2 offload"`.'
)
elif len(self.sharded_ddp) > 1 and ShardedDDPOption.SIMPLE in self.sharded_ddp:
raise ValueError("`--sharded_ddp simple` is not compatible with any other option.")
elif ShardedDDPOption.ZERO_DP_2 in self.sharded_ddp and ShardedDDPOption.ZERO_DP_3 in self.sharded_ddp:
raise ValueError("`--sharded_ddp zero_dp_2` is not compatible with `--sharded_ddp zero_dp_3`.")

if isinstance(self.fsdp, bool):
self.fsdp = "full_shard" if self.fsdp else ""
if isinstance(self.fsdp, str):
@@ -16,7 +16,6 @@ import math
import os
import re
import sys
import unittest
from pathlib import Path
from typing import Tuple
from unittest.mock import patch

@@ -32,7 +31,6 @@ from transformers.testing_utils import (
get_torch_dist_unique_port,
require_apex,
require_bitsandbytes,
require_fairscale,
require_torch,
require_torch_gpu,
require_torch_multi_gpu,

@@ -105,36 +103,6 @@ class TestTrainerExt(TestCasePlus):
def test_run_seq2seq_ddp(self):
self.run_seq2seq_quick(distributed=True)

# test --sharded_ddp w/o --fp16
@unittest.skip("Requires an update of the env running those tests")
@require_torch_multi_gpu
@require_fairscale
def test_run_seq2seq_sharded_ddp(self):
self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp simple")

# test --sharded_ddp w/ --fp16
@unittest.skip("Requires an update of the env running those tests")
@require_torch_multi_gpu
@require_fairscale
def test_run_seq2seq_sharded_ddp_fp16(self):
self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp simple --fp16")

# test --sharded_ddp zero_dp_2 w/o --fp16
@unittest.skip("Requires an update of the env running those tests")
@require_torch_multi_gpu
@require_fairscale
def test_run_seq2seq_fully_sharded_ddp(self):
self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp zero_dp_2", predict_with_generate=False)

# test --sharded_ddp zero_dp_2 w/ --fp16
@unittest.skip("Requires an update of the env running those tests")
@require_torch_multi_gpu
@require_fairscale
def test_run_seq2seq_fully_sharded_ddp_fp16(self):
self.run_seq2seq_quick(
distributed=True, extra_args_str="--sharded_ddp zero_dp_2 --fp16", predict_with_generate=False
)

@require_apex
@require_torch_gpu
def test_run_seq2seq_apex(self):
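Editor's note: the deleted fairscale tests get no direct replacement in this diff. If equivalent coverage were wanted through the FSDP path, a test along these lines could be written against the same `run_seq2seq_quick` helper (a hypothetical sketch, not code from this commit):

# Hypothetical FSDP counterpart of the removed sharded-ddp tests (not part of this commit).
@require_torch_multi_gpu
def test_run_seq2seq_fsdp_fp16(self):
    self.run_seq2seq_quick(distributed=True, extra_args_str="--fsdp full_shard --fp16", predict_with_generate=False)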