Mirror of https://github.com/huggingface/transformers.git
[trainer/deepspeed] load_best_model (reimplement re-init) (#17151)

* [trainer/deepspeed] load_best_model
* to sync with DS PR #1947
* simplify
* rework load_best_model test
* cleanup
* bump deepspeed>=0.6.5

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
parent 046c5ea906
commit 2f59ad1609
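In short, the stashed-kwargs helper `deepspeed_reinit` is removed and `deepspeed_init` itself is made safe to call a second time, so loading the best model simply re-runs it. A condensed excerpt of the trainer.py change below (before/after of the key call only, not standalone code):

    # before (removed in this commit): rebuild the engine from stashed kwargs, then load the checkpoint
    deepspeed_engine, optimizer, lr_scheduler = deepspeed_reinit(self)
    self.deepspeed.load_checkpoint(
        self.state.best_model_checkpoint, load_optimizer_states=True, load_lr_scheduler_states=True
    )

    # after: a second call to deepspeed_init rebuilds the engine from a pristine copy of the
    # DeepSpeed config and resumes from the best checkpoint in one step
    deepspeed_engine, optimizer, lr_scheduler = deepspeed_init(
        self,
        num_training_steps=self.args.max_steps,
        resume_from_checkpoint=self.state.best_model_checkpoint,
    )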
setup.py
@@ -27,7 +27,7 @@ To create the package for pypi.
 
 3. Unpin specific versions from setup.py that use a git install.
 
 4. Checkout the release branch (v<RELEASE>-release, for example v4.19-release), and commit these changes with the
    message: "Release: <VERSION>" and push.
 
 5. Wait for the tests on main to be completed and be green (otherwise revert and fix bugs)
@@ -103,7 +103,7 @@ _deps = [
     "cookiecutter==1.7.3",
     "dataclasses",
     "datasets",
-    "deepspeed>=0.6.4",
+    "deepspeed>=0.6.5",
     "dill<0.3.5",
     "fairscale>0.3",
     "faiss-cpu",
src/transformers/deepspeed.py
@@ -364,18 +364,6 @@ def deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps):
     return optimizer, lr_scheduler
 
 
-def deepspeed_reinit(trainer):
-    """
-    this is a temp hack based on: https://github.com/microsoft/DeepSpeed/issues/1394#issuecomment-937405374 until
-    Deepspeed fixes a bug where it can't resume from a checkpoint after it did some stepping
-    https://github.com/microsoft/DeepSpeed/issues/1612
-    """
-    import deepspeed
-
-    deepspeed_engine, optimizer, _, lr_scheduler = deepspeed.initialize(**trainer.deepspeed_initialize_kwargs)
-    return deepspeed_engine, optimizer, lr_scheduler
-
-
 def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None, inference=False):
     """
     Init DeepSpeed, after updating the DeepSpeed configuration with any relevant Trainer's args.
@@ -390,6 +378,10 @@ def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None, inference=False):
 
     Returns: model, optimizer, lr_scheduler
 
+    We may use `deepspeed_init` more than once during the life of Trainer, when we do - it's a temp hack based on:
+    https://github.com/microsoft/DeepSpeed/issues/1394#issuecomment-937405374 until Deepspeed fixes a bug where it
+    can't resume from a checkpoint after it did some stepping https://github.com/microsoft/DeepSpeed/issues/1612
+
     """
     import deepspeed
     from deepspeed.utils import logger as ds_logger
@@ -397,8 +389,13 @@ def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None, inference=False):
     model = trainer.model
     args = trainer.args
 
+    if hasattr(trainer, "hf_deepspeed_config_orig"):
+        hf_deepspeed_config = deepcopy(trainer.hf_deepspeed_config_orig)
+    else:
+        hf_deepspeed_config = args.hf_deepspeed_config
+        trainer.hf_deepspeed_config_orig = deepcopy(args.hf_deepspeed_config)
+
     # resume config update - some bits like `model` and `num_training_steps` only become available during train
-    hf_deepspeed_config = args.hf_deepspeed_config
     hf_deepspeed_config.trainer_config_finalize(args, model, num_training_steps)
     config = hf_deepspeed_config.config
 
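The deepcopy above is what makes a repeat call safe: `trainer_config_finalize` fills in values that depend on the model and `num_training_steps`, so a re-init has to start from an untouched copy of the config rather than the already-finalized one. A minimal standalone sketch of the same caching pattern (illustrative names only, not the Trainer's actual classes):

    from copy import deepcopy

    class ReinitableConfig:
        """Illustrative only: keep a pristine copy of a mutable config so a later
        re-initialization can finalize it again from scratch."""

        def __init__(self, config):
            self.config = config
            self._config_orig = None

        def init(self, num_training_steps):
            if self._config_orig is not None:
                config = deepcopy(self._config_orig)  # re-init: start from the untouched copy
            else:
                config = self.config
                self._config_orig = deepcopy(self.config)  # stash before "finalize" mutates it
            config["total_num_steps"] = num_training_steps  # stand-in for trainer_config_finalize
            return config

    cfg = ReinitableConfig({"total_num_steps": "auto"})
    cfg.init(10)  # -> {'total_num_steps': 10}
    cfg.init(20)  # second init still starts from the pristine {'total_num_steps': 'auto'}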
@@ -416,6 +413,7 @@ def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None, inference=False):
         optimizer, lr_scheduler = None, None
         model_parameters = None
     else:
+        trainer.optimizer = None  # important for when deepspeed_init is used as re-init
         optimizer, lr_scheduler = deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps)
         model_parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
 
@@ -432,9 +430,6 @@ def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None, inference=False):
 
     deepspeed_engine, optimizer, _, lr_scheduler = deepspeed.initialize(**kwargs)
 
-    # stash kwargs to enabled a later deepspeed_reinit
-    trainer.deepspeed_initialize_kwargs = kwargs
-
     if resume_from_checkpoint is not None:
 
         # it's possible that the user is trying to resume from model_path, which doesn't necessarily
src/transformers/dependency_versions_table.py
@@ -9,7 +9,7 @@ deps = {
     "cookiecutter": "cookiecutter==1.7.3",
     "dataclasses": "dataclasses",
     "datasets": "datasets",
-    "deepspeed": "deepspeed>=0.6.4",
+    "deepspeed": "deepspeed>=0.6.5",
    "dill": "dill<0.3.5",
     "fairscale": "fairscale>0.3",
     "faiss-cpu": "faiss-cpu",
src/transformers/trainer.py
@@ -65,7 +65,7 @@ from . import __version__
 from .configuration_utils import PretrainedConfig
 from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
 from .debug_utils import DebugOption, DebugUnderflowOverflow
-from .deepspeed import deepspeed_init, deepspeed_reinit, is_deepspeed_zero3_enabled
+from .deepspeed import deepspeed_init, is_deepspeed_zero3_enabled
 from .dependency_versions_check import dep_version_check
 from .modelcard import TrainingSummary
 from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
@@ -1749,16 +1749,23 @@ class Trainer:
         model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
         if os.path.exists(best_model_path):
             if self.deepspeed:
+
+                if self.model_wrapped is not None:
+                    # this removes the pre-hooks from the previous engine
+                    self.model_wrapped.destroy()
+                    self.model_wrapped = None
+
                 # temp hack until Deepspeed fixes the problem with resume from an existing engine that did some stepping
-                deepspeed_engine, optimizer, lr_scheduler = deepspeed_reinit(self)
+                deepspeed_engine, optimizer, lr_scheduler = deepspeed_init(
+                    self,
+                    num_training_steps=self.args.max_steps,
+                    resume_from_checkpoint=self.state.best_model_checkpoint,
+                )
                 self.model = deepspeed_engine.module
                 self.model_wrapped = deepspeed_engine
                 self.deepspeed = deepspeed_engine
                 self.optimizer = optimizer
                 self.lr_scheduler = lr_scheduler
-                self.deepspeed.load_checkpoint(
-                    self.state.best_model_checkpoint, load_optimizer_states=True, load_lr_scheduler_states=True
-                )
             else:
                 # We load the model state dict on the CPU to avoid an OOM error.
                 state_dict = torch.load(best_model_path, map_location="cpu")
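For context, this branch runs only when training with a DeepSpeed config and `load_best_model_at_end=True`; a hedged sketch of the settings that exercise it, mirroring the new test further down (output directory and config path are placeholders):

    from transformers import TrainingArguments

    # Illustrative only: "ds_config.json" stands in for a real DeepSpeed config file.
    training_args = TrainingArguments(
        output_dir="out",
        deepspeed="ds_config.json",
        do_train=True,
        do_eval=True,
        evaluation_strategy="steps",
        eval_steps=1,
        save_strategy="steps",
        save_steps=1,
        load_best_model_at_end=True,
        max_steps=1,
    )
    # Trainer(...).train() with these args saves checkpoints, evaluates, and at the end
    # re-initializes the DeepSpeed engine to load the best checkpoint via the code above.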
tests/deepspeed/test_deepspeed.py
@@ -20,6 +20,8 @@ import os
 import unittest
 from copy import deepcopy
 
+import datasets
+
 from parameterized import parameterized
 from tests.trainer.test_trainer import TrainerIntegrationCommon  # noqa
 from transformers import AutoModel, TrainingArguments, is_torch_available, logging
@@ -195,28 +197,7 @@ class CoreIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
         self.assertNotIn("Detected DeepSpeed ZeRO-3", cl.out)
 
 
-@require_deepspeed
-@require_torch_gpu
-class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
-    """
-
-    This class is for testing directly via get_regression_trainer
-
-    It mixes in `TrainerIntegrationCommon` which already has a lot of helper validation methods
-    which we can re-use here.
-
-    Important: this class' setup can only work with a single gpu because it runs within the current
-    pytest worker. For multi-gpu tests use TestDeepSpeedWithLauncher.
-
-    Note: if any of the tests of this class get run there will be at least one gpu occupied by them
-    until this pytest worker exits. This is because the gpu memory allocated by the cuda-kernels
-    won't be released until this pytest worker exits.
-
-    This may appear as some run-away tests if you watch `nvidia-smi` while other tests that fork new
-    processes are run. So there will be one or two "stale" processes reported in `nvidia-smi`. This
-    is not a bug.
-    """
-
+class TrainerIntegrationDeepSpeedWithCustomConfig(TestCasePlus):
     def setUp(self):
         super().setUp()
 
@@ -252,6 +233,29 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
         # As some tests modify the dict, always make a copy
         return deepcopy(self.ds_config_dict[stage])
 
+
+@require_deepspeed
+@require_torch_gpu
+class TrainerIntegrationDeepSpeed(TrainerIntegrationDeepSpeedWithCustomConfig, TrainerIntegrationCommon):
+    """
+
+    This class is for testing directly via get_regression_trainer
+
+    It mixes in `TrainerIntegrationCommon` which already has a lot of helper validation methods
+    which we can re-use here.
+
+    Important: this class' setup can only work with a single gpu because it runs within the current
+    pytest worker. For multi-gpu tests use TestDeepSpeedWithLauncher.
+
+    Note: if any of the tests of this class get run there will be at least one gpu occupied by them
+    until this pytest worker exits. This is because the gpu memory allocated by the cuda-kernels
+    won't be released until this pytest worker exits.
+
+    This may appear as some run-away tests if you watch `nvidia-smi` while other tests that fork new
+    processes are run. So there will be one or two "stale" processes reported in `nvidia-smi`. This
+    is not a bug.
+    """
+
     # --- These tests are enough to run on one of zero stages --- #
 
     def test_hf_ds_config_mismatch(self):
@@ -725,6 +729,95 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
         self.assertFalse(is_deepspeed_zero3_enabled())
         self.assertFalse(bool(config), "Deepspeed config should not be accessible")
 
+    @parameterized.expand(params, name_func=parameterized_custom_name_func)
+    def test_load_best_model(self, stage, dtype):
+        # Test that forced deepspeed reinit doesn't break the model. the forced re-init after
+        # loading the best model in Trainer is there to workaround this bug in Deepspeed
+        # https://github.com/microsoft/DeepSpeed/issues/1612
+        #
+        # The test is derived from a repro script submitted in this Issue:
+        # https://github.com/huggingface/transformers/issues/17114
+        #
+        # One additional feature of this test is that we use a non-AdamW optimizer to test that
+        # deepspeed doesn't fallback to AdamW, which would prevent the optimizer states from loading
+        # correctly
+
+        from transformers import T5ForConditionalGeneration, T5Tokenizer, Trainer  # noqa
+
+        output_dir = self.get_auto_remove_tmp_dir()  # "./xxx", after=False, before=False)
+
+        ds_config_dict = self.get_config_dict(stage)
+        del ds_config_dict["optimizer"]  # will use HF Trainer optimizer
+        del ds_config_dict["scheduler"]  # will use HF Trainer scheduler
+        # must use this setting to get the reload path exercised
+        ds_config_dict["zero_optimization"]["stage3_gather_16bit_weights_on_model_save"] = True
+
+        tokenizer = T5Tokenizer.from_pretrained(T5_TINY)
+        model = T5ForConditionalGeneration.from_pretrained(T5_TINY)
+
+        def _add_eos_to_examples(example):
+            example["input_text"] = f"question: {example['question']} context: {example['context']}"
+            example["target_text"] = example["answers"]["text"][0] if len(example["answers"]["text"]) > 0 else ""
+            return example
+
+        def _convert_to_features(example_batch):
+            input_encodings = tokenizer.batch_encode_plus(
+                example_batch["input_text"], pad_to_max_length=True, max_length=512, truncation=True
+            )
+            target_encodings = tokenizer.batch_encode_plus(
+                example_batch["target_text"], pad_to_max_length=True, max_length=16, truncation=True
+            )
+
+            encodings = {
+                "input_ids": input_encodings["input_ids"],
+                "attention_mask": input_encodings["attention_mask"],
+                "labels": target_encodings["input_ids"],
+            }
+
+            return encodings
+
+        def get_dataset():
+            data_file = str(self.tests_dir / "fixtures/tests_samples/SQUAD/sample.json")
+            data_files = dict(train=data_file, validation=data_file)
+            raw_datasets = datasets.load_dataset("json", data_files=data_files, field="data")
+            train_dataset = raw_datasets["train"].map(_add_eos_to_examples).map(_convert_to_features, batched=True)
+            valid_dataset = deepcopy(train_dataset)
+            return train_dataset, valid_dataset
+
+        train_dataset, eval_dataset = get_dataset()
+
+        args_dict = {
+            "per_gpu_train_batch_size": 1,
+            "per_gpu_eval_batch_size": 1,
+            "gradient_accumulation_steps": 1,
+            "learning_rate": 1e-4,
+            "num_train_epochs": 1,
+            "do_train": True,
+            "do_eval": True,
+            "optim": "adafactor",
+            "evaluation_strategy": "steps",
+            "eval_steps": 1,
+            "save_strategy": "steps",
+            "save_steps": 1,
+            "load_best_model_at_end": True,
+            "max_steps": 1,
+            "deepspeed": ds_config_dict,
+        }
+
+        with mockenv_context(**self.dist_env_1_gpu):
+
+            training_args = TrainingArguments(output_dir, **args_dict)
+
+            trainer = Trainer(
+                model=model,
+                tokenizer=tokenizer,
+                args=training_args,
+                train_dataset=train_dataset,
+                eval_dataset=eval_dataset,
+            )
+            trainer.train()  # crash 1 was here
+            trainer.evaluate()  # crash 2 was here
+
 
 @slow
 @require_deepspeed
@@ -1035,50 +1128,3 @@ class TestDeepSpeedWithLauncher(TestCasePlus):
         with CaptureStderr() as cs:
             execute_subprocess_async(cmd, env=self.get_env())
         self.assertIn("Detected DeepSpeed ZeRO-3", cs.err)
-
-    @parameterized.expand(params, name_func=parameterized_custom_name_func)
-    def test_load_best_model(self, stage, dtype):
-        # this test exercises --load_best_model_at_end - the key is being able to resume after some training
-
-        data_dir = self.tests_dir / "fixtures/tests_samples/wmt_en_ro"
-        output_dir = self.get_auto_remove_tmp_dir()
-        args = f"""
-            --model_name_or_path {T5_TINY}
-            --tokenizer_name {T5_TINY}
-            --train_file {data_dir}/train.json
-            --validation_file {data_dir}/val.json
-            --output_dir {output_dir}
-            --overwrite_output_dir
-            --source_lang en
-            --target_lang ro
-            --do_train
-            --max_train_samples 3
-            --do_eval
-            --max_eval_samples 1
-            --logging_strategy steps
-            --logging_steps 1
-            --evaluation_strategy steps
-            --eval_steps 1
-            --save_strategy steps
-            --save_steps 1
-            --load_best_model_at_end
-            --per_device_train_batch_size 1
-            --per_device_eval_batch_size 1
-            --num_train_epochs 1
-            --report_to none
-        """.split()
-        args.extend(["--source_prefix", "translate English to Romanian: "])
-
-        args.extend([f"--{dtype}"])
-
-        ds_args = f"--deepspeed {self.test_file_dir_str}/ds_config_{stage}.json".split()
-        script = [f"{self.examples_dir_str}/pytorch/translation/run_translation.py"]
-        launcher = get_launcher(distributed=False)
-
-        cmd = launcher + script + args + ds_args
-        # keep for quick debug
-        # print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] +cmd)); die
-        with CaptureStd() as cs:
-            execute_subprocess_async(cmd, env=self.get_env())
-        # enough to test it didn't fail
-        self.assertIn("DeepSpeed info", cs.out)