[Deepspeed] add many more models to the model zoo test (#12695)
* model zoo take 2
* add deberta
* new param for zero2
* doc update
* doc update
* add layoutlm
* bump deepspeed
* add deberta-v2, funnel, longformer
* new models
* style
* add t5_v1
* update TAPAS status
* reorg problematic models
* move doc to another PR
* style
* fix checkpoint check test
* making progress on more models running
* cleanup
* new version
* cleanup
This commit is contained in:
parent 9aeacfe0ff
commit f861504466

setup.py: 4 changed lines
@@ -19,7 +19,7 @@ To create the package for pypi.
 1. Run `make pre-release` (or `make pre-patch` for a patch release) then run `make fix-copies` to fix the index of the
    documentation.
 
 If releasing on a special branch, copy the updated README.md on the main branch for your the commit you will make
 for the post-release and run `make fix-copies` on the main branch as well.
 
@@ -102,7 +102,7 @@ _deps = [
     "cookiecutter==1.7.3",
     "dataclasses",
     "datasets",
-    "deepspeed>=0.6.0",
+    "deepspeed>=0.6.4",
     "fairscale>0.3",
     "faiss-cpu",
     "fastapi",
@@ -9,7 +9,7 @@ deps = {
     "cookiecutter": "cookiecutter==1.7.3",
     "dataclasses": "dataclasses",
     "datasets": "datasets",
-    "deepspeed": "deepspeed>=0.6.0",
+    "deepspeed": "deepspeed>=0.6.4",
     "fairscale": "fairscale>0.3",
     "faiss-cpu": "faiss-cpu",
     "fastapi": "fastapi",
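Both pins raise the minimum DeepSpeed to 0.6.4 (the `_deps` list in setup.py and the mirrored `deps` table are kept in sync). For illustration only, a floor like this can also be verified at runtime; the helper below is a hypothetical sketch, not the repository's own version-checking code:

```python
# hypothetical sketch: enforcing a minimum DeepSpeed version matching the ">=0.6.4" pin above
from packaging import version


def require_min_deepspeed(minimum: str = "0.6.4") -> None:
    import deepspeed  # assumes deepspeed is installed

    if version.parse(deepspeed.__version__) < version.parse(minimum):
        raise ImportError(f"deepspeed>={minimum} is required, but {deepspeed.__version__} is installed")
```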
@@ -522,7 +522,7 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
         # see the note above how to get identical loss on a small bs
         self.assertAlmostEqual(no_grad_accum_loss, yes_grad_accum_loss, places=2)
 
-    def check_saved_checkpoints_deepspeed(self, output_dir, freq, total, stage):
+    def check_saved_checkpoints_deepspeed(self, output_dir, freq, total, stage, dtype):
         # adapted from TrainerIntegrationCommon.check_saved_checkpoints
 
         file_list = [WEIGHTS_NAME, "training_args.bin", "trainer_state.json", "config.json"]
@@ -534,7 +534,8 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
         else:
             raise ValueError(f"unknown stage {stage}")
 
-        ds_file_list.append("zero_pp_rank_0_mp_rank_00_optim_states.pt")
+        if dtype == "bf16":
+            ds_file_list.append("bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt")
 
         for step in range(freq, total, freq):
             checkpoint = os.path.join(output_dir, f"checkpoint-{step}")
@@ -578,7 +579,7 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
         trainer.train()
 
         total = int(self.n_epochs * 64 / self.batch_size)
-        self.check_saved_checkpoints_deepspeed(output_dir, freq, total, stage)
+        self.check_saved_checkpoints_deepspeed(output_dir, freq, total, stage, dtype)
 
     @parameterized.expand(params, name_func=parameterized_custom_name_func)
     def test_can_resume_training_errors(self, stage, dtype):
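The new `dtype` argument matters because, per the hunk above, a bf16 run saves its optimizer-state shard with a `bf16_` prefix. Below is a standalone sketch of the same check; the zero2/zero3 model-state file names and the non-bf16 branch are assumptions carried over from the pre-change code, not part of this diff:

```python
import os


# hypothetical helper in the spirit of check_saved_checkpoints_deepspeed; not the test's exact body
def expected_deepspeed_files(stage: str, dtype: str) -> list:
    files = ["training_args.bin", "trainer_state.json", "config.json"]
    if stage == "zero2":
        files.append("mp_rank_00_model_states.pt")  # assumed zero2 model-state shard name
    elif stage == "zero3":
        files.append("zero_pp_rank_0_mp_rank_00_model_states.pt")  # assumed zero3 shard name
    else:
        raise ValueError(f"unknown stage {stage}")
    if dtype == "bf16":
        # from the diff: bf16 runs produce a "bf16_"-prefixed optimizer-state shard
        files.append("bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt")
    else:
        # pre-change (fp16/fp32) file name, kept here as an assumption
        files.append("zero_pp_rank_0_mp_rank_00_optim_states.pt")
    return files


def check_checkpoints(output_dir: str, freq: int, total: int, stage: str, dtype: str) -> None:
    for step in range(freq, total, freq):
        checkpoint = os.path.join(output_dir, f"checkpoint-{step}")
        for name in expected_deepspeed_files(stage, dtype):
            path = os.path.join(checkpoint, name)
            assert os.path.isfile(path), f"expected {path} to exist"
```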
@@ -42,51 +42,99 @@ if is_torch_available():
 
 set_seed(42)
 
+FIXTURE_DIRECTORY = get_tests_dir("fixtures")
+ROOT_DIRECTORY = os.path.join(dirname(get_tests_dir()))
+DS_TESTS_DIRECTORY = dirname(os.path.abspath(__file__))
 
 # default torch.distributed port
 DEFAULT_MASTER_PORT = "10999"
 
-# translation
-FSMT_TINY = "stas/tiny-wmt19-en-de"
-BART_TINY = "sshleifer/bart-tiny-random"
 T5_SMALL = "t5-small"
-T5_TINY = "patrickvonplaten/t5-tiny-random"
-MBART_TINY = "sshleifer/tiny-mbart"
-MARIAN_TINY = "sshleifer/tiny-marian-en-de"
 
-# summarization
-PEGASUS_TINY = "stas/pegasus-cnn_dailymail-tiny-random"
-# causal lm
-GPT2_TINY = "sshleifer/tiny-gpt2"
-XLM_ROBERTA_TINY = "hf-internal-testing/tiny-xlm-roberta"
-# question-answering
-ROBERTA_TINY = "sshleifer/tiny-distilroberta-base"
+# *** Working Models ***
+ALBERT_TINY = "hf-internal-testing/tiny-albert"
+BART_TINY = "sshleifer/bart-tiny-random"
+BERT_TINY = "hf-internal-testing/tiny-bert"
+BIGBIRD_PEGASUS_TINY = "hf-internal-testing/tiny-random-bigbird_pegasus"
+BIG_BIRD_TINY = "hf-internal-testing/tiny-random-big_bird"
+BLENDERBOT_TINY = "hf-internal-testing/tiny-random-blenderbot"
+DEBERTA_TINY = "hf-internal-testing/tiny-random-deberta"
+DEBERTA_V2_TINY = "hf-internal-testing/tiny-random-deberta-v2"
 
-# masked lm
 DISTILBERT_TINY = "sshleifer/tiny-distilbert-base-cased"
 ELECTRA_TINY = "hf-internal-testing/tiny-electra"
-# classification
+FLAUBERT_TINY = "hf-internal-testing/tiny-random-flaubert"
+FSMT_TINY = "stas/tiny-wmt19-en-de"
+FUNNEL_TINY = "hf-internal-testing/tiny-random-funnel"
+GPT2_TINY = "sshleifer/tiny-gpt2"
+GPTJ_TINY = "hf-internal-testing/tiny-random-gptj"
+GPT_NEO_TINY = "hf-internal-testing/tiny-random-gpt_neo"
+LAYOUTLM_TINY = "hf-internal-testing/tiny-layoutlm"
+LED_TINY = "hf-internal-testing/tiny-random-led"
+LONGFORMER_TINY = "hf-internal-testing/tiny-random-longformer"
+M2M_100_TINY = "stas/tiny-m2m_100"  # hf tiny model is unsuitable
+MARIAN_TINY = "sshleifer/tiny-marian-en-de"
+MBART_TINY = "sshleifer/tiny-mbart"
+MOBILEBERT_TINY = "hf-internal-testing/tiny-random-mobilebert"
+MPNET_TINY = "hf-internal-testing/tiny-random-mpnet"
+PEGASUS_TINY = "stas/pegasus-cnn_dailymail-tiny-random"
+PROPHETNET_TINY = "hf-internal-testing/tiny-random-prophetnet"
+ROBERTA_TINY = "sshleifer/tiny-distilroberta-base"
+SQUEEZEBERT_TINY = "hf-internal-testing/tiny-random-squeezebert"
+T5_TINY = "patrickvonplaten/t5-tiny-random"
+T5_V1_TINY = "hf-internal-testing/tiny-random-t5-v1.1"
+VIT_TINY = "hf-internal-testing/tiny-random-vit"
+XLM_ROBERTA_TINY = "hf-internal-testing/tiny-xlm-roberta"
 XLNET_TINY = "sshleifer/tiny-xlnet-base-cased"
-BERT_TINY = "hf-internal-testing/tiny-bert"
 
-FIXTURE_DIRECTORY = get_tests_dir("fixtures")
-ROOT_DIRECTORY = os.path.join(dirname(get_tests_dir()))
 
-# TODO: to add:
-# albert
-# deberta
-# funnel
-# longformer
-# dpr
-# gpt_neo
-# camembert
-# deberta-v2
-# m2m_100
-# tapas
-# vit
-# big_bird
+# *** To Fix ***
+# *** tiny model issues ***
+# missing model files:
+MT5_TINY = "hf-internal-testing/tiny-random-mt5"
+CAMEMBERT_TINY = "hf-internal-testing/tiny-random-camembert"
+OPENAI_GPT_TINY = "hf-internal-testing/tiny-random-openai-gpt"
+
+# missing tokenizer files
+CONVBERT_TINY = "hf-internal-testing/tiny-random-convbert"
+LAYOUTLMV2_TINY = "hf-internal-testing/tiny-random-layoutlmv2"
+HUBERT_TINY = "hf-internal-testing/tiny-random-hubert"
 
+# issues with tokenizer
+CTRL_TINY = "hf-internal-testing/tiny-random-ctrl"
+TRANSFO_XL_TINY = "hf-internal-testing/tiny-random-transfo-xl"  # same as ctrl
+
+# other issues with tiny models
+IBERT_TINY = "hf-internal-testing/tiny-random-ibert"  # multiple issues with either mlm/qa/clas
+REFORMER_TINY = "hf-internal-testing/tiny-random-reformer"  # multiple issues with either mlm/qa/clas
+
+# *** Lacking official examples to test with ***
+# or not working with examples
+DPR_TINY = "hf-internal-testing/tiny-random-dpr"
+# - "dpr" examples/research_projects/rag-end2end-retriever/
+RAG_TINY = "hf-internal-testing/tiny-random-rag"
+# - "rag" research_projects
+LUKE_TINY = ""
+# - "luke" Entities classes - no plan to make such example
+LXMERT_TINY = "hf-internal-testing/tiny-random-lxmert"
+# - "lxmert" doesn't work with run_qa.py
+CLIP_TINY = "hf-internal-testing/tiny-random-clip"
+# - "clip" nothing under pytorch examples - XXX: Suraj is working on adding some - check by end of Sep
+SPEECH_TO_TEXT_TINY = "hf-internal-testing/tiny-random-speech_to_text"
+# - "speech_to_text", nothing under pytorch examples
+
+
+# *** Reactive mode ***
+# models with low usage, unstable API, things about to change - do nothing about the following until someone runs into a problem
+TAPAS_TINY = "hf-internal-testing/tiny-random-tapas"
+# additional notes on tapas
+# 1. requires torch_scatter - skip if it's not installed?
+# 2. "Table must be of type pd.DataFrame" failure
+
+
+# TODO: new models to add:
+#
 
 
 def get_launcher(distributed=False):
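The hunk ends at the `get_launcher` definition, whose body is outside this diff. As a rough, assumed reconstruction (flag choice and GPU-count logic are guesses; only `DEFAULT_MASTER_PORT` comes from the constants above), such a helper typically builds the `deepspeed` launcher prefix:

```python
import torch


# hypothetical reconstruction of a get_launcher()-style helper; not the file's actual body
def get_launcher(distributed: bool = False) -> list:
    num_gpus = min(2, torch.cuda.device_count()) if distributed else 1
    # DEFAULT_MASTER_PORT = "10999" is defined in the hunk above
    return f"deepspeed --num_nodes 1 --num_gpus {num_gpus} --master_port {DEFAULT_MASTER_PORT}".split()
```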
@@ -113,35 +161,68 @@ def make_task_cmds():
         --overwrite_output_dir
         """.split()
 
-    # XXX: try to cover as many models as possible once (it's enough to run on one task per model)
+    # try to cover as many models as possible once (it's enough to run on one task per model)
     # but need a tiny model for each
     #
-    # should have T5_TINY, etc. global var defined
+    # should have "{model_type.upper()}_TINY" corresponding vars defined, e.g., T5_TINY, etc.
     tasks2models = dict(
         trans=[
             "bart",
             "fsmt",
+            "m2m_100",
             "marian",
             "mbart",
             "t5",
+            "t5_v1",
+            # "mt5", missing model files
         ],
         sum=[
             "pegasus",
         ],
         clm=[
+            "big_bird",
+            "bigbird_pegasus",
+            "blenderbot",
             "gpt2",
+            "gpt_neo",
+            "gptj",
             "xlm-roberta",
+            "prophetnet",
+            # "camembert", missing model files
         ],
         mlm=[
-            "electra",
+            "albert",
+            "deberta",
+            "deberta-v2",
             "distilbert",
+            "electra",
+            "flaubert",
+            "funnel",
+            "layoutlm",
+            # "reformer", # multiple issues with either mlm/qa/clas
         ],
         qa=[
+            "led",
+            "longformer",
+            "mobilebert",
+            "mpnet",
             "roberta",
+            "squeezebert",
+            # "convbert", # missing tokenizer files
+            # "layoutlmv2", missing model files
         ],
         clas=[
             "bert",
             "xlnet",
+            # "hubert", # missing tokenizer files
+            # "ibert", # multiple issues with either mlm/qa/clas
+            # "transfo-xl", # tokenizer issues as ctrl
+            # "ctrl", # tokenizer issues
+            # "openai-gpt", missing model files
+            # "tapas", multiple issues
+        ],
+        img_clas=[
+            "vit",
         ],
     )
 
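Per the updated comment, every entry in `tasks2models` must have a matching `{model_type.upper()}_TINY` constant defined at module level (e.g. "t5_v1" resolves to `T5_V1_TINY`). Below is a hedged sketch of how that convention can be turned into per-model task commands; the helper names and the command assembly are illustrative, not the test file's exact code:

```python
# illustrative only: resolve a model type to its tiny checkpoint via the naming convention
def model_type_to_tiny_checkpoint(model_type: str) -> str:
    var_name = f"{model_type.upper().replace('-', '_')}_TINY"  # "xlm-roberta" -> "XLM_ROBERTA_TINY"
    try:
        return globals()[var_name]
    except KeyError:
        raise RuntimeError(f"no tiny checkpoint defined for {model_type!r}: expected a {var_name} constant")


def build_task_cmds(tasks2models: dict, tasks2args: dict) -> dict:
    # one command per (task, model) pair, each pointing at the model's tiny checkpoint
    cmds = {}
    for task, models in tasks2models.items():
        for model_type in models:
            tiny = model_type_to_tiny_checkpoint(model_type)
            cmds[f"{task}_{model_type}"] = tasks2args[task].split() + ["--model_name_or_path", tiny]
    return cmds
```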
@@ -180,6 +261,13 @@ def make_task_cmds():
         --max_seq_length 12
         --task_name MRPC
         """,
+        img_clas=f"""
+        {scripts_dir}/image-classification/run_image_classification.py
+        --dataset_name hf-internal-testing/cats_vs_dogs_sample
+        --remove_unused_columns False
+        --max_steps 10
+        --feature_extractor_name {DS_TESTS_DIRECTORY}/vit_feature_extractor.json
+        """,
     )
 
     launcher = get_launcher(distributed=True)