mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-19 04:28:26 +06:00

* rebase to master * change tf to pytorch * change to pytorch * small fix * renaming * add gpu training possibility * renaming * improve README * incorporate collins feedback * better Readme * better README.md
32 lines
1.1 KiB
Python
32 lines
1.1 KiB
Python
import argparse
|
|
from pathlib import Path
|
|
|
|
import tensorflow_datasets as tfds
|
|
|
|
|
|
def main(input_path, reference_path, data_dir):
    """Export the CNN/DailyMail test split to two parallel text files.

    Loads the ``cnn_dailymail`` test split through ``tfds.load`` (with file
    shuffling disabled) and, for every example, writes the UTF-8 decoded
    article to ``input_path`` and the UTF-8 decoded highlights to
    ``reference_path``. Each record is terminated by a newline.

    Args:
        input_path: Path of the file receiving the articles. Opened in ``"w"``
            mode, so existing content is overwritten.
        reference_path: Path of the file receiving the reference summaries.
            Opened in ``"w"`` mode, so existing content is overwritten.
        data_dir: Directory forwarded to ``tfds.load`` as ``data_dir``.
    """
    cnn_ds = tfds.load("cnn_dailymail", split="test", shuffle_files=False, data_dir=data_dir)
    cnn_ds_iter = tfds.as_numpy(cnn_ds)

    # The context managers guarantee both files are flushed and closed, even if
    # iterating the dataset raises; this also makes per-line flush() calls
    # unnecessary.
    with Path(input_path).open("w") as test_articles_file, Path(reference_path).open("w") as test_summaries_file:
        for example in cnn_ds_iter:
            test_articles_file.write(example["article"].decode("utf-8") + "\n")
            # Newlines inside the highlights are replaced by spaces so that each
            # summary occupies a single line of the reference file.
            test_summaries_file.write(example["highlights"].decode("utf-8").replace("\n", " ") + "\n")
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: two required output paths plus an optional
    # dataset directory, all forwarded unchanged to main().
    cli = argparse.ArgumentParser()
    cli.add_argument("input_path", type=str, help="where to save the articles input data")
    cli.add_argument("reference_path", type=str, help="where to save the reference summaries")
    cli.add_argument(
        "--data_dir",
        type=str,
        default="~/tensorflow_datasets",
        help="where to save the tensorflow datasets.",
    )
    options = cli.parse_args()

    main(
        input_path=options.input_path,
        reference_path=options.reference_path,
        data_dir=options.data_dir,
    )
|