diff --git a/.circleci/config.yml b/.circleci/config.yml
index 856211e280c..1bfe5d29f7f 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -587,6 +587,7 @@ jobs:
             - run: pip install --upgrade pip
             - run: pip install .[sklearn,torch,sentencepiece,testing,torch-speech]
             - run: pip install -r examples/pytorch/_tests_requirements.txt
+            - run: pip install git+https://github.com/huggingface/accelerate
         - save_cache:
             key: v0.4-torch_examples-{{ checksum "setup.py" }}
             paths:
diff --git a/examples/pytorch/language-modeling/run_clm_no_trainer.py b/examples/pytorch/language-modeling/run_clm_no_trainer.py
index 76eca548693..247ba09d54a 100755
--- a/examples/pytorch/language-modeling/run_clm_no_trainer.py
+++ b/examples/pytorch/language-modeling/run_clm_no_trainer.py
@@ -23,6 +23,7 @@ https://huggingface.co/models?filter=text-generation
 # You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments.
 
 import argparse
+import json
 import logging
 import math
 import os
@@ -537,7 +538,10 @@ def main():
 
             if isinstance(checkpointing_steps, int):
                 if completed_steps % checkpointing_steps == 0:
-                    accelerator.save_state(f"step_{completed_steps}")
+                    output_dir = f"step_{completed_steps}"
+                    if args.output_dir is not None:
+                        output_dir = os.path.join(args.output_dir, output_dir)
+                    accelerator.save_state(output_dir)
 
             if completed_steps >= args.max_train_steps:
                 break
@@ -581,7 +585,10 @@ def main():
                 )
 
         if args.checkpointing_steps == "epoch":
-            accelerator.save_state(f"epoch_{epoch}")
+            output_dir = f"epoch_{epoch}"
+            if args.output_dir is not None:
+                output_dir = os.path.join(args.output_dir, output_dir)
+            accelerator.save_state(output_dir)
 
     if args.output_dir is not None:
         accelerator.wait_for_everyone()
@@ -592,6 +599,9 @@ def main():
         if args.push_to_hub:
             repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True)
 
+        with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
+            json.dump({"perplexity": perplexity}, f)
+
 
 if __name__ == "__main__":
     main()
diff --git a/examples/pytorch/language-modeling/run_mlm_no_trainer.py b/examples/pytorch/language-modeling/run_mlm_no_trainer.py
index 6a3b48c3b1c..2634cc25e5b 100755
--- a/examples/pytorch/language-modeling/run_mlm_no_trainer.py
+++ b/examples/pytorch/language-modeling/run_mlm_no_trainer.py
@@ -23,6 +23,7 @@ https://huggingface.co/models?filter=fill-mask
 # You can also adapt this script on your own mlm task. Pointers for this are left as comments.
 
 import argparse
+import json
 import logging
 import math
 import os
@@ -457,9 +458,11 @@ def main():
     train_dataset = tokenized_datasets["train"]
     eval_dataset = tokenized_datasets["validation"]
 
-    # Log a few random samples from the training set:
-    for index in random.sample(range(len(train_dataset)), 3):
-        logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
+    # Conditional for small test subsets
+    if len(train_dataset) > 3:
+        # Log a few random samples from the training set:
+        for index in random.sample(range(len(train_dataset)), 3):
+            logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
 
     # Data collator
     # This one will take care of randomly masking the tokens.
@@ -581,7 +584,10 @@ def main():
 
             if isinstance(checkpointing_steps, int):
                 if completed_steps % checkpointing_steps == 0:
-                    accelerator.save_state(f"step_{completed_steps}")
+                    output_dir = f"step_{completed_steps}"
+                    if args.output_dir is not None:
+                        output_dir = os.path.join(args.output_dir, output_dir)
+                    accelerator.save_state(output_dir)
 
             if completed_steps >= args.max_train_steps:
                 break
@@ -625,7 +631,10 @@ def main():
                 )
 
         if args.checkpointing_steps == "epoch":
