Expose padding_strategy on squad processor to fix QA pipeline performance regression (#5932)

* Attempt to fix the way squad_convert_examples_to_features pad the elements for the QA pipeline.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Quality

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Make the code easier to read and avoid testing multiple test the same thing.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* missing enum value on truncation_strategy.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Rethinking for the easiest fix: expose the padding strategy on squad_convert_examples_to_features.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Remove unused imports.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>
This commit is contained in:
Funtowicz Morgan 2020-07-22 16:11:57 +02:00 committed by GitHub
parent ae67b2439f
commit 896300177b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 23 additions and 5 deletions

View File

@ -9,6 +9,7 @@ from tqdm import tqdm
from ...file_utils import is_tf_available, is_torch_available
from ...tokenization_bert import whitespace_tokenize
from ...tokenization_utils_base import TruncationStrategy
from .utils import DataProcessor
@ -87,7 +88,9 @@ def _is_whitespace(c):
return False
def squad_convert_example_to_features(example, max_seq_length, doc_stride, max_query_length, is_training):
def squad_convert_example_to_features(
example, max_seq_length, doc_stride, max_query_length, padding_strategy, is_training
):
features = []
if is_training and not example.is_impossible:
# Get start and end position
@ -141,11 +144,21 @@ def squad_convert_example_to_features(example, max_seq_length, doc_stride, max_q
span_doc_tokens = all_doc_tokens
while len(spans) * doc_stride < len(all_doc_tokens):
# Define the side we want to truncate / pad and the text/pair sorting
if tokenizer.padding_side == "right":
texts = truncated_query
pairs = span_doc_tokens
truncation = TruncationStrategy.ONLY_SECOND.value
else:
texts = span_doc_tokens
pairs = truncated_query
truncation = TruncationStrategy.ONLY_FIRST.value
encoded_dict = tokenizer.encode_plus( # TODO(thom) update this logic
truncated_query if tokenizer.padding_side == "right" else span_doc_tokens,
span_doc_tokens if tokenizer.padding_side == "right" else truncated_query,
truncation="only_second" if tokenizer.padding_side == "right" else "only_first",
padding="max_length",
texts,
pairs,
truncation=truncation,
padding=padding_strategy,
max_length=max_seq_length,
return_overflowing_tokens=True,
stride=max_seq_length - doc_stride - len(truncated_query) - sequence_pair_added_tokens,
@ -285,6 +298,7 @@ def squad_convert_examples_to_features(
doc_stride,
max_query_length,
is_training,
padding_strategy="max_length",
return_dataset=False,
threads=1,
tqdm_enabled=True,
@ -300,6 +314,7 @@ def squad_convert_examples_to_features(
doc_stride: The stride used when the context is too large and is split across several features.
max_query_length: The maximum length of the query.
is_training: whether to create features for model evaluation or model training.
padding_strategy: Default to "max_length". Which padding strategy to use
return_dataset: Default False. Either 'pt' or 'tf'.
if 'pt': returns a torch.data.TensorDataset,
if 'tf': returns a tf.data.Dataset
@ -333,6 +348,7 @@ def squad_convert_examples_to_features(
max_seq_length=max_seq_length,
doc_stride=doc_stride,
max_query_length=max_query_length,
padding_strategy=padding_strategy,
is_training=is_training,
)
features = list(

View File

@ -36,6 +36,7 @@ from .modelcard import ModelCard
from .tokenization_auto import AutoTokenizer
from .tokenization_bert import BasicTokenizer
from .tokenization_utils import PreTrainedTokenizer
from .tokenization_utils_base import PaddingStrategy
if is_tf_available():
@ -1318,6 +1319,7 @@ class QuestionAnsweringPipeline(Pipeline):
max_seq_length=kwargs["max_seq_len"],
doc_stride=kwargs["doc_stride"],
max_query_length=kwargs["max_question_len"],
padding_strategy=PaddingStrategy.DO_NOT_PAD.value,
is_training=False,
tqdm_enabled=False,
)