Update Seq2Seq QA example script to use SQuAD metric. (#14335)

* Update post-processing to use the SQuAD metric.

* Update test expectations based on the SQuAD metric.

* Fix function naming error.
karthikrangasai 2021-11-09 18:34:23 +05:30 committed by GitHub
parent be4a6c64dc
commit 4f24058c58
3 changed files with 201 additions and 43 deletions
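
For reference, the SQuAD metric from the datasets library scores exact match and F1 over prediction/reference dicts keyed by example id (the squad_v2 variant additionally takes a no_answer_probability per prediction, as the new post-processing below supplies). A minimal sketch of that interface; the id and answer text are invented for illustration, not taken from this commit:

from datasets import load_metric

# Minimal sketch of the format the SQuAD metric expects; the id and answer
# below are made up purely for illustration.
metric = load_metric("squad")
predictions = [{"id": "001", "prediction_text": "Denver Broncos"}]
references = [{"id": "001", "answers": {"text": ["Denver Broncos"], "answer_start": [0]}}]
print(metric.compute(predictions=predictions, references=references))
# {'exact_match': 100.0, 'f1': 100.0}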

run_seq2seq_qa.py (the Seq2Seq QA example script)

@@ -25,22 +25,20 @@ from dataclasses import dataclass, field
 from typing import List, Optional, Tuple
 
 import datasets
-import nltk
-import numpy as np
 from datasets import load_dataset, load_metric
 
 import transformers
+from trainer_seq2seq_qa import QuestionAnsweringSeq2SeqTrainer
 from transformers import (
     AutoConfig,
     AutoModelForSeq2SeqLM,
     AutoTokenizer,
     DataCollatorForSeq2Seq,
     HfArgumentParser,
-    Seq2SeqTrainer,
     Seq2SeqTrainingArguments,
     set_seed,
 )
-from transformers.trainer_utils import EvalPrediction, get_last_checkpoint
+from transformers.trainer_utils import EvalLoopOutput, EvalPrediction, get_last_checkpoint
 from transformers.utils import check_min_version
 from transformers.utils.versions import require_version
@@ -411,7 +409,7 @@ def main():
     )
     max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
 
-    def preprocess_sqaud_batch(
+    def preprocess_squad_batch(
         examples,
         question_column: str,
         context_column: str,
@@ -422,14 +420,14 @@ def main():
         answers = examples[answer_column]
 
         def generate_input(_question, _context):
-            return " ".join(["question:", _question, "context:", _context])
+            return " ".join(["question:", _question.lstrip(), "context:", _context.lstrip()])
 
         inputs = [generate_input(question, context) for question, context in zip(questions, contexts)]
         targets = [answer["text"][0] if len(answer["text"]) > 0 else "" for answer in answers]
         return inputs, targets
 
     def preprocess_function(examples):
-        inputs, targets = preprocess_sqaud_batch(examples, question_column, context_column, answer_column)
+        inputs, targets = preprocess_squad_batch(examples, question_column, context_column, answer_column)
 
         model_inputs = tokenizer(inputs, max_length=max_seq_length, padding=padding, truncation=True)
         # Setup the tokenizer for targets
@@ -446,6 +444,45 @@ def main():
         model_inputs["labels"] = labels["input_ids"]
         return model_inputs
 
+    # Validation preprocessing
+    def preprocess_validation_function(examples):
+        inputs, targets = preprocess_squad_batch(examples, question_column, context_column, answer_column)
+
+        model_inputs = tokenizer(
+            inputs,
+            max_length=max_seq_length,
+            padding=padding,
+            truncation=True,
+            return_overflowing_tokens=True,
+            return_offsets_mapping=True,
+        )
+        # Setup the tokenizer for targets
+        with tokenizer.as_target_tokenizer():
+            labels = tokenizer(targets, max_length=max_answer_length, padding=padding, truncation=True)
+
+        # Since one example might give us several features if it has a long context, we need a map from a feature to
+        # its corresponding example. This key gives us just that.
+        sample_mapping = model_inputs.pop("overflow_to_sample_mapping")
+
+        # For evaluation, we will need to convert our predictions to substrings of the context, so we keep the
+        # corresponding example_id and we will store the offset mappings.
+        model_inputs["example_id"] = []
+
+        for i in range(len(model_inputs["input_ids"])):
+            # One example can give several spans, this is the index of the example containing this span of text.
+            sample_index = sample_mapping[i]
+            model_inputs["example_id"].append(examples["id"][sample_index])
+
+        # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
+        # padding in the loss.
+        if padding == "max_length" and data_args.ignore_pad_token_for_loss:
+            labels["input_ids"] = [
+                [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
+            ]
+
+        model_inputs["labels"] = labels["input_ids"]
+        return model_inputs
+
     if training_args.do_train:
         if "train" not in raw_datasets:
             raise ValueError("--do_train requires a train dataset")
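
The preprocess_validation_function above leans on the fast tokenizer's return_overflowing_tokens: an over-long context is split into several features, and overflow_to_sample_mapping records which example each feature came from. A rough standalone sketch of that behavior, assuming an arbitrary checkpoint and made-up inputs:

from transformers import AutoTokenizer

# Illustrative only: the long first context overflows into several features;
# the mapping points each feature back to its source example index.
tokenizer = AutoTokenizer.from_pretrained("t5-small")
encoded = tokenizer(
    ["question: q1 context: " + "long context " * 300, "question: q2 context: short"],
    max_length=128,
    truncation=True,
    return_overflowing_tokens=True,
)
print(encoded["overflow_to_sample_mapping"])  # e.g. [0, 0, 0, ..., 1]
print(len(encoded["input_ids"]))              # more features than the two input examples
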
@@ -477,7 +514,7 @@ def main():
         # Validation Feature Creation
         with training_args.main_process_first(desc="validation dataset map pre-processing"):
             eval_dataset = eval_examples.map(
-                preprocess_function,
+                preprocess_validation_function,
                 batched=True,
                 num_proc=data_args.preprocessing_num_workers,
                 remove_columns=column_names,
@@ -498,7 +535,7 @@ def main():
         # Predict Feature Creation
         with training_args.main_process_first(desc="prediction dataset map pre-processing"):
             predict_dataset = predict_examples.map(
-                preprocess_function,
+                preprocess_validation_function,
                 batched=True,
                 num_proc=data_args.preprocessing_num_workers,
                 remove_columns=column_names,
@@ -518,50 +555,53 @@ def main():
         pad_to_multiple_of=8 if training_args.fp16 else None,
     )
 
+    metric = load_metric("squad_v2" if data_args.version_2_with_negative else "squad")
+
+    def compute_metrics(p: EvalPrediction):
+        return metric.compute(predictions=p.predictions, references=p.label_ids)
+
     # Post-processing:
-    def postprocess_text(preds, labels):
-        preds = [" ".join(pred) for pred in preds]
-        preds = [pred.strip() for pred in preds]
-        labels = [label.strip() for label in labels]
-
-        # rougeLSum expects newline after each sentence
-        preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
-        labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]
-
-        return preds, labels
-
-    metric = load_metric("rouge")
-
-    def compute_metrics(eval_preds: EvalPrediction):
-        preds, labels = eval_preds
+    def post_processing_function(
+        examples: datasets.Dataset, features: datasets.Dataset, outputs: EvalLoopOutput, stage="eval"
+    ):
+        # Decode the predicted tokens.
+        preds = outputs.predictions
         if isinstance(preds, tuple):
             preds = preds[0]
-        decoded_preds = [tokenizer.batch_decode(pred, skip_special_tokens=True) for pred in preds]
-        if data_args.ignore_pad_token_for_loss:
-            # Replace -100 in the labels as we can't decode them.
-            labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
-        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
-
-        # Some simple post-processing
-        decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
-
-        result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
-        # Extract a few results from ROUGE
-        result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
-
-        prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
-        result["gen_len"] = np.mean(prediction_lens)
-        result = {k: round(v, 4) for k, v in result.items()}
-        return result
+        decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
+
+        # Build a map example to its corresponding features.
+        example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
+        feature_per_example = {example_id_to_index[feature["example_id"]]: i for i, feature in enumerate(features)}
+        predictions = {}
+        # Let's loop over all the examples!
+        for example_index, example in enumerate(examples):
+            # This is the index of the feature associated to the current example.
+            feature_index = feature_per_example[example_index]
+            predictions[example["id"]] = decoded_preds[feature_index]
+
+        # Format the result to the format the metric expects.
+        if data_args.version_2_with_negative:
+            formatted_predictions = [
+                {"id": k, "prediction_text": v, "no_answer_probability": 0.0} for k, v in predictions.items()
+            ]
+        else:
+            formatted_predictions = [{"id": k, "prediction_text": v} for k, v in predictions.items()]
+
+        references = [{"id": ex["id"], "answers": ex[answer_column]} for ex in examples]
+        return EvalPrediction(predictions=formatted_predictions, label_ids=references)
 
     # Initialize our Trainer
-    trainer = Seq2SeqTrainer(
+    trainer = QuestionAnsweringSeq2SeqTrainer(
         model=model,
         args=training_args,
         train_dataset=train_dataset if training_args.do_train else None,
         eval_dataset=eval_dataset if training_args.do_eval else None,
+        eval_examples=eval_examples if training_args.do_eval else None,
         tokenizer=tokenizer,
         data_collator=data_collator,
         compute_metrics=compute_metrics,
+        post_process_function=post_processing_function,
     )
 
     # Training

trainer_seq2seq_qa.py (new file)

@@ -0,0 +1,120 @@
+# coding=utf-8
+# Copyright 2021 The HuggingFace Team All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+A subclass of `Trainer` specific to Question-Answering tasks
+"""
+from typing import Dict, List, Optional
+
+from torch.utils.data import Dataset
+
+from transformers import Seq2SeqTrainer, is_torch_tpu_available
+from transformers.trainer_utils import PredictionOutput
+
+
+if is_torch_tpu_available():
+    import torch_xla.core.xla_model as xm
+    import torch_xla.debug.metrics as met
+
+
+class QuestionAnsweringSeq2SeqTrainer(Seq2SeqTrainer):
+    def __init__(self, *args, eval_examples=None, post_process_function=None, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.eval_examples = eval_examples
+        self.post_process_function = post_process_function
+
+    # def evaluate(self, eval_dataset=None, eval_examples=None, ignore_keys=None, metric_key_prefix: str = "eval"):
+    def evaluate(
+        self,
+        eval_dataset: Optional[Dataset] = None,
+        eval_examples=None,
+        ignore_keys: Optional[List[str]] = None,
+        metric_key_prefix: str = "eval",
+        max_length: Optional[int] = None,
+        num_beams: Optional[int] = None,
+    ) -> Dict[str, float]:
+        self._max_length = max_length if max_length is not None else self.args.generation_max_length
+        self._num_beams = num_beams if num_beams is not None else self.args.generation_num_beams
+
+        eval_dataset = self.eval_dataset if eval_dataset is None else eval_dataset
+        eval_dataloader = self.get_eval_dataloader(eval_dataset)
+        eval_examples = self.eval_examples if eval_examples is None else eval_examples
+
+        # Temporarily disable metric computation, we will do it in the loop here.
+        compute_metrics = self.compute_metrics
+        self.compute_metrics = None
+        eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
+        try:
+            output = eval_loop(
+                eval_dataloader,
+                description="Evaluation",
+                # No point gathering the predictions if there are no metrics, otherwise we defer to
+                # self.args.prediction_loss_only
+                prediction_loss_only=True if compute_metrics is None else None,
+                ignore_keys=ignore_keys,
+            )
+        finally:
+            self.compute_metrics = compute_metrics
+
+        if self.post_process_function is not None and self.compute_metrics is not None:
+            eval_preds = self.post_process_function(eval_examples, eval_dataset, output)
+            metrics = self.compute_metrics(eval_preds)
+
+            # Prefix all keys with metric_key_prefix + '_'
+            for key in list(metrics.keys()):
+                if not key.startswith(f"{metric_key_prefix}_"):
+                    metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
+
+            self.log(metrics)
+        else:
+            metrics = {}
+
+        if self.args.tpu_metrics_debug or self.args.debug:
+            # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
+            xm.master_print(met.metrics_report())
+
+        self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics)
+        return metrics
+
+    def predict(self, predict_dataset, predict_examples, ignore_keys=None, metric_key_prefix: str = "test"):
+        predict_dataloader = self.get_test_dataloader(predict_dataset)
+
+        # Temporarily disable metric computation, we will do it in the loop here.
+        compute_metrics = self.compute_metrics
+        self.compute_metrics = None
+        eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
+        try:
+            output = eval_loop(
+                predict_dataloader,
+                description="Prediction",
+                # No point gathering the predictions if there are no metrics, otherwise we defer to
+                # self.args.prediction_loss_only
+                prediction_loss_only=True if compute_metrics is None else None,
+                ignore_keys=ignore_keys,
+            )
+        finally:
+            self.compute_metrics = compute_metrics
+
+        if self.post_process_function is None or self.compute_metrics is None:
+            return output
+
+        predictions = self.post_process_function(predict_examples, predict_dataset, output.predictions, "predict")
+        metrics = self.compute_metrics(predictions)
+
+        # Prefix all keys with metric_key_prefix + '_'
+        for key in list(metrics.keys()):
+            if not key.startswith(f"{metric_key_prefix}_"):
+                metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
+
+        return PredictionOutput(predictions=predictions.predictions, label_ids=predictions.label_ids, metrics=metrics)
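
For context, a hedged sketch of how this trainer is typically driven, assuming the trainer, predict_dataset and predict_examples objects built in the example script above; the generation settings are illustrative values, not the script's defaults:

# evaluate() forwards generation settings, runs the prediction loop, then applies
# post_process_function and compute_metrics; metric keys get the "eval_" prefix.
metrics = trainer.evaluate(max_length=30, num_beams=4)
print(metrics)  # e.g. {"eval_exact_match": ..., "eval_f1": ...}

# predict() returns a PredictionOutput; its metrics carry the "test_" prefix.
results = trainer.predict(predict_dataset, predict_examples)
print(results.metrics)

Keeping eval_examples and post_process_function on the trainer lets the generic generation loop stay untouched while the SQuAD-specific formatting happens after decoding.
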

PyTorch examples test suite (class ExamplesTests)

@@ -274,10 +274,8 @@ class ExamplesTests(TestCasePlus):
         with patch.object(sys, "argv", testargs):
             run_squad_seq2seq.main()
             result = get_results(tmp_dir)
-            self.assertGreaterEqual(result["eval_rouge1"], 10)
-            self.assertGreaterEqual(result["eval_rouge2"], 10)
-            self.assertGreaterEqual(result["eval_rougeL"], 10)
-            self.assertGreaterEqual(result["eval_rougeLsum"], 10)
+            self.assertGreaterEqual(result["eval_f1"], 30)
+            self.assertGreaterEqual(result["eval_exact"], 30)
 
     def test_run_swag(self):
         stream_handler = logging.StreamHandler(sys.stdout)