mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
Adding a new align_to_words
param to qa pipeline. (#18010)
* Adding a new `align_to_words` param to qa pipeline. * Update src/transformers/pipelines/question_answering.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Import protection. Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
parent
ab2006e3d6
commit
9f5fe63548
@ -8,7 +8,14 @@ import numpy as np
|
|||||||
from ..data import SquadExample, SquadFeatures, squad_convert_examples_to_features
|
from ..data import SquadExample, SquadFeatures, squad_convert_examples_to_features
|
||||||
from ..modelcard import ModelCard
|
from ..modelcard import ModelCard
|
||||||
from ..tokenization_utils import PreTrainedTokenizer
|
from ..tokenization_utils import PreTrainedTokenizer
|
||||||
from ..utils import PaddingStrategy, add_end_docstrings, is_tf_available, is_torch_available, logging
|
from ..utils import (
|
||||||
|
PaddingStrategy,
|
||||||
|
add_end_docstrings,
|
||||||
|
is_tf_available,
|
||||||
|
is_tokenizers_available,
|
||||||
|
is_torch_available,
|
||||||
|
logging,
|
||||||
|
)
|
||||||
from .base import PIPELINE_INIT_ARGS, ArgumentHandler, ChunkPipeline
|
from .base import PIPELINE_INIT_ARGS, ArgumentHandler, ChunkPipeline
|
||||||
|
|
||||||
|
|
||||||
@ -18,6 +25,9 @@ if TYPE_CHECKING:
|
|||||||
from ..modeling_tf_utils import TFPreTrainedModel
|
from ..modeling_tf_utils import TFPreTrainedModel
|
||||||
from ..modeling_utils import PreTrainedModel
|
from ..modeling_utils import PreTrainedModel
|
||||||
|
|
||||||
|
if is_tokenizers_available():
|
||||||
|
import tokenizers
|
||||||
|
|
||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
@ -180,6 +190,7 @@ class QuestionAnsweringPipeline(ChunkPipeline):
|
|||||||
max_seq_len=None,
|
max_seq_len=None,
|
||||||
max_question_len=None,
|
max_question_len=None,
|
||||||
handle_impossible_answer=None,
|
handle_impossible_answer=None,
|
||||||
|
align_to_words=None,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
# Set defaults values
|
# Set defaults values
|
||||||
@ -208,6 +219,8 @@ class QuestionAnsweringPipeline(ChunkPipeline):
|
|||||||
postprocess_params["max_answer_len"] = max_answer_len
|
postprocess_params["max_answer_len"] = max_answer_len
|
||||||
if handle_impossible_answer is not None:
|
if handle_impossible_answer is not None:
|
||||||
postprocess_params["handle_impossible_answer"] = handle_impossible_answer
|
postprocess_params["handle_impossible_answer"] = handle_impossible_answer
|
||||||
|
if align_to_words is not None:
|
||||||
|
postprocess_params["align_to_words"] = align_to_words
|
||||||
return preprocess_params, {}, postprocess_params
|
return preprocess_params, {}, postprocess_params
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
def __call__(self, *args, **kwargs):
|
||||||
@ -243,6 +256,9 @@ class QuestionAnsweringPipeline(ChunkPipeline):
|
|||||||
The maximum length of the question after tokenization. It will be truncated if needed.
|
The maximum length of the question after tokenization. It will be truncated if needed.
|
||||||
handle_impossible_answer (`bool`, *optional*, defaults to `False`):
|
handle_impossible_answer (`bool`, *optional*, defaults to `False`):
|
||||||
Whether or not we accept impossible as an answer.
|
Whether or not we accept impossible as an answer.
|
||||||
|
align_to_words (`bool`, *optional*, defaults to `True`):
|
||||||
|
Attempts to align the answer to real words. Improves quality on space separated langages. Might hurt on
|
||||||
|
non-space-separated languages (like Japanese or Chinese)
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
A `dict` or a list of `dict`: Each result comes as a dictionary with the following keys:
|
A `dict` or a list of `dict`: Each result comes as a dictionary with the following keys:
|
||||||
@ -386,6 +402,7 @@ class QuestionAnsweringPipeline(ChunkPipeline):
|
|||||||
top_k=1,
|
top_k=1,
|
||||||
handle_impossible_answer=False,
|
handle_impossible_answer=False,
|
||||||
max_answer_len=15,
|
max_answer_len=15,
|
||||||
|
align_to_words=True,
|
||||||
):
|
):
|
||||||
min_null_score = 1000000 # large and positive
|
min_null_score = 1000000 # large and positive
|
||||||
answers = []
|
answers = []
|
||||||
@ -464,15 +481,8 @@ class QuestionAnsweringPipeline(ChunkPipeline):
|
|||||||
for s, e, score in zip(starts, ends, scores):
|
for s, e, score in zip(starts, ends, scores):
|
||||||
s = s - offset
|
s = s - offset
|
||||||
e = e - offset
|
e = e - offset
|
||||||
try:
|
|
||||||
start_word = enc.token_to_word(s)
|
start_index, end_index = self.get_indices(enc, s, e, sequence_index, align_to_words)
|
||||||
end_word = enc.token_to_word(e)
|
|
||||||
start_index = enc.word_to_chars(start_word, sequence_index=sequence_index)[0]
|
|
||||||
end_index = enc.word_to_chars(end_word, sequence_index=sequence_index)[1]
|
|
||||||
except Exception:
|
|
||||||
# Some tokenizers don't really handle words. Keep to offsets then.
|
|
||||||
start_index = enc.offsets[s][0]
|
|
||||||
end_index = enc.offsets[e][1]
|
|
||||||
|
|
||||||
answers.append(
|
answers.append(
|
||||||
{
|
{
|
||||||
@ -490,6 +500,24 @@ class QuestionAnsweringPipeline(ChunkPipeline):
|
|||||||
return answers[0]
|
return answers[0]
|
||||||
return answers
|
return answers
|
||||||
|
|
||||||
|
def get_indices(
|
||||||
|
self, enc: "tokenizers.Encoding", s: int, e: int, sequence_index: int, align_to_words: bool
|
||||||
|
) -> Tuple[int, int]:
|
||||||
|
if align_to_words:
|
||||||
|
try:
|
||||||
|
start_word = enc.token_to_word(s)
|
||||||
|
end_word = enc.token_to_word(e)
|
||||||
|
start_index = enc.word_to_chars(start_word, sequence_index=sequence_index)[0]
|
||||||
|
end_index = enc.word_to_chars(end_word, sequence_index=sequence_index)[1]
|
||||||
|
except Exception:
|
||||||
|
# Some tokenizers don't really handle words. Keep to offsets then.
|
||||||
|
start_index = enc.offsets[s][0]
|
||||||
|
end_index = enc.offsets[e][1]
|
||||||
|
else:
|
||||||
|
start_index = enc.offsets[s][0]
|
||||||
|
end_index = enc.offsets[e][1]
|
||||||
|
return start_index, end_index
|
||||||
|
|
||||||
def decode(
|
def decode(
|
||||||
self, start: np.ndarray, end: np.ndarray, topk: int, max_answer_len: int, undesired_tokens: np.ndarray
|
self, start: np.ndarray, end: np.ndarray, topk: int, max_answer_len: int, undesired_tokens: np.ndarray
|
||||||
) -> Tuple:
|
) -> Tuple:
|
||||||
|
@ -171,6 +171,29 @@ class QAPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
|
|||||||
|
|
||||||
self.assertEqual(nested_simplify(outputs), {"score": 0.028, "start": 0, "end": 11, "answer": "HuggingFace"})
|
self.assertEqual(nested_simplify(outputs), {"score": 0.028, "start": 0, "end": 11, "answer": "HuggingFace"})
|
||||||
|
|
||||||
|
@slow
|
||||||
|
@require_torch
|
||||||
|
def test_small_model_japanese(self):
|
||||||
|
question_answerer = pipeline(
|
||||||
|
"question-answering",
|
||||||
|
model="KoichiYasuoka/deberta-base-japanese-aozora-ud-head",
|
||||||
|
)
|
||||||
|
output = question_answerer(question="国語", context="全学年にわたって小学校の国語の教科書に挿し絵が用いられている")
|
||||||
|
|
||||||
|
# Wrong answer, the whole text is identified as one "word" since the tokenizer does not include
|
||||||
|
# a pretokenizer
|
||||||
|
self.assertEqual(
|
||||||
|
nested_simplify(output),
|
||||||
|
{"score": 1.0, "start": 0, "end": 30, "answer": "全学年にわたって小学校の国語の教科書に挿し絵が用いられている"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Disable word alignment
|
||||||
|
output = question_answerer(question="国語", context="全学年にわたって小学校の国語の教科書に挿し絵が用いられている", align_to_words=False)
|
||||||
|
self.assertEqual(
|
||||||
|
nested_simplify(output),
|
||||||
|
{"score": 1.0, "start": 15, "end": 18, "answer": "教科書"},
|
||||||
|
)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_torch
|
@require_torch
|
||||||
def test_small_model_long_context_cls_slow(self):
|
def test_small_model_long_context_cls_slow(self):
|
||||||
|
Loading…
Reference in New Issue
Block a user