Add tests for no_trainer and fix existing examples (#16656)

* Fixed some bugs involving saving during epochs
* Added tests mimicking the existing examples tests
* Added JSON exporting of final metrics to all `no_trainer` examples for consistency
Zachary Mueller 2022-04-08 10:03:56 -04:00 committed by GitHub
parent ab229663b5
commit d57da99237
11 changed files with 414 additions and 22 deletions
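
The diffs below apply one pattern to every `no_trainer` script: checkpoint folders produced by `accelerator.save_state` are nested under `--output_dir` instead of the working directory, and final metrics are written to `all_results.json` so the new tests can read them back. A minimal sketch of that pattern, pulled together from the diffs below; the helper names are illustrative and not part of the commit, and `accelerator`, `args`, `epoch`, and the metrics dict stand in for each script's own objects:

# Illustrative sketch only -- not part of the commit. `accelerator` is an
# accelerate.Accelerator and `args` is the script's parsed argparse namespace.
import json
import os


def save_epoch_checkpoint(accelerator, args, epoch):
    # Nest the checkpoint folder under --output_dir rather than the current working directory.
    if args.checkpointing_steps == "epoch":
        output_dir = f"epoch_{epoch}"
        if args.output_dir is not None:
            output_dir = os.path.join(args.output_dir, output_dir)
        accelerator.save_state(output_dir)


def export_results(args, metrics):
    # Write final metrics so the tests' get_results() helper can load and assert on them.
    if args.output_dir is not None:
        with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
            json.dump(metrics, f)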

.circleci/config.yml

@@ -587,6 +587,7 @@ jobs:
       - run: pip install --upgrade pip
       - run: pip install .[sklearn,torch,sentencepiece,testing,torch-speech]
       - run: pip install -r examples/pytorch/_tests_requirements.txt
+      - run: pip install git+https://github.com/huggingface/accelerate
       - save_cache:
           key: v0.4-torch_examples-{{ checksum "setup.py" }}
           paths:

examples/pytorch/language-modeling/run_clm_no_trainer.py

@@ -23,6 +23,7 @@ https://huggingface.co/models?filter=text-generation
 # You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments.
 import argparse
+import json
 import logging
 import math
 import os
@@ -537,7 +538,10 @@ def main():
             if isinstance(checkpointing_steps, int):
                 if completed_steps % checkpointing_steps == 0:
-                    accelerator.save_state(f"step_{completed_steps}")
+                    output_dir = f"step_{completed_steps}"
+                    if args.output_dir is not None:
+                        output_dir = os.path.join(args.output_dir, output_dir)
+                    accelerator.save_state(output_dir)
             if completed_steps >= args.max_train_steps:
                 break
@@ -581,7 +585,10 @@ def main():
                 )
         if args.checkpointing_steps == "epoch":
-            accelerator.save_state(f"epoch_{epoch}")
+            output_dir = f"epoch_{epoch}"
+            if args.output_dir is not None:
+                output_dir = os.path.join(args.output_dir, output_dir)
+            accelerator.save_state(output_dir)
     if args.output_dir is not None:
         accelerator.wait_for_everyone()
@@ -592,6 +599,9 @@ def main():
             if args.push_to_hub:
                 repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True)
+        with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
+            json.dump({"perplexity": perplexity}, f)
 if __name__ == "__main__":
     main()

examples/pytorch/language-modeling/run_mlm_no_trainer.py

@@ -23,6 +23,7 @@ https://huggingface.co/models?filter=fill-mask
 # You can also adapt this script on your own mlm task. Pointers for this are left as comments.
 import argparse
+import json
 import logging
 import math
 import os
@@ -457,6 +458,8 @@ def main():
     train_dataset = tokenized_datasets["train"]
     eval_dataset = tokenized_datasets["validation"]
-    # Log a few random samples from the training set:
-    for index in random.sample(range(len(train_dataset)), 3):
-        logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
+    # Conditional for small test subsets
+    if len(train_dataset) > 3:
+        # Log a few random samples from the training set:
+        for index in random.sample(range(len(train_dataset)), 3):
+            logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
@@ -581,7 +584,10 @@ def main():
             if isinstance(checkpointing_steps, int):
                 if completed_steps % checkpointing_steps == 0:
-                    accelerator.save_state(f"step_{completed_steps}")
+                    output_dir = f"step_{completed_steps}"
+                    if args.output_dir is not None:
+                        output_dir = os.path.join(args.output_dir, output_dir)
+                    accelerator.save_state(output_dir)
             if completed_steps >= args.max_train_steps:
                 break
@@ -625,7 +631,10 @@ def main():
                 )
         if args.checkpointing_steps == "epoch":
-            accelerator.save_state(f"epoch_{epoch}")
+            output_dir = f"epoch_{epoch}"
+            if args.output_dir is not None:
+                output_dir = os.path.join(args.output_dir, output_dir)
+            accelerator.save_state(output_dir)
     if args.output_dir is not None:
         accelerator.wait_for_everyone()
@@ -636,6 +645,9 @@ def main():
             if args.push_to_hub:
                 repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True)
+        with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
+            json.dump({"perplexity": perplexity}, f)
 if __name__ == "__main__":
     main()

examples/pytorch/multiple-choice/run_swag_no_trainer.py

@@ -19,6 +19,7 @@ Fine-tuning a 🤗 Transformers model on multiple choice relying on the accelera
 # You can also adapt this script on your own multiple choice task. Pointers for this are left as comments.
 import argparse
+import json
 import logging
 import math
 import os
@@ -540,7 +541,10 @@ def main():
             if isinstance(checkpointing_steps, int):
                 if completed_steps % checkpointing_steps == 0:
-                    accelerator.save_state(f"step_{completed_steps}")
+                    output_dir = f"step_{completed_steps}"
+                    if args.output_dir is not None:
+                        output_dir = os.path.join(args.output_dir, output_dir)
+                    accelerator.save_state(output_dir)
             if completed_steps >= args.max_train_steps:
                 break
@@ -578,6 +582,12 @@ def main():
                     commit_message=f"Training in progress epoch {epoch}", blocking=False, auto_lfs_prune=True
                 )
+        if args.checkpointing_steps == "epoch":
+            output_dir = f"epoch_{epoch}"
+            if args.output_dir is not None:
+                output_dir = os.path.join(args.output_dir, output_dir)
+            accelerator.save_state(output_dir)
     if args.output_dir is not None:
         accelerator.wait_for_everyone()
         unwrapped_model = accelerator.unwrap_model(model)
@@ -586,6 +596,8 @@ def main():
             tokenizer.save_pretrained(args.output_dir)
             if args.push_to_hub:
                 repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True)
+        with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
+            json.dump({"eval_accuracy": eval_metric["accuracy"]}, f)
 if __name__ == "__main__":

examples/pytorch/question-answering/run_qa_no_trainer.py

@@ -19,6 +19,7 @@ Fine-tuning a 🤗 Transformers model for question answering using 🤗 Accelera
 # You can also adapt this script on your own question answering task. Pointers for this are left as comments.
 import argparse
+import json
 import logging
 import math
 import os
@@ -783,11 +784,20 @@ def main():
             if isinstance(checkpointing_steps, int):
                 if completed_steps % checkpointing_steps == 0:
-                    accelerator.save_state(f"step_{completed_steps}")
+                    output_dir = f"step_{completed_steps}"
+                    if args.output_dir is not None:
+                        output_dir = os.path.join(args.output_dir, output_dir)
+                    accelerator.save_state(output_dir)
             if completed_steps >= args.max_train_steps:
                 break
+        if args.checkpointing_steps == "epoch":
+            output_dir = f"epoch_{epoch}"
+            if args.output_dir is not None:
+                output_dir = os.path.join(args.output_dir, output_dir)
+            accelerator.save_state(output_dir)
         if args.push_to_hub and epoch < args.num_train_epochs - 1:
             accelerator.wait_for_everyone()
             unwrapped_model = accelerator.unwrap_model(model)
@@ -879,9 +889,6 @@ def main():
             accelerator.log(log, step=completed_steps)
-        if args.checkpointing_steps == "epoch":
-            accelerator.save_state(f"epoch_{epoch}")
     if args.output_dir is not None:
         accelerator.wait_for_everyone()
         unwrapped_model = accelerator.unwrap_model(model)
@@ -890,6 +897,8 @@ def main():
             tokenizer.save_pretrained(args.output_dir)
             if args.push_to_hub:
                 repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True)
+        with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
+            json.dump({"eval_f1": eval_metric["f1"], "eval_exact": eval_metric["exact"]}, f)
 if __name__ == "__main__":

examples/pytorch/summarization/run_summarization_no_trainer.py

@@ -19,6 +19,7 @@ Fine-tuning a 🤗 Transformers model on summarization.
 # You can also adapt this script on your own summarization task. Pointers for this are left as comments.
 import argparse
+import json
 import logging
 import math
 import os
@@ -602,7 +603,10 @@ def main():
             if isinstance(checkpointing_steps, int):
                 if completed_steps % checkpointing_steps == 0:
-                    accelerator.save_state(f"step_{completed_steps}")
+                    output_dir = f"step_{completed_steps}"
+                    if args.output_dir is not None:
+                        output_dir = os.path.join(args.output_dir, output_dir)
+                    accelerator.save_state(output_dir)
             if completed_steps >= args.max_train_steps:
                 break
@@ -669,7 +673,10 @@ def main():
                 )
         if args.checkpointing_steps == "epoch":
-            accelerator.save_state(f"epoch_{epoch}")
+            output_dir = f"epoch_{epoch}"
+            if args.output_dir is not None:
+                output_dir = os.path.join(args.output_dir, output_dir)
+            accelerator.save_state(output_dir)
     if args.output_dir is not None:
         accelerator.wait_for_everyone()
@@ -679,6 +686,16 @@ def main():
             tokenizer.save_pretrained(args.output_dir)
             if args.push_to_hub:
                 repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True)
+        with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
+            json.dump(
+                {
+                    "eval_rouge1": result["rouge1"],
+                    "eval_rouge2": result["rouge2"],
+                    "eval_rougeL": result["rougeL"],
+                    "eval_rougeLsum": result["rougeLsum"],
+                },
+                f,
+            )
 if __name__ == "__main__":

examples/pytorch/test_accelerate_examples.py (new file)

@@ -0,0 +1,302 @@
# coding=utf-8
# Copyright 2018 HuggingFace Inc..
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json
import logging
import os
import sys
from unittest.mock import patch

import torch

from transformers.testing_utils import TestCasePlus, get_gpu_count, slow, torch_device
from transformers.utils import is_apex_available


SRC_DIRS = [
    os.path.join(os.path.dirname(__file__), dirname)
    for dirname in [
        "text-generation",
        "text-classification",
        "token-classification",
        "language-modeling",
        "multiple-choice",
        "question-answering",
        "summarization",
        "translation",
        "image-classification",
        "speech-recognition",
        "audio-classification",
        "speech-pretraining",
        "image-pretraining",
    ]
]
sys.path.extend(SRC_DIRS)

if SRC_DIRS is not None:
    import run_clm_no_trainer
    import run_glue_no_trainer
    import run_mlm_no_trainer
    import run_ner_no_trainer
    import run_qa_no_trainer as run_squad_no_trainer
    import run_summarization_no_trainer
    import run_swag_no_trainer
    import run_translation_no_trainer

logging.basicConfig(level=logging.DEBUG)

logger = logging.getLogger()


def get_setup_file():
    parser = argparse.ArgumentParser()
    parser.add_argument("-f")
    args = parser.parse_args()
    return args.f


def get_results(output_dir):
    results = {}
    path = os.path.join(output_dir, "all_results.json")
    if os.path.exists(path):
        with open(path, "r") as f:
            results = json.load(f)
    else:
        raise ValueError(f"can't find {path}")
    return results


def is_cuda_and_apex_available():
    is_using_cuda = torch.cuda.is_available() and torch_device == "cuda"
    return is_using_cuda and is_apex_available()


class ExamplesTestsNoTrainer(TestCasePlus):
    def test_run_glue_no_trainer(self):
        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)

        tmp_dir = self.get_auto_remove_tmp_dir()
        testargs = f"""
            run_glue_no_trainer.py
            --model_name_or_path distilbert-base-uncased
            --output_dir {tmp_dir}
            --train_file ./tests/fixtures/tests_samples/MRPC/train.csv
            --validation_file ./tests/fixtures/tests_samples/MRPC/dev.csv
            --per_device_train_batch_size=2
            --per_device_eval_batch_size=1
            --learning_rate=1e-4
            --seed=42
            --checkpointing_steps epoch
        """.split()

        if is_cuda_and_apex_available():
            testargs.append("--fp16")

        with patch.object(sys, "argv", testargs):
            run_glue_no_trainer.main()
            result = get_results(tmp_dir)
            self.assertGreaterEqual(result["eval_accuracy"], 0.75)
            self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))

    def test_run_clm_no_trainer(self):
        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)

        tmp_dir = self.get_auto_remove_tmp_dir()
        testargs = f"""
            run_clm_no_trainer.py
            --model_name_or_path distilgpt2
            --train_file ./tests/fixtures/sample_text.txt
            --validation_file ./tests/fixtures/sample_text.txt
            --block_size 128
            --per_device_train_batch_size 5
            --per_device_eval_batch_size 5
            --num_train_epochs 2
            --output_dir {tmp_dir}
            --checkpointing_steps epoch
        """.split()

        if torch.cuda.device_count() > 1:
            # Skipping because there are not enough batches to train the model + would need a drop_last to work.
            return

        with patch.object(sys, "argv", testargs):
            run_clm_no_trainer.main()
            result = get_results(tmp_dir)
            self.assertLess(result["perplexity"], 100)
            self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))

    def test_run_mlm_no_trainer(self):
        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)

        tmp_dir = self.get_auto_remove_tmp_dir()
        testargs = f"""
            run_mlm_no_trainer.py
            --model_name_or_path distilroberta-base
            --train_file ./tests/fixtures/sample_text.txt
            --validation_file ./tests/fixtures/sample_text.txt
            --output_dir {tmp_dir}
            --num_train_epochs=1
            --checkpointing_steps epoch
        """.split()

        with patch.object(sys, "argv", testargs):
            run_mlm_no_trainer.main()
            result = get_results(tmp_dir)
            self.assertLess(result["perplexity"], 42)
            self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))

    def test_run_ner_no_trainer(self):
        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)

        # with so little data distributed training needs more epochs to get the score on par with 0/1 gpu
        epochs = 7 if get_gpu_count() > 1 else 2

        tmp_dir = self.get_auto_remove_tmp_dir()
        testargs = f"""
            run_ner_no_trainer.py
            --model_name_or_path bert-base-uncased
            --train_file tests/fixtures/tests_samples/conll/sample.json
            --validation_file tests/fixtures/tests_samples/conll/sample.json
            --output_dir {tmp_dir}
            --learning_rate=2e-4
            --per_device_train_batch_size=2
            --per_device_eval_batch_size=2
            --num_train_epochs={epochs}
            --seed 7
            --checkpointing_steps epoch
        """.split()

        with patch.object(sys, "argv", testargs):
            run_ner_no_trainer.main()
            result = get_results(tmp_dir)
            self.assertGreaterEqual(result["eval_accuracy"], 0.75)
            self.assertLess(result["train_loss"], 0.5)
            self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))

    def test_run_squad_no_trainer(self):
        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)

        tmp_dir = self.get_auto_remove_tmp_dir()
        testargs = f"""
            run_qa_no_trainer.py
            --model_name_or_path bert-base-uncased
            --version_2_with_negative=False
            --train_file tests/fixtures/tests_samples/SQUAD/sample.json
            --validation_file tests/fixtures/tests_samples/SQUAD/sample.json
            --output_dir {tmp_dir}
            --max_train_steps=10
            --num_warmup_steps=2
            --learning_rate=2e-4
            --per_device_train_batch_size=2
            --per_device_eval_batch_size=1
            --checkpointing_steps epoch
        """.split()

        with patch.object(sys, "argv", testargs):
            run_squad_no_trainer.main()
            result = get_results(tmp_dir)
            self.assertGreaterEqual(result["eval_f1"], 30)
            self.assertGreaterEqual(result["eval_exact"], 30)
            self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))

    def test_run_swag_no_trainer(self):
        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)

        tmp_dir = self.get_auto_remove_tmp_dir()
        testargs = f"""
            run_swag_no_trainer.py
            --model_name_or_path bert-base-uncased
            --train_file tests/fixtures/tests_samples/swag/sample.json
            --validation_file tests/fixtures/tests_samples/swag/sample.json
            --output_dir {tmp_dir}
            --max_train_steps=20
            --num_warmup_steps=2
            --learning_rate=2e-4
            --per_device_train_batch_size=2
            --per_device_eval_batch_size=1
        """.split()

        with patch.object(sys, "argv", testargs):
            run_swag_no_trainer.main()
            result = get_results(tmp_dir)
            self.assertGreaterEqual(result["eval_accuracy"], 0.8)

    @slow
    def test_run_summarization_no_trainer(self):
        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)

        tmp_dir = self.get_auto_remove_tmp_dir()
        testargs = f"""
            run_summarization_no_trainer.py
            --model_name_or_path t5-small
            --train_file tests/fixtures/tests_samples/xsum/sample.json
            --validation_file tests/fixtures/tests_samples/xsum/sample.json
            --output_dir {tmp_dir}
            --max_train_steps=50
            --num_warmup_steps=8
            --learning_rate=2e-4
            --per_device_train_batch_size=2
            --per_device_eval_batch_size=1
            --checkpointing_steps epoch
        """.split()

        with patch.object(sys, "argv", testargs):
            run_summarization_no_trainer.main()
            result = get_results(tmp_dir)
            self.assertGreaterEqual(result["eval_rouge1"], 10)
            self.assertGreaterEqual(result["eval_rouge2"], 2)
            self.assertGreaterEqual(result["eval_rougeL"], 7)
            self.assertGreaterEqual(result["eval_rougeLsum"], 7)
            self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))

    @slow
    def test_run_translation_no_trainer(self):
        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)

        tmp_dir = self.get_auto_remove_tmp_dir()
        testargs = f"""
            run_translation_no_trainer.py
            --model_name_or_path sshleifer/student_marian_en_ro_6_1
            --source_lang en
            --target_lang ro
            --train_file tests/fixtures/tests_samples/wmt16/sample.json
            --validation_file tests/fixtures/tests_samples/wmt16/sample.json
            --output_dir {tmp_dir}
            --max_train_steps=50
            --num_warmup_steps=8
            --learning_rate=3e-3
            --per_device_train_batch_size=2
            --per_device_eval_batch_size=1
            --source_lang en_XX
            --target_lang ro_RO
            --checkpointing_steps epoch
        """.split()

        with patch.object(sys, "argv", testargs):
            run_translation_no_trainer.main()
            result = get_results(tmp_dir)
            self.assertGreaterEqual(result["eval_bleu"], 30)
            self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))

examples/pytorch/text-classification/run_glue_no_trainer.py

@@ -14,6 +14,7 @@
 # limitations under the License.
 """ Finetuning a 🤗 Transformers model for sequence classification on GLUE."""
 import argparse
+import json
 import logging
 import math
 import os
@@ -150,7 +151,6 @@ def parse_args():
         "--hub_model_id", type=str, help="The name of the repository to keep in sync with the local `output_dir`."
     )
     parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.")
-    parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.")
     parser.add_argument(
         "--checkpointing_steps",
         type=str,
@@ -488,7 +488,10 @@ def main():
             if isinstance(checkpointing_steps, int):
                 if completed_steps % checkpointing_steps == 0:
-                    accelerator.save_state(f"step_{completed_steps}")
+                    output_dir = f"step_{completed_steps}"
+                    if args.output_dir is not None:
+                        output_dir = os.path.join(args.output_dir, output_dir)
+                    accelerator.save_state(output_dir)
             if completed_steps >= args.max_train_steps:
                 break
@@ -526,7 +529,10 @@ def main():
                 )
         if args.checkpointing_steps == "epoch":
-            accelerator.save_state(f"epoch_{epoch}")
+            output_dir = f"epoch_{epoch}"
+            if args.output_dir is not None:
+                output_dir = os.path.join(args.output_dir, output_dir)
+            accelerator.save_state(output_dir)
     if args.output_dir is not None:
         accelerator.wait_for_everyone()
@@ -557,6 +563,10 @@ def main():
         eval_metric = metric.compute()
         logger.info(f"mnli-mm: {eval_metric}")
+    if args.output_dir is not None:
+        with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
+            json.dump({"eval_accuracy": eval_metric["accuracy"]}, f)
 if __name__ == "__main__":
     main()

examples/pytorch/token-classification/run_ner_no_trainer.py

@@ -19,6 +19,7 @@ without using a Trainer.
 """
 import argparse
+import json
 import logging
 import math
 import os
@@ -639,7 +640,10 @@ def main():
             if isinstance(checkpointing_steps, int):
                 if completed_steps % checkpointing_steps == 0:
-                    accelerator.save_state(f"step_{completed_steps}")
+                    output_dir = f"step_{completed_steps}"
+                    if args.output_dir is not None:
+                        output_dir = os.path.join(args.output_dir, output_dir)
+                    accelerator.save_state(output_dir)
             if completed_steps >= args.max_train_steps:
                 break
@@ -662,7 +666,6 @@ def main():
                 references=refs,
             )  # predictions and preferences are expected to be a nested list of labels, not label_ids
-        # eval_metric = metric.compute()
         eval_metric = compute_metrics()
         accelerator.print(f"epoch {epoch}:", eval_metric)
         if args.with_tracking:
@@ -686,7 +689,10 @@ def main():
                 )
         if args.checkpointing_steps == "epoch":
-            accelerator.save_state(f"epoch_{epoch}")
+            output_dir = f"epoch_{epoch}"
+            if args.output_dir is not None:
+                output_dir = os.path.join(args.output_dir, output_dir)
+            accelerator.save_state(output_dir)
     if args.output_dir is not None:
         accelerator.wait_for_everyone()
@@ -697,6 +703,9 @@ def main():
             if args.push_to_hub:
                 repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True)
+        with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
+            json.dump({"eval_accuracy": eval_metric["accuracy"], "train_loss": float(loss.cpu().detach().numpy())}, f)
 if __name__ == "__main__":
     main()

examples/pytorch/translation/run_translation_no_trainer.py

@@ -19,6 +19,7 @@ Fine-tuning a 🤗 Transformers model on text translation.
 # You can also adapt this script on your own text translation task. Pointers for this are left as comments.
 import argparse
+import json
 import logging
 import math
 import os
@@ -586,7 +587,10 @@ def main():
             if isinstance(checkpointing_steps, int):
                 if completed_steps % checkpointing_steps == 0:
-                    accelerator.save_state(f"step_{completed_steps}")
+                    output_dir = f"step_{completed_steps}"
+                    if args.output_dir is not None:
+                        output_dir = os.path.join(args.output_dir, output_dir)
+                    accelerator.save_state(output_dir)
             if completed_steps >= args.max_train_steps:
                 break
@@ -653,7 +657,10 @@ def main():
                 )
         if args.checkpointing_steps == "epoch":
-            accelerator.save_state(f"epoch_{epoch}")
+            output_dir = f"epoch_{epoch}"
+            if args.output_dir is not None:
+                output_dir = os.path.join(args.output_dir, output_dir)
+            accelerator.save_state(output_dir)
     if args.output_dir is not None:
         accelerator.wait_for_everyone()
@@ -663,6 +670,8 @@ def main():
             tokenizer.save_pretrained(args.output_dir)
             if args.push_to_hub:
                 repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True)
+        with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
+            json.dump({"eval_bleu": eval_metric["score"]}, f)
 if __name__ == "__main__":

utils/tests_fetcher.py

@@ -466,6 +466,7 @@ def infer_tests_to_run(output_file, diff_with_last_commit=False, filters=None):
         # Example files are tested separately
         elif f.startswith("examples/pytorch"):
             test_files_to_run.append("examples/pytorch/test_pytorch_examples.py")
+            test_files_to_run.append("examples/pytorch/test_accelerate_examples.py")
         elif f.startswith("examples/flax"):
             test_files_to_run.append("examples/flax/test_flax_examples.py")
         else:
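
For local verification, the new test module can be run with pytest. The sketch below is one way to do it and is not part of the commit; it assumes the example requirements and `accelerate` installed by the CI step above are present, and that `RUN_SLOW=1` opts in to the tests marked `@slow`.

# Illustrative only -- run the new accelerate example tests from the repo root.
import os

import pytest

os.environ.setdefault("RUN_SLOW", "1")  # include the @slow summarization/translation tests
raise SystemExit(pytest.main(["-sv", "examples/pytorch/test_accelerate_examples.py"]))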