From f7d80cb3d25e26cfae7fa5c91bc6fe8552ed68ea Mon Sep 17 00:00:00 2001
From: Ethan
Date: Mon, 12 Jun 2023 23:49:55 +0800
Subject: [PATCH] Fix steps bugs in no trainer examples (#24197)

Fix step bugs in no trainer + load checkpoint + grad acc
---
 .../run_image_classification_no_trainer.py                | 5 +++--
 examples/pytorch/image-pretraining/run_mim_no_trainer.py  | 2 +-
 examples/pytorch/language-modeling/run_clm_no_trainer.py  | 2 +-
 examples/pytorch/language-modeling/run_mlm_no_trainer.py  | 2 +-
 examples/pytorch/multiple-choice/run_swag_no_trainer.py   | 5 +++--
 .../question-answering/run_qa_beam_search_no_trainer.py   | 5 +++--
 examples/pytorch/question-answering/run_qa_no_trainer.py  | 2 +-
 .../run_semantic_segmentation_no_trainer.py               | 5 +++--
 .../pytorch/summarization/run_summarization_no_trainer.py | 5 +++--
 .../pytorch/text-classification/run_glue_no_trainer.py    | 8 +++++++-
 .../pytorch/token-classification/run_ner_no_trainer.py    | 8 +++++++-
 .../pytorch/translation/run_translation_no_trainer.py     | 2 +-
 12 files changed, 34 insertions(+), 17 deletions(-)

diff --git a/examples/pytorch/image-classification/run_image_classification_no_trainer.py b/examples/pytorch/image-classification/run_image_classification_no_trainer.py
index a072160c73b..f447b5a7a27 100644
--- a/examples/pytorch/image-classification/run_image_classification_no_trainer.py
+++ b/examples/pytorch/image-classification/run_image_classification_no_trainer.py
@@ -453,10 +453,11 @@ def main():
             resume_step = None
             completed_steps = starting_epoch * num_update_steps_per_epoch
         else:
-            resume_step = int(training_difference.replace("step_", ""))
+            # need to multiply `gradient_accumulation_steps` to reflect real steps
+            resume_step = int(training_difference.replace("step_", "")) * args.gradient_accumulation_steps
             starting_epoch = resume_step // len(train_dataloader)
             resume_step -= starting_epoch * len(train_dataloader)
-            completed_steps = resume_step
+            completed_steps = resume_step // args.gradient_accumulation_steps
 
         # update the progress_bar if load from checkpoint
         progress_bar.update(completed_steps)
diff --git a/examples/pytorch/image-pretraining/run_mim_no_trainer.py b/examples/pytorch/image-pretraining/run_mim_no_trainer.py
index 870be3dfa13..93782ae4b8c 100644
--- a/examples/pytorch/image-pretraining/run_mim_no_trainer.py
+++ b/examples/pytorch/image-pretraining/run_mim_no_trainer.py
@@ -666,7 +666,7 @@ def main():
             resume_step = int(training_difference.replace("step_", "")) * args.gradient_accumulation_steps
             starting_epoch = resume_step // len(train_dataloader)
             resume_step -= starting_epoch * len(train_dataloader)
-            completed_steps = resume_step
+            completed_steps = resume_step // args.gradient_accumulation_steps
 
         # update the progress_bar if load from checkpoint
         progress_bar.update(completed_steps)
diff --git a/examples/pytorch/language-modeling/run_clm_no_trainer.py b/examples/pytorch/language-modeling/run_clm_no_trainer.py
index c38102f6c5c..1ec311289aa 100755
--- a/examples/pytorch/language-modeling/run_clm_no_trainer.py
+++ b/examples/pytorch/language-modeling/run_clm_no_trainer.py
@@ -572,7 +572,7 @@ def main():
             resume_step = int(training_difference.replace("step_", "")) * args.gradient_accumulation_steps
             starting_epoch = resume_step // len(train_dataloader)
             resume_step -= starting_epoch * len(train_dataloader)
-            completed_steps = resume_step
+            completed_steps = resume_step // args.gradient_accumulation_steps
 
         # update the progress_bar if load from checkpoint
         progress_bar.update(completed_steps)
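The arithmetic is the same in every file touched by this patch: `step_{n}` checkpoints count completed optimizer steps, while `resume_step` and `len(train_dataloader)` count micro-batches, so the checkpoint step has to be scaled up by `gradient_accumulation_steps` before the epoch math and scaled back down for the progress bar. A minimal sketch of that bookkeeping, with made-up numbers (the standalone function is illustrative, not part of the patch):

    # Sketch only: mirrors the resume bookkeeping the patch establishes.
    def resume_state(checkpoint_step, gradient_accumulation_steps, batches_per_epoch):
        # `step_{n}` checkpoints count optimizer steps; convert to micro-batches.
        resume_step = checkpoint_step * gradient_accumulation_steps
        starting_epoch = resume_step // batches_per_epoch
        resume_step -= starting_epoch * batches_per_epoch
        # Convert back: optimizer steps already finished inside the resumed epoch.
        completed_steps = resume_step // gradient_accumulation_steps
        return starting_epoch, resume_step, completed_steps

    # A checkpoint saved at optimizer step 50, with gradient accumulation 4 and
    # 120 micro-batches per epoch, resumes 80 micro-batches into epoch 1:
    assert resume_state(50, 4, 120) == (1, 80, 20)

Before the fix, `completed_steps = resume_step` would report 80 here, inflating the progress bar by a factor of `gradient_accumulation_steps`.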
diff --git a/examples/pytorch/language-modeling/run_mlm_no_trainer.py b/examples/pytorch/language-modeling/run_mlm_no_trainer.py
index ccc4e0a0987..760181feffc 100755
--- a/examples/pytorch/language-modeling/run_mlm_no_trainer.py
+++ b/examples/pytorch/language-modeling/run_mlm_no_trainer.py
@@ -616,7 +616,7 @@ def main():
             resume_step = int(training_difference.replace("step_", "")) * args.gradient_accumulation_steps
             starting_epoch = resume_step // len(train_dataloader)
             resume_step -= starting_epoch * len(train_dataloader)
-            completed_steps = resume_step
+            completed_steps = resume_step // args.gradient_accumulation_steps
 
         # update the progress_bar if load from checkpoint
         progress_bar.update(completed_steps)
diff --git a/examples/pytorch/multiple-choice/run_swag_no_trainer.py b/examples/pytorch/multiple-choice/run_swag_no_trainer.py
index 6999b9c800d..ca6f49a4ab1 100755
--- a/examples/pytorch/multiple-choice/run_swag_no_trainer.py
+++ b/examples/pytorch/multiple-choice/run_swag_no_trainer.py
@@ -559,10 +559,11 @@ def main():
             resume_step = None
             completed_steps = starting_epoch * num_update_steps_per_epoch
         else:
-            resume_step = int(training_difference.replace("step_", ""))
+            # need to multiply `gradient_accumulation_steps` to reflect real steps
+            resume_step = int(training_difference.replace("step_", "")) * args.gradient_accumulation_steps
             starting_epoch = resume_step // len(train_dataloader)
             resume_step -= starting_epoch * len(train_dataloader)
-            completed_steps = resume_step
+            completed_steps = resume_step // args.gradient_accumulation_steps
 
         # update the progress_bar if load from checkpoint
         progress_bar.update(completed_steps)
diff --git a/examples/pytorch/question-answering/run_qa_beam_search_no_trainer.py b/examples/pytorch/question-answering/run_qa_beam_search_no_trainer.py
index d1ace742471..3cc6a9686a2 100644
--- a/examples/pytorch/question-answering/run_qa_beam_search_no_trainer.py
+++ b/examples/pytorch/question-answering/run_qa_beam_search_no_trainer.py
@@ -811,10 +811,11 @@ def main():
             resume_step = None
             completed_steps = starting_epoch * num_update_steps_per_epoch
         else:
-            resume_step = int(training_difference.replace("step_", ""))
+            # need to multiply `gradient_accumulation_steps` to reflect real steps
+            resume_step = int(training_difference.replace("step_", "")) * args.gradient_accumulation_steps
             starting_epoch = resume_step // len(train_dataloader)
             resume_step -= starting_epoch * len(train_dataloader)
-            completed_steps = resume_step
+            completed_steps = resume_step // args.gradient_accumulation_steps
 
         # update the progress_bar if load from checkpoint
         progress_bar.update(completed_steps)
diff --git a/examples/pytorch/question-answering/run_qa_no_trainer.py b/examples/pytorch/question-answering/run_qa_no_trainer.py
index 95564db9679..60e54ad8c5e 100755
--- a/examples/pytorch/question-answering/run_qa_no_trainer.py
+++ b/examples/pytorch/question-answering/run_qa_no_trainer.py
@@ -830,7 +830,7 @@ def main():
             resume_step = int(training_difference.replace("step_", "")) * args.gradient_accumulation_steps
             starting_epoch = resume_step // len(train_dataloader)
             resume_step -= starting_epoch * len(train_dataloader)
-            completed_steps = resume_step
+            completed_steps = resume_step // args.gradient_accumulation_steps
 
         # update the progress_bar if load from checkpoint
         progress_bar.update(completed_steps)
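For context on where `training_difference` comes from: these scripts save checkpoints into directories named `epoch_{n}` or `step_{n}` and recover the resume point from the basename, which is why `step_{n}` (optimizer steps) needs the multiplication above. A hedged sketch of that convention (the helper function is ours, not from the scripts):

    import os

    # Illustrative only: mirrors how the no_trainer scripts turn a checkpoint
    # directory named `epoch_{n}` or `step_{n}` into a resume point.
    def parse_checkpoint_name(checkpoint_dir):
        training_difference = os.path.splitext(os.path.basename(checkpoint_dir))[0]
        if "epoch" in training_difference:
            return "epoch", int(training_difference.replace("epoch_", ""))
        return "step", int(training_difference.replace("step_", ""))

    assert parse_checkpoint_name("output/step_50") == ("step", 50)
    assert parse_checkpoint_name("output/epoch_2") == ("epoch", 2)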
diff --git a/examples/pytorch/semantic-segmentation/run_semantic_segmentation_no_trainer.py b/examples/pytorch/semantic-segmentation/run_semantic_segmentation_no_trainer.py
index 740c2ea5cbb..fe0ae637936 100644
--- a/examples/pytorch/semantic-segmentation/run_semantic_segmentation_no_trainer.py
+++ b/examples/pytorch/semantic-segmentation/run_semantic_segmentation_no_trainer.py
@@ -556,10 +556,11 @@ def main():
             resume_step = None
             completed_steps = starting_epoch * num_update_steps_per_epoch
         else:
-            resume_step = int(training_difference.replace("step_", ""))
+            # need to multiply `gradient_accumulation_steps` to reflect real steps
+            resume_step = int(training_difference.replace("step_", "")) * args.gradient_accumulation_steps
             starting_epoch = resume_step // len(train_dataloader)
             resume_step -= starting_epoch * len(train_dataloader)
-            completed_steps = resume_step
+            completed_steps = resume_step // args.gradient_accumulation_steps
 
         # update the progress_bar if load from checkpoint
         progress_bar.update(completed_steps)
diff --git a/examples/pytorch/summarization/run_summarization_no_trainer.py b/examples/pytorch/summarization/run_summarization_no_trainer.py
index 7cf627046e5..7b5ce25d880 100644
--- a/examples/pytorch/summarization/run_summarization_no_trainer.py
+++ b/examples/pytorch/summarization/run_summarization_no_trainer.py
@@ -628,10 +628,11 @@ def main():
             resume_step = None
             completed_steps = starting_epoch * num_update_steps_per_epoch
         else:
-            resume_step = int(training_difference.replace("step_", ""))
+            # need to multiply `gradient_accumulation_steps` to reflect real steps
+            resume_step = int(training_difference.replace("step_", "")) * args.gradient_accumulation_steps
             starting_epoch = resume_step // len(train_dataloader)
             resume_step -= starting_epoch * len(train_dataloader)
-            completed_steps = resume_step
+            completed_steps = resume_step // args.gradient_accumulation_steps
 
         # update the progress_bar if load from checkpoint
         progress_bar.update(completed_steps)
diff --git a/examples/pytorch/text-classification/run_glue_no_trainer.py b/examples/pytorch/text-classification/run_glue_no_trainer.py
index 0c2b84ec165..ff8b65ef6c1 100644
--- a/examples/pytorch/text-classification/run_glue_no_trainer.py
+++ b/examples/pytorch/text-classification/run_glue_no_trainer.py
@@ -501,10 +501,16 @@ def main():
         if "epoch" in training_difference:
             starting_epoch = int(training_difference.replace("epoch_", "")) + 1
             resume_step = None
+            completed_steps = starting_epoch * num_update_steps_per_epoch
         else:
-            resume_step = int(training_difference.replace("step_", ""))
+            # need to multiply `gradient_accumulation_steps` to reflect real steps
+            resume_step = int(training_difference.replace("step_", "")) * args.gradient_accumulation_steps
             starting_epoch = resume_step // len(train_dataloader)
             resume_step -= starting_epoch * len(train_dataloader)
+            completed_steps = resume_step // args.gradient_accumulation_steps
+
+        # update the progress_bar if load from checkpoint
+        progress_bar.update(completed_steps)
 
     for epoch in range(starting_epoch, args.num_train_epochs):
         model.train()
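run_glue_no_trainer.py above (and run_ner_no_trainer.py below) needed a second fix on top of the step math: their epoch branch never set `completed_steps`, and they never fast-forwarded the progress bar after loading a checkpoint. Resuming from an `epoch_{n}` checkpoint means epochs 0 through n are finished, hence `starting_epoch * num_update_steps_per_epoch` completed optimizer steps. A worked sketch with illustrative numbers (in the scripts, `num_update_steps_per_epoch` is `math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)`):

    import math

    batches_per_epoch = 120          # stand-in for len(train_dataloader)
    gradient_accumulation_steps = 4
    num_update_steps_per_epoch = math.ceil(batches_per_epoch / gradient_accumulation_steps)  # 30

    # `epoch_{n}` checkpoints record the last finished epoch, so resume at n + 1.
    training_difference = "epoch_2"
    starting_epoch = int(training_difference.replace("epoch_", "")) + 1
    completed_steps = starting_epoch * num_update_steps_per_epoch
    assert (starting_epoch, completed_steps) == (3, 90)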
diff --git a/examples/pytorch/token-classification/run_ner_no_trainer.py b/examples/pytorch/token-classification/run_ner_no_trainer.py
index dcda152dbe3..84f506a42ca 100755
--- a/examples/pytorch/token-classification/run_ner_no_trainer.py
+++ b/examples/pytorch/token-classification/run_ner_no_trainer.py
@@ -659,10 +659,16 @@ def main():
         if "epoch" in training_difference:
             starting_epoch = int(training_difference.replace("epoch_", "")) + 1
             resume_step = None
+            completed_steps = starting_epoch * num_update_steps_per_epoch
         else:
-            resume_step = int(training_difference.replace("step_", ""))
+            # need to multiply `gradient_accumulation_steps` to reflect real steps
+            resume_step = int(training_difference.replace("step_", "")) * args.gradient_accumulation_steps
             starting_epoch = resume_step // len(train_dataloader)
             resume_step -= starting_epoch * len(train_dataloader)
+            completed_steps = resume_step // args.gradient_accumulation_steps
+
+        # update the progress_bar if load from checkpoint
+        progress_bar.update(completed_steps)
 
     for epoch in range(starting_epoch, args.num_train_epochs):
         model.train()
diff --git a/examples/pytorch/translation/run_translation_no_trainer.py b/examples/pytorch/translation/run_translation_no_trainer.py
index d42ade1f59f..3faf99cc245 100644
--- a/examples/pytorch/translation/run_translation_no_trainer.py
+++ b/examples/pytorch/translation/run_translation_no_trainer.py
@@ -613,7 +613,7 @@ def main():
             resume_step = int(training_difference.replace("step_", "")) * args.gradient_accumulation_steps
             starting_epoch = resume_step // len(train_dataloader)
             resume_step -= starting_epoch * len(train_dataloader)
-            completed_steps = resume_step
+            completed_steps = resume_step // args.gradient_accumulation_steps
 
         # update the progress_bar if load from checkpoint
        progress_bar.update(completed_steps)
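With the identifiers fixed, `progress_bar.update(completed_steps)` now advances the bar in optimizer-step units, matching a bar sized by `max_train_steps`. A self-contained way to sanity-check the resumed-bar behavior (tqdm is already a dependency of these examples; all numbers are made up):

    import math
    from tqdm.auto import tqdm

    gradient_accumulation_steps = 4
    batches_per_epoch = 120
    num_update_steps_per_epoch = math.ceil(batches_per_epoch / gradient_accumulation_steps)
    max_train_steps = 3 * num_update_steps_per_epoch  # 3 epochs

    progress_bar = tqdm(range(max_train_steps))

    # The corrected step-branch logic from the patch, resuming from `step_50`:
    resume_step = 50 * gradient_accumulation_steps
    starting_epoch = resume_step // batches_per_epoch
    resume_step -= starting_epoch * batches_per_epoch
    completed_steps = resume_step // gradient_accumulation_steps

    # update the progress_bar if load from checkpoint
    progress_bar.update(completed_steps)
    assert progress_bar.n == 20  # in-epoch optimizer steps, not micro-batches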