mirror of https://github.com/huggingface/transformers.git
integrate ROUGE

commit ade3cdf5ad (parent 076602bdc4)
@@ -21,9 +21,6 @@
 # SOFTWARE.
 import copy
 import math
-import shutil
-import time
-import os

 import numpy as np
 import torch
@@ -1082,11 +1079,6 @@ class Translator(object):

         return translations

-    def _report_rouge(self, gold_path, can_path):
-        self.logger.info("Calculating Rouge")
-        results_dict = test_rouge(self.args.temp_dir, can_path, gold_path)
-        return results_dict
-

 def tile(x, count, dim=0):
     """
@@ -1113,63 +1105,10 @@ def tile(x, count, dim=0):


 #
-# All things ROUGE. Uses `pyrouge` which is a hot mess.
+# Optimizer for training. We keep this here in case we want to add
+# a finetuning script.
 #


-def test_rouge(temp_dir, cand, ref):
-    candidates = [line.strip() for line in open(cand, encoding="utf-8")]
-    references = [line.strip() for line in open(ref, encoding="utf-8")]
-    print(len(candidates))
-    print(len(references))
-    assert len(candidates) == len(references)
-
-    cnt = len(candidates)
-    current_time = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
-    tmp_dir = os.path.join(temp_dir, "rouge-tmp-{}".format(current_time))
-    if not os.path.isdir(tmp_dir):
-        os.mkdir(tmp_dir)
-        os.mkdir(tmp_dir + "/candidate")
-        os.mkdir(tmp_dir + "/reference")
-    try:
-
-        for i in range(cnt):
-            if len(references[i]) < 1:
-                continue
-            with open(
-                tmp_dir + "/candidate/cand.{}.txt".format(i), "w", encoding="utf-8"
-            ) as f:
-                f.write(candidates[i])
-            with open(
-                tmp_dir + "/reference/ref.{}.txt".format(i), "w", encoding="utf-8"
-            ) as f:
-                f.write(references[i])
-        r = pyrouge.Rouge155(temp_dir=temp_dir)
-        r.model_dir = tmp_dir + "/reference/"
-        r.system_dir = tmp_dir + "/candidate/"
-        r.model_filename_pattern = "ref.#ID#.txt"
-        r.system_filename_pattern = r"cand.(\d+).txt"
-        rouge_results = r.convert_and_evaluate()
-        print(rouge_results)
-        results_dict = r.output_to_dict(rouge_results)
-    finally:
-        pass
-        if os.path.isdir(tmp_dir):
-            shutil.rmtree(tmp_dir)
-    return results_dict
-
-
-def rouge_results_to_str(results_dict):
-    return ">> ROUGE-F(1/2/3/l): {:.2f}/{:.2f}/{:.2f}\nROUGE-R(1/2/3/l): {:.2f}/{:.2f}/{:.2f}\n".format(
-        results_dict["rouge_1_f_score"] * 100,
-        results_dict["rouge_2_f_score"] * 100,
-        results_dict["rouge_l_f_score"] * 100,
-        results_dict["rouge_1_recall"] * 100,
-        results_dict["rouge_2_recall"] * 100,
-        results_dict["rouge_l_recall"] * 100,
-    )
-
-
 class BertSumOptimizer(object):
     """ Specific optimizer for BertSum.
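For reference, the two helpers deleted above were consumed by the `Translator._report_rouge` method removed in the previous hunk. A minimal sketch of that wiring, assuming `pyrouge` plus a local ROUGE-1.5.5 install, with placeholder file names:

    # Sketch only: "cands.txt" and "refs.txt" are hypothetical files holding
    # one summary per line; test_rouge stages them in temp dirs for pyrouge.
    results_dict = test_rouge(temp_dir="/tmp", cand="cands.txt", ref="refs.txt")
    print(rouge_results_to_str(results_dict))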
@@ -41,6 +41,26 @@ def evaluate(args):
         "PAD": tokenizer.vocab["[PAD]"],
     }

+    if args.compute_rouge:
+        reference_summaries = []
+        generated_summaries = []
+
+        import rouge
+        import nltk
+        nltk.download('punkt')
+        rouge_evaluator = rouge.Rouge(
+            metrics=['rouge-n', 'rouge-l'],
+            max_n=2,
+            limit_length=True,
+            length_limit=args.beam_size,
+            length_limit_type='words',
+            apply_avg=True,
+            apply_best=False,
+            alpha=0.5,  # Default F1_score
+            weight_factor=1.2,
+            stemming=True,
+        )
+
     # these (unused) arguments are defined to keep the compatibility
     # with the legacy code and will be deleted in a next iteration.
     args.result_path = ""
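Here `rouge.Rouge` comes from the `py-rouge` package added to the requirements below (installed as `py-rouge`, imported as `rouge`). A self-contained sketch of the same evaluator, with a fixed word limit standing in for `args.beam_size` and toy strings invented for illustration:

    import rouge

    evaluator = rouge.Rouge(
        metrics=["rouge-n", "rouge-l"],
        max_n=2,                    # report ROUGE-1 and ROUGE-2
        limit_length=True,
        length_limit=100,           # placeholder for args.beam_size
        length_limit_type="words",
        apply_avg=True,             # average scores over the whole corpus
        apply_best=False,
        alpha=0.5,                  # precision/recall balance in the F-score
        weight_factor=1.2,
        stemming=True,
    )
    hypotheses = ["the cat sat on the mat"]
    references = ["a cat was sitting on the mat"]
    scores = evaluator.get_scores(hypotheses, references)
    print(scores["rouge-1"]["f"])   # nested dicts, since apply_avg=True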
@@ -66,6 +86,16 @@ def evaluate(args):
         summaries = [format_summary(t) for t in translations]
         save_summaries(summaries, args.summaries_output_dir, batch.document_names)

+        if args.compute_rouge:
+            reference_summaries += batch.tgt_str
+            generated_summaries += summaries
+
+    if args.compute_rouge:
+        scores = rouge_evaluator.get_scores(generated_summaries, reference_summaries)
+        str_scores = format_rouge_scores(scores)
+        save_rouge_scores(str_scores)
+        print(str_scores)
+

 def format_summary(translation):
     """ Transforms the output of the `from_batch` function
@@ -86,6 +116,41 @@ def format_summary(translation):
     return summary


+def format_rouge_scores(scores):
+    return """\n
+****** ROUGE SCORES ******
+
+** ROUGE 1
+F1        >> {:.3f}
+Precision >> {:.3f}
+Recall    >> {:.3f}
+
+** ROUGE 2
+F1        >> {:.3f}
+Precision >> {:.3f}
+Recall    >> {:.3f}
+
+** ROUGE L
+F1        >> {:.3f}
+Precision >> {:.3f}
+Recall    >> {:.3f}""".format(
+        scores['rouge-1']['f'],
+        scores['rouge-1']['p'],
+        scores['rouge-1']['r'],
+        scores['rouge-2']['f'],
+        scores['rouge-2']['p'],
+        scores['rouge-2']['r'],
+        scores['rouge-l']['f'],
+        scores['rouge-l']['p'],
+        scores['rouge-l']['r'],
+    )
+
+
+def save_rouge_scores(str_scores):
+    with open("rouge_scores.txt", "w") as output:
+        output.write(str_scores)
+
+
 def save_summaries(summaries, path, original_document_name):
     """ Write the summaries in files that are prefixed by the original
         files' name with the `_summary` appended.
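`format_rouge_scores` assumes the averaged dictionary shape that `py-rouge` returns when `apply_avg=True`, roughly like this (numbers invented for illustration):

    scores = {
        "rouge-1": {"f": 0.41, "p": 0.45, "r": 0.39},
        "rouge-2": {"f": 0.19, "p": 0.21, "r": 0.18},
        "rouge-l": {"f": 0.38, "p": 0.41, "r": 0.36},
    }
    print(format_rouge_scores(scores))  # fills the template above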
@@ -142,26 +207,27 @@ def collate(data, tokenizer, block_size):
     """
     data = [x for x in data if not len(x[1]) == 0]  # remove empty_files
     names = [name for name, _, _ in data]
+    summaries = [" ".join(summary_list) for _, _, summary_list in data]

     encoded_text = [
         encode_for_summarization(story, summary, tokenizer) for _, story, summary in data
     ]
-    stories = torch.tensor(
+    encoded_stories = torch.tensor(
         [
             fit_to_block_size(story, block_size, tokenizer.pad_token_id)
             for story, _ in encoded_text
         ]
     )
-    encoder_token_type_ids = compute_token_type_ids(stories, tokenizer.cls_token_id)
-    encoder_mask = build_mask(stories, tokenizer.pad_token_id)
+    encoder_token_type_ids = compute_token_type_ids(encoded_stories, tokenizer.cls_token_id)
+    encoder_mask = build_mask(encoded_stories, tokenizer.pad_token_id)

     batch = Batch(
         document_names=names,
-        batch_size=len(stories),
-        src=stories,
+        batch_size=len(encoded_stories),
+        src=encoded_stories,
         segs=encoder_token_type_ids,
         mask_src=encoder_mask,
-        tgt_str=[""] * len(stories),
+        tgt_str=summaries,
     )

     return batch
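The `collate` change assumes each item of `data` is a `(name, story_lines, summary_lines)` triple; carrying the joined gold summary in `tgt_str` is what lets `evaluate` accumulate `reference_summaries`. A toy illustration, assuming a `tokenizer` and the module's helpers are in scope (strings made up):

    # Hypothetical item: (document name, story sentences, summary sentences)
    data = [("0001.story", ["first sentence .", "second sentence ."], ["the gold summary ."])]
    batch = collate(data, tokenizer, block_size=512)
    assert batch.tgt_str == ["the gold summary ."]  # read back by the ROUGE pass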
@@ -196,6 +262,13 @@ def main():
         required=False,
         help="The folder in which the summaries should be written. Defaults to the folder where the documents are",
     )
+    parser.add_argument(
+        "--compute_rouge",
+        default=False,
+        type=bool,
+        required=False,
+        help="Compute the ROUGE metrics during evaluation. Only available for the CNN/DailyMail dataset.",
+    )
     # EVALUATION options
     parser.add_argument(
         "--visible_gpus",
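One caveat on the option above: argparse's `type=bool` applies `bool()` to the raw string, so any non-empty value, including `--compute_rouge False`, parses as `True`. A boolean flag is the usual alternative; a sketch of that spelling (not what this commit does):

    parser.add_argument(
        "--compute_rouge",
        action="store_true",  # defaults to False; pass the bare flag to enable
        help="Compute the ROUGE metrics during evaluation.",
    )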
@@ -11,4 +11,5 @@ sentencepiece
 # For XLM
 sacremoses
 # For ROUGE
 nltk
+py-rouge