-            accelerator.save_state(f"epoch_{epoch}")
+            output_dir = f"epoch_{epoch}"
+            if args.output_dir is not None:
+                output_dir = os.path.join(args.output_dir, output_dir)
+            accelerator.save_state(output_dir)
 
     if args.output_dir is not None:
         accelerator.wait_for_everyone()
@@ -636,6 +645,9 @@ def main():
         if args.push_to_hub:
             repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True)
 
+        with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
+            json.dump({"perplexity": perplexity}, f)
+
 
 if __name__ == "__main__":
     main()
diff --git a/examples/pytorch/multiple-choice/run_swag_no_trainer.py b/examples/pytorch/multiple-choice/run_swag_no_trainer.py
index d97fb71f395..a575644130f 100755
--- a/examples/pytorch/multiple-choice/run_swag_no_trainer.py
+++ b/examples/pytorch/multiple-choice/run_swag_no_trainer.py
@@ -19,6 +19,7 @@ Fine-tuning a 🤗 Transformers model on multiple choice relying on the accelera
 # You can also adapt this script on your own multiple choice task. Pointers for this are left as comments.
 
 import argparse
+import json
 import logging
 import math
 import os
@@ -540,7 +541,10 @@ def main():
 
             if isinstance(checkpointing_steps, int):
                 if completed_steps % checkpointing_steps == 0:
-                    accelerator.save_state(f"step_{completed_steps}")
+                    output_dir = f"step_{completed_steps}"
+                    if args.output_dir is not None:
+                        output_dir = os.path.join(args.output_dir, output_dir)
+                    accelerator.save_state(output_dir)
 
             if completed_steps >= args.max_train_steps:
                 break
@@ -578,6 +582,12 @@ def main():
                     commit_message=f"Training in progress epoch {epoch}", blocking=False, auto_lfs_prune=True
                 )
 
+        if args.checkpointing_steps == "epoch":
+            output_dir = f"epoch_{epoch}"
+            if args.output_dir is not None:
+                output_dir = os.path.join(args.output_dir, output_dir)
+            accelerator.save_state(output_dir)
+
     if args.output_dir is not None:
         accelerator.wait_for_everyone()
         unwrapped_model = accelerator.unwrap_model(model)
@@ -586,6 +596,8 @@ def main():
         tokenizer.save_pretrained(args.output_dir)
         if args.push_to_hub:
             repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True)
+        with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
+            json.dump({"eval_accuracy": eval_metric["accuracy"]}, f)
 
 
 if __name__ == "__main__":
diff --git a/examples/pytorch/question-answering/run_qa_no_trainer.py b/examples/pytorch/question-answering/run_qa_no_trainer.py
index 08f8339036c..6da75822398 100755
--- a/examples/pytorch/question-answering/run_qa_no_trainer.py
+++ b/examples/pytorch/question-answering/run_qa_no_trainer.py
@@ -19,6 +19,7 @@ Fine-tuning a 🤗 Transformers model for question answering using 🤗 Accelera
 # You can also adapt this script on your own question answering task. Pointers for this are left as comments.
 
 import argparse
+import json
 import logging
 import math
 import os
@@ -783,11 +784,20 @@ def main():
 
             if isinstance(checkpointing_steps, int):
                 if completed_steps % checkpointing_steps == 0:
-                    accelerator.save_state(f"step_{completed_steps}")
+                    output_dir = f"step_{completed_steps}"
+                    if args.output_dir is not None:
+                        output_dir = os.path.join(args.output_dir, output_dir)
+                    accelerator.save_state(output_dir)
 
             if completed_steps >= args.max_train_steps:
                 break
 
+        if args.checkpointing_steps == "epoch":
+            output_dir = f"epoch_{epoch}"
+            if args.output_dir is not None:
+                output_dir = os.path.join(args.output_dir, output_dir)
+            accelerator.save_state(output_dir)
+
         if args.push_to_hub and epoch < args.num_train_epochs - 1:
             accelerator.wait_for_everyone()
             unwrapped_model = accelerator.unwrap_model(model)
