[Examples, T5] Change newstest2013 to newstest2014 and clean up (#3817)

* Replaced newstest2013 with newstest2014. Fixed a bug where argparse consumed the first positional command-line argument as model_size instead of falling back to the default; model_size must now be passed explicitly via the --model_size flag
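A minimal sketch of the fix described in this bullet (the broken "before" state is only described in comments, not shown): as a positional, `model_size` would swallow the first command-line argument; as an explicit `--model_size` flag it falls back to its default.

```python
import argparse

parser = argparse.ArgumentParser()
# Before the fix (sketch): a positional "model_size" argument would consume
# the first CLI argument, e.g. the input path, instead of using a default.
# After the fix: model_size is an optional flag with a default, so it can
# no longer swallow the first positional argument.
parser.add_argument("--model_size", type=str, default="t5-base")
parser.add_argument("input_path", type=str, help="like wmt/newstest2014.en")

# Without --model_size on the command line, the default is used:
args = parser.parse_args(["newstest2014.en"])
assert args.model_size == "t5-base" and args.input_path == "newstest2014.en"
```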

* More Pythonic file handling via 'with' context managers

* COSMETIC - ran Black and isort

* Fixed reference to number of lines in newstest2014

* Fixed failing test; more Pythonic file handling

* finish PR from tholiao

* remove commented-out lines

* make style

* make isort happy

Co-authored-by: Thomas Liao <tholiao@gmail.com>
Patrick von Platen 2020-04-16 20:00:41 +02:00 committed by GitHub
parent d486795158
commit 80a1694514
2 changed files with 34 additions and 27 deletions

README.md

@@ -9,17 +9,17 @@ evaluated on the WMT English-German dataset.
To be able to reproduce the authors' results on WMT English to German, you first need to download
the WMT14 en-de news datasets.
-Go on Stanford's official NLP [website](https://nlp.stanford.edu/projects/nmt/) and find "newstest2013.en" and "newstest2013.de" under WMT'14 English-German data or download the dataset directly via:
+Go on Stanford's official NLP [website](https://nlp.stanford.edu/projects/nmt/) and find "newstest2014.en" and "newstest2014.de" under WMT'14 English-German data or download the dataset directly via:
```bash
-curl https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/newstest2013.en > newstest2013.en
-curl https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/newstest2013.de > newstest2013.de
+curl https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/newstest2014.en > newstest2014.en
+curl https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/newstest2014.de > newstest2014.de
```
-You should have 3000 sentence in each file. You can verify this by running:
+You should have 2737 sentences in each file. You can verify this by running:
```bash
-wc -l newstest2013.en # should give 3000
+wc -l newstest2014.en # should give 2737
```
### Usage
@@ -29,8 +29,8 @@ Let's check the longest and shortest sentence in our file to find reasonable decoding hyperparameters.
Get the longest and shortest sentence:
```bash
-awk '{print NF}' newstest2013.en | sort -n | head -1 # shortest sentence has 1 word
-awk '{print NF}' newstest2013.en | sort -n | tail -1 # longest sentence has 106 words
+awk '{print NF}' newstest2014.en | sort -n | head -1 # shortest sentence has 2 words
+awk '{print NF}' newstest2014.en | sort -n | tail -1 # longest sentence has 91 words
```
We will set our `max_length` to ~3 times the longest sentence and leave `min_length` to its default value of 0.
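For illustration, here is roughly how those decoding settings end up in a `generate` call. This is a hedged sketch using `t5-base` and the en-to-de task prefix, not the script itself (which reads these values from the model's `task_specific_params`); `3 * 91` just encodes the heuristic described above.

```python
from transformers import T5ForConditionalGeneration, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-base")
model = T5ForConditionalGeneration.from_pretrained("t5-base")

# "translate English to German: " is T5's en->de task prefix.
input_ids = tokenizer.encode(
    "translate English to German: The house is wonderful.", return_tensors="pt"
)
# max_length ~3x the longest source sentence (91 words); beam search with
# num_beams=4 as proposed in the paper.
output = model.generate(input_ids, max_length=3 * 91, num_beams=4)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```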
@@ -38,7 +38,7 @@ We decode with beam search `num_beams=4` as proposed in the paper.
To create translations for each sentence in the dataset and get a final BLEU score, run:
```bash
-python evaluate_wmt.py <path_to_newstest2013.en> newstest2013_de_translations.txt <path_to_newstest2013.de> newsstest2013_en_de_bleu.txt
+python evaluate_wmt.py <path_to_newstest2014.en> newstest2014_de_translations.txt <path_to_newstest2014.de> newstest2014_en_de_bleu.txt
```
The default batch size, 16, fits in 16GB of GPU memory but may need to be adjusted for your system.
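A concrete invocation matching the argument order above (file names are placeholders; `--model_size` is optional and defaults to `t5-base`, per the parser changes below):

```bash
python evaluate_wmt.py \
    --model_size t5-base \
    newstest2014.en \
    newstest2014_de_translations.txt \
    newstest2014.de \
    newstest2014_en_de_bleu.txt
```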

evaluate_wmt.py

@@ -15,8 +15,6 @@ def chunks(lst, n):

def generate_translations(lns, output_file_path, model_size, batch_size, device):
-    output_file = Path(output_file_path).open("w")
-
    model = T5ForConditionalGeneration.from_pretrained(model_size)
    model.to(device)
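The `chunks` helper appears only as hunk context here; presumably it is the usual batching generator, something like:

```python
def chunks(lst, n):
    """Yield successive n-sized chunks from lst."""
    for i in range(0, len(lst), n):
        yield lst[i : i + n]
```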
@@ -27,27 +25,29 @@ def generate_translations(lns, output_file_path, model_size, batch_size, device):
    if task_specific_params is not None:
        model.config.update(task_specific_params.get("translation_en_to_de", {}))

-    for batch in tqdm(list(chunks(lns, batch_size))):
-        batch = [model.config.prefix + text for text in batch]
+    with Path(output_file_path).open("w") as output_file:
+        for batch in tqdm(list(chunks(lns, batch_size))):
+            batch = [model.config.prefix + text for text in batch]

-        dct = tokenizer.batch_encode_plus(batch, max_length=512, return_tensors="pt", pad_to_max_length=True)
+            dct = tokenizer.batch_encode_plus(batch, max_length=512, return_tensors="pt", pad_to_max_length=True)

-        input_ids = dct["input_ids"].to(device)
-        attention_mask = dct["attention_mask"].to(device)
+            input_ids = dct["input_ids"].to(device)
+            attention_mask = dct["attention_mask"].to(device)

-        translations = model.generate(input_ids=input_ids, attention_mask=attention_mask)
-        dec = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in translations]
+            translations = model.generate(input_ids=input_ids, attention_mask=attention_mask)
+            dec = [
+                tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in translations
+            ]

-        for hypothesis in dec:
-            output_file.write(hypothesis + "\n")
-            output_file.flush()
+            for hypothesis in dec:
+                output_file.write(hypothesis + "\n")


def calculate_bleu_score(output_lns, refs_lns, score_path):
    bleu = corpus_bleu(output_lns, [refs_lns])
    result = "BLEU score: {}".format(bleu.score)
-    score_file = Path(score_path).open("w")
-    score_file.write(result)
+    with Path(score_path).open("w") as score_file:
+        score_file.write(result)


def run_generate():
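`corpus_bleu` here is presumably sacrebleu's (its import lies outside these hunks); it takes the hypotheses and a list of reference streams and returns an object exposing the score as `.score`:

```python
from sacrebleu import corpus_bleu

hypotheses = ["Das Haus ist wunderbar."]
references = [["Das Haus ist wunderbar."]]  # a single reference stream
bleu = corpus_bleu(hypotheses, references)
print("BLEU score: {}".format(bleu.score))  # 100.0 for an exact match
```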
@@ -59,13 +59,13 @@ def run_generate():
        default="t5-base",
    )
    parser.add_argument(
-        "input_path", type=str, help="like wmt/newstest2013.en",
+        "input_path", type=str, help="like wmt/newstest2014.en",
    )
    parser.add_argument(
        "output_path", type=str, help="where to save translation",
    )
    parser.add_argument(
-        "reference_path", type=str, help="like wmt/newstest2013.de",
+        "reference_path", type=str, help="like wmt/newstest2014.de",
    )
    parser.add_argument(
        "score_path", type=str, help="where to save the bleu score",
@@ -82,12 +82,19 @@ def run_generate():

    dash_pattern = (" ##AT##-##AT## ", "-")

-    input_lns = [x.strip().replace(dash_pattern[0], dash_pattern[1]) for x in open(args.input_path).readlines()]
+    # Read input lines into python
+    with open(args.input_path, "r") as input_file:
+        input_lns = [x.strip().replace(dash_pattern[0], dash_pattern[1]) for x in input_file.readlines()]

    generate_translations(input_lns, args.output_path, args.model_size, args.batch_size, args.device)

-    output_lns = [x.strip() for x in open(args.output_path).readlines()]
-    refs_lns = [x.strip().replace(dash_pattern[0], dash_pattern[1]) for x in open(args.reference_path).readlines()]
+    # Read generated lines into python
+    with open(args.output_path, "r") as output_file:
+        output_lns = [x.strip() for x in output_file.readlines()]
+
+    # Read reference lines into python
+    with open(args.reference_path, "r") as reference_file:
+        refs_lns = [x.strip().replace(dash_pattern[0], dash_pattern[1]) for x in reference_file.readlines()]

    calculate_bleu_score(output_lns, refs_lns, args.score_path)
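For context on `dash_pattern`: the Stanford-hosted WMT files appear to escape hyphens as ` ##AT##-##AT## `, and the replacement above restores them, e.g.:

```python
dash_pattern = (" ##AT##-##AT## ", "-")
line = "a state ##AT##-##AT## of ##AT##-##AT## the ##AT##-##AT## art system"
print(line.replace(dash_pattern[0], dash_pattern[1]))
# -> "a state-of-the-art system"
```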