mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
Moving question_answering tests to the new testing scheme. Had to tweak a little some ModelTesterConfig for pipelines. (#13277)
* Moving question_answering tests to the new testing scheme. Had to tweak a little some ModelTesterConfig for pipelines. * Removing commented code.
This commit is contained in:
parent
4fa1cd995c
commit
55fb88d369
@ -202,7 +202,7 @@ class QuestionAnsweringPipeline(Pipeline):
|
||||
- **answer** (:obj:`str`) -- The answer to the question.
|
||||
"""
|
||||
# Set defaults values
|
||||
kwargs.setdefault("padding", "longest")
|
||||
kwargs.setdefault("padding", "longest" if getattr(self.tokenizer, "pad_token", None) is not None else False)
|
||||
kwargs.setdefault("topk", 1)
|
||||
kwargs.setdefault("doc_stride", 128)
|
||||
kwargs.setdefault("max_answer_len", 15)
|
||||
@ -353,17 +353,17 @@ class QuestionAnsweringPipeline(Pipeline):
|
||||
# Start: Index of the first character of the answer in the context string
|
||||
# End: Index of the character following the last character of the answer in the context string
|
||||
# Answer: Plain text of the answer
|
||||
answers += [
|
||||
{
|
||||
"score": score.item(),
|
||||
"start": np.where(char_to_word == feature.token_to_orig_map[s])[0][0].item(),
|
||||
"end": np.where(char_to_word == feature.token_to_orig_map[e])[0][-1].item(),
|
||||
"answer": " ".join(
|
||||
example.doc_tokens[feature.token_to_orig_map[s] : feature.token_to_orig_map[e] + 1]
|
||||
),
|
||||
}
|
||||
for s, e, score in zip(starts, ends, scores)
|
||||
]
|
||||
for s, e, score in zip(starts, ends, scores):
|
||||
answers.append(
|
||||
{
|
||||
"score": score.item(),
|
||||
"start": np.where(char_to_word == feature.token_to_orig_map[s])[0][0].item(),
|
||||
"end": np.where(char_to_word == feature.token_to_orig_map[e])[0][-1].item(),
|
||||
"answer": " ".join(
|
||||
example.doc_tokens[feature.token_to_orig_map[s] : feature.token_to_orig_map[e] + 1]
|
||||
),
|
||||
}
|
||||
)
|
||||
else:
|
||||
# Convert the answer (tokens) back to the original text
|
||||
# Score: score from the model
|
||||
@ -376,25 +376,26 @@ class QuestionAnsweringPipeline(Pipeline):
|
||||
# Sometimes the max probability token is in the middle of a word so:
|
||||
# - we start by finding the right word containing the token with `token_to_word`
|
||||
# - then we convert this word in a character span with `word_to_chars`
|
||||
answers += [
|
||||
{
|
||||
"score": score.item(),
|
||||
"start": enc.word_to_chars(
|
||||
enc.token_to_word(s), sequence_index=1 if question_first else 0
|
||||
)[0],
|
||||
"end": enc.word_to_chars(enc.token_to_word(e), sequence_index=1 if question_first else 0)[
|
||||
1
|
||||
],
|
||||
"answer": example.context_text[
|
||||
enc.word_to_chars(enc.token_to_word(s), sequence_index=1 if question_first else 0)[
|
||||
0
|
||||
] : enc.word_to_chars(enc.token_to_word(e), sequence_index=1 if question_first else 0)[
|
||||
1
|
||||
]
|
||||
],
|
||||
}
|
||||
for s, e, score in zip(starts, ends, scores)
|
||||
]
|
||||
sequence_index = 1 if question_first else 0
|
||||
for s, e, score in zip(starts, ends, scores):
|
||||
try:
|
||||
start_word = enc.token_to_word(s)
|
||||
end_word = enc.token_to_word(e)
|
||||
start_index = enc.word_to_chars(start_word, sequence_index=sequence_index)[0]
|
||||
end_index = enc.word_to_chars(end_word, sequence_index=sequence_index)[1]
|
||||
except Exception:
|
||||
# Some tokenizers don't really handle words. Keep to offsets then.
|
||||
start_index = enc.offsets[s][0]
|
||||
end_index = enc.offsets[e][1]
|
||||
|
||||
answers.append(
|
||||
{
|
||||
"score": score.item(),
|
||||
"start": start_index,
|
||||
"end": end_index,
|
||||
"answer": example.context_text[start_index:end_index],
|
||||
}
|
||||
)
|
||||
|
||||
if kwargs["handle_impossible_answer"]:
|
||||
answers.append({"score": min_null_score, "start": 0, "end": 0, "answer": ""})
|
||||
|
@ -147,6 +147,11 @@ class BartModelTester:
|
||||
pad_token_id=self.pad_token_id,
|
||||
)
|
||||
|
||||
def get_pipeline_config(self):
|
||||
config = self.get_config()
|
||||
config.max_position_embeddings = 100
|
||||
return config
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config, inputs_dict = self.prepare_config_and_inputs()
|
||||
return config, inputs_dict
|
||||
|
@ -189,6 +189,7 @@ class ReformerModelTester:
|
||||
def get_pipeline_config(self):
|
||||
config = self.get_config()
|
||||
config.vocab_size = 100
|
||||
config.axial_pos_shape = (4, 25)
|
||||
config.is_decoder = False
|
||||
return config
|
||||
|
||||
|
@ -14,107 +14,126 @@
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers import is_tf_available, is_torch_available
|
||||
from transformers import (
|
||||
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
||||
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
||||
LxmertConfig,
|
||||
QuestionAnsweringPipeline,
|
||||
)
|
||||
from transformers.data.processors.squad import SquadExample
|
||||
from transformers.pipelines import Pipeline, QuestionAnsweringArgumentHandler, pipeline
|
||||
from transformers.testing_utils import slow
|
||||
from transformers.pipelines import QuestionAnsweringArgumentHandler, pipeline
|
||||
from transformers.testing_utils import is_pipeline_test, nested_simplify, require_tf, require_torch, slow
|
||||
|
||||
from .test_pipelines_common import CustomInputPipelineCommonMixin
|
||||
from .test_pipelines_common import ANY, PipelineTestCaseMeta
|
||||
|
||||
|
||||
class QAPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
|
||||
pipeline_task = "question-answering"
|
||||
pipeline_running_kwargs = {
|
||||
"padding": "max_length",
|
||||
"max_seq_len": 25,
|
||||
"doc_stride": 5,
|
||||
} # Default is 'longest' but we use 'max_length' to test equivalence between slow/fast tokenizers
|
||||
small_models = [
|
||||
"sshleifer/tiny-distilbert-base-cased-distilled-squad"
|
||||
] # Models tested without the @slow decorator
|
||||
large_models = [] # Models tested with the @slow decorator
|
||||
valid_inputs = [
|
||||
{"question": "Where was HuggingFace founded ?", "context": "HuggingFace was founded in Paris."},
|
||||
{
|
||||
"question": "In what field is HuggingFace working ?",
|
||||
"context": "HuggingFace is a startup based in New-York founded in Paris which is trying to solve NLP.",
|
||||
},
|
||||
{
|
||||
"question": ["In what field is HuggingFace working ?", "In what field is HuggingFace working ?"],
|
||||
"context": "HuggingFace is a startup based in New-York founded in Paris which is trying to solve NLP.",
|
||||
},
|
||||
{
|
||||
"question": ["In what field is HuggingFace working ?", "In what field is HuggingFace working ?"],
|
||||
"context": [
|
||||
"HuggingFace is a startup based in New-York founded in Paris which is trying to solve NLP.",
|
||||
"HuggingFace is a startup based in New-York founded in Paris which is trying to solve NLP.",
|
||||
@is_pipeline_test
|
||||
class QAPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
|
||||
model_mapping = MODEL_FOR_QUESTION_ANSWERING_MAPPING
|
||||
tf_model_mapping = TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING
|
||||
|
||||
def run_pipeline_test(self, model, tokenizer, feature_extractor):
|
||||
if isinstance(model.config, LxmertConfig):
|
||||
# This is an bimodal model, we need to find a more consistent way
|
||||
# to switch on those models.
|
||||
return
|
||||
question_answerer = QuestionAnsweringPipeline(model, tokenizer)
|
||||
|
||||
outputs = question_answerer(
|
||||
question="Where was HuggingFace founded ?", context="HuggingFace was founded in Paris."
|
||||
)
|
||||
self.assertEqual(outputs, {"answer": ANY(str), "start": ANY(int), "end": ANY(int), "score": ANY(float)})
|
||||
|
||||
outputs = question_answerer(
|
||||
question=["In what field is HuggingFace working ?", "In what field is HuggingFace working ?"],
|
||||
context="HuggingFace was founded in Paris.",
|
||||
)
|
||||
self.assertEqual(
|
||||
outputs,
|
||||
[
|
||||
{"answer": ANY(str), "start": ANY(int), "end": ANY(int), "score": ANY(float)},
|
||||
{"answer": ANY(str), "start": ANY(int), "end": ANY(int), "score": ANY(float)},
|
||||
],
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
def get_pipelines(self):
|
||||
question_answering_pipelines = [
|
||||
pipeline(
|
||||
task=self.pipeline_task,
|
||||
model=model,
|
||||
tokenizer=model,
|
||||
framework="pt" if is_torch_available() else "tf",
|
||||
**self.pipeline_loading_kwargs,
|
||||
)
|
||||
for model in self.small_models
|
||||
]
|
||||
return question_answering_pipelines
|
||||
outputs = question_answerer(
|
||||
question=["What field is HuggingFace working ?", "In what field is HuggingFace ?"],
|
||||
context=[
|
||||
"HuggingFace is a startup based in New-York",
|
||||
"HuggingFace is a startup founded in Paris",
|
||||
],
|
||||
)
|
||||
self.assertEqual(
|
||||
outputs,
|
||||
[
|
||||
{"answer": ANY(str), "start": ANY(int), "end": ANY(int), "score": ANY(float)},
|
||||
{"answer": ANY(str), "start": ANY(int), "end": ANY(int), "score": ANY(float)},
|
||||
],
|
||||
)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
question_answerer(question="", context="HuggingFace was founded in Paris.")
|
||||
with self.assertRaises(ValueError):
|
||||
question_answerer(question=None, context="HuggingFace was founded in Paris.")
|
||||
with self.assertRaises(ValueError):
|
||||
question_answerer(question="In what field is HuggingFace working ?", context="")
|
||||
with self.assertRaises(ValueError):
|
||||
question_answerer(question="In what field is HuggingFace working ?", context=None)
|
||||
|
||||
outputs = question_answerer(
|
||||
question="Where was HuggingFace founded ?", context="HuggingFace was founded in Paris.", topk=20
|
||||
)
|
||||
self.assertEqual(
|
||||
outputs, [{"answer": ANY(str), "start": ANY(int), "end": ANY(int), "score": ANY(float)} for i in range(20)]
|
||||
)
|
||||
|
||||
@require_torch
|
||||
def test_small_model_pt(self):
|
||||
question_answerer = pipeline(
|
||||
"question-answering", model="sshleifer/tiny-distilbert-base-cased-distilled-squad"
|
||||
)
|
||||
outputs = question_answerer(
|
||||
question="Where was HuggingFace founded ?", context="HuggingFace was founded in Paris."
|
||||
)
|
||||
|
||||
self.assertEqual(nested_simplify(outputs), {"score": 0.01, "start": 0, "end": 11, "answer": "HuggingFace"})
|
||||
|
||||
@require_tf
|
||||
def test_small_model_tf(self):
|
||||
question_answerer = pipeline(
|
||||
"question-answering", model="sshleifer/tiny-distilbert-base-cased-distilled-squad", framework="tf"
|
||||
)
|
||||
outputs = question_answerer(
|
||||
question="Where was HuggingFace founded ?", context="HuggingFace was founded in Paris."
|
||||
)
|
||||
|
||||
self.assertEqual(nested_simplify(outputs), {"score": 0.011, "start": 0, "end": 11, "answer": "HuggingFace"})
|
||||
|
||||
@slow
|
||||
@unittest.skipIf(not is_torch_available() and not is_tf_available(), "Either torch or TF must be installed.")
|
||||
def test_high_topk_small_context(self):
|
||||
self.pipeline_running_kwargs.update({"topk": 20})
|
||||
valid_inputs = [
|
||||
{"question": "Where was HuggingFace founded ?", "context": "Paris"},
|
||||
]
|
||||
question_answering_pipelines = self.get_pipelines()
|
||||
output_keys = {"score", "answer", "start", "end"}
|
||||
for question_answering_pipeline in question_answering_pipelines:
|
||||
result = question_answering_pipeline(valid_inputs, **self.pipeline_running_kwargs)
|
||||
self.assertIsInstance(result, dict)
|
||||
@require_torch
|
||||
def test_large_model_pt(self):
|
||||
question_answerer = pipeline(
|
||||
"question-answering",
|
||||
)
|
||||
outputs = question_answerer(
|
||||
question="Where was HuggingFace founded ?", context="HuggingFace was founded in Paris."
|
||||
)
|
||||
|
||||
for key in output_keys:
|
||||
self.assertIn(key, result)
|
||||
self.assertEqual(nested_simplify(outputs), {"score": 0.979, "start": 27, "end": 32, "answer": "Paris"})
|
||||
|
||||
def _test_pipeline(self, question_answering_pipeline: Pipeline):
|
||||
output_keys = {"score", "answer", "start", "end"}
|
||||
valid_inputs = [
|
||||
{"question": "Where was HuggingFace founded ?", "context": "HuggingFace was founded in Paris."},
|
||||
{
|
||||
"question": "In what field is HuggingFace working ?",
|
||||
"context": "HuggingFace is a startup based in New-York founded in Paris which is trying to solve NLP.",
|
||||
},
|
||||
]
|
||||
invalid_inputs = [
|
||||
{"question": "", "context": "This is a test to try empty question edge case"},
|
||||
{"question": None, "context": "This is a test to try empty question edge case"},
|
||||
{"question": "What is does with empty context ?", "context": ""},
|
||||
{"question": "What is does with empty context ?", "context": None},
|
||||
]
|
||||
self.assertIsNotNone(question_answering_pipeline)
|
||||
@slow
|
||||
@require_tf
|
||||
def test_large_model_tf(self):
|
||||
question_answerer = pipeline("question-answering", framework="tf")
|
||||
outputs = question_answerer(
|
||||
question="Where was HuggingFace founded ?", context="HuggingFace was founded in Paris."
|
||||
)
|
||||
|
||||
mono_result = question_answering_pipeline(valid_inputs[0])
|
||||
self.assertIsInstance(mono_result, dict)
|
||||
self.assertEqual(nested_simplify(outputs), {"score": 0.979, "start": 27, "end": 32, "answer": "Paris"})
|
||||
|
||||
for key in output_keys:
|
||||
self.assertIn(key, mono_result)
|
||||
|
||||
multi_result = question_answering_pipeline(valid_inputs)
|
||||
self.assertIsInstance(multi_result, list)
|
||||
self.assertIsInstance(multi_result[0], dict)
|
||||
|
||||
for result in multi_result:
|
||||
for key in output_keys:
|
||||
self.assertIn(key, result)
|
||||
for bad_input in invalid_inputs:
|
||||
self.assertRaises(ValueError, question_answering_pipeline, bad_input)
|
||||
self.assertRaises(ValueError, question_answering_pipeline, invalid_inputs)
|
||||
|
||||
@is_pipeline_test
|
||||
class QuestionAnsweringArgumentHandlerTests(unittest.TestCase):
|
||||
def test_argument_handler(self):
|
||||
qa = QuestionAnsweringArgumentHandler()
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user