@@ -879,9 +889,6 @@ def main():
 
         accelerator.log(log, step=completed_steps)
 
-    if args.checkpointing_steps == "epoch":
-        accelerator.save_state(f"epoch_{epoch}")
-
     if args.output_dir is not None:
         accelerator.wait_for_everyone()
         unwrapped_model = accelerator.unwrap_model(model)
@@ -890,6 +897,8 @@ def main():
         tokenizer.save_pretrained(args.output_dir)
         if args.push_to_hub:
             repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True)
+        with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
+            json.dump({"eval_f1": eval_metric["f1"], "eval_exact": eval_metric["exact"]}, f)
 
 
 if __name__ == "__main__":
diff --git a/examples/pytorch/summarization/run_summarization_no_trainer.py b/examples/pytorch/summarization/run_summarization_no_trainer.py
index fd2bb2cc816..adc9e616dda 100644
--- a/examples/pytorch/summarization/run_summarization_no_trainer.py
+++ b/examples/pytorch/summarization/run_summarization_no_trainer.py
@@ -19,6 +19,7 @@ Fine-tuning a 🤗 Transformers model on summarization.
 # You can also adapt this script on your own summarization task. Pointers for this are left as comments.
 
 import argparse
+import json
 import logging
 import math
 import os
@@ -602,7 +603,10 @@ def main():
 
             if isinstance(checkpointing_steps, int):
                 if completed_steps % checkpointing_steps == 0:
-                    accelerator.save_state(f"step_{completed_steps}")
+                    output_dir = f"step_{completed_steps}"
+                    if args.output_dir is not None:
+                        output_dir = os.path.join(args.output_dir, output_dir)
+                    accelerator.save_state(output_dir)
 
             if completed_steps >= args.max_train_steps:
                 break
@@ -669,7 +673,10 @@ def main():
                 )
 
         if args.checkpointing_steps == "epoch":
-            accelerator.save_state(f"epoch_{epoch}")
+            output_dir = f"epoch_{epoch}"
+            if args.output_dir is not None:
+                output_dir = os.path.join(args.output_dir, output_dir)
+            accelerator.save_state(output_dir)
 
     if args.output_dir is not None:
         accelerator.wait_for_everyone()
@@ -679,6 +686,16 @@ def main():
         tokenizer.save_pretrained(args.output_dir)
         if args.push_to_hub:
             repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True)
+        with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
+            json.dump(
+                {
+                    "eval_rouge1": result["rouge1"],
+                    "eval_rouge2": result["rouge2"],
+                    "eval_rougeL": result["rougeL"],
+                    "eval_rougeLsum": result["rougeLsum"],
+                },
+                f,
+            )
 
 
 if __name__ == "__main__":
