mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-04 21:30:07 +06:00
Add tests for no_trainer and fix existing examples (#16656)
* Fixed some bugs involving saving during epochs * Added tests mimicking the existing examples tests * Added in json exporting to all `no_trainer` examples for consistency
This commit is contained in:
parent
ab229663b5
commit
d57da99237
@ -587,6 +587,7 @@ jobs:
|
|||||||
- run: pip install --upgrade pip
|
- run: pip install --upgrade pip
|
||||||
- run: pip install .[sklearn,torch,sentencepiece,testing,torch-speech]
|
- run: pip install .[sklearn,torch,sentencepiece,testing,torch-speech]
|
||||||
- run: pip install -r examples/pytorch/_tests_requirements.txt
|
- run: pip install -r examples/pytorch/_tests_requirements.txt
|
||||||
|
- run: pip install git+https://github.com/huggingface/accelerate
|
||||||
- save_cache:
|
- save_cache:
|
||||||
key: v0.4-torch_examples-{{ checksum "setup.py" }}
|
key: v0.4-torch_examples-{{ checksum "setup.py" }}
|
||||||
paths:
|
paths:
|
||||||
|
@ -23,6 +23,7 @@ https://huggingface.co/models?filter=text-generation
|
|||||||
# You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments.
|
# You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments.
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
@ -537,7 +538,10 @@ def main():
|
|||||||
|
|
||||||
if isinstance(checkpointing_steps, int):
|
if isinstance(checkpointing_steps, int):
|
||||||
if completed_steps % checkpointing_steps == 0:
|
if completed_steps % checkpointing_steps == 0:
|
||||||
accelerator.save_state(f"step_{completed_steps}")
|
output_dir = f"step_{completed_steps}"
|
||||||
|
if args.output_dir is not None:
|
||||||
|
output_dir = os.path.join(args.output_dir, output_dir)
|
||||||
|
accelerator.save_state(output_dir)
|
||||||
|
|
||||||
if completed_steps >= args.max_train_steps:
|
if completed_steps >= args.max_train_steps:
|
||||||
break
|
break
|
||||||
@ -581,7 +585,10 @@ def main():
|
|||||||
)
|
)
|
||||||
|
|
||||||
if args.checkpointing_steps == "epoch":
|
if args.checkpointing_steps == "epoch":
|
||||||
accelerator.save_state(f"epoch_{epoch}")
|
output_dir = f"epoch_{epoch}"
|
||||||
|
if args.output_dir is not None:
|
||||||
|
output_dir = os.path.join(args.output_dir, output_dir)
|
||||||
|
accelerator.save_state(output_dir)
|
||||||
|
|
||||||
if args.output_dir is not None:
|
if args.output_dir is not None:
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
@ -592,6 +599,9 @@ def main():
|
|||||||
if args.push_to_hub:
|
if args.push_to_hub:
|
||||||
repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True)
|
repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True)
|
||||||
|
|
||||||
|
with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
|
||||||
|
json.dump({"perplexity": perplexity}, f)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
@ -23,6 +23,7 @@ https://huggingface.co/models?filter=fill-mask
|
|||||||
# You can also adapt this script on your own mlm task. Pointers for this are left as comments.
|
# You can also adapt this script on your own mlm task. Pointers for this are left as comments.
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
@ -457,6 +458,8 @@ def main():
|
|||||||
train_dataset = tokenized_datasets["train"]
|
train_dataset = tokenized_datasets["train"]
|
||||||
eval_dataset = tokenized_datasets["validation"]
|
eval_dataset = tokenized_datasets["validation"]
|
||||||
|
|
||||||
|
# Conditional for small test subsets
|
||||||
|
if len(train_dataset) > 3:
|
||||||
# Log a few random samples from the training set:
|
# Log a few random samples from the training set:
|
||||||
for index in random.sample(range(len(train_dataset)), 3):
|
for index in random.sample(range(len(train_dataset)), 3):
|
||||||
logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
|
logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
|
||||||
@ -581,7 +584,10 @@ def main():
|
|||||||
|
|
||||||
if isinstance(checkpointing_steps, int):
|
if isinstance(checkpointing_steps, int):
|
||||||
if completed_steps % checkpointing_steps == 0:
|
if completed_steps % checkpointing_steps == 0:
|
||||||
accelerator.save_state(f"step_{completed_steps}")
|
output_dir = f"step_{completed_steps}"
|
||||||
|
if args.output_dir is not None:
|
||||||
|
output_dir = os.path.join(args.output_dir, output_dir)
|
||||||
|
accelerator.save_state(output_dir)
|
||||||
|
|
||||||
if completed_steps >= args.max_train_steps:
|
if completed_steps >= args.max_train_steps:
|
||||||
break
|
break
|
||||||
@ -625,7 +631,10 @@ def main():
|
|||||||
)
|
)
|
||||||
|
|
||||||
if args.checkpointing_steps == "epoch":
|
if args.checkpointing_steps == "epoch":
|
||||||
accelerator.save_state(f"epoch_{epoch}")
|
output_dir = f"epoch_{epoch}"
|
||||||
|
if args.output_dir is not None:
|
||||||
|
output_dir = os.path.join(args.output_dir, output_dir)
|
||||||
|
accelerator.save_state(output_dir)
|
||||||
|
|
||||||
if args.output_dir is not None:
|
if args.output_dir is not None:
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
@ -636,6 +645,9 @@ def main():
|
|||||||
if args.push_to_hub:
|
if args.push_to_hub:
|
||||||
repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True)
|
repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True)
|
||||||
|
|
||||||
|
with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
|
||||||
|
json.dump({"perplexity": perplexity}, f)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
@ -19,6 +19,7 @@ Fine-tuning a 🤗 Transformers model on multiple choice relying on the accelera
|
|||||||
# You can also adapt this script on your own multiple choice task. Pointers for this are left as comments.
|
# You can also adapt this script on your own multiple choice task. Pointers for this are left as comments.
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
@ -540,7 +541,10 @@ def main():
|
|||||||
|
|
||||||
if isinstance(checkpointing_steps, int):
|
if isinstance(checkpointing_steps, int):
|
||||||
if completed_steps % checkpointing_steps == 0:
|
if completed_steps % checkpointing_steps == 0:
|
||||||
accelerator.save_state(f"step_{completed_steps}")
|
output_dir = f"step_{completed_steps}"
|
||||||
|
if args.output_dir is not None:
|
||||||
|
output_dir = os.path.join(args.output_dir, output_dir)
|
||||||
|
accelerator.save_state(output_dir)
|
||||||
|
|
||||||
if completed_steps >= args.max_train_steps:
|
if completed_steps >= args.max_train_steps:
|
||||||
break
|
break
|
||||||
@ -578,6 +582,12 @@ def main():
|
|||||||
commit_message=f"Training in progress epoch {epoch}", blocking=False, auto_lfs_prune=True
|
commit_message=f"Training in progress epoch {epoch}", blocking=False, auto_lfs_prune=True
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if args.checkpointing_steps == "epoch":
|
||||||
|
output_dir = f"epoch_{epoch}"
|
||||||
|
if args.output_dir is not None:
|
||||||
|
output_dir = os.path.join(args.output_dir, output_dir)
|
||||||
|
accelerator.save_state(output_dir)
|
||||||
|
|
||||||
if args.output_dir is not None:
|
if args.output_dir is not None:
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
unwrapped_model = accelerator.unwrap_model(model)
|
unwrapped_model = accelerator.unwrap_model(model)
|
||||||
@ -586,6 +596,8 @@ def main():
|
|||||||
tokenizer.save_pretrained(args.output_dir)
|
tokenizer.save_pretrained(args.output_dir)
|
||||||
if args.push_to_hub:
|
if args.push_to_hub:
|
||||||
repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True)
|
repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True)
|
||||||
|
with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
|
||||||
|
json.dump({"eval_accuracy": eval_metric["accuracy"]}, f)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -19,6 +19,7 @@ Fine-tuning a 🤗 Transformers model for question answering using 🤗 Accelera
|
|||||||
# You can also adapt this script on your own question answering task. Pointers for this are left as comments.
|
# You can also adapt this script on your own question answering task. Pointers for this are left as comments.
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
@ -783,11 +784,20 @@ def main():
|
|||||||
|
|
||||||
if isinstance(checkpointing_steps, int):
|
if isinstance(checkpointing_steps, int):
|
||||||
if completed_steps % checkpointing_steps == 0:
|
if completed_steps % checkpointing_steps == 0:
|
||||||
accelerator.save_state(f"step_{completed_steps}")
|
output_dir = f"step_{completed_steps}"
|
||||||
|
if args.output_dir is not None:
|
||||||
|
output_dir = os.path.join(args.output_dir, output_dir)
|
||||||
|
accelerator.save_state(output_dir)
|
||||||
|
|
||||||
if completed_steps >= args.max_train_steps:
|
if completed_steps >= args.max_train_steps:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
if args.checkpointing_steps == "epoch":
|
||||||
|
output_dir = f"epoch_{epoch}"
|
||||||
|
if args.output_dir is not None:
|
||||||
|
output_dir = os.path.join(args.output_dir, output_dir)
|
||||||
|
accelerator.save_state(output_dir)
|
||||||
|
|
||||||
if args.push_to_hub and epoch < args.num_train_epochs - 1:
|
if args.push_to_hub and epoch < args.num_train_epochs - 1:
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
unwrapped_model = accelerator.unwrap_model(model)
|
unwrapped_model = accelerator.unwrap_model(model)
|
||||||
@ -879,9 +889,6 @@ def main():
|
|||||||
|
|
||||||
accelerator.log(log, step=completed_steps)
|
accelerator.log(log, step=completed_steps)
|
||||||
|
|
||||||
if args.checkpointing_steps == "epoch":
|
|
||||||
accelerator.save_state(f"epoch_{epoch}")
|
|
||||||
|
|
||||||
if args.output_dir is not None:
|
if args.output_dir is not None:
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
unwrapped_model = accelerator.unwrap_model(model)
|
unwrapped_model = accelerator.unwrap_model(model)
|
||||||
@ -890,6 +897,8 @@ def main():
|
|||||||
tokenizer.save_pretrained(args.output_dir)
|
tokenizer.save_pretrained(args.output_dir)
|
||||||
if args.push_to_hub:
|
if args.push_to_hub:
|
||||||
repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True)
|
repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True)
|
||||||
|
with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
|
||||||
|
json.dump({"eval_f1": eval_metric["f1"], "eval_exact": eval_metric["exact"]}, f)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -19,6 +19,7 @@ Fine-tuning a 🤗 Transformers model on summarization.
|
|||||||
# You can also adapt this script on your own summarization task. Pointers for this are left as comments.
|
# You can also adapt this script on your own summarization task. Pointers for this are left as comments.
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
@ -602,7 +603,10 @@ def main():
|
|||||||
|
|
||||||
if isinstance(checkpointing_steps, int):
|
if isinstance(checkpointing_steps, int):
|
||||||
if completed_steps % checkpointing_steps == 0:
|
if completed_steps % checkpointing_steps == 0:
|
||||||
accelerator.save_state(f"step_{completed_steps}")
|
output_dir = f"step_{completed_steps}"
|
||||||
|
if args.output_dir is not None:
|
||||||
|
output_dir = os.path.join(args.output_dir, output_dir)
|
||||||
|
accelerator.save_state(output_dir)
|
||||||
|
|
||||||
if completed_steps >= args.max_train_steps:
|
if completed_steps >= args.max_train_steps:
|
||||||
break
|
break
|
||||||
@ -669,7 +673,10 @@ def main():
|
|||||||
)
|
)
|
||||||
|
|
||||||
if args.checkpointing_steps == "epoch":
|
if args.checkpointing_steps == "epoch":
|
||||||
accelerator.save_state(f"epoch_{epoch}")
|
output_dir = f"epoch_{epoch}"
|
||||||
|
if args.output_dir is not None:
|
||||||
|
output_dir = os.path.join(args.output_dir, output_dir)
|
||||||
|
accelerator.save_state(output_dir)
|
||||||
|
|
||||||
if args.output_dir is not None:
|
if args.output_dir is not None:
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
@ -679,6 +686,16 @@ def main():
|
|||||||
tokenizer.save_pretrained(args.output_dir)
|
tokenizer.save_pretrained(args.output_dir)
|
||||||
if args.push_to_hub:
|
if args.push_to_hub:
|
||||||
repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True)
|
repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True)
|
||||||
|
with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
|
||||||
|
json.dump(
|
||||||
|
{
|
||||||
|
"eval_rouge1": result["rouge1"],
|
||||||
|
"eval_rouge2": result["rouge2"],
|
||||||
|
"eval_rougeL": result["rougeL"],
|
||||||
|
"eval_rougeLsum": result["rougeLsum"],
|
||||||
|
},
|
||||||
|
f,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
302
examples/pytorch/test_accelerate_examples.py
Normal file
302
examples/pytorch/test_accelerate_examples.py
Normal file
@ -0,0 +1,302 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2018 HuggingFace Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from transformers.testing_utils import TestCasePlus, get_gpu_count, slow, torch_device
|
||||||
|
from transformers.utils import is_apex_available
|
||||||
|
|
||||||
|
|
||||||
|
# Folder that holds this test module; the example folders are joined onto it.
_examples_root = os.path.dirname(__file__)

# Paths of the example folders. They are appended to ``sys.path`` so the
# example scripts inside them can be imported as ordinary modules.
SRC_DIRS = []
for _folder in (
    "text-generation",
    "text-classification",
    "token-classification",
    "language-modeling",
    "multiple-choice",
    "question-answering",
    "summarization",
    "translation",
    "image-classification",
    "speech-recognition",
    "audio-classification",
    "speech-pretraining",
    "image-pretraining",
):
    SRC_DIRS.append(os.path.join(_examples_root, _folder))
sys.path.extend(SRC_DIRS)
|
||||||
|
|
||||||
|
|
||||||
|
if SRC_DIRS is not None:
|
||||||
|
import run_clm_no_trainer
|
||||||
|
import run_glue_no_trainer
|
||||||
|
import run_mlm_no_trainer
|
||||||
|
import run_ner_no_trainer
|
||||||
|
import run_qa_no_trainer as run_squad_no_trainer
|
||||||
|
import run_summarization_no_trainer
|
||||||
|
import run_swag_no_trainer
|
||||||
|
import run_translation_no_trainer
|
||||||
|
|
||||||
|
# Configure root logging at DEBUG level for this test module.
logging.basicConfig(level=logging.DEBUG)

# Root logger; the tests below attach a stdout handler to it.
logger = logging.getLogger()
|
||||||
|
|
||||||
|
|
||||||
|
def get_setup_file():
    """Parse ``sys.argv`` for a single ``-f`` option and return its value.

    Returns:
        The string given after ``-f``, or ``None`` when the option is absent.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("-f")
    return cli.parse_args().f
|
||||||
|
|
||||||
|
|
||||||
|
def get_results(output_dir):
    """Load the metrics an example script wrote to ``all_results.json``.

    Args:
        output_dir: Directory expected to contain ``all_results.json``.

    Returns:
        The decoded JSON content of the file.

    Raises:
        ValueError: If ``all_results.json`` does not exist in ``output_dir``.
    """
    path = os.path.join(output_dir, "all_results.json")
    try:
        # Open directly instead of checking existence first, so there is no
        # window between the check and the read. JSON is read as UTF-8
        # regardless of the platform's default encoding.
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError as err:
        raise ValueError(f"can't find {path}") from err
|
||||||
|
|
||||||
|
|
||||||
|
def is_cuda_and_apex_available():
    """Tell whether the ``--fp16`` flag can be added to an example run.

    Returns ``False`` unless ``torch.cuda.is_available()`` is true and the
    test device is ``"cuda"``; in that case the result of
    ``is_apex_available()`` is returned.
    """
    if not (torch.cuda.is_available() and torch_device == "cuda"):
        return False
    return is_apex_available()
|
||||||
|
|
||||||
|
|
||||||
|
class ExamplesTestsNoTrainer(TestCasePlus):
    """End-to-end tests for the ``run_*_no_trainer.py`` example scripts.

    Each test builds an argument list, patches it into ``sys.argv``, calls the
    script's ``main()`` and then checks the metrics the script wrote to
    ``all_results.json`` in the output directory. Tests that pass
    ``--checkpointing_steps epoch`` also check that an ``epoch_0`` folder was
    created inside the output directory.
    """

    def _attach_stdout_handler(self):
        """Send root-logger records to stdout for the current test only."""
        handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(handler)
        # Remove the handler when the test ends; otherwise handlers pile up on
        # the root logger and later tests print every record several times.
        self.addCleanup(logger.removeHandler, handler)

    def test_run_glue_no_trainer(self):
        self._attach_stdout_handler()

        tmp_dir = self.get_auto_remove_tmp_dir()
        testargs = f"""
            run_glue_no_trainer.py
            --model_name_or_path distilbert-base-uncased
            --output_dir {tmp_dir}
            --train_file ./tests/fixtures/tests_samples/MRPC/train.csv
            --validation_file ./tests/fixtures/tests_samples/MRPC/dev.csv
            --per_device_train_batch_size=2
            --per_device_eval_batch_size=1
            --learning_rate=1e-4
            --seed=42
            --checkpointing_steps epoch
            """.split()

        if is_cuda_and_apex_available():
            testargs.append("--fp16")

        with patch.object(sys, "argv", testargs):
            run_glue_no_trainer.main()
            result = get_results(tmp_dir)
            self.assertGreaterEqual(result["eval_accuracy"], 0.75)
            self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))

    def test_run_clm_no_trainer(self):
        self._attach_stdout_handler()

        tmp_dir = self.get_auto_remove_tmp_dir()
        testargs = f"""
            run_clm_no_trainer.py
            --model_name_or_path distilgpt2
            --train_file ./tests/fixtures/sample_text.txt
            --validation_file ./tests/fixtures/sample_text.txt
            --block_size 128
            --per_device_train_batch_size 5
            --per_device_eval_batch_size 5
            --num_train_epochs 2
            --output_dir {tmp_dir}
            --checkpointing_steps epoch
            """.split()

        if torch.cuda.device_count() > 1:
            # Report a skip rather than returning silently, so the run is not
            # counted as a pass when nothing was executed.
            self.skipTest("Not enough batches to train the model on multiple GPUs; would need a drop_last to work.")

        with patch.object(sys, "argv", testargs):
            run_clm_no_trainer.main()
            result = get_results(tmp_dir)
            self.assertLess(result["perplexity"], 100)
            self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))

    def test_run_mlm_no_trainer(self):
        self._attach_stdout_handler()

        tmp_dir = self.get_auto_remove_tmp_dir()
        testargs = f"""
            run_mlm_no_trainer.py
            --model_name_or_path distilroberta-base
            --train_file ./tests/fixtures/sample_text.txt
            --validation_file ./tests/fixtures/sample_text.txt
            --output_dir {tmp_dir}
            --num_train_epochs=1
            --checkpointing_steps epoch
            """.split()

        with patch.object(sys, "argv", testargs):
            run_mlm_no_trainer.main()
            result = get_results(tmp_dir)
            self.assertLess(result["perplexity"], 42)
            self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))

    def test_run_ner_no_trainer(self):
        self._attach_stdout_handler()

        # with so little data distributed training needs more epochs to get the score on par with 0/1 gpu
        epochs = 7 if get_gpu_count() > 1 else 2

        tmp_dir = self.get_auto_remove_tmp_dir()
        testargs = f"""
            run_ner_no_trainer.py
            --model_name_or_path bert-base-uncased
            --train_file tests/fixtures/tests_samples/conll/sample.json
            --validation_file tests/fixtures/tests_samples/conll/sample.json
            --output_dir {tmp_dir}
            --learning_rate=2e-4
            --per_device_train_batch_size=2
            --per_device_eval_batch_size=2
            --num_train_epochs={epochs}
            --seed 7
            --checkpointing_steps epoch
            """.split()

        with patch.object(sys, "argv", testargs):
            run_ner_no_trainer.main()
            result = get_results(tmp_dir)
            self.assertGreaterEqual(result["eval_accuracy"], 0.75)
            self.assertLess(result["train_loss"], 0.5)
            self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))

    def test_run_squad_no_trainer(self):
        self._attach_stdout_handler()

        tmp_dir = self.get_auto_remove_tmp_dir()
        testargs = f"""
            run_qa_no_trainer.py
            --model_name_or_path bert-base-uncased
            --version_2_with_negative=False
            --train_file tests/fixtures/tests_samples/SQUAD/sample.json
            --validation_file tests/fixtures/tests_samples/SQUAD/sample.json
            --output_dir {tmp_dir}
            --max_train_steps=10
            --num_warmup_steps=2
            --learning_rate=2e-4
            --per_device_train_batch_size=2
            --per_device_eval_batch_size=1
            --checkpointing_steps epoch
            """.split()

        with patch.object(sys, "argv", testargs):
            run_squad_no_trainer.main()
            result = get_results(tmp_dir)
            self.assertGreaterEqual(result["eval_f1"], 30)
            self.assertGreaterEqual(result["eval_exact"], 30)
            self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))

    def test_run_swag_no_trainer(self):
        self._attach_stdout_handler()

        tmp_dir = self.get_auto_remove_tmp_dir()
        testargs = f"""
            run_swag_no_trainer.py
            --model_name_or_path bert-base-uncased
            --train_file tests/fixtures/tests_samples/swag/sample.json
            --validation_file tests/fixtures/tests_samples/swag/sample.json
            --output_dir {tmp_dir}
            --max_train_steps=20
            --num_warmup_steps=2
            --learning_rate=2e-4
            --per_device_train_batch_size=2
            --per_device_eval_batch_size=1
            """.split()

        with patch.object(sys, "argv", testargs):
            run_swag_no_trainer.main()
            result = get_results(tmp_dir)
            self.assertGreaterEqual(result["eval_accuracy"], 0.8)

    @slow
    def test_run_summarization_no_trainer(self):
        self._attach_stdout_handler()

        tmp_dir = self.get_auto_remove_tmp_dir()
        testargs = f"""
            run_summarization_no_trainer.py
            --model_name_or_path t5-small
            --train_file tests/fixtures/tests_samples/xsum/sample.json
            --validation_file tests/fixtures/tests_samples/xsum/sample.json
            --output_dir {tmp_dir}
            --max_train_steps=50
            --num_warmup_steps=8
            --learning_rate=2e-4
            --per_device_train_batch_size=2
            --per_device_eval_batch_size=1
            --checkpointing_steps epoch
            """.split()

        with patch.object(sys, "argv", testargs):
            run_summarization_no_trainer.main()
            result = get_results(tmp_dir)
            self.assertGreaterEqual(result["eval_rouge1"], 10)
            self.assertGreaterEqual(result["eval_rouge2"], 2)
            self.assertGreaterEqual(result["eval_rougeL"], 7)
            self.assertGreaterEqual(result["eval_rougeLsum"], 7)
            self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))

    @slow
    def test_run_translation_no_trainer(self):
        self._attach_stdout_handler()

        tmp_dir = self.get_auto_remove_tmp_dir()
        # The language pair is passed once. The argument list used to start
        # with an extra `--source_lang en --target_lang ro` pair, which
        # argparse overrode with the later values kept here.
        testargs = f"""
            run_translation_no_trainer.py
            --model_name_or_path sshleifer/student_marian_en_ro_6_1
            --train_file tests/fixtures/tests_samples/wmt16/sample.json
            --validation_file tests/fixtures/tests_samples/wmt16/sample.json
            --output_dir {tmp_dir}
            --max_train_steps=50
            --num_warmup_steps=8
            --learning_rate=3e-3
            --per_device_train_batch_size=2
            --per_device_eval_batch_size=1
            --source_lang en_XX
            --target_lang ro_RO
            --checkpointing_steps epoch
            """.split()

        with patch.object(sys, "argv", testargs):
            run_translation_no_trainer.main()
            result = get_results(tmp_dir)
            self.assertGreaterEqual(result["eval_bleu"], 30)
            self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
|
@ -14,6 +14,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" Finetuning a 🤗 Transformers model for sequence classification on GLUE."""
|
""" Finetuning a 🤗 Transformers model for sequence classification on GLUE."""
|
||||||
import argparse
|
import argparse
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
@ -150,7 +151,6 @@ def parse_args():
|
|||||||
"--hub_model_id", type=str, help="The name of the repository to keep in sync with the local `output_dir`."
|
"--hub_model_id", type=str, help="The name of the repository to keep in sync with the local `output_dir`."
|
||||||
)
|
)
|
||||||
parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.")
|
parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.")
|
||||||
parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.")
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--checkpointing_steps",
|
"--checkpointing_steps",
|
||||||
type=str,
|
type=str,
|
||||||
@ -488,7 +488,10 @@ def main():
|
|||||||
|
|
||||||
if isinstance(checkpointing_steps, int):
|
if isinstance(checkpointing_steps, int):
|
||||||
if completed_steps % checkpointing_steps == 0:
|
if completed_steps % checkpointing_steps == 0:
|
||||||
accelerator.save_state(f"step_{completed_steps}")
|
output_dir = f"step_{completed_steps}"
|
||||||
|
if args.output_dir is not None:
|
||||||
|
output_dir = os.path.join(args.output_dir, output_dir)
|
||||||
|
accelerator.save_state(output_dir)
|
||||||
|
|
||||||
if completed_steps >= args.max_train_steps:
|
if completed_steps >= args.max_train_steps:
|
||||||
break
|
break
|
||||||
@ -526,7 +529,10 @@ def main():
|
|||||||
)
|
)
|
||||||
|
|
||||||
if args.checkpointing_steps == "epoch":
|
if args.checkpointing_steps == "epoch":
|
||||||
accelerator.save_state(f"epoch_{epoch}")
|
output_dir = f"epoch_{epoch}"
|
||||||
|
if args.output_dir is not None:
|
||||||
|
output_dir = os.path.join(args.output_dir, output_dir)
|
||||||
|
accelerator.save_state(output_dir)
|
||||||
|
|
||||||
if args.output_dir is not None:
|
if args.output_dir is not None:
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
@ -557,6 +563,10 @@ def main():
|
|||||||
eval_metric = metric.compute()
|
eval_metric = metric.compute()
|
||||||
logger.info(f"mnli-mm: {eval_metric}")
|
logger.info(f"mnli-mm: {eval_metric}")
|
||||||
|
|
||||||
|
if args.output_dir is not None:
|
||||||
|
with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
|
||||||
|
json.dump({"eval_accuracy": eval_metric["accuracy"]}, f)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
@ -19,6 +19,7 @@ without using a Trainer.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
@ -639,7 +640,10 @@ def main():
|
|||||||
|
|
||||||
if isinstance(checkpointing_steps, int):
|
if isinstance(checkpointing_steps, int):
|
||||||
if completed_steps % checkpointing_steps == 0:
|
if completed_steps % checkpointing_steps == 0:
|
||||||
accelerator.save_state(f"step_{completed_steps}")
|
output_dir = f"step_{completed_steps}"
|
||||||
|
if args.output_dir is not None:
|
||||||
|
output_dir = os.path.join(args.output_dir, output_dir)
|
||||||
|
accelerator.save_state(output_dir)
|
||||||
|
|
||||||
if completed_steps >= args.max_train_steps:
|
if completed_steps >= args.max_train_steps:
|
||||||
break
|
break
|
||||||
@ -662,7 +666,6 @@ def main():
|
|||||||
references=refs,
|
references=refs,
|
||||||
) # predictions and preferences are expected to be a nested list of labels, not label_ids
|
) # predictions and preferences are expected to be a nested list of labels, not label_ids
|
||||||
|
|
||||||
# eval_metric = metric.compute()
|
|
||||||
eval_metric = compute_metrics()
|
eval_metric = compute_metrics()
|
||||||
accelerator.print(f"epoch {epoch}:", eval_metric)
|
accelerator.print(f"epoch {epoch}:", eval_metric)
|
||||||
if args.with_tracking:
|
if args.with_tracking:
|
||||||
@ -686,7 +689,10 @@ def main():
|
|||||||
)
|
)
|
||||||
|
|
||||||
if args.checkpointing_steps == "epoch":
|
if args.checkpointing_steps == "epoch":
|
||||||
accelerator.save_state(f"epoch_{epoch}")
|
output_dir = f"epoch_{epoch}"
|
||||||
|
if args.output_dir is not None:
|
||||||
|
output_dir = os.path.join(args.output_dir, output_dir)
|
||||||
|
accelerator.save_state(output_dir)
|
||||||
|
|
||||||
if args.output_dir is not None:
|
if args.output_dir is not None:
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
@ -697,6 +703,9 @@ def main():
|
|||||||
if args.push_to_hub:
|
if args.push_to_hub:
|
||||||
repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True)
|
repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True)
|
||||||
|
|
||||||
|
with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
|
||||||
|
json.dump({"eval_accuracy": eval_metric["accuracy"], "train_loss": float(loss.cpu().detach().numpy())}, f)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
@ -19,6 +19,7 @@ Fine-tuning a 🤗 Transformers model on text translation.
|
|||||||
# You can also adapt this script on your own text translation task. Pointers for this are left as comments.
|
# You can also adapt this script on your own text translation task. Pointers for this are left as comments.
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
@ -586,7 +587,10 @@ def main():
|
|||||||
|
|
||||||
if isinstance(checkpointing_steps, int):
|
if isinstance(checkpointing_steps, int):
|
||||||
if completed_steps % checkpointing_steps == 0:
|
if completed_steps % checkpointing_steps == 0:
|
||||||
accelerator.save_state(f"step_{completed_steps}")
|
output_dir = f"step_{completed_steps}"
|
||||||
|
if args.output_dir is not None:
|
||||||
|
output_dir = os.path.join(args.output_dir, output_dir)
|
||||||
|
accelerator.save_state(output_dir)
|
||||||
|
|
||||||
if completed_steps >= args.max_train_steps:
|
if completed_steps >= args.max_train_steps:
|
||||||
break
|
break
|
||||||
@ -653,7 +657,10 @@ def main():
|
|||||||
)
|
)
|
||||||
|
|
||||||
if args.checkpointing_steps == "epoch":
|
if args.checkpointing_steps == "epoch":
|
||||||
accelerator.save_state(f"epoch_{epoch}")
|
# Per-epoch checkpoint: name the folder after the epoch, not the step count.
output_dir = f"epoch_{epoch}"
|
||||||
|
if args.output_dir is not None:
|
||||||
|
output_dir = os.path.join(args.output_dir, output_dir)
|
||||||
|
accelerator.save_state(output_dir)
|
||||||
|
|
||||||
if args.output_dir is not None:
|
if args.output_dir is not None:
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
@ -663,6 +670,8 @@ def main():
|
|||||||
tokenizer.save_pretrained(args.output_dir)
|
tokenizer.save_pretrained(args.output_dir)
|
||||||
if args.push_to_hub:
|
if args.push_to_hub:
|
||||||
repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True)
|
repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True)
|
||||||
|
with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
|
||||||
|
json.dump({"eval_bleu": eval_metric["score"]}, f)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -466,6 +466,7 @@ def infer_tests_to_run(output_file, diff_with_last_commit=False, filters=None):
|
|||||||
# Example files are tested separately
|
# Example files are tested separately
|
||||||
elif f.startswith("examples/pytorch"):
|
elif f.startswith("examples/pytorch"):
|
||||||
test_files_to_run.append("examples/pytorch/test_pytorch_examples.py")
|
test_files_to_run.append("examples/pytorch/test_pytorch_examples.py")
|
||||||
|
test_files_to_run.append("examples/pytorch/test_accelerate_examples.py")
|
||||||
elif f.startswith("examples/flax"):
|
elif f.startswith("examples/flax"):
|
||||||
test_files_to_run.append("examples/flax/test_flax_examples.py")
|
test_files_to_run.append("examples/flax/test_flax_examples.py")
|
||||||
else:
|
else:
|
||||||
|
Loading…
Reference in New Issue
Block a user