Mirror of https://github.com/huggingface/transformers.git
[s2s] wmt download script use less ram (#6405)
commit f6cb0f806e (parent 7c6a085ebf)
@@ -4,44 +4,48 @@ import fire
 from tqdm import tqdm
 
 
-def download_wmt_dataset(src_lang, tgt_lang, dataset="wmt19", save_dir=None) -> None:
+def download_wmt_dataset(src_lang="ro", tgt_lang="en", dataset="wmt16", save_dir=None) -> None:
     """Download a dataset using the nlp package and save it to the format expected by finetune.py
     Format of save_dir: train.source, train.target, val.source, val.target, test.source, test.target.
 
     Args:
         src_lang: <str> source language
         tgt_lang: <str> target language
-        dataset: <str> like wmt19 (if you don't know, try wmt19).
+        dataset: <str> wmt16, wmt17, etc. wmt16 is a good start as it's small. To get the full list run `import nlp; print([d.id for d in nlp.list_datasets() if "wmt" in d.id])`
         save_dir: <str>, where to save the datasets, defaults to f'{dataset}-{src_lang}-{tgt_lang}'
 
     Usage:
-        >>> download_wmt_dataset('en', 'ru', dataset='wmt19') # saves to wmt19_en_ru
+        >>> download_wmt_dataset('ro', 'en', dataset='wmt16') # saves to wmt16-ro-en
     """
     try:
         import nlp
     except (ModuleNotFoundError, ImportError):
         raise ImportError("run pip install nlp")
     pair = f"{src_lang}-{tgt_lang}"
+    print(f"Converting {dataset}-{pair}")
     ds = nlp.load_dataset(dataset, pair)
     if save_dir is None:
         save_dir = f"{dataset}-{pair}"
     save_dir = Path(save_dir)
     save_dir.mkdir(exist_ok=True)
 
-    for split in tqdm(ds.keys()):
-        tr_list = list(ds[split])
-        data = [x["translation"] for x in tr_list]
-        src, tgt = [], []
-        for example in data:
-            src.append(example[src_lang])
-            tgt.append(example[tgt_lang])
-        if split == "validation":
-            split = "val"  # to save to val.source, val.target like summary datasets
-        src_path = save_dir.joinpath(f"{split}.source")
-        src_path.open("w+").write("\n".join(src))
-        tgt_path = save_dir.joinpath(f"{split}.target")
-        tgt_path.open("w+").write("\n".join(tgt))
-    print(f"saved dataset to {save_dir}")
+    for split in ds.keys():
+        print(f"Splitting {split} with {ds[split].num_rows} records")
+
+        # to save to val.source, val.target like summary datasets
+        fn = "val" if split == "validation" else split
+        src_path = save_dir.joinpath(f"{fn}.source")
+        tgt_path = save_dir.joinpath(f"{fn}.target")
+        src_fp = src_path.open("w+")
+        tgt_fp = tgt_path.open("w+")
+
+        # reader is the bottleneck so writing one record at a time doesn't slow things down
+        for x in tqdm(ds[split]):
+            ex = x["translation"]
+            src_fp.write(ex[src_lang] + "\n")
+            tgt_fp.write(ex[tgt_lang] + "\n")
+
+    print(f"Saved {dataset} dataset to {save_dir}")
 
 
 if __name__ == "__main__":
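For reference, here is a minimal sketch of the streaming-write pattern this patch switches to, with a plain list standing in for an nlp dataset split; the helper name, sample records, and output paths below are illustrative only, not part of the repository:

    # Hypothetical helper illustrating the one-record-at-a-time write pattern.
    from pathlib import Path

    def write_split(records, src_lang, tgt_lang, save_dir, prefix):
        save_dir = Path(save_dir)
        save_dir.mkdir(exist_ok=True)
        src_fp = save_dir.joinpath(f"{prefix}.source").open("w")
        tgt_fp = save_dir.joinpath(f"{prefix}.target").open("w")
        # Each record is written as soon as it is read, so memory use stays flat
        # instead of holding the whole split in Python lists.
        for ex in records:
            src_fp.write(ex[src_lang] + "\n")
            tgt_fp.write(ex[tgt_lang] + "\n")
        src_fp.close()
        tgt_fp.close()

    write_split(
        [{"ro": "Salut", "en": "Hello"}, {"ro": "Multumesc", "en": "Thank you"}],
        "ro",
        "en",
        save_dir="wmt16-ro-en",
        prefix="train",
    )

The old loop materialized list(ds[split]) plus separate src and tgt lists before writing anything, so a large split was held in memory several times over; the new loop only keeps the current record. The docstring's usage example, download_wmt_dataset('ro', 'en', dataset='wmt16'), is unchanged in spirit and still writes the train/val/test .source and .target files.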