diff --git a/examples/pytorch/test_accelerate_examples.py b/examples/pytorch/test_accelerate_examples.py
new file mode 100644
index 00000000000..883dc434deb
--- /dev/null
+++ b/examples/pytorch/test_accelerate_examples.py
@@ -0,0 +1,302 @@
+# coding=utf-8
+# Copyright 2018 HuggingFace Inc..
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import json
+import logging
+import os
+import sys
+from unittest.mock import patch
+
+import torch
+
+from transformers.testing_utils import TestCasePlus, get_gpu_count, slow, torch_device
+from transformers.utils import is_apex_available
+
+
+SRC_DIRS = [
+    os.path.join(os.path.dirname(__file__), dirname)
+    for dirname in [
+        "text-generation",
+        "text-classification",
+        "token-classification",
+        "language-modeling",
+        "multiple-choice",
+        "question-answering",
+        "summarization",
+        "translation",
+        "image-classification",
+        "speech-recognition",
+        "audio-classification",
+        "speech-pretraining",
+        "image-pretraining",
+    ]
+]
+sys.path.extend(SRC_DIRS)
+
+
+if SRC_DIRS is not None:
+    import run_clm_no_trainer
+    import run_glue_no_trainer
+    import run_mlm_no_trainer
+    import run_ner_no_trainer
+    import run_qa_no_trainer as run_squad_no_trainer
+    import run_summarization_no_trainer
+    import run_swag_no_trainer
+    import run_translation_no_trainer
+
+logging.basicConfig(level=logging.DEBUG)
+
+logger = logging.getLogger()
+
+
+def get_setup_file():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-f")
+    args = parser.parse_args()
+    return args.f
+
+
+def get_results(output_dir):
+    results = {}
+    path = os.path.join(output_dir, "all_results.json")
+    if os.path.exists(path):
+        with open(path, "r") as f:
+            results = json.load(f)
+    else:
+        raise ValueError(f"can't find {path}")
+    return results
+
+
+def is_cuda_and_apex_available():
+    is_using_cuda = torch.cuda.is_available() and torch_device == "cuda"
+    return is_using_cuda and is_apex_available()
+
+
+class ExamplesTestsNoTrainer(TestCasePlus):
+    def test_run_glue_no_trainer(self):
+        stream_handler = logging.StreamHandler(sys.stdout)
+        logger.addHandler(stream_handler)
+
+        tmp_dir = self.get_auto_remove_tmp_dir()
+        testargs = f"""
+            run_glue_no_trainer.py
+            --model_name_or_path distilbert-base-uncased
+            --output_dir {tmp_dir}
+            --train_file ./tests/fixtures/tests_samples/MRPC/train.csv
+            --validation_file ./tests/fixtures/tests_samples/MRPC/dev.csv
+            --per_device_train_batch_size=2
+            --per_device_eval_batch_size=1
+            --learning_rate=1e-4
+            --seed=42
+            --checkpointing_steps epoch
+        """.split()
+
+        if is_cuda_and_apex_available():
+            testargs.append("--fp16")
+
+        with patch.object(sys, "argv", testargs):
+            run_glue_no_trainer.main()
+            result = get_results(tmp_dir)
+            self.assertGreaterEqual(result["eval_accuracy"], 0.75)
+            self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
+
+    def test_run_clm_no_trainer(self):
+        stream_handler = logging.StreamHandler(sys.stdout)
+        logger.addHandler(stream_handler)
+
+        tmp_dir = self.get_auto_remove_tmp_dir()
+        testargs = f"""
+            run_clm_no_trainer.py
+            --model_name_or_path distilgpt2
+            --train_file ./tests/fixtures/sample_text.txt
+            --validation_file ./tests/fixtures/sample_text.txt
+            --block_size 128
+            --per_device_train_batch_size 5
+            --per_device_eval_batch_size 5
+            --num_train_epochs 2
+            --output_dir {tmp_dir}
+            --checkpointing_steps epoch
+        """.split()
+
+        if torch.cuda.device_count() > 1:
+            # Skipping because there are not enough batches to train the model + would need a drop_last to work.
+            return
+
+        with patch.object(sys, "argv", testargs):
+            run_clm_no_trainer.main()
+            result = get_results(tmp_dir)
+            self.assertLess(result["perplexity"], 100)
+            self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
+
+    def test_run_mlm_no_trainer(self):
+        stream_handler = logging.StreamHandler(sys.stdout)
+        logger.addHandler(stream_handler)
+
+        tmp_dir = self.get_auto_remove_tmp_dir()
+        testargs = f"""
+            run_mlm_no_trainer.py
+            --model_name_or_path distilroberta-base
+            --train_file ./tests/fixtures/sample_text.txt
+            --validation_file ./tests/fixtures/sample_text.txt
+            --output_dir {tmp_dir}
+            --num_train_epochs=1
+            --checkpointing_steps epoch
+        """.split()
+
+        with patch.object(sys, "argv", testargs):
+            run_mlm_no_trainer.main()
+            result = get_results(tmp_dir)
+            self.assertLess(result["perplexity"], 42)
+            self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
+
+    def test_run_ner_no_trainer(self):
+        stream_handler = logging.StreamHandler(sys.stdout)
+        logger.addHandler(stream_handler)
+
+        # with so little data distributed training needs more epochs to get the score on par with 0/1 gpu
+        epochs = 7 if get_gpu_count() > 1 else 2
+
+        tmp_dir = self.get_auto_remove_tmp_dir()
+        testargs = f"""
+            run_ner_no_trainer.py
+            --model_name_or_path bert-base-uncased
+            --train_file tests/fixtures/tests_samples/conll/sample.json
+            --validation_file tests/fixtures/tests_samples/conll/sample.json
+            --output_dir {tmp_dir}
+            --learning_rate=2e-4
+            --per_device_train_batch_size=2
+            --per_device_eval_batch_size=2
+            --num_train_epochs={epochs}
+            --seed 7
+            --checkpointing_steps epoch
+        """.split()
+
+        with patch.object(sys, "argv", testargs):
+            run_ner_no_trainer.main()
+            result = get_results(tmp_dir)
+            self.assertGreaterEqual(result["eval_accuracy"], 0.75)
+            self.assertLess(result["train_loss"], 0.5)
+            self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
+
+    def test_run_squad_no_trainer(self):
+        stream_handler = logging.StreamHandler(sys.stdout)
+        logger.addHandler(stream_handler)
+
+        tmp_dir = self.get_auto_remove_tmp_dir()
+        testargs = f"""
+            run_qa_no_trainer.py
+            --model_name_or_path bert-base-uncased
+            --version_2_with_negative=False
+            --train_file tests/fixtures/tests_samples/SQUAD/sample.json
+            --validation_file tests/fixtures/tests_samples/SQUAD/sample.json
+            --output_dir {tmp_dir}
+            --max_train_steps=10
+            --num_warmup_steps=2
+            --learning_rate=2e-4
+            --per_device_train_batch_size=2
+            --per_device_eval_batch_size=1
+            --checkpointing_steps epoch
+        """.split()
+
+        with patch.object(sys, "argv", testargs):
+            run_squad_no_trainer.main()
+            result = get_results(tmp_dir)
+            self.assertGreaterEqual(result["eval_f1"], 30)
+            self.assertGreaterEqual(result["eval_exact"], 30)
+            self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
+
+    def test_run_swag_no_trainer(self):
+        stream_handler = logging.StreamHandler(sys.stdout)
+        logger.addHandler(stream_handler)
+
+        tmp_dir = self.get_auto_remove_tmp_dir()
+        testargs = f"""
+            run_swag_no_trainer.py
+            --model_name_or_path bert-base-uncased
+            --train_file tests/fixtures/tests_samples/swag/sample.json
+            --validation_file tests/fixtures/tests_samples/swag/sample.json
+            --output_dir {tmp_dir}
+            --max_train_steps=20
+            --num_warmup_steps=2
+            --learning_rate=2e-4
+            --per_device_train_batch_size=2
+            --per_device_eval_batch_size=1
+        """.split()
+
+        with patch.object(sys, "argv", testargs):
+            run_swag_no_trainer.main()
+            result = get_results(tmp_dir)
+            self.assertGreaterEqual(result["eval_accuracy"], 0.8)
+
+    @slow
+    def test_run_summarization_no_trainer(self):
+        stream_handler = logging.StreamHandler(sys.stdout)
+        logger.addHandler(stream_handler)
+
+        tmp_dir = self.get_auto_remove_tmp_dir()
+        testargs = f"""
+            run_summarization_no_trainer.py
+            --model_name_or_path t5-small
+            --train_file tests/fixtures/tests_samples/xsum/sample.json
+            --validation_file tests/fixtures/tests_samples/xsum/sample.json
+            --output_dir {tmp_dir}
+            --max_train_steps=50
+            --num_warmup_steps=8
+            --learning_rate=2e-4
+            --per_device_train_batch_size=2
+            --per_device_eval_batch_size=1
+            --checkpointing_steps epoch
+        """.split()
+
+        with patch.object(sys, "argv", testargs):
+            run_summarization_no_trainer.main()
+            result = get_results(tmp_dir)
+            self.assertGreaterEqual(result["eval_rouge1"], 10)
+            self.assertGreaterEqual(result["eval_rouge2"], 2)
+            self.assertGreaterEqual(result["eval_rougeL"], 7)
+            self.assertGreaterEqual(result["eval_rougeLsum"], 7)
+            self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
+
+    @slow
+    def test_run_translation_no_trainer(self):
+        stream_handler = logging.StreamHandler(sys.stdout)
+        logger.addHandler(stream_handler)
+
+        tmp_dir = self.get_auto_remove_tmp_dir()
+        testargs = f"""
+            run_translation_no_trainer.py
+            --model_name_or_path sshleifer/student_marian_en_ro_6_1
+            --source_lang en
+            --target_lang ro
+            --train_file tests/fixtures/tests_samples/wmt16/sample.json
+            --validation_file tests/fixtures/tests_samples/wmt16/sample.json
+            --output_dir {tmp_dir}
+            --max_train_steps=50
+            --num_warmup_steps=8
+            --learning_rate=3e-3
+            --per_device_train_batch_size=2
+            --per_device_eval_batch_size=1
+            --source_lang en_XX
+            --target_lang ro_RO
+            --checkpointing_steps epoch
+        """.split()
+
+        with patch.object(sys, "argv", testargs):
+            run_translation_no_trainer.main()
+            result = get_results(tmp_dir)
+            self.assertGreaterEqual(result["eval_bleu"], 30)
+            self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
diff --git a/examples/pytorch/text-classification/run_glue_no_trainer.py b/examples/pytorch/text-classification/run_glue_no_trainer.py
index 5bd1d1fa1e5..2c7fa186d0e 100644
--- a/examples/pytorch/text-classification/run_glue_no_trainer.py
+++ b/examples/pytorch/text-classification/run_glue_no_trainer.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 """ Finetuning a 🤗 Transformers model for sequence classification on GLUE."""
 import argparse
