Moving question_answering tests to the new testing scheme. Had to tweak some ModelTesterConfig a little for pipelines. (#13277)

* Moving question_answering tests to the new testing scheme. Had to tweak
some ModelTesterConfig a little for pipelines.

* Removing commented code.
Nicolas Patry 2021-08-26 12:37:55 +02:00 committed by GitHub
parent 4fa1cd995c
commit 55fb88d369
4 changed files with 145 additions and 119 deletions
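
For orientation, the pipeline being migrated here can be exercised on its own; a minimal usage sketch, using the tiny checkpoint from the small-model tests further down (the exact score depends on the checkpoint and is illustrative only):

from transformers import pipeline

# Tiny QA checkpoint also used by test_small_model_pt below; any question-answering checkpoint works.
question_answerer = pipeline("question-answering", model="sshleifer/tiny-distilbert-base-cased-distilled-squad")
output = question_answerer(question="Where was HuggingFace founded ?", context="HuggingFace was founded in Paris.")
# output is a dict with "score", "start", "end" and "answer" keys,
# e.g. roughly {"score": 0.01, "start": 0, "end": 11, "answer": "HuggingFace"} for this tiny model.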


@@ -202,7 +202,7 @@ class QuestionAnsweringPipeline(Pipeline):
- **answer** (:obj:`str`) -- The answer to the question.
"""
# Set default values
kwargs.setdefault("padding", "longest")
kwargs.setdefault("padding", "longest" if getattr(self.tokenizer, "pad_token", None) is not None else False)
kwargs.setdefault("topk", 1)
kwargs.setdefault("doc_stride", 128)
kwargs.setdefault("max_answer_len", 15)
@@ -353,17 +353,17 @@ class QuestionAnsweringPipeline(Pipeline):
# Start: Index of the first character of the answer in the context string
# End: Index of the character following the last character of the answer in the context string
# Answer: Plain text of the answer
answers += [
{
"score": score.item(),
"start": np.where(char_to_word == feature.token_to_orig_map[s])[0][0].item(),
"end": np.where(char_to_word == feature.token_to_orig_map[e])[0][-1].item(),
"answer": " ".join(
example.doc_tokens[feature.token_to_orig_map[s] : feature.token_to_orig_map[e] + 1]
),
}
for s, e, score in zip(starts, ends, scores)
]
for s, e, score in zip(starts, ends, scores):
answers.append(
{
"score": score.item(),
"start": np.where(char_to_word == feature.token_to_orig_map[s])[0][0].item(),
"end": np.where(char_to_word == feature.token_to_orig_map[e])[0][-1].item(),
"answer": " ".join(
example.doc_tokens[feature.token_to_orig_map[s] : feature.token_to_orig_map[e] + 1]
),
}
)
else:
# Convert the answer (tokens) back to the original text
# Score: score from the model
@@ -376,25 +376,26 @@ class QuestionAnsweringPipeline(Pipeline):
# Sometimes the max probability token is in the middle of a word so:
# - we start by finding the right word containing the token with `token_to_word`
# - then we convert this word in a character span with `word_to_chars`
answers += [
{
"score": score.item(),
"start": enc.word_to_chars(
enc.token_to_word(s), sequence_index=1 if question_first else 0
)[0],
"end": enc.word_to_chars(enc.token_to_word(e), sequence_index=1 if question_first else 0)[
1
],
"answer": example.context_text[
enc.word_to_chars(enc.token_to_word(s), sequence_index=1 if question_first else 0)[
0
] : enc.word_to_chars(enc.token_to_word(e), sequence_index=1 if question_first else 0)[
1
]
],
}
for s, e, score in zip(starts, ends, scores)
]
sequence_index = 1 if question_first else 0
for s, e, score in zip(starts, ends, scores):
try:
start_word = enc.token_to_word(s)
end_word = enc.token_to_word(e)
start_index = enc.word_to_chars(start_word, sequence_index=sequence_index)[0]
end_index = enc.word_to_chars(end_word, sequence_index=sequence_index)[1]
except Exception:
# Some tokenizers don't really handle words. Fall back to offsets then.
start_index = enc.offsets[s][0]
end_index = enc.offsets[e][1]
answers.append(
{
"score": score.item(),
"start": start_index,
"end": end_index,
"answer": example.context_text[start_index:end_index],
}
)
if kwargs["handle_impossible_answer"]:
answers.append({"score": min_null_score, "start": 0, "end": 0, "answer": ""})


@@ -147,6 +147,11 @@ class BartModelTester:
pad_token_id=self.pad_token_id,
)
def get_pipeline_config(self):
config = self.get_config()
config.max_position_embeddings = 100
return config
def prepare_config_and_inputs_for_common(self):
config, inputs_dict = self.prepare_config_and_inputs()
return config, inputs_dict


@@ -189,6 +189,7 @@ class ReformerModelTester:
def get_pipeline_config(self):
config = self.get_config()
config.vocab_size = 100
config.axial_pos_shape = (4, 25)
config.is_decoder = False
return config
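
These are the "ModelTesterConfig" tweaks mentioned in the commit message: pipeline tests push real, untruncated inputs through the tiny test models, so the testers hand the pipeline harness a slightly roomier config (100 positions for Bart; for Reformer, an axial position shape of (4, 25), whose product has to line up with the padded sequence length the model sees). A minimal sketch of the override pattern for a hypothetical tester (get_config/get_pipeline_config are the real hook names from this diff, the rest is illustrative):

class MyModelTester:
    def get_config(self):
        # Regular tiny config used by the unit tests (details omitted).
        ...

    def get_pipeline_config(self):
        # Same tiny config, with limits relaxed so pipeline-sized inputs fit;
        # presumably picked up by the pipeline test machinery instead of get_config.
        config = self.get_config()
        config.max_position_embeddings = 100  # value mirrored from the Bart tester above
        return config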


@@ -14,107 +14,126 @@
import unittest
from transformers import is_tf_available, is_torch_available
from transformers import (
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
LxmertConfig,
QuestionAnsweringPipeline,
)
from transformers.data.processors.squad import SquadExample
from transformers.pipelines import Pipeline, QuestionAnsweringArgumentHandler, pipeline
from transformers.testing_utils import slow
from transformers.pipelines import QuestionAnsweringArgumentHandler, pipeline
from transformers.testing_utils import is_pipeline_test, nested_simplify, require_tf, require_torch, slow
from .test_pipelines_common import CustomInputPipelineCommonMixin
from .test_pipelines_common import ANY, PipelineTestCaseMeta
class QAPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
pipeline_task = "question-answering"
pipeline_running_kwargs = {
"padding": "max_length",
"max_seq_len": 25,
"doc_stride": 5,
} # Default is 'longest' but we use 'max_length' to test equivalence between slow/fast tokenizers
small_models = [
"sshleifer/tiny-distilbert-base-cased-distilled-squad"
] # Models tested without the @slow decorator
large_models = [] # Models tested with the @slow decorator
valid_inputs = [
{"question": "Where was HuggingFace founded ?", "context": "HuggingFace was founded in Paris."},
{
"question": "In what field is HuggingFace working ?",
"context": "HuggingFace is a startup based in New-York founded in Paris which is trying to solve NLP.",
},
{
"question": ["In what field is HuggingFace working ?", "In what field is HuggingFace working ?"],
"context": "HuggingFace is a startup based in New-York founded in Paris which is trying to solve NLP.",
},
{
"question": ["In what field is HuggingFace working ?", "In what field is HuggingFace working ?"],
"context": [
"HuggingFace is a startup based in New-York founded in Paris which is trying to solve NLP.",
"HuggingFace is a startup based in New-York founded in Paris which is trying to solve NLP.",
@is_pipeline_test
class QAPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
model_mapping = MODEL_FOR_QUESTION_ANSWERING_MAPPING
tf_model_mapping = TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING
def run_pipeline_test(self, model, tokenizer, feature_extractor):
if isinstance(model.config, LxmertConfig):
# This is a bimodal model, we need to find a more consistent way
# to switch on those models.
return
question_answerer = QuestionAnsweringPipeline(model, tokenizer)
outputs = question_answerer(
question="Where was HuggingFace founded ?", context="HuggingFace was founded in Paris."
)
self.assertEqual(outputs, {"answer": ANY(str), "start": ANY(int), "end": ANY(int), "score": ANY(float)})
outputs = question_answerer(
question=["In what field is HuggingFace working ?", "In what field is HuggingFace working ?"],
context="HuggingFace was founded in Paris.",
)
self.assertEqual(
outputs,
[
{"answer": ANY(str), "start": ANY(int), "end": ANY(int), "score": ANY(float)},
{"answer": ANY(str), "start": ANY(int), "end": ANY(int), "score": ANY(float)},
],
},
]
)
def get_pipelines(self):
question_answering_pipelines = [
pipeline(
task=self.pipeline_task,
model=model,
tokenizer=model,
framework="pt" if is_torch_available() else "tf",
**self.pipeline_loading_kwargs,
)
for model in self.small_models
]
return question_answering_pipelines
outputs = question_answerer(
question=["What field is HuggingFace working ?", "In what field is HuggingFace ?"],
context=[
"HuggingFace is a startup based in New-York",
"HuggingFace is a startup founded in Paris",
],
)
self.assertEqual(
outputs,
[
{"answer": ANY(str), "start": ANY(int), "end": ANY(int), "score": ANY(float)},
{"answer": ANY(str), "start": ANY(int), "end": ANY(int), "score": ANY(float)},
],
)
with self.assertRaises(ValueError):
question_answerer(question="", context="HuggingFace was founded in Paris.")
with self.assertRaises(ValueError):
question_answerer(question=None, context="HuggingFace was founded in Paris.")
with self.assertRaises(ValueError):
question_answerer(question="In what field is HuggingFace working ?", context="")
with self.assertRaises(ValueError):
question_answerer(question="In what field is HuggingFace working ?", context=None)
outputs = question_answerer(
question="Where was HuggingFace founded ?", context="HuggingFace was founded in Paris.", topk=20
)
self.assertEqual(
outputs, [{"answer": ANY(str), "start": ANY(int), "end": ANY(int), "score": ANY(float)} for i in range(20)]
)
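The assertions above rely on two helpers imported from .test_pipelines_common and transformers.testing_utils: ANY(type), which compares equal to any value of the given type, and nested_simplify, which rounds floats inside nested structures so scores can be compared approximately. A rough, behaviorally equivalent sketch of ANY (the real helper may differ in detail):

class ANY:
    """Placeholder that compares equal to any value of the given type(s)."""
    def __init__(self, *types):
        self.types = types
    def __eq__(self, other):
        return isinstance(other, self.types)

# Usage mirroring the assertions above:
assert {"answer": "Paris", "score": 0.9} == {"answer": ANY(str), "score": ANY(float)}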
@require_torch
def test_small_model_pt(self):
question_answerer = pipeline(
"question-answering", model="sshleifer/tiny-distilbert-base-cased-distilled-squad"
)
outputs = question_answerer(
question="Where was HuggingFace founded ?", context="HuggingFace was founded in Paris."
)
self.assertEqual(nested_simplify(outputs), {"score": 0.01, "start": 0, "end": 11, "answer": "HuggingFace"})
@require_tf
def test_small_model_tf(self):
question_answerer = pipeline(
"question-answering", model="sshleifer/tiny-distilbert-base-cased-distilled-squad", framework="tf"
)
outputs = question_answerer(
question="Where was HuggingFace founded ?", context="HuggingFace was founded in Paris."
)
self.assertEqual(nested_simplify(outputs), {"score": 0.011, "start": 0, "end": 11, "answer": "HuggingFace"})
@slow
@unittest.skipIf(not is_torch_available() and not is_tf_available(), "Either torch or TF must be installed.")
def test_high_topk_small_context(self):
self.pipeline_running_kwargs.update({"topk": 20})
valid_inputs = [
{"question": "Where was HuggingFace founded ?", "context": "Paris"},
]
question_answering_pipelines = self.get_pipelines()
output_keys = {"score", "answer", "start", "end"}
for question_answering_pipeline in question_answering_pipelines:
result = question_answering_pipeline(valid_inputs, **self.pipeline_running_kwargs)
self.assertIsInstance(result, dict)
@require_torch
def test_large_model_pt(self):
question_answerer = pipeline(
"question-answering",
)
outputs = question_answerer(
question="Where was HuggingFace founded ?", context="HuggingFace was founded in Paris."
)
for key in output_keys:
self.assertIn(key, result)
self.assertEqual(nested_simplify(outputs), {"score": 0.979, "start": 27, "end": 32, "answer": "Paris"})
def _test_pipeline(self, question_answering_pipeline: Pipeline):
output_keys = {"score", "answer", "start", "end"}
valid_inputs = [
{"question": "Where was HuggingFace founded ?", "context": "HuggingFace was founded in Paris."},
{
"question": "In what field is HuggingFace working ?",
"context": "HuggingFace is a startup based in New-York founded in Paris which is trying to solve NLP.",
},
]
invalid_inputs = [
{"question": "", "context": "This is a test to try empty question edge case"},
{"question": None, "context": "This is a test to try empty question edge case"},
{"question": "What is does with empty context ?", "context": ""},
{"question": "What is does with empty context ?", "context": None},
]
self.assertIsNotNone(question_answering_pipeline)
@slow
@require_tf
def test_large_model_tf(self):
question_answerer = pipeline("question-answering", framework="tf")
outputs = question_answerer(
question="Where was HuggingFace founded ?", context="HuggingFace was founded in Paris."
)
mono_result = question_answering_pipeline(valid_inputs[0])
self.assertIsInstance(mono_result, dict)
self.assertEqual(nested_simplify(outputs), {"score": 0.979, "start": 27, "end": 32, "answer": "Paris"})
for key in output_keys:
self.assertIn(key, mono_result)
multi_result = question_answering_pipeline(valid_inputs)
self.assertIsInstance(multi_result, list)
self.assertIsInstance(multi_result[0], dict)
for result in multi_result:
for key in output_keys:
self.assertIn(key, result)
for bad_input in invalid_inputs:
self.assertRaises(ValueError, question_answering_pipeline, bad_input)
self.assertRaises(ValueError, question_answering_pipeline, invalid_inputs)
@is_pipeline_test
class QuestionAnsweringArgumentHandlerTests(unittest.TestCase):
def test_argument_handler(self):
qa = QuestionAnsweringArgumentHandler()