Expose padding_strategy on squad processor to fix QA pipeline performance regression (#5932)
* Attempt to fix the way squad_convert_examples_to_features pads the elements for the QA pipeline.
* Quality.
* Make the code easier to read and avoid multiple tests checking the same thing.
* Missing enum value on truncation_strategy.
* Rethinking for the easiest fix: expose the padding strategy on squad_convert_examples_to_features.
* Remove unused imports.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>
parent ae67b2439f
commit 896300177b
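The change in a nutshell: squad_convert_examples_to_features used to pad every feature to max_seq_length unconditionally, which is wasteful for the short inputs the QA pipeline typically handles. The commit threads a padding_strategy argument through the processor, keeps the old "max_length" default, and lets the pipeline ask for no padding at all. The diff below shows the processor-side change first, then the pipeline-side change. A minimal sketch of the exposed argument, assuming a transformers build that contains this commit and a slow tokenizer (the squad processor feeds token lists to encode_plus, which fast tokenizers do not accept):

from transformers import AutoTokenizer, squad_convert_examples_to_features
from transformers.data.processors.squad import SquadExample

# Guard needed: the processor spawns a multiprocessing pool internally.
if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=False)

    example = SquadExample(
        qas_id="0",
        question_text="Where is the Eiffel Tower?",
        context_text="The Eiffel Tower is located in Paris, France.",
        answer_text=None,
        start_position_character=None,
        title="Eiffel Tower",
    )

    # The old default ("max_length") is preserved; inference callers
    # can now opt out of padding entirely.
    features = squad_convert_examples_to_features(
        examples=[example],
        tokenizer=tokenizer,
        max_seq_length=384,
        doc_stride=128,
        max_query_length=64,
        is_training=False,
        padding_strategy="do_not_pad",
        tqdm_enabled=False,
    )
    # Unpadded: only as long as the actual input, not 384.
    print(len(features[0].input_ids))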
@@ -9,6 +9,7 @@ from tqdm import tqdm
 
 from ...file_utils import is_tf_available, is_torch_available
 from ...tokenization_bert import whitespace_tokenize
+from ...tokenization_utils_base import TruncationStrategy
 from .utils import DataProcessor
 
 
@@ -87,7 +88,9 @@ def _is_whitespace(c):
     return False
 
 
-def squad_convert_example_to_features(example, max_seq_length, doc_stride, max_query_length, is_training):
+def squad_convert_example_to_features(
+    example, max_seq_length, doc_stride, max_query_length, padding_strategy, is_training
+):
     features = []
     if is_training and not example.is_impossible:
         # Get start and end position
@@ -141,11 +144,21 @@ def squad_convert_example_to_features(example, max_seq_length, doc_stride, max_q
     span_doc_tokens = all_doc_tokens
     while len(spans) * doc_stride < len(all_doc_tokens):
 
+        # Define the side we want to truncate / pad and the text/pair sorting
+        if tokenizer.padding_side == "right":
+            texts = truncated_query
+            pairs = span_doc_tokens
+            truncation = TruncationStrategy.ONLY_SECOND.value
+        else:
+            texts = span_doc_tokens
+            pairs = truncated_query
+            truncation = TruncationStrategy.ONLY_FIRST.value
+
         encoded_dict = tokenizer.encode_plus(  # TODO(thom) update this logic
-            truncated_query if tokenizer.padding_side == "right" else span_doc_tokens,
-            span_doc_tokens if tokenizer.padding_side == "right" else truncated_query,
-            truncation="only_second" if tokenizer.padding_side == "right" else "only_first",
-            padding="max_length",
+            texts,
+            pairs,
+            truncation=truncation,
+            padding=padding_strategy,
             max_length=max_seq_length,
             return_overflowing_tokens=True,
             stride=max_seq_length - doc_stride - len(truncated_query) - sequence_pair_added_tokens,
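The refactor above is behaviour-preserving except for padding=padding_strategy: the context is the sequence that must be truncated and strided, so its position in the (text, pair) call and the TruncationStrategy flip together with tokenizer.padding_side. A standalone sketch of that logic (not the processor's actual code path; the model name and toy strings are illustrative):

from transformers import AutoTokenizer
from transformers.tokenization_utils_base import TruncationStrategy

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=False)

truncated_query = tokenizer.encode("Where is the Eiffel Tower?", add_special_tokens=False)
span_doc_tokens = tokenizer.tokenize("The Eiffel Tower is located in Paris, France.")

# Mirror of the committed logic: truncate on the side opposite the query.
if tokenizer.padding_side == "right":
    texts, pairs = truncated_query, span_doc_tokens
    truncation = TruncationStrategy.ONLY_SECOND.value  # truncate the context
else:
    texts, pairs = span_doc_tokens, truncated_query
    truncation = TruncationStrategy.ONLY_FIRST.value

encoded_dict = tokenizer.encode_plus(
    texts,
    pairs,
    truncation=truncation,
    padding="do_not_pad",  # what the QA pipeline passes after this commit
    max_length=32,
    return_overflowing_tokens=True,
    stride=8,
)
print(len(encoded_dict["input_ids"]))  # <= 32, and unpadded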
@@ -285,6 +298,7 @@ def squad_convert_examples_to_features(
     doc_stride,
     max_query_length,
     is_training,
+    padding_strategy="max_length",
     return_dataset=False,
     threads=1,
     tqdm_enabled=True,
@@ -300,6 +314,7 @@ def squad_convert_examples_to_features(
         doc_stride: The stride used when the context is too large and is split across several features.
         max_query_length: The maximum length of the query.
         is_training: whether to create features for model evaluation or model training.
+        padding_strategy: Default to "max_length". Which padding strategy to use
         return_dataset: Default False. Either 'pt' or 'tf'.
             if 'pt': returns a torch.data.TensorDataset,
             if 'tf': returns a tf.data.Dataset
@@ -333,6 +348,7 @@ def squad_convert_examples_to_features(
             max_seq_length=max_seq_length,
             doc_stride=doc_stride,
             max_query_length=max_query_length,
+            padding_strategy=padding_strategy,
             is_training=is_training,
         )
         features = list(
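The keyword added above is part of the argument set that the processor fans out over a worker pool, presumably bound once with functools.partial as the surrounding code does. A self-contained sketch of that pattern with stand-in names (convert and annotate are hypothetical, not the processor's identifiers):

from functools import partial
from multiprocessing import Pool

def convert(example, max_seq_length, padding_strategy, is_training):
    # Stand-in for squad_convert_example_to_features: just echo its inputs.
    return (example, max_seq_length, padding_strategy, is_training)

if __name__ == "__main__":
    annotate = partial(
        convert,
        max_seq_length=384,
        padding_strategy="max_length",  # the newly threaded-through argument
        is_training=False,
    )
    with Pool(2) as pool:
        features = list(pool.imap(annotate, ["example-1", "example-2"]))
    print(features)

The remaining hunks are on the pipeline side of the change.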
@@ -36,6 +36,7 @@ from .modelcard import ModelCard
 from .tokenization_auto import AutoTokenizer
 from .tokenization_bert import BasicTokenizer
 from .tokenization_utils import PreTrainedTokenizer
+from .tokenization_utils_base import PaddingStrategy
 
 
 if is_tf_available():
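For reference, PaddingStrategy is a small string-valued enum in tokenization_utils_base; its .value strings are what encode_plus accepts for padding=. A quick sketch:

from transformers.tokenization_utils_base import PaddingStrategy

# Enumerate the accepted padding strategies and their string values:
# "longest", "max_length" and "do_not_pad".
for strategy in PaddingStrategy:
    print(strategy.name, "->", strategy.value)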
@@ -1318,6 +1319,7 @@ class QuestionAnsweringPipeline(Pipeline):
             max_seq_length=kwargs["max_seq_len"],
             doc_stride=kwargs["doc_stride"],
             max_query_length=kwargs["max_question_len"],
+            padding_strategy=PaddingStrategy.DO_NOT_PAD.value,
             is_training=False,
             tqdm_enabled=False,
         )
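Net effect: by passing PaddingStrategy.DO_NOT_PAD.value, the pipeline's features keep their natural length instead of being padded out to max_seq_len (384 by default at the time), so short question/context pairs push far fewer tokens through the model. Caller-facing usage is unchanged; a sketch (this downloads the default QA model on first use):

from transformers import pipeline

qa = pipeline("question-answering")

answer = qa(
    question="Where is the Eiffel Tower?",
    context="The Eiffel Tower is located in Paris, France.",
)
print(answer)  # dict with 'score', 'start', 'end' and 'answer' keys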