[trainer/deepspeed] load_best_model (reimplement re-init) (#17151)

* [trainer/deepspeed] load_best_model

* to sync with DS PR #1947

* simplify

* rework load_best_model test

* cleanup

* bump deepspeed>=0.6.5

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Stas Bekman 2022-06-02 09:14:21 -07:00 committed by GitHub
parent 046c5ea906
commit 2f59ad1609
5 changed files with 141 additions and 93 deletions


@@ -27,7 +27,7 @@ To create the package for pypi.
 3. Unpin specific versions from setup.py that use a git install.
 4. Checkout the release branch (v<RELEASE>-release, for example v4.19-release), and commit these changes with the
    message: "Release: <VERSION>" and push.
 5. Wait for the tests on main to be completed and be green (otherwise revert and fix bugs)
@@ -103,7 +103,7 @@ _deps = [
     "cookiecutter==1.7.3",
     "dataclasses",
     "datasets",
-    "deepspeed>=0.6.4",
+    "deepspeed>=0.6.5",
     "dill<0.3.5",
     "fairscale>0.3",
     "faiss-cpu",

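Both setup.py and the dependency table further down raise the floor to deepspeed>=0.6.5, which the commit message ties to DS PR #1947. A quick local sanity check, as a minimal sketch (not part of the commit; uses importlib.metadata and packaging, which ships with transformers):

    # sketch: verify the locally installed deepspeed meets the new minimum version
    import importlib.metadata

    from packaging import version

    installed = version.parse(importlib.metadata.version("deepspeed"))
    assert installed >= version.parse("0.6.5"), f"deepspeed {installed} is too old, need >= 0.6.5"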

@@ -364,18 +364,6 @@ def deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps
     return optimizer, lr_scheduler
 
-def deepspeed_reinit(trainer):
-    """
-    this is a temp hack based on: https://github.com/microsoft/DeepSpeed/issues/1394#issuecomment-937405374 until
-    Deepspeed fixes a bug where it can't resume from a checkpoint after it did some stepping
-    https://github.com/microsoft/DeepSpeed/issues/1612
-    """
-    import deepspeed
-
-    deepspeed_engine, optimizer, _, lr_scheduler = deepspeed.initialize(**trainer.deepspeed_initialize_kwargs)
-    return deepspeed_engine, optimizer, lr_scheduler
-
 def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None, inference=False):
     """
     Init DeepSpeed, after updating the DeepSpeed configuration with any relevant Trainer's args.

@@ -390,6 +378,10 @@ def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None, inf
     Returns: model, optimizer, lr_scheduler
 
+    We may use `deepspeed_init` more than once during the life of Trainer, when we do - it's a temp hack based on:
+    https://github.com/microsoft/DeepSpeed/issues/1394#issuecomment-937405374 until Deepspeed fixes a bug where it
+    can't resume from a checkpoint after it did some stepping https://github.com/microsoft/DeepSpeed/issues/1612
+
     """
     import deepspeed
     from deepspeed.utils import logger as ds_logger

@@ -397,8 +389,13 @@ def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None, inf
     model = trainer.model
     args = trainer.args
 
+    if hasattr(trainer, "hf_deepspeed_config_orig"):
+        hf_deepspeed_config = deepcopy(trainer.hf_deepspeed_config_orig)
+    else:
+        hf_deepspeed_config = args.hf_deepspeed_config
+        trainer.hf_deepspeed_config_orig = deepcopy(args.hf_deepspeed_config)
+
     # resume config update - some bits like `model` and `num_training_steps` only become available during train
-    hf_deepspeed_config = args.hf_deepspeed_config
     hf_deepspeed_config.trainer_config_finalize(args, model, num_training_steps)
     config = hf_deepspeed_config.config

@@ -416,6 +413,7 @@ def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None, inf
         optimizer, lr_scheduler = None, None
         model_parameters = None
     else:
+        trainer.optimizer = None  # important for when deepspeed_init is used as re-init
         optimizer, lr_scheduler = deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps)
         model_parameters = list(filter(lambda p: p.requires_grad, model.parameters()))

@@ -432,9 +430,6 @@ def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None, inf
     deepspeed_engine, optimizer, _, lr_scheduler = deepspeed.initialize(**kwargs)
 
-    # stash kwargs to enabled a later deepspeed_reinit
-    trainer.deepspeed_initialize_kwargs = kwargs
-
     if resume_from_checkpoint is not None:
         # it's possible that the user is trying to resume from model_path, which doesn't necessarily

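Taken together, these hunks make `deepspeed_init` safely re-runnable, instead of relying on a separate `deepspeed_reinit` that replayed stashed `deepspeed.initialize` kwargs. A minimal standalone sketch of the pattern, simplified from the code above (the function name is illustrative, and it skips the branch where the optimizer/scheduler come from the DeepSpeed config itself, as well as the inference path):

    from copy import deepcopy

    # helper shown in the first hunk of this file's diff
    from transformers.deepspeed import deepspeed_optim_sched


    def build_fresh_engine(trainer, num_training_steps):
        """Sketch of the re-init pattern: never finalize the user's DeepSpeed config in place."""
        import deepspeed

        args, model = trainer.args, trainer.model

        # keep one pristine copy of the config; finalize a fresh deepcopy on every call, so a
        # later re-init is not polluted by values filled in during the previous init
        if not hasattr(trainer, "hf_deepspeed_config_orig"):
            trainer.hf_deepspeed_config_orig = deepcopy(args.hf_deepspeed_config)
        hf_deepspeed_config = deepcopy(trainer.hf_deepspeed_config_orig)
        hf_deepspeed_config.trainer_config_finalize(args, model, num_training_steps)

        # a re-init must not reuse the stepped optimizer of the previous engine
        trainer.optimizer = None
        optimizer, lr_scheduler = deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps)

        engine, optimizer, _, lr_scheduler = deepspeed.initialize(
            model=model,
            model_parameters=[p for p in model.parameters() if p.requires_grad],
            config_params=hf_deepspeed_config.config,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
        )
        return engine, optimizer, lr_scheduler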

@@ -9,7 +9,7 @@ deps = {
    "cookiecutter": "cookiecutter==1.7.3",
    "dataclasses": "dataclasses",
    "datasets": "datasets",
-    "deepspeed": "deepspeed>=0.6.4",
+    "deepspeed": "deepspeed>=0.6.5",
    "dill": "dill<0.3.5",
    "fairscale": "fairscale>0.3",
    "faiss-cpu": "faiss-cpu",


@@ -65,7 +65,7 @@ from . import __version__
 from .configuration_utils import PretrainedConfig
 from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
 from .debug_utils import DebugOption, DebugUnderflowOverflow
-from .deepspeed import deepspeed_init, deepspeed_reinit, is_deepspeed_zero3_enabled
+from .deepspeed import deepspeed_init, is_deepspeed_zero3_enabled
 from .dependency_versions_check import dep_version_check
 from .modelcard import TrainingSummary
 from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model

@@ -1749,16 +1749,23 @@ class Trainer:
         model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
         if os.path.exists(best_model_path):
             if self.deepspeed:
+
+                if self.model_wrapped is not None:
+                    # this removes the pre-hooks from the previous engine
+                    self.model_wrapped.destroy()
+                    self.model_wrapped = None
+
                 # temp hack until Deepspeed fixes the problem with resume from an existing engine that did some stepping
-                deepspeed_engine, optimizer, lr_scheduler = deepspeed_reinit(self)
+                deepspeed_engine, optimizer, lr_scheduler = deepspeed_init(
+                    self,
+                    num_training_steps=self.args.max_steps,
+                    resume_from_checkpoint=self.state.best_model_checkpoint,
+                )
                 self.model = deepspeed_engine.module
                 self.model_wrapped = deepspeed_engine
                 self.deepspeed = deepspeed_engine
                 self.optimizer = optimizer
                 self.lr_scheduler = lr_scheduler
-                self.deepspeed.load_checkpoint(
-                    self.state.best_model_checkpoint, load_optimizer_states=True, load_lr_scheduler_states=True
-                )
             else:
                 # We load the model state dict on the CPU to avoid an OOM error.
                 state_dict = torch.load(best_model_path, map_location="cpu")

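Paraphrasing the `Trainer._load_best_model` hunk above as straight code, purely as a readability sketch (the method name is illustrative, not an additional change): when training ends under DeepSpeed with load_best_model_at_end, the old engine is torn down, `deepspeed_init` rebuilds a fresh one pointed at the best checkpoint (it loads the checkpoint itself when resume_from_checkpoint is passed, which is why the explicit load_checkpoint call is gone), and the trainer's handles are rewired:

    def _load_best_model_with_deepspeed(self):
        # sketch of the DeepSpeed branch added above
        if self.model_wrapped is not None:
            # drop the previous engine and its pre-forward hooks before building a new one
            self.model_wrapped.destroy()
            self.model_wrapped = None

        # rebuild the engine from scratch, resuming from the best checkpoint; this works around
        # https://github.com/microsoft/DeepSpeed/issues/1612
        deepspeed_engine, optimizer, lr_scheduler = deepspeed_init(
            self,
            num_training_steps=self.args.max_steps,
            resume_from_checkpoint=self.state.best_model_checkpoint,
        )

        # point every trainer handle at the freshly built engine
        self.model = deepspeed_engine.module
        self.model_wrapped = deepspeed_engine
        self.deepspeed = deepspeed_engine
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler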

@@ -20,6 +20,8 @@ import os
 import unittest
 from copy import deepcopy
 
+import datasets
 from parameterized import parameterized
 
 from tests.trainer.test_trainer import TrainerIntegrationCommon  # noqa
 from transformers import AutoModel, TrainingArguments, is_torch_available, logging
@@ -195,28 +197,7 @@ class CoreIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
             self.assertNotIn("Detected DeepSpeed ZeRO-3", cl.out)
 
-@require_deepspeed
-@require_torch_gpu
-class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
-    """
-    This class is for testing directly via get_regression_trainer
-
-    It mixes in `TrainerIntegrationCommon` which already has a lot of helper validation methods
-    which we can re-use here.
-
-    Important: this class' setup can only work with a single gpu because it runs within the current
-    pytest worker. For multi-gpu tests use TestDeepSpeedWithLauncher.
-
-    Note: if any of the tests of this class get run there will be at least one gpu occupied by them
-    until this pytest worker exits. This is because the gpu memory allocated by the cuda-kernels
-    won't be released until this pytest worker exits.
-
-    This may appear as some run-away tests if you watch `nvidia-smi` while other tests that fork new
-    processes are run. So there will be one or two "stale" processes reported in `nvidia-smi`. This
-    is not a bug.
-    """
+class TrainerIntegrationDeepSpeedWithCustomConfig(TestCasePlus):
     def setUp(self):
         super().setUp()
@@ -252,6 +233,29 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
         # As some tests modify the dict, always make a copy
         return deepcopy(self.ds_config_dict[stage])
 
+
+@require_deepspeed
+@require_torch_gpu
+class TrainerIntegrationDeepSpeed(TrainerIntegrationDeepSpeedWithCustomConfig, TrainerIntegrationCommon):
+    """
+    This class is for testing directly via get_regression_trainer
+
+    It mixes in `TrainerIntegrationCommon` which already has a lot of helper validation methods
+    which we can re-use here.
+
+    Important: this class' setup can only work with a single gpu because it runs within the current
+    pytest worker. For multi-gpu tests use TestDeepSpeedWithLauncher.
+
+    Note: if any of the tests of this class get run there will be at least one gpu occupied by them
+    until this pytest worker exits. This is because the gpu memory allocated by the cuda-kernels
+    won't be released until this pytest worker exits.
+
+    This may appear as some run-away tests if you watch `nvidia-smi` while other tests that fork new
+    processes are run. So there will be one or two "stale" processes reported in `nvidia-smi`. This
+    is not a bug.
+    """
+
     # --- These tests are enough to run on one of zero stages --- #
 
     def test_hf_ds_config_mismatch(self):
@@ -725,6 +729,95 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
         self.assertFalse(is_deepspeed_zero3_enabled())
         self.assertFalse(bool(config), "Deepspeed config should not be accessible")
 
+    @parameterized.expand(params, name_func=parameterized_custom_name_func)
+    def test_load_best_model(self, stage, dtype):
+        # Test that forced deepspeed reinit doesn't break the model. the forced re-init after
+        # loading the best model in Trainer is there to workaround this bug in Deepspeed
+        # https://github.com/microsoft/DeepSpeed/issues/1612
+        #
+        # The test is derived from a repro script submitted in this Issue:
+        # https://github.com/huggingface/transformers/issues/17114
+        #
+        # One additional feature of this test is that we use a non-AdamW optimizer to test that
+        # deepspeed doesn't fallback to AdamW, which would prevent the optimizer states from loading
+        # correctly
+        from transformers import T5ForConditionalGeneration, T5Tokenizer, Trainer  # noqa
+
+        output_dir = self.get_auto_remove_tmp_dir()  # "./xxx", after=False, before=False)
+        ds_config_dict = self.get_config_dict(stage)
+        del ds_config_dict["optimizer"]  # will use HF Trainer optimizer
+        del ds_config_dict["scheduler"]  # will use HF Trainer scheduler
+        # must use this setting to get the reload path exercised
+        ds_config_dict["zero_optimization"]["stage3_gather_16bit_weights_on_model_save"] = True
+
+        tokenizer = T5Tokenizer.from_pretrained(T5_TINY)
+        model = T5ForConditionalGeneration.from_pretrained(T5_TINY)
+
+        def _add_eos_to_examples(example):
+            example["input_text"] = f"question: {example['question']} context: {example['context']}"
+            example["target_text"] = example["answers"]["text"][0] if len(example["answers"]["text"]) > 0 else ""
+            return example
+
+        def _convert_to_features(example_batch):
+            input_encodings = tokenizer.batch_encode_plus(
+                example_batch["input_text"], pad_to_max_length=True, max_length=512, truncation=True
+            )
+            target_encodings = tokenizer.batch_encode_plus(
+                example_batch["target_text"], pad_to_max_length=True, max_length=16, truncation=True
+            )
+            encodings = {
+                "input_ids": input_encodings["input_ids"],
+                "attention_mask": input_encodings["attention_mask"],
+                "labels": target_encodings["input_ids"],
+            }
+            return encodings
+
+        def get_dataset():
+            data_file = str(self.tests_dir / "fixtures/tests_samples/SQUAD/sample.json")
+            data_files = dict(train=data_file, validation=data_file)
+            raw_datasets = datasets.load_dataset("json", data_files=data_files, field="data")
+            train_dataset = raw_datasets["train"].map(_add_eos_to_examples).map(_convert_to_features, batched=True)
+            valid_dataset = deepcopy(train_dataset)
+            return train_dataset, valid_dataset
+
+        train_dataset, eval_dataset = get_dataset()
+
+        args_dict = {
+            "per_gpu_train_batch_size": 1,
+            "per_gpu_eval_batch_size": 1,
+            "gradient_accumulation_steps": 1,
+            "learning_rate": 1e-4,
+            "num_train_epochs": 1,
+            "do_train": True,
+            "do_eval": True,
+            "optim": "adafactor",
+            "evaluation_strategy": "steps",
+            "eval_steps": 1,
+            "save_strategy": "steps",
+            "save_steps": 1,
+            "load_best_model_at_end": True,
+            "max_steps": 1,
+            "deepspeed": ds_config_dict,
+        }
+
+        with mockenv_context(**self.dist_env_1_gpu):
+            training_args = TrainingArguments(output_dir, **args_dict)
+            trainer = Trainer(
+                model=model,
+                tokenizer=tokenizer,
+                args=training_args,
+                train_dataset=train_dataset,
+                eval_dataset=eval_dataset,
+            )
+            trainer.train()  # crash 1 was here
+            trainer.evaluate()  # crash 2 was here
+
 @slow
 @require_deepspeed
@@ -1035,50 +1128,3 @@ class TestDeepSpeedWithLauncher(TestCasePlus):
         with CaptureStderr() as cs:
             execute_subprocess_async(cmd, env=self.get_env())
 
         self.assertIn("Detected DeepSpeed ZeRO-3", cs.err)
-
-    @parameterized.expand(params, name_func=parameterized_custom_name_func)
-    def test_load_best_model(self, stage, dtype):
-        # this test exercises --load_best_model_at_end - the key is being able to resume after some training
-
-        data_dir = self.tests_dir / "fixtures/tests_samples/wmt_en_ro"
-        output_dir = self.get_auto_remove_tmp_dir()
-        args = f"""
-            --model_name_or_path {T5_TINY}
-            --tokenizer_name {T5_TINY}
-            --train_file {data_dir}/train.json
-            --validation_file {data_dir}/val.json
-            --output_dir {output_dir}
-            --overwrite_output_dir
-            --source_lang en
-            --target_lang ro
-            --do_train
-            --max_train_samples 3
-            --do_eval
-            --max_eval_samples 1
-            --logging_strategy steps
-            --logging_steps 1
-            --evaluation_strategy steps
-            --eval_steps 1
-            --save_strategy steps
-            --save_steps 1
-            --load_best_model_at_end
-            --per_device_train_batch_size 1
-            --per_device_eval_batch_size 1
-            --num_train_epochs 1
-            --report_to none
-        """.split()
-        args.extend(["--source_prefix", "translate English to Romanian: "])
-        args.extend([f"--{dtype}"])
-
-        ds_args = f"--deepspeed {self.test_file_dir_str}/ds_config_{stage}.json".split()
-        script = [f"{self.examples_dir_str}/pytorch/translation/run_translation.py"]
-        launcher = get_launcher(distributed=False)
-
-        cmd = launcher + script + args + ds_args
-        # keep for quick debug
-        # print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] +cmd)); die
-        with CaptureStd() as cs:
-            execute_subprocess_async(cmd, env=self.get_env())
-        # enough to test it didn't fail
-        self.assertIn("DeepSpeed info", cs.out)
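For reference, the user-facing recipe the new in-process test exercises boils down to the sketch below. The model, tokenizer, datasets and the DeepSpeed config path are placeholders (not part of the commit); any run that saves and evaluates during training and sets load_best_model_at_end will hit the engine-rebuild path.

    from transformers import Trainer, TrainingArguments

    # placeholders: supply your own model/tokenizer/datasets, and a ZeRO config without
    # "optimizer"/"scheduler" sections so the HF-side optimizer (here Adafactor) is used
    training_args = TrainingArguments(
        output_dir="output",
        deepspeed="ds_config.json",   # placeholder path to a DeepSpeed config
        optim="adafactor",            # non-AdamW, so a silent fallback to DeepSpeed's AdamW would be noticed
        evaluation_strategy="steps",
        eval_steps=1,
        save_strategy="steps",
        save_steps=1,
        load_best_model_at_end=True,  # triggers the DeepSpeed engine rebuild at the end of training
        max_steps=1,
    )
    trainer = Trainer(
        model=model,                  # placeholder
        tokenizer=tokenizer,          # placeholder
        args=training_args,
        train_dataset=train_dataset,  # placeholder
        eval_dataset=eval_dataset,    # placeholder
    )
    trainer.train()     # previously crashed when reloading the best model under DeepSpeed ("crash 1")
    trainer.evaluate()  # and here, on the first use of the rebuilt engine ("crash 2")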