+import json
 import logging
 import math
 import os
@@ -150,7 +151,6 @@ def parse_args():
         "--hub_model_id", type=str, help="The name of the repository to keep in sync with the local `output_dir`."
     )
     parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.")
-    parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.")
     parser.add_argument(
         "--checkpointing_steps",
         type=str,
@@ -488,7 +488,10 @@ def main():
 
             if isinstance(checkpointing_steps, int):
                 if completed_steps % checkpointing_steps == 0:
-                    accelerator.save_state(f"step_{completed_steps}")
+                    output_dir = f"step_{completed_steps}"
+                    if args.output_dir is not None:
+                        output_dir = os.path.join(args.output_dir, output_dir)
+                    accelerator.save_state(output_dir)
 
             if completed_steps >= args.max_train_steps:
                 break
@@ -526,7 +529,10 @@ def main():
                 )
 
         if args.checkpointing_steps == "epoch":
-            accelerator.save_state(f"epoch_{epoch}")
+            output_dir = f"epoch_{epoch}"
+            if args.output_dir is not None:
+                output_dir = os.path.join(args.output_dir, output_dir)
+            accelerator.save_state(output_dir)
 
     if args.output_dir is not None:
         accelerator.wait_for_everyone()
@@ -557,6 +563,10 @@ def main():
         eval_metric = metric.compute()
         logger.info(f"mnli-mm: {eval_metric}")
 
