Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-03 12:50:06 +06:00)
Tensorflow QA example (#12252)
* New Tensorflow QA example!
* Style pass
* Updating README.md for the new example
* flake8 fixes
* Update examples/tensorflow/question-answering/README.md

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
parent 4e9a6796c7
commit e3cb7a0b60
examples/tensorflow/question-answering/README.md
@@ -1,5 +1,5 @@
 <!---
-Copyright 2020 The HuggingFace Team. All rights reserved.
+Copyright 2021 The HuggingFace Team. All rights reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -14,21 +14,42 @@ See the License for the specific language governing permissions and
 limitations under the License.
 -->
 
-## SQuAD with the Tensorflow Trainer
+# Question answering example
 
-```bash
-python run_tf_squad.py \
-  --model_name_or_path bert-base-uncased \
-  --output_dir model \
-  --max_seq_length 384 \
-  --num_train_epochs 2 \
-  --per_gpu_train_batch_size 8 \
-  --per_gpu_eval_batch_size 16 \
-  --do_train \
-  --logging_dir logs \
-  --logging_steps 10 \
-  --learning_rate 3e-5 \
-  --doc_stride 128
-```
-
-For the moment evaluation is not available in the Tensorflow Trainer only the training.
+This folder contains the `run_qa.py` script, demonstrating *question answering* with the 🤗 Transformers library.
+For straightforward use-cases you may be able to use this script without modification, although we have also
+included comments in the code to indicate areas that you may need to adapt to your own projects.
+
+### Usage notes
+Note that when contexts are long they may be split into multiple training cases, not all of which may contain
+the answer span.
+
+As-is, the example script will train on SQuAD or any other question-answering dataset formatted the same way, and can handle user
+inputs as well.
+
+### Multi-GPU and TPU usage
+
+By default, the script uses a `MirroredStrategy` and will use multiple GPUs effectively if they are available. TPUs
+can also be used by passing the name of the TPU resource with the `--tpu` argument. There are some issues surrounding
+these strategies and our models right now, which are most likely to appear in the evaluation/prediction steps. We're
+actively working on better support for multi-GPU and TPU training in TF, but if you encounter problems a quick
+workaround is to train in the multi-GPU or TPU context and then perform predictions outside of it.
+
+### Memory usage and data loading
+
+One thing to note is that all data is loaded into memory in this script. Most question answering datasets are small
+enough that this is not an issue, but if you have a very large dataset you will need to modify the script to handle
+data streaming. This is particularly challenging for TPUs, given the stricter requirements and the sheer volume of data
+required to keep them fed. A full explanation of all the possible pitfalls is a bit beyond this example script and
+README, but for more information you can see the 'Input Datasets' section of
+[this document](https://www.tensorflow.org/guide/tpu).
+
+### Example command
+```
+python run_qa.py \
+--model_name_or_path distilbert-base-cased \
+--output_dir output \
+--dataset_name squad \
+--do_train \
+--do_eval \
+```
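The "Memory usage and data loading" section of the new README notes that very large datasets call for streaming rather than loading everything into memory. A minimal sketch of one way to do that with `tf.data`, assuming a hypothetical `feature_generator` that lazily yields pre-tokenized features (not part of the example itself):

```python
import tensorflow as tf

def feature_generator():
    # Illustrative only: lazily yield pre-tokenized features from disk or a database
    # instead of materializing the whole dataset in memory.
    yield {"input_ids": [101, 2054, 2003, 102], "attention_mask": [1, 1, 1, 1]}

signature = {
    "input_ids": tf.TensorSpec(shape=(None,), dtype=tf.int32),
    "attention_mask": tf.TensorSpec(shape=(None,), dtype=tf.int32),
}
streamed = (
    tf.data.Dataset.from_generator(feature_generator, output_signature=signature)
    .padded_batch(8)
    .prefetch(tf.data.AUTOTUNE)
)
```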
examples/tensorflow/question-answering/run_qa.py (new executable file, 694 lines)
@@ -0,0 +1,694 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The HuggingFace Team All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Fine-tuning the library models for question answering.
|
||||
"""
|
||||
# You can also adapt this script on your own question answering task. Pointers for this are left as comments.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import tensorflow as tf
|
||||
from datasets import load_dataset, load_metric
|
||||
|
||||
import transformers
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoTokenizer,
|
||||
EvalPrediction,
|
||||
HfArgumentParser,
|
||||
PreTrainedTokenizerFast,
|
||||
TFAutoModelForQuestionAnswering,
|
||||
TFTrainingArguments,
|
||||
set_seed,
|
||||
)
|
||||
from transformers.file_utils import CONFIG_NAME, TF2_WEIGHTS_NAME
|
||||
from transformers.utils import check_min_version
|
||||
from utils_qa import postprocess_qa_predictions
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.7.0.dev0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# region Arguments
|
||||
@dataclass
|
||||
class ModelArguments:
|
||||
"""
|
||||
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
|
||||
"""
|
||||
|
||||
model_name_or_path: str = field(
|
||||
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
|
||||
)
|
||||
config_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
|
||||
)
|
||||
tokenizer_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
|
||||
)
|
||||
cache_dir: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to directory to store the pretrained models downloaded from huggingface.co"},
|
||||
)
|
||||
model_revision: str = field(
|
||||
default="main",
|
||||
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
|
||||
)
|
||||
use_auth_token: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
||||
"with private models)."
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataTrainingArguments:
|
||||
"""
|
||||
Arguments pertaining to what data we are going to input our model for training and eval.
|
||||
"""
|
||||
|
||||
dataset_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
|
||||
)
|
||||
dataset_config_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
|
||||
)
|
||||
train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
|
||||
validation_file: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
|
||||
)
|
||||
test_file: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "An optional input test data file to evaluate the perplexity on (a text file)."},
|
||||
)
|
||||
overwrite_cache: bool = field(
|
||||
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
||||
)
|
||||
preprocessing_num_workers: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "The number of processes to use for the preprocessing."},
|
||||
)
|
||||
max_seq_length: int = field(
|
||||
default=384,
|
||||
metadata={
|
||||
"help": "The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded."
|
||||
},
|
||||
)
|
||||
pad_to_max_length: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Whether to pad all samples to `max_seq_length`. "
|
||||
"If False, will pad the samples dynamically when batching to the maximum length in the batch (which can "
|
||||
"be faster on GPU but will be slower on TPU)."
|
||||
},
|
||||
)
|
||||
max_train_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
max_eval_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
max_predict_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
version_2_with_negative: bool = field(
|
||||
default=False, metadata={"help": "If true, some of the examples do not have an answer."}
|
||||
)
|
||||
null_score_diff_threshold: float = field(
|
||||
default=0.0,
|
||||
metadata={
|
||||
"help": "The threshold used to select the null answer: if the best answer has a score that is less than "
|
||||
"the score of the null answer minus this threshold, the null answer is selected for this example. "
|
||||
"Only useful when `version_2_with_negative=True`."
|
||||
},
|
||||
)
|
||||
doc_stride: int = field(
|
||||
default=128,
|
||||
metadata={"help": "When splitting up a long document into chunks, how much stride to take between chunks."},
|
||||
)
|
||||
n_best_size: int = field(
|
||||
default=20,
|
||||
metadata={"help": "The total number of n-best predictions to generate when looking for an answer."},
|
||||
)
|
||||
max_answer_length: int = field(
|
||||
default=30,
|
||||
metadata={
|
||||
"help": "The maximum length of an answer that can be generated. This is needed because the start "
|
||||
"and end predictions are not conditioned on one another."
|
||||
},
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if (
|
||||
self.dataset_name is None
|
||||
and self.train_file is None
|
||||
and self.validation_file is None
|
||||
and self.test_file is None
|
||||
):
|
||||
raise ValueError("Need either a dataset name or a training/validation file/test_file.")
|
||||
else:
|
||||
if self.train_file is not None:
|
||||
extension = self.train_file.split(".")[-1]
|
||||
assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
|
||||
if self.validation_file is not None:
|
||||
extension = self.validation_file.split(".")[-1]
|
||||
assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
|
||||
if self.test_file is not None:
|
||||
extension = self.test_file.split(".")[-1]
|
||||
assert extension in ["csv", "json"], "`test_file` should be a csv or a json file."
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
# region Helper classes
|
||||
class SavePretrainedCallback(tf.keras.callbacks.Callback):
|
||||
# Hugging Face models have a save_pretrained() method that saves both the weights and the necessary
|
||||
# metadata to allow them to be loaded as a pretrained model in future. This is a simple Keras callback
|
||||
# that saves the model with this method after each epoch.
|
||||
def __init__(self, output_dir, **kwargs):
|
||||
super().__init__()
|
||||
self.output_dir = output_dir
|
||||
|
||||
def on_epoch_end(self, epoch, logs=None):
|
||||
self.model.save_pretrained(self.output_dir)
|
||||
|
||||
|
||||
def convert_dataset_for_tensorflow(
|
||||
dataset, batch_size, dataset_mode="variable_batch", shuffle=True, drop_remainder=True
|
||||
):
|
||||
"""Converts a Hugging Face dataset to a Tensorflow Dataset. The dataset_mode controls whether we pad all batches
|
||||
to the maximum sequence length, or whether we only pad to the maximum length within that batch. The former
|
||||
is most useful when training on TPU, as a new graph compilation is required for each sequence length.
|
||||
"""
|
||||
|
||||
def densify_ragged_batch(features, label=None):
|
||||
features = {
|
||||
feature: ragged_tensor.to_tensor(shape=batch_shape[feature]) if feature in tensor_keys else ragged_tensor
|
||||
for feature, ragged_tensor in features.items()
|
||||
}
|
||||
if label is None:
|
||||
return features
|
||||
else:
|
||||
return features, label
|
||||
|
||||
tensor_keys = ["attention_mask", "input_ids"]
|
||||
label_keys = ["start_positions", "end_positions"]
|
||||
if dataset_mode == "variable_batch":
|
||||
batch_shape = {key: None for key in tensor_keys}
|
||||
data = {key: tf.ragged.constant(dataset[key]) for key in tensor_keys}
|
||||
elif dataset_mode == "constant_batch":
|
||||
data = {key: tf.ragged.constant(dataset[key]) for key in tensor_keys}
|
||||
batch_shape = {
|
||||
key: tf.concat(([batch_size], ragged_tensor.bounding_shape()[1:]), axis=0)
|
||||
for key, ragged_tensor in data.items()
|
||||
}
|
||||
else:
|
||||
raise ValueError("Unknown dataset mode!")
|
||||
|
||||
if all([key in dataset.features for key in label_keys]):
|
||||
for key in label_keys:
|
||||
data[key] = tf.convert_to_tensor(dataset[key])
|
||||
dummy_labels = tf.zeros_like(dataset[key])
|
||||
tf_dataset = tf.data.Dataset.from_tensor_slices((data, dummy_labels))
|
||||
else:
|
||||
tf_dataset = tf.data.Dataset.from_tensor_slices(data)
|
||||
if shuffle:
|
||||
tf_dataset = tf_dataset.shuffle(buffer_size=len(dataset))
|
||||
tf_dataset = tf_dataset.batch(batch_size=batch_size, drop_remainder=drop_remainder).map(densify_ragged_batch)
|
||||
return tf_dataset
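The function above relies on ragged tensors so that each batch is padded no further than it needs to be. A self-contained sketch of that idea, with illustrative token IDs (not part of the script):

```python
import tensorflow as tf

# Two tokenized sequences of different lengths, kept ragged until batching time.
input_ids = tf.ragged.constant([[101, 7592, 102], [101, 2129, 2024, 2017, 102]])
dataset = tf.data.Dataset.from_tensor_slices({"input_ids": input_ids}).batch(2)

# Densify each batch only up to the longest sequence it contains ("variable_batch" mode).
dense = dataset.map(lambda batch: {"input_ids": batch["input_ids"].to_tensor()})
for batch in dense:
    print(batch["input_ids"].shape)  # (2, 5): padded to the longest sequence in this batch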
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
def main():
|
||||
# region Argument parsing
|
||||
# See all possible arguments in src/transformers/training_args.py
|
||||
# or by passing the --help flag to this script.
|
||||
# We now keep distinct sets of args, for a cleaner separation of concerns.
|
||||
|
||||
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TFTrainingArguments))
|
||||
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
||||
# If we pass only one argument to the script and it's the path to a json file,
|
||||
# let's parse it to get our arguments.
|
||||
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
|
||||
else:
|
||||
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
||||
|
||||
output_dir = Path(training_args.output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
# endregion
|
||||
|
||||
# region Checkpoints
|
||||
checkpoint = None
|
||||
if len(os.listdir(training_args.output_dir)) > 0 and not training_args.overwrite_output_dir:
|
||||
if (output_dir / CONFIG_NAME).is_file() and (output_dir / TF2_WEIGHTS_NAME).is_file():
|
||||
checkpoint = output_dir
|
||||
logger.info(
|
||||
f"Checkpoint detected, resuming training from checkpoint in {training_args.output_dir}. To avoid this"
|
||||
" behavior, change the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Output directory ({training_args.output_dir}) already exists and is not empty. "
|
||||
"Use --overwrite_output_dir to continue regardless."
|
||||
)
|
||||
# endregion
|
||||
|
||||
# region Logging
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
handlers=[logging.StreamHandler(sys.stdout)],
|
||||
)
|
||||
logger.setLevel(logging.INFO if training_args.should_log else logging.WARN)
|
||||
|
||||
# Set the verbosity to info of the Transformers logger (on main process only):
|
||||
if training_args.should_log:
|
||||
transformers.utils.logging.set_verbosity_info()
|
||||
transformers.utils.logging.enable_default_handler()
|
||||
transformers.utils.logging.enable_explicit_format()
|
||||
logger.info(f"Training/evaluation parameters {training_args}")
|
||||
# endregion
|
||||
|
||||
# Set seed before initializing model.
|
||||
set_seed(training_args.seed)
|
||||
|
||||
# region Load Data
|
||||
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
|
||||
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
|
||||
# (the dataset will be downloaded automatically from the datasets Hub).
|
||||
#
|
||||
# For CSV/JSON files, this script expects columns for the question, context and answers (by default
# 'question', 'context' and 'answers'; see the column detection logic below, which you can tweak).
|
||||
#
|
||||
# In distributed training, the load_dataset function guarantees that only one local process can concurrently
|
||||
# download the dataset.
|
||||
if data_args.dataset_name is not None:
|
||||
# Downloading and loading a dataset from the hub.
|
||||
datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)
|
||||
else:
|
||||
data_files = {}
|
||||
if data_args.train_file is not None:
|
||||
data_files["train"] = data_args.train_file
|
||||
extension = data_args.train_file.split(".")[-1]
|
||||
|
||||
if data_args.validation_file is not None:
|
||||
data_files["validation"] = data_args.validation_file
|
||||
extension = data_args.validation_file.split(".")[-1]
|
||||
if data_args.test_file is not None:
|
||||
data_files["test"] = data_args.test_file
|
||||
extension = data_args.test_file.split(".")[-1]
|
||||
datasets = load_dataset(extension, data_files=data_files, field="data", cache_dir=model_args.cache_dir)
|
||||
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
|
||||
# https://huggingface.co/docs/datasets/loading_datasets.html.
|
||||
# endregion
|
||||
|
||||
# region Load pretrained model and tokenizer
|
||||
#
|
||||
# Distributed training:
|
||||
# The .from_pretrained methods guarantee that only one local process can concurrently
|
||||
# download model & vocab.
|
||||
config = AutoConfig.from_pretrained(
|
||||
model_args.config_name if model_args.config_name else model_args.model_name_or_path,
|
||||
cache_dir=model_args.cache_dir,
|
||||
revision=model_args.model_revision,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
|
||||
cache_dir=model_args.cache_dir,
|
||||
use_fast=True,
|
||||
revision=model_args.model_revision,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
)
|
||||
# endregion
|
||||
|
||||
# region Tokenizer check: this script requires a fast tokenizer.
|
||||
if not isinstance(tokenizer, PreTrainedTokenizerFast):
|
||||
raise ValueError(
|
||||
"This example script only works for models that have a fast tokenizer. Checkout the big table of models "
|
||||
"at https://huggingface.co/transformers/index.html#supported-frameworks to find the model types that meet this "
|
||||
"requirement"
|
||||
)
|
||||
# endregion
|
||||
|
||||
# region Preprocessing the datasets
|
||||
# Preprocessing is slightly different for training and evaluation.
|
||||
if training_args.do_train:
|
||||
column_names = datasets["train"].column_names
|
||||
elif training_args.do_eval:
|
||||
column_names = datasets["validation"].column_names
|
||||
else:
|
||||
column_names = datasets["test"].column_names
|
||||
question_column_name = "question" if "question" in column_names else column_names[0]
|
||||
context_column_name = "context" if "context" in column_names else column_names[1]
|
||||
answer_column_name = "answers" if "answers" in column_names else column_names[2]
|
||||
|
||||
# Padding side determines if we do (question|context) or (context|question).
|
||||
pad_on_right = tokenizer.padding_side == "right"
|
||||
|
||||
if data_args.max_seq_length > tokenizer.model_max_length:
|
||||
logger.warning(
|
||||
f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the"
|
||||
f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
|
||||
)
|
||||
max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
|
||||
|
||||
# Training preprocessing
|
||||
def prepare_train_features(examples):
|
||||
# Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results
|
||||
# in one example possibly giving several features when a context is long, each of those features having a
|
||||
# context that overlaps a bit the context of the previous feature.
|
||||
tokenized_examples = tokenizer(
|
||||
examples[question_column_name if pad_on_right else context_column_name],
|
||||
examples[context_column_name if pad_on_right else question_column_name],
|
||||
truncation="only_second" if pad_on_right else "only_first",
|
||||
max_length=max_seq_length,
|
||||
stride=data_args.doc_stride,
|
||||
return_overflowing_tokens=True,
|
||||
return_offsets_mapping=True,
|
||||
padding="max_length" if data_args.pad_to_max_length else False,
|
||||
)
|
||||
|
||||
# Since one example might give us several features if it has a long context, we need a map from a feature to
|
||||
# its corresponding example. This key gives us just that.
|
||||
sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
|
||||
# The offset mappings will give us a map from token to character position in the original context. This will
|
||||
# help us compute the start_positions and end_positions.
|
||||
offset_mapping = tokenized_examples.pop("offset_mapping")
|
||||
|
||||
# Let's label those examples!
|
||||
tokenized_examples["start_positions"] = []
|
||||
tokenized_examples["end_positions"] = []
|
||||
|
||||
for i, offsets in enumerate(offset_mapping):
|
||||
# We will label impossible answers with the index of the CLS token.
|
||||
input_ids = tokenized_examples["input_ids"][i]
|
||||
cls_index = input_ids.index(tokenizer.cls_token_id)
|
||||
|
||||
# Grab the sequence corresponding to that example (to know what is the context and what is the question).
|
||||
sequence_ids = tokenized_examples.sequence_ids(i)
|
||||
|
||||
# One example can give several spans, this is the index of the example containing this span of text.
|
||||
sample_index = sample_mapping[i]
|
||||
answers = examples[answer_column_name][sample_index]
|
||||
# If no answers are given, set the cls_index as answer.
|
||||
if len(answers["answer_start"]) == 0:
|
||||
tokenized_examples["start_positions"].append(cls_index)
|
||||
tokenized_examples["end_positions"].append(cls_index)
|
||||
else:
|
||||
# Start/end character index of the answer in the text.
|
||||
start_char = answers["answer_start"][0]
|
||||
end_char = start_char + len(answers["text"][0])
|
||||
|
||||
# Start token index of the current span in the text.
|
||||
token_start_index = 0
|
||||
while sequence_ids[token_start_index] != (1 if pad_on_right else 0):
|
||||
token_start_index += 1
|
||||
|
||||
# End token index of the current span in the text.
|
||||
token_end_index = len(input_ids) - 1
|
||||
while sequence_ids[token_end_index] != (1 if pad_on_right else 0):
|
||||
token_end_index -= 1
|
||||
|
||||
# Detect if the answer is out of the span (in which case this feature is labeled with the CLS index).
|
||||
if not (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char):
|
||||
tokenized_examples["start_positions"].append(cls_index)
|
||||
tokenized_examples["end_positions"].append(cls_index)
|
||||
else:
|
||||
# Otherwise move the token_start_index and token_end_index to the two ends of the answer.
|
||||
# Note: we could go after the last offset if the answer is the last word (edge case).
|
||||
while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:
|
||||
token_start_index += 1
|
||||
tokenized_examples["start_positions"].append(token_start_index - 1)
|
||||
while offsets[token_end_index][1] >= end_char:
|
||||
token_end_index -= 1
|
||||
tokenized_examples["end_positions"].append(token_end_index + 1)
|
||||
|
||||
return tokenized_examples
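The tokenizer call above is what turns one long (question, context) pair into several overlapping features. A standalone sketch of the same call on toy inputs (the model name is illustrative; any fast tokenizer works):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-cased", use_fast=True)
question = "What does the stride control?"
context = "A very long paragraph about question answering. " * 200  # long enough to overflow

encoded = tokenizer(
    question,
    context,
    truncation="only_second",
    max_length=384,
    stride=128,
    return_overflowing_tokens=True,
    return_offsets_mapping=True,
)
# One (question, context) pair becomes several overlapping features.
print(len(encoded["input_ids"]))              # > 1
print(encoded["overflow_to_sample_mapping"])  # all zeros: every feature maps back to example 0
```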
|
||||
|
||||
processed_datasets = dict()
|
||||
if training_args.do_train:
|
||||
if "train" not in datasets:
|
||||
raise ValueError("--do_train requires a train dataset")
|
||||
train_dataset = datasets["train"]
|
||||
if data_args.max_train_samples is not None:
|
||||
# We will select samples from the whole data if the argument is specified
|
||||
train_dataset = train_dataset.select(range(data_args.max_train_samples))
|
||||
# Create train feature from dataset
|
||||
train_dataset = train_dataset.map(
|
||||
prepare_train_features,
|
||||
batched=True,
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
remove_columns=column_names,
|
||||
load_from_cache_file=not data_args.overwrite_cache,
|
||||
)
|
||||
if data_args.max_train_samples is not None:
|
||||
# The number of samples might increase during feature creation, so we select only the specified max samples
|
||||
train_dataset = train_dataset.select(range(data_args.max_train_samples))
|
||||
processed_datasets["train"] = train_dataset
|
||||
|
||||
# Validation preprocessing
|
||||
def prepare_validation_features(examples):
|
||||
# Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results
|
||||
# in one example possibly giving several features when a context is long, each of those features having a
|
||||
# context that overlaps a bit the context of the previous feature.
|
||||
tokenized_examples = tokenizer(
|
||||
examples[question_column_name if pad_on_right else context_column_name],
|
||||
examples[context_column_name if pad_on_right else question_column_name],
|
||||
truncation="only_second" if pad_on_right else "only_first",
|
||||
max_length=max_seq_length,
|
||||
stride=data_args.doc_stride,
|
||||
return_overflowing_tokens=True,
|
||||
return_offsets_mapping=True,
|
||||
padding="max_length" if data_args.pad_to_max_length else False,
|
||||
)
|
||||
|
||||
# Since one example might give us several features if it has a long context, we need a map from a feature to
|
||||
# its corresponding example. This key gives us just that.
|
||||
sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
|
||||
|
||||
# For evaluation, we will need to convert our predictions to substrings of the context, so we keep the
|
||||
# corresponding example_id and we will store the offset mappings.
|
||||
tokenized_examples["example_id"] = []
|
||||
|
||||
for i in range(len(tokenized_examples["input_ids"])):
|
||||
# Grab the sequence corresponding to that example (to know what is the context and what is the question).
|
||||
sequence_ids = tokenized_examples.sequence_ids(i)
|
||||
context_index = 1 if pad_on_right else 0
|
||||
|
||||
# One example can give several spans, this is the index of the example containing this span of text.
|
||||
sample_index = sample_mapping[i]
|
||||
tokenized_examples["example_id"].append(examples["id"][sample_index])
|
||||
|
||||
# Set to None the offset_mapping that are not part of the context so it's easy to determine if a token
|
||||
# position is part of the context or not.
|
||||
tokenized_examples["offset_mapping"][i] = [
|
||||
(o if sequence_ids[k] == context_index else None)
|
||||
for k, o in enumerate(tokenized_examples["offset_mapping"][i])
|
||||
]
|
||||
|
||||
return tokenized_examples
|
||||
|
||||
if training_args.do_eval:
|
||||
if "validation" not in datasets:
|
||||
raise ValueError("--do_eval requires a validation dataset")
|
||||
eval_examples = datasets["validation"]
|
||||
if data_args.max_eval_samples is not None:
|
||||
# We will select samples from the whole data
|
||||
eval_examples = eval_examples.select(range(data_args.max_eval_samples))
|
||||
# Validation Feature Creation
|
||||
eval_dataset = eval_examples.map(
|
||||
prepare_validation_features,
|
||||
batched=True,
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
remove_columns=column_names,
|
||||
load_from_cache_file=not data_args.overwrite_cache,
|
||||
)
|
||||
if data_args.max_eval_samples is not None:
|
||||
# The number of features might increase during feature creation, so we select only the required samples again
|
||||
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
|
||||
processed_datasets["validation"] = eval_dataset
|
||||
|
||||
if training_args.do_predict:
|
||||
if "test" not in datasets:
|
||||
raise ValueError("--do_predict requires a test dataset")
|
||||
predict_examples = datasets["test"]
|
||||
if data_args.max_predict_samples is not None:
|
||||
# We will select samples from the whole data
|
||||
predict_examples = predict_examples.select(range(data_args.max_predict_samples))
|
||||
# Predict Feature Creation
|
||||
predict_dataset = predict_examples.map(
|
||||
prepare_validation_features,
|
||||
batched=True,
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
remove_columns=column_names,
|
||||
load_from_cache_file=not data_args.overwrite_cache,
|
||||
)
|
||||
if data_args.max_predict_samples is not None:
|
||||
# The number of features might increase during feature creation, so we select only the required samples again
|
||||
predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
|
||||
processed_datasets["test"] = predict_dataset
|
||||
# endregion
|
||||
|
||||
# region Metrics and Post-processing:
|
||||
def post_processing_function(examples, features, predictions, stage="eval"):
|
||||
# Post-processing: we match the start logits and end logits to answers in the original context.
|
||||
predictions = postprocess_qa_predictions(
|
||||
examples=examples,
|
||||
features=features,
|
||||
predictions=predictions,
|
||||
version_2_with_negative=data_args.version_2_with_negative,
|
||||
n_best_size=data_args.n_best_size,
|
||||
max_answer_length=data_args.max_answer_length,
|
||||
null_score_diff_threshold=data_args.null_score_diff_threshold,
|
||||
output_dir=training_args.output_dir,
|
||||
prefix=stage,
|
||||
)
|
||||
# Format the result to the format the metric expects.
|
||||
if data_args.version_2_with_negative:
|
||||
formatted_predictions = [
|
||||
{"id": k, "prediction_text": v, "no_answer_probability": 0.0} for k, v in predictions.items()
|
||||
]
|
||||
else:
|
||||
formatted_predictions = [{"id": k, "prediction_text": v} for k, v in predictions.items()]
|
||||
|
||||
references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in examples]
|
||||
return EvalPrediction(predictions=formatted_predictions, label_ids=references)
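# Note: this matches what the "squad" / "squad_v2" metrics expect: predictions as dicts with
# "id" and "prediction_text" (plus "no_answer_probability" for squad_v2), and references as
# dicts with "id" and "answers".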
|
||||
|
||||
metric = load_metric("squad_v2" if data_args.version_2_with_negative else "squad")
|
||||
|
||||
def compute_metrics(p: EvalPrediction):
|
||||
return metric.compute(predictions=p.predictions, references=p.label_ids)
|
||||
|
||||
# endregion
|
||||
|
||||
with training_args.strategy.scope():
|
||||
# region Load model
|
||||
if checkpoint is None:
|
||||
model_path = model_args.model_name_or_path
|
||||
else:
|
||||
model_path = checkpoint
|
||||
model = TFAutoModelForQuestionAnswering.from_pretrained(
|
||||
model_path,
|
||||
config=config,
|
||||
cache_dir=model_args.cache_dir,
|
||||
revision=model_args.model_revision,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
)
|
||||
optimizer = tf.keras.optimizers.Adam(
|
||||
learning_rate=training_args.learning_rate,
|
||||
beta_1=training_args.adam_beta1,
|
||||
beta_2=training_args.adam_beta2,
|
||||
epsilon=training_args.adam_epsilon,
|
||||
clipnorm=training_args.max_grad_norm,
|
||||
)
|
||||
|
||||
def dummy_loss(y_true, y_pred):
|
||||
return tf.reduce_mean(y_pred)
|
||||
|
||||
losses = {"loss": dummy_loss}
|
||||
model.compile(optimizer=optimizer, loss=losses)
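# Note: the dummy loss works because the model computes its own question-answering loss when
# "start_positions"/"end_positions" are present in the inputs and returns it as an output named
# "loss"; the Keras loss above therefore only has to average that precomputed value.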
|
||||
# endregion
|
||||
|
||||
# region Training
|
||||
if training_args.do_train:
|
||||
# Make a tf.data.Dataset for this
|
||||
if isinstance(training_args.strategy, tf.distribute.TPUStrategy) or data_args.pad_to_max_length:
|
||||
logger.info("Padding all batches to max length because argument was set or we're on TPU.")
|
||||
dataset_mode = "constant_batch"
|
||||
else:
|
||||
dataset_mode = "variable_batch"
|
||||
training_dataset = convert_dataset_for_tensorflow(
|
||||
processed_datasets["train"],
|
||||
batch_size=training_args.per_device_train_batch_size,
|
||||
dataset_mode=dataset_mode,
|
||||
drop_remainder=True,
|
||||
shuffle=True,
|
||||
)
|
||||
model.fit(training_dataset, epochs=int(training_args.num_train_epochs))
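# Tip: the SavePretrainedCallback defined above can be passed here, e.g.
# callbacks=[SavePretrainedCallback(output_dir=training_args.output_dir)], if you want the
# model written out with save_pretrained() at the end of each epoch.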
|
||||
# endregion
|
||||
|
||||
# region Evaluation
|
||||
if training_args.do_eval:
|
||||
logger.info("*** Evaluation ***")
|
||||
eval_inputs = {
|
||||
"input_ids": tf.ragged.constant(processed_datasets["validation"]["input_ids"]).to_tensor(),
|
||||
"attention_mask": tf.ragged.constant(processed_datasets["validation"]["attention_mask"]).to_tensor(),
|
||||
}
|
||||
eval_predictions = model.predict(eval_inputs)
|
||||
|
||||
post_processed_eval = post_processing_function(
|
||||
datasets["validation"],
|
||||
processed_datasets["validation"],
|
||||
(eval_predictions.start_logits, eval_predictions.end_logits),
|
||||
)
|
||||
metrics = compute_metrics(post_processed_eval)
|
||||
logging.info("Evaluation metrics:")
|
||||
for metric, value in metrics.items():
|
||||
logging.info(f"{metric}: {value:.3f}")
|
||||
# endregion
|
||||
|
||||
# region Prediction
|
||||
if training_args.do_predict:
|
||||
logger.info("*** Predict ***")
|
||||
predict_inputs = {
|
||||
"input_ids": tf.ragged.constant(processed_datasets["test"]["input_ids"]).to_tensor(),
|
||||
"attention_mask": tf.ragged.constant(processed_datasets["test"]["attention_mask"]).to_tensor(),
|
||||
}
|
||||
test_predictions = model.predict(predict_inputs)
|
||||
post_processed_test = post_processing_function(
|
||||
datasets["test"],
|
||||
processed_datasets["test"],
|
||||
(test_predictions.start_logits, test_predictions.end_logits),
|
||||
)
|
||||
metrics = compute_metrics(post_processed_test)
|
||||
|
||||
logging.info("Test metrics:")
|
||||
for metric, value in metrics.items():
|
||||
logging.info(f"{metric}: {value:.3f}")
|
||||
# endregion
|
||||
|
||||
if training_args.push_to_hub:
|
||||
model.push_to_hub()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
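As a hedged illustration of the workaround mentioned in the README (train inside the distribution strategy, then predict outside of it), one possible pattern is sketched below; the model and checkpoint directory names are illustrative, and this is not the script's own save logic:

```python
import tensorflow as tf
from transformers import TFAutoModelForQuestionAnswering

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    model = TFAutoModelForQuestionAnswering.from_pretrained("distilbert-base-cased")
    model.compile(optimizer="adam")
    # model.fit(train_dataset, epochs=2)  # multi-GPU training happens inside the scope
    model.save_pretrained("qa_checkpoint")  # illustrative directory name

# Reload outside the strategy and run predictions on a single device.
eval_model = TFAutoModelForQuestionAnswering.from_pretrained("qa_checkpoint")
# predictions = eval_model.predict(eval_inputs)
```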
examples/tensorflow/question-answering/run_tf_squad.py (deleted, 255 lines)
@@ -1,255 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
||||
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Fine-tuning the library models for question-answering."""
|
||||
|
||||
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoTokenizer,
|
||||
HfArgumentParser,
|
||||
TFAutoModelForQuestionAnswering,
|
||||
TFTrainer,
|
||||
TFTrainingArguments,
|
||||
squad_convert_examples_to_features,
|
||||
)
|
||||
from transformers.data.processors.squad import SquadV1Processor, SquadV2Processor
|
||||
from transformers.utils import logging as hf_logging
|
||||
|
||||
|
||||
hf_logging.set_verbosity_info()
|
||||
hf_logging.enable_default_handler()
|
||||
hf_logging.enable_explicit_format()
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArguments:
|
||||
"""
|
||||
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
|
||||
"""
|
||||
|
||||
model_name_or_path: str = field(
|
||||
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
|
||||
)
|
||||
config_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
|
||||
)
|
||||
tokenizer_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
|
||||
)
|
||||
use_fast: bool = field(default=False, metadata={"help": "Set this flag to use fast tokenization."})
|
||||
# If you want to tweak more attributes on your tokenizer, you should do it in a distinct script,
|
||||
# or just modify its tokenizer_config.json.
|
||||
cache_dir: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataTrainingArguments:
|
||||
"""
|
||||
Arguments pertaining to what data we are going to input our model for training and eval.
|
||||
"""
|
||||
|
||||
data_dir: Optional[str] = field(
|
||||
default=None, metadata={"help": "The input data dir. Should contain the .json files for the SQuAD task."}
|
||||
)
|
||||
use_tfds: Optional[bool] = field(default=True, metadata={"help": "If TFDS should be used or not."})
|
||||
max_seq_length: int = field(
|
||||
default=128,
|
||||
metadata={
|
||||
"help": "The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded."
|
||||
},
|
||||
)
|
||||
doc_stride: int = field(
|
||||
default=128,
|
||||
metadata={"help": "When splitting up a long document into chunks, how much stride to take between chunks."},
|
||||
)
|
||||
max_query_length: int = field(
|
||||
default=64,
|
||||
metadata={
|
||||
"help": "The maximum number of tokens for the question. Questions longer than this will "
|
||||
"be truncated to this length."
|
||||
},
|
||||
)
|
||||
max_answer_length: int = field(
|
||||
default=30,
|
||||
metadata={
|
||||
"help": "The maximum length of an answer that can be generated. This is needed because the start "
|
||||
"and end predictions are not conditioned on one another."
|
||||
},
|
||||
)
|
||||
overwrite_cache: bool = field(
|
||||
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
||||
)
|
||||
version_2_with_negative: bool = field(
|
||||
default=False, metadata={"help": "If true, the SQuAD examples contain some that do not have an answer."}
|
||||
)
|
||||
null_score_diff_threshold: float = field(
|
||||
default=0.0, metadata={"help": "If null_score - best_non_null is greater than the threshold predict null."}
|
||||
)
|
||||
n_best_size: int = field(
|
||||
default=20, metadata={"help": "The total number of n-best predictions to generate when looking for an answer."}
|
||||
)
|
||||
lang_id: int = field(
|
||||
default=0,
|
||||
metadata={
|
||||
"help": "language id of input for language-specific xlm models (see tokenization_xlm.PRETRAINED_INIT_CONFIGURATION)"
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
# See all possible arguments in src/transformers/training_args.py
|
||||
# or by passing the --help flag to this script.
|
||||
# We now keep distinct sets of args, for a cleaner separation of concerns.
|
||||
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TFTrainingArguments))
|
||||
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
||||
|
||||
if (
|
||||
os.path.exists(training_args.output_dir)
|
||||
and os.listdir(training_args.output_dir)
|
||||
and training_args.do_train
|
||||
and not training_args.overwrite_output_dir
|
||||
):
|
||||
raise ValueError(
|
||||
f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
|
||||
)
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO,
|
||||
)
|
||||
logger.info(
|
||||
f"n_replicas: {training_args.n_replicas}, distributed training: {bool(training_args.n_replicas > 1)}, "
|
||||
f"16-bits training: {training_args.fp16}"
|
||||
)
|
||||
logger.info(f"Training/evaluation parameters {training_args}")
|
||||
|
||||
# Prepare Question-Answering task
|
||||
# Load pretrained model and tokenizer
|
||||
#
|
||||
# Distributed training:
|
||||
# The .from_pretrained methods guarantee that only one local process can concurrently
|
||||
# download model & vocab.
|
||||
|
||||
config = AutoConfig.from_pretrained(
|
||||
model_args.config_name if model_args.config_name else model_args.model_name_or_path,
|
||||
cache_dir=model_args.cache_dir,
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
|
||||
cache_dir=model_args.cache_dir,
|
||||
use_fast=model_args.use_fast,
|
||||
)
|
||||
|
||||
with training_args.strategy.scope():
|
||||
model = TFAutoModelForQuestionAnswering.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
from_pt=bool(".bin" in model_args.model_name_or_path),
|
||||
config=config,
|
||||
cache_dir=model_args.cache_dir,
|
||||
)
|
||||
|
||||
# Get datasets
|
||||
if data_args.use_tfds:
|
||||
if data_args.version_2_with_negative:
|
||||
logger.warning("tensorflow_datasets does not handle version 2 of SQuAD. Switch to version 1 automatically")
|
||||
|
||||
try:
|
||||
import tensorflow_datasets as tfds
|
||||
except ImportError:
|
||||
raise ImportError("If not data_dir is specified, tensorflow_datasets needs to be installed.")
|
||||
|
||||
tfds_examples = tfds.load("squad", data_dir=data_args.data_dir)
|
||||
train_examples = (
|
||||
SquadV1Processor().get_examples_from_dataset(tfds_examples, evaluate=False)
|
||||
if training_args.do_train
|
||||
else None
|
||||
)
|
||||
eval_examples = (
|
||||
SquadV1Processor().get_examples_from_dataset(tfds_examples, evaluate=True)
|
||||
if training_args.do_eval
|
||||
else None
|
||||
)
|
||||
else:
|
||||
processor = SquadV2Processor() if data_args.version_2_with_negative else SquadV1Processor()
|
||||
train_examples = processor.get_train_examples(data_args.data_dir) if training_args.do_train else None
|
||||
eval_examples = processor.get_dev_examples(data_args.data_dir) if training_args.do_eval else None
|
||||
|
||||
train_dataset = (
|
||||
squad_convert_examples_to_features(
|
||||
examples=train_examples,
|
||||
tokenizer=tokenizer,
|
||||
max_seq_length=data_args.max_seq_length,
|
||||
doc_stride=data_args.doc_stride,
|
||||
max_query_length=data_args.max_query_length,
|
||||
is_training=True,
|
||||
return_dataset="tf",
|
||||
)
|
||||
if training_args.do_train
|
||||
else None
|
||||
)
|
||||
|
||||
train_dataset = train_dataset.apply(tf.data.experimental.assert_cardinality(len(train_examples)))
|
||||
|
||||
eval_dataset = (
|
||||
squad_convert_examples_to_features(
|
||||
examples=eval_examples,
|
||||
tokenizer=tokenizer,
|
||||
max_seq_length=data_args.max_seq_length,
|
||||
doc_stride=data_args.doc_stride,
|
||||
max_query_length=data_args.max_query_length,
|
||||
is_training=False,
|
||||
return_dataset="tf",
|
||||
)
|
||||
if training_args.do_eval
|
||||
else None
|
||||
)
|
||||
|
||||
eval_dataset = eval_dataset.apply(tf.data.experimental.assert_cardinality(len(eval_examples)))
|
||||
|
||||
# Initialize our Trainer
|
||||
trainer = TFTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
)
|
||||
|
||||
# Training
|
||||
if training_args.do_train:
|
||||
trainer.train()
|
||||
trainer.save_model()
|
||||
tokenizer.save_pretrained(training_args.output_dir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
examples/tensorflow/question-answering/utils_qa.py (new file, 425 lines)
@@ -0,0 +1,425 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The HuggingFace Team All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Post-processing utilities for question answering.
|
||||
"""
|
||||
import collections
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def postprocess_qa_predictions(
|
||||
examples,
|
||||
features,
|
||||
predictions: Tuple[np.ndarray, np.ndarray],
|
||||
version_2_with_negative: bool = False,
|
||||
n_best_size: int = 20,
|
||||
max_answer_length: int = 30,
|
||||
null_score_diff_threshold: float = 0.0,
|
||||
output_dir: Optional[str] = None,
|
||||
prefix: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Post-processes the predictions of a question-answering model to convert them to answers that are substrings of the
|
||||
original contexts. This is the base postprocessing function for models that only return start and end logits.
|
||||
|
||||
Args:
|
||||
examples: The non-preprocessed dataset (see the main script for more information).
|
||||
features: The processed dataset (see the main script for more information).
|
||||
predictions (:obj:`Tuple[np.ndarray, np.ndarray]`):
|
||||
The predictions of the model: two arrays containing the start logits and the end logits respectively. Its
|
||||
first dimension must match the number of elements of :obj:`features`.
|
||||
version_2_with_negative (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not the underlying dataset contains examples with no answers.
|
||||
n_best_size (:obj:`int`, `optional`, defaults to 20):
|
||||
The total number of n-best predictions to generate when looking for an answer.
|
||||
max_answer_length (:obj:`int`, `optional`, defaults to 30):
|
||||
The maximum length of an answer that can be generated. This is needed because the start and end predictions
|
||||
are not conditioned on one another.
|
||||
null_score_diff_threshold (:obj:`float`, `optional`, defaults to 0):
|
||||
The threshold used to select the null answer: if the best answer has a score that is less than the score of
|
||||
the null answer minus this threshold, the null answer is selected for this example (note that the score of
|
||||
the null answer for an example giving several features is the minimum of the scores for the null answer on
|
||||
each feature: all features must be aligned on the fact they `want` to predict a null answer).
|
||||
|
||||
Only useful when :obj:`version_2_with_negative` is :obj:`True`.
|
||||
output_dir (:obj:`str`, `optional`):
|
||||
If provided, the dictionaries of predictions, n_best predictions (with their scores and logits) and, if
|
||||
:obj:`version_2_with_negative=True`, the dictionary of the scores differences between best and null
|
||||
answers, are saved in `output_dir`.
|
||||
prefix (:obj:`str`, `optional`):
|
||||
If provided, the dictionaries mentioned above are saved with `prefix` added to their names.
|
||||
is_world_process_zero (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether this process is the main process or not (used to determine if logging/saves should be done).
|
||||
"""
|
||||
assert len(predictions) == 2, "`predictions` should be a tuple with two elements (start_logits, end_logits)."
|
||||
all_start_logits, all_end_logits = predictions
|
||||
|
||||
assert len(predictions[0]) == len(features), f"Got {len(predictions[0])} predictions and {len(features)} features."
|
||||
|
||||
# Build a map example to its corresponding features.
|
||||
example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
|
||||
features_per_example = collections.defaultdict(list)
|
||||
for i, feature in enumerate(features):
|
||||
features_per_example[example_id_to_index[feature["example_id"]]].append(i)
|
||||
|
||||
# The dictionaries we have to fill.
|
||||
all_predictions = collections.OrderedDict()
|
||||
all_nbest_json = collections.OrderedDict()
|
||||
if version_2_with_negative:
|
||||
scores_diff_json = collections.OrderedDict()
|
||||
|
||||
# Logging.
|
||||
logger.info(f"Post-processing {len(examples)} example predictions split into {len(features)} features.")
|
||||
|
||||
# Let's loop over all the examples!
|
||||
for example_index, example in enumerate(tqdm(examples)):
|
||||
# Those are the indices of the features associated to the current example.
|
||||
feature_indices = features_per_example[example_index]
|
||||
|
||||
min_null_prediction = None
|
||||
prelim_predictions = []
|
||||
|
||||
# Looping through all the features associated to the current example.
|
||||
for feature_index in feature_indices:
|
||||
# We grab the predictions of the model for this feature.
|
||||
start_logits = all_start_logits[feature_index]
|
||||
end_logits = all_end_logits[feature_index]
|
||||
# This is what will allow us to map some of the positions in our logits to spans of text in the original
|
||||
# context.
|
||||
offset_mapping = features[feature_index]["offset_mapping"]
|
||||
# Optional `token_is_max_context`, if provided we will remove answers that do not have the maximum context
|
||||
# available in the current feature.
|
||||
token_is_max_context = features[feature_index].get("token_is_max_context", None)
|
||||
|
||||
# Update minimum null prediction.
|
||||
feature_null_score = start_logits[0] + end_logits[0]
|
||||
if min_null_prediction is None or min_null_prediction["score"] > feature_null_score:
|
||||
min_null_prediction = {
|
||||
"offsets": (0, 0),
|
||||
"score": feature_null_score,
|
||||
"start_logit": start_logits[0],
|
||||
"end_logit": end_logits[0],
|
||||
}
|
||||
|
||||
# Go through all possibilities for the `n_best_size` greater start and end logits.
|
||||
start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist()
|
||||
end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()
|
||||
for start_index in start_indexes:
|
||||
for end_index in end_indexes:
|
||||
# Don't consider out-of-scope answers, either because the indices are out of bounds or correspond
|
||||
# to part of the input_ids that are not in the context.
|
||||
if (
|
||||
start_index >= len(offset_mapping)
|
||||
or end_index >= len(offset_mapping)
|
||||
or offset_mapping[start_index] is None
|
||||
or offset_mapping[end_index] is None
|
||||
):
|
||||
continue
|
||||
# Don't consider answers with a length that is either < 0 or > max_answer_length.
|
||||
if end_index < start_index or end_index - start_index + 1 > max_answer_length:
|
||||
continue
|
||||
# Don't consider answers that don't have the maximum context available (if such information is
|
||||
# provided).
|
||||
if token_is_max_context is not None and not token_is_max_context.get(str(start_index), False):
|
||||
continue
|
||||
prelim_predictions.append(
|
||||
{
|
||||
"offsets": (offset_mapping[start_index][0], offset_mapping[end_index][1]),
|
||||
"score": start_logits[start_index] + end_logits[end_index],
|
||||
"start_logit": start_logits[start_index],
|
||||
"end_logit": end_logits[end_index],
|
||||
}
|
||||
)
|
||||
if version_2_with_negative:
|
||||
# Add the minimum null prediction
|
||||
prelim_predictions.append(min_null_prediction)
|
||||
null_score = min_null_prediction["score"]
|
||||
|
||||
# Only keep the best `n_best_size` predictions.
|
||||
predictions = sorted(prelim_predictions, key=lambda x: x["score"], reverse=True)[:n_best_size]
|
||||
|
||||
# Add back the minimum null prediction if it was removed because of its low score.
|
||||
if version_2_with_negative and not any(p["offsets"] == (0, 0) for p in predictions):
|
||||
predictions.append(min_null_prediction)
|
||||
|
||||
# Use the offsets to gather the answer text in the original context.
|
||||
context = example["context"]
|
||||
for pred in predictions:
|
||||
offsets = pred.pop("offsets")
|
||||
pred["text"] = context[offsets[0] : offsets[1]]
|
||||
|
||||
# In the very rare edge case where we don't have a single non-null prediction, we create a fake prediction to avoid
|
||||
# failure.
|
||||
if len(predictions) == 0 or (len(predictions) == 1 and predictions[0]["text"] == ""):
|
||||
predictions.insert(0, {"text": "empty", "start_logit": 0.0, "end_logit": 0.0, "score": 0.0})
|
||||
|
||||
# Compute the softmax of all scores (we do it with numpy to stay independent from torch/tf in this file, using
|
||||
# the LogSumExp trick).
|
||||
scores = np.array([pred.pop("score") for pred in predictions])
|
||||
exp_scores = np.exp(scores - np.max(scores))
|
||||
probs = exp_scores / exp_scores.sum()
|
||||
|
||||
# Include the probabilities in our predictions.
|
||||
for prob, pred in zip(probs, predictions):
|
||||
pred["probability"] = prob
|
||||
|
||||
# Pick the best prediction. If the null answer is not possible, this is easy.
|
||||
if not version_2_with_negative:
|
||||
all_predictions[example["id"]] = predictions[0]["text"]
|
||||
else:
|
||||
# Otherwise we first need to find the best non-empty prediction.
|
||||
i = 0
|
||||
while predictions[i]["text"] == "":
|
||||
i += 1
|
||||
best_non_null_pred = predictions[i]
|
||||
|
||||
# Then we compare to the null prediction using the threshold.
|
||||
score_diff = null_score - best_non_null_pred["start_logit"] - best_non_null_pred["end_logit"]
|
||||
scores_diff_json[example["id"]] = float(score_diff) # To be JSON-serializable.
|
||||
if score_diff > null_score_diff_threshold:
|
||||
all_predictions[example["id"]] = ""
|
||||
else:
|
||||
all_predictions[example["id"]] = best_non_null_pred["text"]
|
||||
|
||||
# Make `predictions` JSON-serializable by casting np.float back to float.
|
||||
all_nbest_json[example["id"]] = [
|
||||
{k: (float(v) if isinstance(v, (np.float16, np.float32, np.float64)) else v) for k, v in pred.items()}
|
||||
for pred in predictions
|
||||
]
|
||||
|
||||
# If we have an output_dir, let's save all those dicts.
|
||||
if output_dir is not None:
|
||||
assert os.path.isdir(output_dir), f"{output_dir} is not a directory."
|
||||
|
||||
prediction_file = os.path.join(
|
||||
output_dir, "predictions.json" if prefix is None else f"{prefix}_predictions.json"
|
||||
)
|
||||
nbest_file = os.path.join(
|
||||
output_dir, "nbest_predictions.json" if prefix is None else f"{prefix}_nbest_predictions.json"
|
||||
)
|
||||
if version_2_with_negative:
|
||||
null_odds_file = os.path.join(
|
||||
output_dir, "null_odds.json" if prefix is None else f"{prefix}_null_odds.json"
|
||||
)
|
||||
|
||||
logger.info(f"Saving predictions to {prediction_file}.")
|
||||
with open(prediction_file, "w") as writer:
|
||||
writer.write(json.dumps(all_predictions, indent=4) + "\n")
|
||||
logger.info(f"Saving nbest_preds to {nbest_file}.")
|
||||
with open(nbest_file, "w") as writer:
|
||||
writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
|
||||
if version_2_with_negative:
|
||||
logger.info(f"Saving null_odds to {null_odds_file}.")
|
||||
with open(null_odds_file, "w") as writer:
|
||||
writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
|
||||
|
||||
return all_predictions
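

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): roughly how a driver
# script such as run_qa.py could call `postprocess_qa_predictions` once the
# start/end logits have been gathered for the evaluation features. All argument
# names and values below are hypothetical placeholders.
# ---------------------------------------------------------------------------
def _example_usage_of_postprocess_qa_predictions(eval_examples, eval_features, start_logits, end_logits, output_dir=None):
    # `eval_examples` is the raw validation set, `eval_features` the tokenized features,
    # and the two logits arrays are the model outputs aligned with `eval_features`.
    return postprocess_qa_predictions(
        examples=eval_examples,
        features=eval_features,
        predictions=(start_logits, end_logits),
        version_2_with_negative=True,
        n_best_size=20,
        max_answer_length=30,
        null_score_diff_threshold=0.0,
        output_dir=output_dir,
    )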


def postprocess_qa_predictions_with_beam_search(
    examples,
    features,
    predictions: Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray],
    version_2_with_negative: bool = False,
    n_best_size: int = 20,
    max_answer_length: int = 30,
    start_n_top: int = 5,
    end_n_top: int = 5,
    output_dir: Optional[str] = None,
    prefix: Optional[str] = None,
    is_world_process_zero: bool = True,
):
    """
    Post-processes the predictions of a question-answering model with beam search to convert them to answers that are
    substrings of the original contexts. This is the postprocessing function for models that return start and end
    logits, indices, as well as cls token predictions.

    Args:
        examples: The non-preprocessed dataset (see the main script for more information).
        features: The processed dataset (see the main script for more information).
        predictions (:obj:`Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]`):
            The predictions of the model: five arrays containing the top start log-probabilities, the top start
            indices, the top end log-probabilities, the top end indices and the cls token logits respectively. The
            first dimension of each array must match the number of elements of :obj:`features`.
        version_2_with_negative (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not the underlying dataset contains examples with no answers.
        n_best_size (:obj:`int`, `optional`, defaults to 20):
            The total number of n-best predictions to generate when looking for an answer.
        max_answer_length (:obj:`int`, `optional`, defaults to 30):
            The maximum length of an answer that can be generated. This is needed because the start and end predictions
            are not conditioned on one another.
        start_n_top (:obj:`int`, `optional`, defaults to 5):
            The number of top start logits to keep when searching for the :obj:`n_best_size` predictions.
        end_n_top (:obj:`int`, `optional`, defaults to 5):
            The number of top end logits to keep when searching for the :obj:`n_best_size` predictions.
        output_dir (:obj:`str`, `optional`):
            If provided, the dictionaries of predictions, n_best predictions (with their scores and logits) and, if
            :obj:`version_2_with_negative=True`, the dictionary of the scores differences between best and null
            answers, are saved in `output_dir`.
        prefix (:obj:`str`, `optional`):
            If provided, the dictionaries mentioned above are saved with `prefix` added to their names.
        is_world_process_zero (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether this process is the main process or not (used to determine if logging/saves should be done).
    """
    assert len(predictions) == 5, "`predictions` should be a tuple with five elements."
    start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits = predictions

    assert len(predictions[0]) == len(
        features
    ), f"Got {len(predictions[0])} predictions and {len(features)} features."

    # Build a map example to its corresponding features.
    example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
    features_per_example = collections.defaultdict(list)
    for i, feature in enumerate(features):
        features_per_example[example_id_to_index[feature["example_id"]]].append(i)

    # The dictionaries we have to fill.
    all_predictions = collections.OrderedDict()
    all_nbest_json = collections.OrderedDict()
    scores_diff_json = collections.OrderedDict() if version_2_with_negative else None

    # Logging.
    logger.setLevel(logging.INFO if is_world_process_zero else logging.WARN)
    logger.info(f"Post-processing {len(examples)} example predictions split into {len(features)} features.")

    # Let's loop over all the examples!
    for example_index, example in enumerate(tqdm(examples)):
        # Those are the indices of the features associated with the current example.
        feature_indices = features_per_example[example_index]

        min_null_score = None
        prelim_predictions = []

        # Looping through all the features associated with the current example.
        for feature_index in feature_indices:
            # We grab the predictions of the model for this feature.
            start_log_prob = start_top_log_probs[feature_index]
            start_indexes = start_top_index[feature_index]
            end_log_prob = end_top_log_probs[feature_index]
            end_indexes = end_top_index[feature_index]
            feature_null_score = cls_logits[feature_index]
            # This is what will allow us to map some of the positions in our logits to spans of text in the original
            # context.
            offset_mapping = features[feature_index]["offset_mapping"]
            # Optional `token_is_max_context`, if provided we will remove answers that do not have the maximum context
            # available in the current feature.
            token_is_max_context = features[feature_index].get("token_is_max_context", None)

            # Update the minimum null prediction.
            if min_null_score is None or feature_null_score < min_null_score:
                min_null_score = feature_null_score

            # Go through all possibilities for the `start_n_top`/`end_n_top` greatest start and end logits.
            for i in range(start_n_top):
                for j in range(end_n_top):
                    start_index = int(start_indexes[i])
                    j_index = i * end_n_top + j
                    end_index = int(end_indexes[j_index])
                    # Don't consider out-of-scope answers (the last part of the test should be unnecessary because of
                    # the p_mask, but let's not take any risk).
                    if (
                        start_index >= len(offset_mapping)
                        or end_index >= len(offset_mapping)
                        or offset_mapping[start_index] is None
                        or offset_mapping[end_index] is None
                    ):
                        continue
                    # Don't consider answers with a negative length or a length > max_answer_length.
                    if end_index < start_index or end_index - start_index + 1 > max_answer_length:
                        continue
                    # Don't consider answers that don't have the maximum context available (if such information is
                    # provided).
                    if token_is_max_context is not None and not token_is_max_context.get(str(start_index), False):
                        continue
                    prelim_predictions.append(
                        {
                            "offsets": (offset_mapping[start_index][0], offset_mapping[end_index][1]),
                            "score": start_log_prob[i] + end_log_prob[j_index],
                            "start_log_prob": start_log_prob[i],
                            "end_log_prob": end_log_prob[j_index],
                        }
                    )

        # Only keep the best `n_best_size` predictions.
        predictions = sorted(prelim_predictions, key=lambda x: x["score"], reverse=True)[:n_best_size]
        # Use the offsets to gather the answer text in the original context.
        context = example["context"]
        for pred in predictions:
            offsets = pred.pop("offsets")
            pred["text"] = context[offsets[0] : offsets[1]]

        # In the very rare edge case where we don't have a single non-null prediction, we create a fake prediction to
        # avoid failure.
        if len(predictions) == 0:
            predictions.insert(0, {"text": "", "start_logit": -1e-6, "end_logit": -1e-6, "score": -2e-6})

        # Compute the softmax of all scores (we do it with numpy to stay independent from torch/tf in this file, using
        # the LogSumExp trick).
        scores = np.array([pred.pop("score") for pred in predictions])
        exp_scores = np.exp(scores - np.max(scores))
        probs = exp_scores / exp_scores.sum()

        # Include the probabilities in our predictions.
        for prob, pred in zip(probs, predictions):
            pred["probability"] = prob

        # Pick the best prediction and set the probability for the null answer.
        all_predictions[example["id"]] = predictions[0]["text"]
        if version_2_with_negative:
            scores_diff_json[example["id"]] = float(min_null_score)

        # Make `predictions` JSON-serializable by casting np.float back to float.
        all_nbest_json[example["id"]] = [
            {k: (float(v) if isinstance(v, (np.float16, np.float32, np.float64)) else v) for k, v in pred.items()}
            for pred in predictions
        ]

    # If we have an output_dir, let's save all those dicts.
    if output_dir is not None:
        assert os.path.isdir(output_dir), f"{output_dir} is not a directory."

        prediction_file = os.path.join(
            output_dir, "predictions.json" if prefix is None else f"{prefix}_predictions.json"
        )
        nbest_file = os.path.join(
            output_dir, "nbest_predictions.json" if prefix is None else f"{prefix}_nbest_predictions.json"
        )
        if version_2_with_negative:
            null_odds_file = os.path.join(
                output_dir, "null_odds.json" if prefix is None else f"{prefix}_null_odds.json"
            )

        logger.info(f"Saving predictions to {prediction_file}.")
        with open(prediction_file, "w") as writer:
            writer.write(json.dumps(all_predictions, indent=4) + "\n")
        logger.info(f"Saving nbest_preds to {nbest_file}.")
        with open(nbest_file, "w") as writer:
            writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
        if version_2_with_negative:
            logger.info(f"Saving null_odds to {null_odds_file}.")
            with open(null_odds_file, "w") as writer:
                writer.write(json.dumps(scores_diff_json, indent=4) + "\n")

    return all_predictions, scores_diff_json
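

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): roughly how a script
# using an XLNet-style model might call the beam-search post-processing above.
# The five arrays are the per-feature outputs of the model's QA head, gathered
# over the evaluation set; all argument names and values below are placeholders.
# ---------------------------------------------------------------------------
def _example_usage_of_beam_search_postprocessing(
    eval_examples, eval_features, start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits
):
    # Returns the best answer per example plus, for SQuAD v2-style data, the null-answer score differences.
    predictions, scores_diff = postprocess_qa_predictions_with_beam_search(
        examples=eval_examples,
        features=eval_features,
        predictions=(start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits),
        version_2_with_negative=True,
        n_best_size=20,
        max_answer_length=30,
        start_n_top=5,
        end_n_top=5,
        output_dir=None,
    )
    return predictions, scores_diff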