Add min and max question length options to TapasTokenizer (#12803)

* Add min and max question length option to the tokenizer

* Add corresponding test
This commit is contained in:
NielsRogge 2021-08-23 09:44:42 +02:00 committed by GitHub
parent 588e6caa15
commit 8679bd7144
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 58 additions and 4 deletions

View File

@ -262,7 +262,10 @@ class TapasTokenizer(PreTrainedTokenizer):
Whether to add empty strings instead of column names.
update_answer_coordinates (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether to recompute the answer coordinates from the answer text.
min_question_length (:obj:`int`, `optional`):
Minimum length of each question in terms of tokens (will be skipped otherwise).
max_question_length (:obj:`int`, `optional`):
Maximum length of each question in terms of tokens (will be skipped otherwise).
"""
vocab_files_names = VOCAB_FILES_NAMES
@ -288,6 +291,8 @@ class TapasTokenizer(PreTrainedTokenizer):
max_row_id: int = None,
strip_column_names: bool = False,
update_answer_coordinates: bool = False,
min_question_length=None,
max_question_length=None,
model_max_length: int = 512,
additional_special_tokens: Optional[List[str]] = None,
**kwargs
@ -318,6 +323,8 @@ class TapasTokenizer(PreTrainedTokenizer):
max_row_id=max_row_id,
strip_column_names=strip_column_names,
update_answer_coordinates=update_answer_coordinates,
min_question_length=min_question_length,
max_question_length=max_question_length,
model_max_length=model_max_length,
additional_special_tokens=additional_special_tokens,
**kwargs,
@ -346,6 +353,8 @@ class TapasTokenizer(PreTrainedTokenizer):
self.max_row_id = max_row_id if max_row_id is not None else self.model_max_length
self.strip_column_names = strip_column_names
self.update_answer_coordinates = update_answer_coordinates
self.min_question_length = min_question_length
self.max_question_length = max_question_length
@property
def do_lower_case(self):
@ -729,6 +738,19 @@ class TapasTokenizer(PreTrainedTokenizer):
**kwargs,
)
def _get_question_tokens(self, query):
"""Tokenizes the query, taking into account the max and min question length."""
query_tokens = self.tokenize(query)
if self.max_question_length is not None and len(query_tokens) > self.max_question_length:
logger.warning("Skipping query as its tokens are longer than the max question length")
return "", []
if self.min_question_length is not None and len(query_tokens) < self.min_question_length:
logger.warning("Skipping query as its tokens are shorter than the min question length")
return "", []
return query, query_tokens
def _batch_encode_plus(
self,
table,
@ -757,8 +779,9 @@ class TapasTokenizer(PreTrainedTokenizer):
table_tokens = self._tokenize_table(table)
queries_tokens = []
for query in queries:
query_tokens = self.tokenize(query)
for idx, query in enumerate(queries):
query, query_tokens = self._get_question_tokens(query)
queries[idx] = query
queries_tokens.append(query_tokens)
batch_outputs = self._batch_prepare_for_model(
@ -1015,7 +1038,7 @@ class TapasTokenizer(PreTrainedTokenizer):
)
table_tokens = self._tokenize_table(table)
query_tokens = self.tokenize(query)
query, query_tokens = self._get_question_tokens(query)
return self.prepare_for_model(
table,

View File

@ -1076,6 +1076,37 @@ class TapasTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
self.assertListEqual(new_encoded_inputs, dropped_encoded_inputs)
self.assertLessEqual(len(new_encoded_inputs), 20)
@slow
def test_min_max_question_length(self):
data = {
"Actors": ["Brad Pitt", "Leonardo Di Caprio", "George Clooney"],
"Age": ["56", "45", "59"],
"Number of movies": ["87", "53", "69"],
"Date of birth": ["18 december 1963", "11 november 1974", "6 may 1961"],
}
queries = "When was Brad Pitt born?"
table = pd.DataFrame.from_dict(data)
# test max_question_length
tokenizer = TapasTokenizer.from_pretrained("lysandre/tapas-temporary-repo", max_question_length=2)
encoding = tokenizer(table=table, queries=queries)
# query should not be tokenized as it's longer than the specified max_question_length
expected_results = [101, 102]
self.assertListEqual(encoding.input_ids[:2], expected_results)
# test min_question_length
tokenizer = TapasTokenizer.from_pretrained("lysandre/tapas-temporary-repo", min_question_length=30)
encoding = tokenizer(table=table, queries=queries)
# query should not be tokenized as it's shorter than the specified min_question_length
expected_results = [101, 102]
self.assertListEqual(encoding.input_ids[:2], expected_results)
@is_pt_tf_cross_test
def test_batch_encode_plus_tensors(self):
tokenizers = self.get_tokenizers(do_lower_case=False)