+    if args.output_dir is not None:
+        with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
+            json.dump({"eval_accuracy": eval_metric["accuracy"]}, f)
+
 
 if __name__ == "__main__":
     main()
diff --git a/examples/pytorch/token-classification/run_ner_no_trainer.py b/examples/pytorch/token-classification/run_ner_no_trainer.py
index 57d3ceee905..ab9fcce6df9 100755
--- a/examples/pytorch/token-classification/run_ner_no_trainer.py
+++ b/examples/pytorch/token-classification/run_ner_no_trainer.py
@@ -19,6 +19,7 @@ without using a Trainer.
 """
 import argparse
+import json
 import logging
 import math
 import os
@@ -639,7 +640,10 @@ def main():
 
             if isinstance(checkpointing_steps, int):
                 if completed_steps % checkpointing_steps == 0:
-                    accelerator.save_state(f"step_{completed_steps}")
+                    output_dir = f"step_{completed_steps}"
+                    if args.output_dir is not None:
+                        output_dir = os.path.join(args.output_dir, output_dir)
+                    accelerator.save_state(output_dir)
 
             if completed_steps >= args.max_train_steps:
                 break
@@ -662,7 +666,6 @@ def main():
                 references=refs,
             )  # predictions and preferences are expected to be a nested list of labels, not label_ids
 
-        # eval_metric = metric.compute()
         eval_metric = compute_metrics()
         accelerator.print(f"epoch {epoch}:", eval_metric)
         if args.with_tracking:
@@ -686,7 +689,10 @@ def main():
                 )
 
         if args.checkpointing_steps == "epoch":
-            accelerator.save_state(f"epoch_{epoch}")
+            output_dir = f"epoch_{epoch}"
+            if args.output_dir is not None:
+                output_dir = os.path.join(args.output_dir, output_dir)
+            accelerator.save_state(output_dir)
 
     if args.output_dir is not None:
         accelerator.wait_for_everyone()
@@ -697,6 +703,9 @@ def main():
         if args.push_to_hub:
             repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True)
 
+        with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
+            json.dump({"eval_accuracy": eval_metric["accuracy"], "train_loss": float(loss.cpu().detach().numpy())}, f)
+
 
 if __name__ == "__main__":
     main()
diff --git a/examples/pytorch/translation/run_translation_no_trainer.py b/examples/pytorch/translation/run_translation_no_trainer.py
index bf7e15ae4dd..034387582b8 100644
--- a/examples/pytorch/translation/run_translation_no_trainer.py
+++ b/examples/pytorch/translation/run_translation_no_trainer.py
@@ -19,6 +19,7 @@ Fine-tuning a 🤗 Transformers model on text translation.
 # You can also adapt this script on your own text translation task. Pointers for this are left as comments.
 
 import argparse
+import json
 import logging
 import math
 import os
@@ -586,7 +587,10 @@ def main():
 
             if isinstance(checkpointing_steps, int):
                 if completed_steps % checkpointing_steps == 0:
-                    accelerator.save_state(f"step_{completed_steps}")
+                    output_dir = f"step_{completed_steps}"
+                    if args.output_dir is not None:
+                        output_dir = os.path.join(args.output_dir, output_dir)
+                    accelerator.save_state(output_dir)
 
             if completed_steps >= args.max_train_steps:
                 break
@@ -653,7 +657,10 @@ def main():
                 )
 
         if args.checkpointing_steps == "epoch":
-            accelerator.save_state(f"epoch_{epoch}")
+            output_dir = f"epoch_{epoch}"
+            if args.output_dir is not None:
+                output_dir = os.path.join(args.output_dir, output_dir)
+            accelerator.save_state(output_dir)
 
     if args.output_dir is not None:
         accelerator.wait_for_everyone()
@@ -663,6 +670,8 @@ def main():
         tokenizer.save_pretrained(args.output_dir)
         if args.push_to_hub:
             repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True)
+        with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
+            json.dump({"eval_bleu": eval_metric["score"]}, f)
 
 
 if __name__ == "__main__":
diff --git a/utils/tests_fetcher.py b/utils/tests_fetcher.py
index 5ad7b4b1f78..16bf6348d38 100644
--- a/utils/tests_fetcher.py
+++ b/utils/tests_fetcher.py
@@ -466,6 +466,7 @@ def infer_tests_to_run(output_file, diff_with_last_commit=False, filters=None):
         # Example files are tested separately
         elif f.startswith("examples/pytorch"):
             test_files_to_run.append("examples/pytorch/test_pytorch_examples.py")
+            test_files_to_run.append("examples/pytorch/test_accelerate_examples.py")
         elif f.startswith("examples/flax"):
             test_files_to_run.append("examples/flax/test_flax_examples.py")
         else: