[Deepspeed] ZeRO-Infinity integration plus config revamp (#11418)

* adding Z-inf

* revamp config process

* up version requirement

* wip

* massive rewrite

* cleanup

* cleanup

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* consistent json commas

* act on suggestions

* leave this feature for 0.3.16

* style

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Author: Stas Bekman, 2021-04-26 10:40:32 -07:00 (committed by GitHub)
Parent: 0661abc545
Commit: bc2571e61c
10 changed files with 896 additions and 503 deletions
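
The heart of the config revamp is the "auto" placeholder: any value left as "auto" in the DeepSpeed config is filled in from the matching TrainingArguments value, while values the user set explicitly are left untouched. A minimal sketch of that substitution, reusing the _set_if_auto helper added in this commit (the sample dict is purely illustrative):

def _set_if_auto(config, key, val):
    # only fill in values the user explicitly marked as "auto"
    if config is None:
        return
    if config.get(key) == "auto":
        config[key] = val


cfg = {"gradient_clipping": "auto", "train_micro_batch_size_per_gpu": 8}
_set_if_auto(cfg, "gradient_clipping", 1.0)             # filled in: was "auto"
_set_if_auto(cfg, "train_micro_batch_size_per_gpu", 4)  # left alone: user pinned it to 8
print(cfg)  # {'gradient_clipping': 1.0, 'train_micro_batch_size_per_gpu': 8}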

(File diff suppressed because it is too large.)

@@ -90,7 +90,7 @@ _deps = [
     "cookiecutter==1.7.2",
     "dataclasses",
     "datasets",
-    "deepspeed>=0.3.14",
+    "deepspeed>=0.3.15",
     "docutils==0.16.0",
     "fairscale>0.3",
     "faiss-cpu",

@@ -7,7 +7,7 @@ deps = {
     "cookiecutter": "cookiecutter==1.7.2",
     "dataclasses": "dataclasses",
     "datasets": "datasets",
-    "deepspeed": "deepspeed>=0.3.14",
+    "deepspeed": "deepspeed>=0.3.15",
     "docutils": "docutils==0.16.0",
     "fairscale": "fairscale>0.3",
     "faiss-cpu": "faiss-cpu",

@@ -19,8 +19,8 @@ import io
 import json
 import numbers
 import os
-import sys
 import tempfile
+import weakref
 from copy import deepcopy
 from pathlib import Path
@@ -269,74 +269,180 @@ def rewrite_logs(d):
     return new_d


-_is_deepspeed_zero3_enabled = None
-
-
-def is_deepspeed_zero3_enabled():
-    """
-    This function answers to the question of whether DeepSpeed is going to be used and run using ZeRO Stage 3.
-
-    It includes an auto-discovery method, see comments in the code for details.
-
-    Returns: ``True`` if either it was explicitly enabled via ``deepspeed_zero3_enable(True)`` or the auto-detector was
-    able to derive that the ``Trainer`` will be running via DeepSpeed ZeRO stage 3.
-    """
-    global _is_deepspeed_zero3_enabled
-    if _is_deepspeed_zero3_enabled is None:
-        _is_deepspeed_zero3_enabled = False
-        # Try to auto-discover if we are about to use DeepSpeed with ZeRO3 enabled. This will only
-        # work for scripts using cli to pass --deepspeed ds_config.json. If cmd args aren't used,
-        # then to get the model efficiently loaded across multiple-gpus one has to explicitly call
-        # is_deepspeed_zero3_enabled(True) **before** instantiating a model object
-        if "--deepspeed" in sys.argv:
-            idx = sys.argv.index("--deepspeed")
-            ds_config = sys.argv[idx + 1]
-            if not os.path.exists(ds_config):
-                raise ValueError("--deepspeed requires a valid path to a config file")
-            config = deepspeed_parse_config(ds_config)
-            if (
-                "zero_optimization" in config
-                and "stage" in config["zero_optimization"]
-                and config["zero_optimization"]["stage"] == 3
-            ):
-                _is_deepspeed_zero3_enabled = True
-
-    return _is_deepspeed_zero3_enabled
-
-
-def deepspeed_zero3_enable(enable=True):
-    """
-    ``is_deepspeed_zero3_enabled()`` tries to derive automatically if DeepSpeed ZeRO 3 is going to be used by looking
-    at ``sys.argv`` which may or may not contain information about where to find the DeepSpeed config if any.
-
-    This function allows for explicit enabling/disabling of this global flag.
-
-    Args:
-        enable: if set to ``True`` will make ``is_deepspeed_zero3_enabled()`` return ``True``
-    """
-    global _is_deepspeed_zero3_enabled
-    _is_deepspeed_zero3_enabled = enable
-
-
-def deepspeed_parse_config(ds_config):
-    """
-    If ``ds_config`` isn't already a dict, read it from the config file.
-
-    If it's already a dict, return a copy of it, so that we can freely modify it.
-    """
-    dep_version_check("deepspeed")
-
-    if isinstance(ds_config, dict):
-        # Don't modify user's data should they want to reuse it (e.g. in tests), because once we
-        # modified it, it will not be accepted here again, since some config params must be not set by users
-        config = deepcopy(ds_config)
-    elif isinstance(ds_config, str):
-        with io.open(ds_config, "r", encoding="utf-8") as f:
-            config = json.load(f)
-    else:
-        raise ValueError("expecting either a path to a config file or a pre-populated dict")
-
-    return config
+def _is_true(config, key):
+    if config is None:
+        return False
+    return bool(config.get(key))
+
+
+def _set_if_auto(config, key, val):
+    if config is None:
+        return
+    if config.get(key) == "auto":
+        config[key] = val
+
+
+class DeepSpeedConfigHF:
+    """
+    This object contains Deepspeed configuration and can be quickly queried for things like zero stage.
+
+    We store a ``weakref`` of this object in the module's global to be able to access the config from areas where the
+    Trainer is not available (e.g. `from_pretrained` and `_get_resized_embeddings`).
+
+    The ``DeepSpeedConfigHF`` object is meant to be created during ``TrainingArguments`` object creation and has the
+    same lifespan as the latter.
+    """
+
+    def __init__(self, args):
+        self.config = None
+        self.stage = 0
+        self.offload = False
+
+        dep_version_check("deepspeed")
+
+        self.config_process(args)
+
+        # set global weakref object
+        deepspeed_config_hf_set(self)
+
+    def is_zero2(self):
+        return self.stage == 2
+
+    def is_zero3(self):
+        return self.stage == 3
+
+    def is_offload(self):
+        return self.offload
+
+    def config_process(self, args):
+        """
+        1. load json if the ``args.deepspeed`` is a path
+        2. replace any ``auto`` values in the config with the correct or recommended value
+
+        This is done as early as possible, before model is created, to allow ``is_deepspeed_zero3_enabled`` query and
+        getting to the early deepspeed config object during ``zero.Init()`` which needs whether fp16 is enabled, dtype,
+        etc.
+        """
+        config_file_or_dict = args.deepspeed
+        if isinstance(config_file_or_dict, dict):
+            # Don't modify user's data should they want to reuse it (e.g. in tests), because once we
+            # modified it, it will not be accepted here again, since `auto` values would have been overridden
+            config = deepcopy(config_file_or_dict)
+        elif isinstance(config_file_or_dict, str):
+            with io.open(config_file_or_dict, "r", encoding="utf-8") as f:
+                config = json.load(f)
+        else:
+            raise ValueError("expecting either a path to a config file or a pre-populated dict")
+        self.config = config
+
+        # DeepSpeed does:
+        # train_batch_size = world_size * train_micro_batch_size_per_gpu * gradient_accumulation_steps
+        train_batch_size = args.world_size * args.per_device_train_batch_size * args.gradient_accumulation_steps
+        _set_if_auto(config, "train_micro_batch_size_per_gpu", args.per_device_train_batch_size)
+        _set_if_auto(config, "gradient_accumulation_steps", args.gradient_accumulation_steps)
+        _set_if_auto(config, "train_batch_size", train_batch_size)
+        _set_if_auto(config, "gradient_clipping", args.max_grad_norm)
+
+        # zero
+        config_zero = config.get("zero_optimization", {})
+        self.stage = config_zero.get("stage", 0)
+
+        config_optim = config.get("optimizer", {})
+        if config_optim != {}:
+            config_optim_params = config_optim.get("params")
+            _set_if_auto(config_optim_params, "lr", args.learning_rate)
+            _set_if_auto(config_optim_params, "betas", [args.adam_beta1, args.adam_beta2])
+            _set_if_auto(config_optim_params, "eps", args.adam_epsilon)
+            _set_if_auto(config_optim_params, "weight_decay", args.weight_decay)
+
+        config_sched = config.get("scheduler", {})
+        if config_sched != {}:
+            config_sched_params = config_sched.get("params")
+            _set_if_auto(config_sched_params, "warmup_min_lr", 0)
+            _set_if_auto(config_sched_params, "warmup_max_lr", args.learning_rate)
+            _set_if_auto(config_sched_params, "warmup_num_steps", args.warmup_steps)
+            # total_num_steps - will get set in deepspeed_init
+
+        # fp16
+        if args.fp16:
+            fp16_backend = "apex" if args.fp16_backend == "apex" else "amp"
+        else:
+            fp16_backend = None
+
+        # amp: similar to the pytorch native amp - it has a bunch of optional params but we won't set
+        # any here unless the user did the work
+        config_fp16 = config.get("fp16")
+        # XXX: at the moment fp16 can't be False, but the fp32 solution is in works - once it's PR'ed and
+        # merged and a new release is made, delete the next line and uncomment the one after it
+        _set_if_auto(config_fp16, "enabled", True)
+        # _set_if_auto(config_fp16, "enabled", fp16_backend == "amp")
+
+        # apex: delegates amp work to apex (which needs to be available), but it cannot be used with any
+        # ZeRO features, so probably best to be avoided.
+        config_amp = config.get("amp")
+        _set_if_auto(config_amp, "enabled", fp16_backend == "apex")
+        _set_if_auto(config_amp, "opt_level", args.fp16_opt_level)
+
+        config_zero = config.get("zero_optimization", {})
+        if self.is_zero2():
+            self.offload = _is_true(config_zero, "cpu_offload")
+        elif self.is_zero3():
+            offload_devices = ["cpu", "nvme"]
+            if config_zero.get("offload_optimizer", {}).get("device") in offload_devices:
+                self.offload = True
+            if config_zero.get("offload_param", {}).get("device") in offload_devices:
+                self.offload = True
+
+    def config_finalize(self, args, model, num_training_steps):
+        """
+        This stage is run after we have the model and know num_training_steps.
+
+        Now we can complete the configuration process.
+        """
+        config = self.config
+
+        # zero
+        config_zero = config.get("zero_optimization", {})
+        if self.is_zero3():
+            # automatically assign the optimal config values based on model config
+            hidden_size = model.config.hidden_size
+            _set_if_auto(config_zero, "reduce_bucket_size", hidden_size * hidden_size)
+            _set_if_auto(config_zero, "stage3_prefetch_bucket_size", 0.9 * hidden_size * hidden_size)
+            _set_if_auto(config_zero, "stage3_param_persistence_threshold", 10 * hidden_size)
+
+        # scheduler
+        config_sched = config.get("scheduler", {})
+        config_sched_params = config_sched.get("params", {})
+        _set_if_auto(config_sched_params, "total_num_steps", num_training_steps)
+
+
+# keep the config object global to be able to access it anywhere during TrainingArguments life-cycle
+_deepspeed_config_hf_weak_ref = None
+
+
+def deepspeed_config_hf_set(deepspeed_config_hf_obj):
+    # this is a special weakref global object to allow us to get to Deepspeed config from APIs
+    # that don't have an easy way to get to the Deepspeed config outside of the Trainer domain.
+    global _deepspeed_config_hf_weak_ref
+    # will go away automatically when DeepSpeedConfigHF is destroyed (when TrainingArguments is destroyed)
+    _deepspeed_config_hf_weak_ref = weakref.ref(deepspeed_config_hf_obj)
+
+
+def is_deepspeed_zero3_enabled():
+    if _deepspeed_config_hf_weak_ref is not None and _deepspeed_config_hf_weak_ref() is not None:
+        return _deepspeed_config_hf_weak_ref().is_zero3()
+    else:
+        return False
+
+
+def deepspeed_config():
+    if _deepspeed_config_hf_weak_ref is not None and _deepspeed_config_hf_weak_ref() is not None:
+        return _deepspeed_config_hf_weak_ref().config
+    else:
+        return None
@@ -355,41 +461,16 @@ def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None):
     """
     import deepspeed

-    args = trainer.args
     model = trainer.model
-    config = deepspeed_parse_config(args.deepspeed)

-    # The following code translates relevant trainer's cl args into the DS config
-
-    # First to ensure that there is no mismatch between cl args values and presets in the config
-    # file, ask to not set in ds config file:
-    # - "train_batch_size",
-    # - "train_micro_batch_size_per_gpu",
-    # - "gradient_accumulation_steps"
-    bs_keys = ["train_batch_size", "train_micro_batch_size_per_gpu"]
-    if len([x for x in bs_keys if x in config.keys()]):
-        raise ValueError(
-            f"Do not include {bs_keys} entries in the ds config file, as they will be set via --per_device_train_batch_size or its default"
-        )
-    if "gradient_accumulation_steps" in config.keys():
-        raise ValueError(
-            "Do not include gradient_accumulation_steps entries in the ds config file, as they will be set via --gradient_accumulation_steps or its default"
-        )
-
-    # DeepSpeed does:
-    # train_batch_size = n_gpus * train_micro_batch_size_per_gpu * gradient_accumulation_steps
-    # therefore we just need to set:
-    config["train_micro_batch_size_per_gpu"] = args.per_device_train_batch_size
-    config["gradient_accumulation_steps"] = args.gradient_accumulation_steps
-
-    if "gradient_clipping" in config:
-        logger.info("Keeping the `gradient_clipping` config intact, ignoring any gradient clipping-specific cl args")
-    else:  # override only if the ds config doesn't already have this section
-        config["gradient_clipping"] = args.max_grad_norm
+    deepspeed_config_hf = trainer.args.deepspeed_config_hf
+    deepspeed_config_hf.config_finalize(trainer.args, model, num_training_steps)
+
+    # resume config update - some bits like `model` and `num_training_steps` only become available during train
+    config = deepspeed_config_hf.config

     # Optimizer + Scheduler
-    # Currently support combos:
+    # Currently supported combos:
     # 1. DS scheduler + DS optimizer: Yes
     # 2. HF scheduler + HF optimizer: Yes
     # 3. DS scheduler + HF optimizer: Yes

@@ -402,36 +483,16 @@ def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None):
     # 4. HF scheduler + DS optimizer: No

     optimizer = None
-    if "optimizer" in config:
-        logger.info("Updating the `scheduler` config with other command line arguments")
-
-        # to avoid inconsistent values of lr and warm up steps the command line args override config
-        params = dict(
-            lr=args.learning_rate,
-            betas=[args.adam_beta1, args.adam_beta2],
-            eps=args.adam_epsilon,
-            weight_decay=args.weight_decay,
-        )
-        for k, v in params.items():
-            if k in config["optimizer"]["params"]:
-                logger.info(f"setting optimizer.params.{k} to {v}")
-                config["optimizer"]["params"][k] = v
-
-    else:  # override only if the ds config doesn't already have this section
-        if (
-            "zero_optimization" in config
-            and "cpu_offload" in config["zero_optimization"]
-            and config["zero_optimization"]["cpu_offload"] is True
-        ):
+    if "optimizer" not in config:
+        if deepspeed_config_hf.is_offload():
             raise ValueError("ZeRO Offload can only work with DeepSpeed optimizers")
-        else:
-            # ds supports Adam, OneBitAdam, and Lamb optimizers and can import other optimizers from torch.
-            # But trainer uses AdamW by default.
-            # To use other optimizers so using a different scheduler requires voiding warranty with: `zero_allow_untested_optimizer`
-            trainer.create_optimizer()
-            optimizer = trainer.optimizer
-            # flag that this is non-native optimizer
-            config["zero_allow_untested_optimizer"] = True
+
+        # ds supports Adam, OneBitAdam, and Lamb optimizers and can import other optimizers from torch.
+        # But trainer uses AdamW by default.
+        trainer.create_optimizer()
+        optimizer = trainer.optimizer
+        # To use other optimizers requires voiding warranty with: `zero_allow_untested_optimizer`
+        config["zero_allow_untested_optimizer"] = True

     # DS schedulers (deepspeed/runtime/lr_schedules.py):
     #

@@ -442,25 +503,7 @@ def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None):
     # WarmupLR     | constant_with_warmup | get_constant_schedule_with_warmup | w/ warmup_min_lr=0
     # WarmupDecayLR| linear               | get_linear_schedule_with_warmup   |

     lr_scheduler = None
-    if "scheduler" in config:
-        logger.info("Updating the `scheduler` config with other command line arguments")
-
-        # the user won't easily know the correct num_training_steps should they use WarmupDecayLR,
-        # so let's set it to the correct value
-        if config["scheduler"]["type"] == "WarmupDecayLR":
-            logger.info(f"setting scheduler.params.total_num_steps to {num_training_steps}")
-            config["scheduler"]["params"]["total_num_steps"] = num_training_steps
-
-        # to avoid inconsistent values of lr and warmup steps the command line args override config
-        params = dict(
-            warmup_max_lr=args.learning_rate,
-            warmup_num_steps=args.warmup_steps,
-        )
-        for k, v in params.items():
-            if k in config["scheduler"]["params"]:
-                logger.info(f"setting scheduler.params.{k} to {v}")
-                config["scheduler"]["params"][k] = v
-
-    else:  # override only if the ds config doesn't already have this section
+    if "scheduler" not in config:
         if "optimizer" in config:
             # to make this option work, we need to init DS optimizer first, then init HS scheduler,
             # then pass the HS scheduler to DS init, which is not possible at the moment

@@ -469,43 +512,6 @@ def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None):
             trainer.create_scheduler(num_training_steps=num_training_steps)
             lr_scheduler = trainer.lr_scheduler

-    # fp16
-    if trainer.fp16_backend is not None:
-        # Deepspeed has 2 possible fp16 config entries:
-        # - `fp16`: for the native amp - it has a bunch of optional params but we won't set any here unless the user did the work
-        # - `amp`: which delegates amp work to apex (which needs to be available), but it cannot be used with any ZeRO features, so probably best to be avoided.
-        if trainer.fp16_backend == "apex":
-            if "amp" in config:
-                logger.info("Keeping the `amp` config intact, ignoring any amp-specific cl args")
-            else:
-                config["amp"] = {
-                    "enabled": True,
-                    "opt_level": args.fp16_opt_level,
-                }
-        elif trainer.fp16_backend == "amp":
-            if "fp16" in config:
-                logger.info("Keeping the `fp16` config intact, ignoring any fp16-specific cl args")
-            else:
-                config["fp16"] = {
-                    "enabled": True,
-                }
-
-    # zero
-    if "zero_optimization" in config:
-        zero = config["zero_optimization"]
-
-        # now we know for sure if zero3 is enabled
-        deepspeed_zero3_enable(zero.get("stage") == 3)
-
-        # automatically assign the optimal config values based on model config
-        hidden_size = model.config.hidden_size
-        if zero.get("reduce_bucket_size") == 0:
-            zero["reduce_bucket_size"] = hidden_size * hidden_size
-        if zero.get("stage3_prefetch_bucket_size") == 0:
-            zero["stage3_prefetch_bucket_size"] = 0.9 * hidden_size * hidden_size
-        if zero.get("stage3_param_persistence_threshold") == 0:
-            zero["stage3_param_persistence_threshold"] = 10 * hidden_size
-
     # keep for quick debug:
     # from pprint import pprint; pprint(config)
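
The weakref global above replaces the old sys.argv sniffing: once a TrainingArguments object with a deepspeed config exists, is_deepspeed_zero3_enabled() and deepspeed_config() can be queried from anywhere (for example inside from_pretrained), and they fall back to False/None once that object is gone. A rough sketch of the intended behavior, assuming deepspeed>=0.3.15 is installed; the minimal config dict is purely illustrative:

from transformers import TrainingArguments
from transformers.integrations import deepspeed_config, is_deepspeed_zero3_enabled

# deliberately minimal config: only the bit needed to mark ZeRO stage 3
ds_config = {"zero_optimization": {"stage": 3}}

args = TrainingArguments(output_dir="tmp_out", deepspeed=ds_config)
assert is_deepspeed_zero3_enabled()    # the weakref'ed DeepSpeedConfigHF reports stage 3
assert deepspeed_config() is not None  # processed config dict, "auto" values resolved

del args                               # TrainingArguments gone -> weakref target collected
assert not is_deepspeed_zero3_enabled()
assert deepspeed_config() is None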

@@ -1122,7 +1122,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             import deepspeed

             logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
-            # this immediately partitions the model to avoid the overhead in time and memory copying it on CPU or each GPU first
+            # this immediately partitions the model across all gpus, to avoid the overhead in time
+            # and memory copying it on CPU or each GPU first
+            # XXX: param_dict will be added in deepspeed==0.3.16 and probably replaced by deepspeed_config
+            # with deepspeed.zero.Init(param_dict=deepspeed_config()):
             with deepspeed.zero.Init():
                 model = cls(config, *model_args, **model_kwargs)
         else:

@@ -70,9 +70,6 @@ class TrainingArguments:
     <https://docs.python.org/3/library/argparse.html#module-argparse>`__ arguments that can be specified on the command
     line.

    Parameters:
        output_dir (:obj:`str`):
            The output directory where the model predictions and checkpoints will be written.

@@ -625,6 +622,14 @@
         elif ShardedDDPOption.ZERO_DP_2 in self.sharded_ddp and ShardedDDPOption.ZERO_DP_3 in self.sharded_ddp:
             raise ValueError("`--sharded_ddp zero_dp_2` is not compatible with `--sharded_ddp zero_dp_3`.")

+        if self.deepspeed:
+            # - must be run very last in arg parsing, since it will use a lot of these settings.
+            # - must be run before the model is created.
+            from transformers.integrations import DeepSpeedConfigHF
+
+            # will be used later by the Trainer (leave self.deepspeed unmodified in case a user relies on it not to be modified)
+            self.deepspeed_config_hf = DeepSpeedConfigHF(self)
+
     def __repr__(self):
         # We override the default repr to remove deprecated arguments from the repr. This method should be removed once
         # those deprecated arguments are removed form TrainingArguments. (TODO: v5)
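
Since DeepSpeedConfigHF now runs at the end of TrainingArguments.__post_init__, most "auto" entries are already resolved from the command-line arguments by the time the Trainer sees the config (the rest, such as total_num_steps, is filled in later by config_finalize). A hedged illustration of that mapping; it assumes the zero2 config shipped with this PR has been saved as ds_config_zero2.json next to the script:

from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="tmp_out",
    deepspeed="ds_config_zero2.json",  # assumed to be the zero2 config from this PR
    learning_rate=3e-5,
    warmup_steps=500,
    per_device_train_batch_size=8,
)

cfg = args.deepspeed_config_hf.config
print(cfg["optimizer"]["params"]["lr"])                # 3e-5, was "auto"
print(cfg["scheduler"]["params"]["warmup_num_steps"])  # 500, was "auto"
print(cfg["train_micro_batch_size_per_gpu"])           # 8, was "auto"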

@@ -1,6 +1,6 @@
 {
     "fp16": {
-        "enabled": true,
+        "enabled": "auto",
         "loss_scale": 0,
         "loss_scale_window": 1000,
         "initial_scale_power": 16,

@@ -8,6 +8,25 @@
         "min_loss_scale": 1
     },

+    "optimizer": {
+        "type": "AdamW",
+        "params": {
+            "lr": "auto",
+            "betas": "auto",
+            "eps": "auto",
+            "weight_decay": "auto"
+        }
+    },
+
+    "scheduler": {
+        "type": "WarmupLR",
+        "params": {
+            "warmup_min_lr": "auto",
+            "warmup_max_lr": "auto",
+            "warmup_num_steps": "auto"
+        }
+    },
+
     "zero_optimization": {
         "stage": 2,
         "allgather_partitions": true,

@@ -19,25 +38,10 @@
         "cpu_offload": true
     },

-    "optimizer": {
-        "type": "AdamW",
-        "params": {
-            "lr": 3e-5,
-            "betas": [0.8, 0.999],
-            "eps": 1e-8,
-            "weight_decay": 3e-7
-        }
-    },
-
-    "scheduler": {
-        "type": "WarmupLR",
-        "params": {
-            "warmup_min_lr": 0,
-            "warmup_max_lr": 3e-5,
-            "warmup_num_steps": 500
-        }
-    },
-
+    "gradient_accumulation_steps": "auto",
+    "gradient_clipping": "auto",
     "steps_per_print": 2000,
+    "train_batch_size": "auto",
+    "train_micro_batch_size_per_gpu": "auto",
     "wall_clock_breakdown": false
 }

@@ -1,6 +1,6 @@
 {
     "fp16": {
-        "enabled": true,
+        "enabled": "auto",
         "loss_scale": 0,
         "loss_scale_window": 1000,
         "initial_scale_power": 16,

@@ -8,41 +8,50 @@
         "min_loss_scale": 1
     },

-    "zero_optimization": {
-        "stage": 3,
-        "cpu_offload": true,
-        "cpu_offload_params": true,
-        "cpu_offload_use_pin_memory" : true,
-        "overlap_comm": true,
-        "contiguous_gradients": true,
-        "sub_group_size": 1e14,
-        "reduce_bucket_size": 0,
-        "stage3_prefetch_bucket_size": 0,
-        "stage3_param_persistence_threshold": 0,
-        "stage3_max_live_parameters": 1e9,
-        "stage3_max_reuse_distance": 1e9,
-        "stage3_gather_fp16_weights_on_model_save": true
-    },
-
     "optimizer": {
         "type": "AdamW",
         "params": {
-            "lr": 3e-5,
-            "betas": [0.8, 0.999],
-            "eps": 1e-8,
-            "weight_decay": 3e-7
+            "lr": "auto",
+            "betas": "auto",
+            "eps": "auto",
+            "weight_decay": "auto"
         }
     },

     "scheduler": {
         "type": "WarmupLR",
         "params": {
-            "warmup_min_lr": 0,
-            "warmup_max_lr": 3e-5,
-            "warmup_num_steps": 500
+            "warmup_min_lr": "auto",
+            "warmup_max_lr": "auto",
+            "warmup_num_steps": "auto"
         }
     },

+    "zero_optimization": {
+        "stage": 3,
+        "offload_optimizer": {
+            "device": "cpu",
+            "pin_memory": true
+        },
+        "offload_param": {
+            "device": "cpu",
+            "pin_memory": true
+        },
+        "overlap_comm": true,
+        "contiguous_gradients": true,
+        "sub_group_size": 1e14,
+        "reduce_bucket_size": "auto",
+        "stage3_prefetch_bucket_size": "auto",
+        "stage3_param_persistence_threshold": "auto",
+        "stage3_max_live_parameters": 1e9,
+        "stage3_max_reuse_distance": 1e9,
+        "stage3_gather_fp16_weights_on_model_save": true
+    },
+
+    "gradient_accumulation_steps": "auto",
+    "gradient_clipping": "auto",
     "steps_per_print": 2000,
+    "train_batch_size": "auto",
+    "train_micro_batch_size_per_gpu": "auto",
     "wall_clock_breakdown": false
 }
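
The shipped zero3 config offloads optimizer states and parameters to CPU; the ZeRO-Infinity part of this PR also allows pointing both at NVMe storage. A sketch of switching the config over, mirroring what test_stage3_nvme_offload below does ("/local_nvme" is a placeholder for wherever the NVMe drive is mounted):

import json

# assumes the zero3 config from this PR has been saved locally
with open("ds_config_zero3.json") as f:
    ds_config = json.load(f)

# point both offload streams at NVMe storage instead of CPU memory
nvme_config = dict(device="nvme", nvme_path="/local_nvme")
ds_config["zero_optimization"]["offload_optimizer"] = nvme_config
ds_config["zero_optimization"]["offload_param"] = nvme_config

# the resulting dict (or a json file with the same content) can then be passed via
# TrainingArguments(..., deepspeed=ds_config)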

@@ -42,7 +42,7 @@ with ExtendSysPath(f"{bindir}/.."):
     from test_trainer import TrainerIntegrationCommon  # noqa

 if is_torch_available():
-    from test_trainer import get_regression_trainer  # noqa
+    from test_trainer import RegressionModelConfig, RegressionPreTrainedModel, get_regression_trainer  # noqa


 set_seed(42)

@@ -66,6 +66,10 @@ def require_deepspeed(test_case):
     return test_case


+if is_deepspeed_available():
+    from deepspeed.utils import logger as deepspeed_logger  # noqa
+    from transformers.integrations import deepspeed_config, is_deepspeed_zero3_enabled  # noqa
+
 ZERO2 = "zero2"
 ZERO3 = "zero3"
 stages = [ZERO2, ZERO3]

@@ -115,12 +119,6 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
         with io.open(self.ds_config_file[ZERO3], "r", encoding="utf-8") as f:
             self.ds_config_dict[ZERO3] = json.load(f)

-    def tearDown(self):
-        # XXX: Fixme - this is a temporary band-aid since this global variable impacts other tests
-        import transformers
-
-        transformers.integrations._is_deepspeed_zero3_enabled = None
-
     def get_config_dict(self, stage):
         """As the tests modify the dict, always make a copy"""
         config = deepcopy(self.ds_config_dict[stage])

@@ -173,25 +171,65 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
             trainer = get_regression_trainer(local_rank=0, deepspeed=ds_config_zero2_dict)
             with self.assertRaises(Exception) as context:
                 trainer.train()
-            self.assertTrue("HF scheduler + DeepSpeed optimizer combination is not possible" in str(context.exception))
+            self.assertTrue(
+                "HF scheduler + DeepSpeed optimizer combination is not possible" in str(context.exception),
+                f"got exception: {context.exception}",
+            )

-    def test_hf_optimizer_with_offload(self):
-        # must not allow non-DS optimizer when using ZERO-offload
+    def test_stage3_nvme_offload(self):
         with mockenv_context(**self.dist_env_1_gpu):
-            ds_config_zero2_dict = self.get_config_dict(ZERO2)
-            del ds_config_zero2_dict["optimizer"]  # force default HF Trainer optimizer
-            ds_config_zero2_dict["zero_optimization"]["cpu_offload"] = True
-            # sanity check - should the default config change
-            assert (
-                "cpu_offload" in ds_config_zero2_dict["zero_optimization"]
-                and ds_config_zero2_dict["zero_optimization"]["cpu_offload"] is True
-            ), "ensure the config is set up correctly"
-            trainer = get_regression_trainer(local_rank=0, deepspeed=ds_config_zero2_dict)
-            with self.assertRaises(Exception) as context:
+            # this actually doesn't have to be on NVMe, any storage will do since this test only
+            # runs a simple check that we can use some directory as if it were NVMe
+            nvme_path = self.get_auto_remove_tmp_dir()
+            nvme_config = dict(device="nvme", nvme_path=nvme_path)
+            ds_config_zero3_dict = self.get_config_dict(ZERO3)
+            ds_config_zero3_dict["zero_optimization"]["offload_optimizer"] = nvme_config
+            ds_config_zero3_dict["zero_optimization"]["offload_param"] = nvme_config
+            trainer = get_regression_trainer(local_rank=0, deepspeed=ds_config_zero3_dict)
+            with CaptureLogger(deepspeed_logger) as cs:
                 trainer.train()
-            self.assertTrue("ZeRO Offload can only work with DeepSpeed optimizers" in str(context.exception))
+            self.assertIn("DeepSpeed info", cs.out, "expected DeepSpeed logger output but got none")

     # --- These tests need to run on both zero stages --- #

+    @parameterized.expand(stages)
+    def test_fp32(self, stage):
+        ds_config_dict = self.get_config_dict(stage)
+        ds_config_dict["fp16"]["enabled"] = False  # force non-fp16 mode
+
+        # XXX: do we go via from_pretrained in zero 3 here? need to test zero.Init(dtype=torch.float)
+
+        # XXX: rewrite this test once fp32 is supported by DeepSpeed
+        with mockenv_context(**self.dist_env_1_gpu):
+            trainer = get_regression_trainer(local_rank=0, deepspeed=ds_config_dict)
+            with self.assertRaises(Exception) as context:
+                trainer.train()
+            self.assertIn(
+                "ZeRO is only supported if fp16 is enabled",
+                str(context.exception),
+                f"got exception: {context.exception}",
+            )
+
+    @parameterized.expand(stages)
+    def test_hf_optimizer_with_offload(self, stage):
+        # must not allow non-DS optimizer when using ZERO-offload
+        ds_config_dict = self.get_config_dict(stage)
+        del ds_config_dict["optimizer"]  # force default HF Trainer optimizer
+        # force cpu offload
+        if stage == "stage2":
+            ds_config_dict["zero_optimization"]["cpu_offload"] = True
+        elif stage == "stage3":
+            ds_config_dict["zero_optimization"]["offload_optimizer"]["device"] = "cpu"
+        with mockenv_context(**self.dist_env_1_gpu):
+            trainer = get_regression_trainer(local_rank=0, deepspeed=ds_config_dict)
+            with self.assertRaises(Exception) as context:
+                trainer.train()
+            self.assertIn(
+                "ZeRO Offload can only work with DeepSpeed optimizers",
+                str(context.exception),
+                f"got exception: {context.exception}",
+            )
+
     @parameterized.expand(stages)
     def test_fake_notebook_no_launcher(self, stage):
         # this setup emulates a notebook where a launcher needs to be emulated by hand

@@ -199,14 +237,12 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
         # note that unittest resets sys.stdout each test, so `CaptureStd` will work here to capture
         # DeepSpeed log if this test happens to run first in this pytest worker. But it will fail if
         # it's run not as a first test as `sys.stdout` will no longer be the same. So we either have
-        # to reset `logger.handlers[0].setStream(sys.stdout)` or directly capture from the logger.
-        from deepspeed.utils import logger
-
-        with CaptureLogger(logger) as cs:
-            with mockenv_context(**self.dist_env_1_gpu):
-                trainer = get_regression_trainer(local_rank=0, deepspeed=self.ds_config_file[stage])
+        # to reset `deepspeed_logger.handlers[0].setStream(sys.stdout)` or directly capture from the deepspeed_logger.
+        with mockenv_context(**self.dist_env_1_gpu):
+            trainer = get_regression_trainer(local_rank=0, deepspeed=self.ds_config_file[stage])
+            with CaptureLogger(deepspeed_logger) as cs:
                 trainer.train()
-        assert "DeepSpeed info" in cs.out, "expected DeepSpeed logger output but got none"
+            self.assertIn("DeepSpeed info", cs.out, "expected DeepSpeed logger output but got none")

     @parameterized.expand(stages)
     def test_early_get_last_lr(self, stage):

@@ -425,6 +461,38 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
         self.assertEqual(b, b1)
         self.check_trainer_state_are_the_same(state, state1)

+    def test_config_object(self):
+        # test that we can switch from zero2 to zero3 in the same process for example
+        # test is_zero, etc.
+        output_dir = self.get_auto_remove_tmp_dir()
+        kwargs = dict(output_dir=output_dir, train_len=8)
+        with mockenv_context(**self.dist_env_1_gpu):
+            ds_config_zero3_dict = self.get_config_dict("zero3")
+            ds_config_zero2_dict = self.get_config_dict("zero2")
+
+            trainer = get_regression_trainer(deepspeed=ds_config_zero3_dict, **kwargs)
+            self.assertTrue(is_deepspeed_zero3_enabled())
+
+            # test we can repeat that and with train this time
+            trainer = get_regression_trainer(deepspeed=ds_config_zero3_dict, **kwargs)
+            trainer.train()
+            self.assertTrue(is_deepspeed_zero3_enabled())
+
+            # test zero3 is disabled
+            trainer = get_regression_trainer(deepspeed=ds_config_zero2_dict, **kwargs)
+            self.assertFalse(is_deepspeed_zero3_enabled())
+
+            # check config obj
+            config = deepspeed_config()
+            self.assertTrue(bool(config), "Deepspeed config should be accessible")
+
+            del trainer
+            # now weakref should gc the global and we shouldn't get anything here
+            config = deepspeed_config()
+            self.assertFalse(is_deepspeed_zero3_enabled())
+            self.assertFalse(bool(config), "Deepspeed config should not be accessible")
+

 @slow
 @require_deepspeed

@@ -557,6 +625,7 @@ class TestDeepSpeedWithLauncher(TestCasePlus):
             --adafactor
             --source_lang en
             --target_lang ro
+            --report_to none
         """.split()

         args.extend(["--source_prefix", '"translate English to Romanian: "'])

@@ -626,6 +695,7 @@ class TestDeepSpeedWithLauncher(TestCasePlus):
             --num_train_epochs 1
             --warmup_steps 8
             --block_size 128
+            --report_to none
         """.split()

         ds_args = f"--deepspeed {self.test_file_dir_str}/ds_config_{stage}.json".split()

@@ -213,16 +213,21 @@ if is_torch_available():
         label_names = kwargs.get("label_names", None)
         train_dataset = RegressionDataset(length=train_len, label_names=label_names)
         eval_dataset = RegressionDataset(length=eval_len, label_names=label_names)
-        if pretrained:
-            config = RegressionModelConfig(a=a, b=b, double_output=double_output)
-            model = RegressionPreTrainedModel(config)
+
+        model_init = kwargs.pop("model_init", None)
+        if model_init is not None:
+            model = None
         else:
-            model = RegressionModel(a=a, b=b, double_output=double_output)
+            if pretrained:
+                config = RegressionModelConfig(a=a, b=b, double_output=double_output)
+                model = RegressionPreTrainedModel(config)
+            else:
+                model = RegressionModel(a=a, b=b, double_output=double_output)
+
         compute_metrics = kwargs.pop("compute_metrics", None)
         data_collator = kwargs.pop("data_collator", None)
         optimizers = kwargs.pop("optimizers", (None, None))
         output_dir = kwargs.pop("output_dir", "./regression")
-        model_init = kwargs.pop("model_init", None)
         args = RegressionTrainingArguments(output_dir, a=a, b=b, **kwargs)

         return Trainer(