Preserve spaces in GPT-2 tokenizers (#2778)

* Preserve spaces in GPT-2 tokenizers

Preserves spaces after special tokens in GPT-2 and inhereted (RoBERTa)
tokenizers, enabling correct BPE encoding. Automatically inserts a space
in front of first token in encode function when adding special tokens.

* Add tokenization preprocessing method

* Add framework argument to pipeline factory

Also fixes pipeline test issue. Each test input now treated as a
distinct sequence.
This commit is contained in:
Joe Davison 2020-02-13 13:29:43 -05:00 committed by GitHub
parent 0ed630f139
commit f1e8a51f08
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 141 additions and 41 deletions

View File

@ -1001,6 +1001,7 @@ def pipeline(
config: Optional[Union[str, PretrainedConfig]] = None,
tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None,
modelcard: Optional[Union[str, ModelCard]] = None,
framework: Optional[str] = None,
**kwargs
) -> Pipeline:
"""
@ -1021,7 +1022,7 @@ def pipeline(
if task not in SUPPORTED_TASKS:
raise KeyError("Unknown task {}, available tasks are {}".format(task, list(SUPPORTED_TASKS.keys())))
framework = get_framework(model)
framework = framework or get_framework(model)
targeted_task = SUPPORTED_TASKS[task]
task, model_class = targeted_task["impl"], targeted_task[framework]

View File

@ -191,15 +191,8 @@ class GPT2Tokenizer(PreTrainedTokenizer):
self.cache[token] = word
return word
def _tokenize(self, text, add_prefix_space=False):
""" Tokenize a string.
Args:
- add_prefix_space (boolean, default False):
Begin the sentence with at least one space to get invariance to word order in GPT-2 (and RoBERTa) tokenizers.
"""
if add_prefix_space:
text = " " + text
def _tokenize(self, text):
""" Tokenize a string. """
bpe_tokens = []
for token in re.findall(self.pat, text):
token = "".join(
@ -248,6 +241,11 @@ class GPT2Tokenizer(PreTrainedTokenizer):
return vocab_file, merge_file
def prepare_for_tokenization(self, text, **kwargs):
if "add_prefix_space" in kwargs and kwargs["add_prefix_space"]:
return " " + text
return text
class GPT2TokenizerFast(PreTrainedTokenizerFast):
vocab_files_names = VOCAB_FILES_NAMES

View File

@ -154,3 +154,12 @@ class RobertaTokenizer(GPT2Tokenizer):
if token_ids_1 is None:
return len(cls + token_ids_0 + sep) * [0]
return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
def prepare_for_tokenization(self, text, add_special_tokens=False, **kwargs):
if "add_prefix_space" in kwargs:
add_prefix_space = kwargs["add_prefix_space"]
else:
add_prefix_space = add_special_tokens
if add_prefix_space and not text[0].isspace():
text = " " + text
return text

View File

@ -662,9 +662,12 @@ class PreTrainedTokenizer(object):
Take care of added tokens.
text: The sequence to be encoded.
**kwargs: passed to the child `self.tokenize()` method
add_prefix_space: Only applies to GPT-2 and RoBERTa tokenizers. When `True`, this ensures that the sequence
begins with an empty space. False by default except for when using RoBERTa with `add_special_tokens=True`.
**kwargs: passed to the `prepare_for_tokenization` preprocessing method.
"""
all_special_tokens = self.all_special_tokens
text = self.prepare_for_tokenization(text, **kwargs)
def lowercase_text(t):
# convert non-special tokens to lowercase
@ -679,7 +682,7 @@ class PreTrainedTokenizer(object):
result = []
split_text = text.split(tok)
for i, sub_text in enumerate(split_text):
sub_text = sub_text.strip()
sub_text = sub_text.rstrip()
if i == 0 and not sub_text:
result += [tok]
elif i == len(split_text) - 1:
@ -697,7 +700,7 @@ class PreTrainedTokenizer(object):
if not text.strip():
return []
if not tok_list:
return self._tokenize(text, **kwargs)
return self._tokenize(text)
tokenized_text = []
text_list = [text]
@ -713,7 +716,7 @@ class PreTrainedTokenizer(object):
return list(
itertools.chain.from_iterable(
(
self._tokenize(token, **kwargs) if token not in self.unique_added_tokens_encoder else [token]
self._tokenize(token) if token not in self.unique_added_tokens_encoder else [token]
for token in tokenized_text
)
)
@ -802,6 +805,8 @@ class PreTrainedTokenizer(object):
Defaults to False: no padding.
return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
or PyTorch torch.Tensor instead of a list of python integers.
add_prefix_space: Only applies to GPT-2 and RoBERTa tokenizers. When `True`, this ensures that the sequence
begins with an empty space. False by default except for when using RoBERTa with `add_special_tokens=True`.
**kwargs: passed to the `self.tokenize()` method
"""
encoded_inputs = self.encode_plus(
@ -865,6 +870,8 @@ class PreTrainedTokenizer(object):
Defaults to False: no padding.
return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
or PyTorch torch.Tensor instead of a list of python integers.
add_prefix_space: Only applies to GPT-2 and RoBERTa tokenizers. When `True`, this ensures that the sequence
begins with an empty space. False by default except for when using RoBERTa with `add_special_tokens=True`.
return_token_type_ids: (optional) Set to False to avoid returning token_type_ids (default True).
return_attention_mask: (optional) Set to False to avoid returning attention mask (default True)
return_overflowing_tokens: (optional) Set to True to return overflowing token information (default False).
@ -895,7 +902,8 @@ class PreTrainedTokenizer(object):
def get_input_ids(text):
if isinstance(text, str):
return self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
tokens = self.tokenize(text, add_special_tokens=add_special_tokens, **kwargs)
return self.convert_tokens_to_ids(tokens)
elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], str):
return self.convert_tokens_to_ids(text)
elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int):
@ -1215,6 +1223,10 @@ class PreTrainedTokenizer(object):
return encoded_inputs
def prepare_for_tokenization(self, text, **kwargs):
""" Performs any necessary transformations before tokenization """
return text
def truncate_sequences(
self, ids, pair_ids=None, num_tokens_to_remove=0, truncation_strategy="longest_first", stride=0
):

View File

@ -94,7 +94,7 @@ class MonoColumnInputTestCase(unittest.TestCase):
for key in output_keys:
self.assertIn(key, mono_result[0])
multi_result = nlp(valid_inputs)
multi_result = [nlp(input) for input in valid_inputs]
self.assertIsInstance(multi_result, list)
self.assertIsInstance(multi_result[0], (dict, list))
@ -129,7 +129,7 @@ class MonoColumnInputTestCase(unittest.TestCase):
valid_inputs = ["HuggingFace is solving NLP one commit at a time.", "HuggingFace is based in New-York & Paris"]
invalid_inputs = [None]
for tokenizer, model, config in TF_NER_FINETUNED_MODELS:
nlp = pipeline(task="ner", model=model, config=config, tokenizer=tokenizer)
nlp = pipeline(task="ner", model=model, config=config, tokenizer=tokenizer, framework="tf")
self._test_mono_column_pipeline(nlp, valid_inputs, invalid_inputs, mandatory_keys)
@require_torch
@ -147,7 +147,7 @@ class MonoColumnInputTestCase(unittest.TestCase):
valid_inputs = ["HuggingFace is solving NLP one commit at a time.", "HuggingFace is based in New-York & Paris"]
invalid_inputs = [None]
for tokenizer, model, config in TF_TEXT_CLASSIF_FINETUNED_MODELS:
nlp = pipeline(task="sentiment-analysis", model=model, config=config, tokenizer=tokenizer)
nlp = pipeline(task="sentiment-analysis", model=model, config=config, tokenizer=tokenizer, framework="tf")
self._test_mono_column_pipeline(nlp, valid_inputs, invalid_inputs, mandatory_keys)
@require_torch
@ -163,7 +163,7 @@ class MonoColumnInputTestCase(unittest.TestCase):
valid_inputs = ["HuggingFace is solving NLP one commit at a time.", "HuggingFace is based in New-York & Paris"]
invalid_inputs = [None]
for tokenizer, model, config in TF_FEATURE_EXTRACT_FINETUNED_MODELS:
nlp = pipeline(task="feature-extraction", model=model, config=config, tokenizer=tokenizer)
nlp = pipeline(task="feature-extraction", model=model, config=config, tokenizer=tokenizer, framework="tf")
self._test_mono_column_pipeline(nlp, valid_inputs, invalid_inputs, {})
@require_torch
@ -176,14 +176,18 @@ class MonoColumnInputTestCase(unittest.TestCase):
invalid_inputs = [None]
expected_multi_result = [
[
{"score": 0.008698059245944023, "sequence": "<s>My name is John</s>", "token": 610},
{"score": 0.007750614080578089, "sequence": "<s>My name is Chris</s>", "token": 1573},
{"sequence": "<s> My name is:</s>", "score": 0.009954338893294334, "token": 35},
{"sequence": "<s> My name is John</s>", "score": 0.0080940006300807, "token": 610},
],
[
{"score": 0.2721288502216339, "sequence": "<s>The largest city in France is Paris</s>", "token": 2201},
{
"score": 0.19764970242977142,
"sequence": "<s>The largest city in France is Lyon</s>",
"sequence": "<s> The largest city in France is Paris</s>",
"score": 0.3185044229030609,
"token": 2201,
},
{
"sequence": "<s> The largest city in France is Lyon</s>",
"score": 0.21112334728240967,
"token": 12790,
},
],
@ -209,20 +213,24 @@ class MonoColumnInputTestCase(unittest.TestCase):
invalid_inputs = [None]
expected_multi_result = [
[
{"score": 0.008698059245944023, "sequence": "<s>My name is John</s>", "token": 610},
{"score": 0.007750614080578089, "sequence": "<s>My name is Chris</s>", "token": 1573},
{"sequence": "<s> My name is:</s>", "score": 0.009954338893294334, "token": 35},
{"sequence": "<s> My name is John</s>", "score": 0.0080940006300807, "token": 610},
],
[
{"score": 0.2721288502216339, "sequence": "<s>The largest city in France is Paris</s>", "token": 2201},
{
"score": 0.19764970242977142,
"sequence": "<s>The largest city in France is Lyon</s>",
"sequence": "<s> The largest city in France is Paris</s>",
"score": 0.3185044229030609,
"token": 2201,
},
{
"sequence": "<s> The largest city in France is Lyon</s>",
"score": 0.21112334728240967,
"token": 12790,
},
],
]
for tokenizer, model, config in TF_FILL_MASK_FINETUNED_MODELS:
nlp = pipeline(task="fill-mask", model=model, config=config, tokenizer=tokenizer, topk=2)
nlp = pipeline(task="fill-mask", model=model, config=config, tokenizer=tokenizer, framework="tf", topk=2)
self._test_mono_column_pipeline(
nlp,
valid_inputs,
@ -293,5 +301,5 @@ class MultiColumnInputTestCase(unittest.TestCase):
]
for tokenizer, model, config in TF_QA_FINETUNED_MODELS:
nlp = pipeline(task="question-answering", model=model, config=config, tokenizer=tokenizer)
nlp = pipeline(task="question-answering", model=model, config=config, tokenizer=tokenizer, framework="tf")
self._test_multicolumn_pipeline(nlp, valid_samples, invalid_samples, mandatory_output_keys)

View File

@ -204,7 +204,7 @@ class TokenizerTesterMixin:
encoded = tokenizer.encode(text, add_special_tokens=False)
input_encoded = tokenizer.encode(input_text, add_special_tokens=False)
output_encoded = tokenizer.encode(output_text, add_special_tokens=False)
output_encoded = tokenizer.encode(" " + output_text, add_special_tokens=False)
special_token_id = tokenizer.encode(special_token, add_special_tokens=False)
assert encoded == input_encoded + special_token_id + output_encoded
@ -264,7 +264,7 @@ class TokenizerTesterMixin:
seq_1 = "With these inputs."
sequences = tokenizer.encode(seq_0, seq_1, add_special_tokens=False)
attached_sequences = tokenizer.encode(seq_0, seq_1, add_special_tokens=True)
attached_sequences = tokenizer.encode(seq_0, seq_1, add_special_tokens=True, add_prefix_space=False)
# Method is implemented (e.g. not GPT-2)
if len(attached_sequences) != 2:
@ -280,7 +280,12 @@ class TokenizerTesterMixin:
num_added_tokens = tokenizer.num_added_tokens()
total_length = len(sequence) + num_added_tokens
information = tokenizer.encode_plus(
seq_0, max_length=total_length - 2, add_special_tokens=True, stride=stride, return_overflowing_tokens=True,
seq_0,
max_length=total_length - 2,
add_special_tokens=True,
stride=stride,
return_overflowing_tokens=True,
add_prefix_space=False,
)
truncated_sequence = information["input_ids"]
@ -301,7 +306,7 @@ class TokenizerTesterMixin:
sequence_0_no_special_tokens = tokenizer.encode(seq_0, add_special_tokens=False)
sequence_1_no_special_tokens = tokenizer.encode(seq_1, add_special_tokens=False)
sequence = tokenizer.encode(seq_0, seq_1, add_special_tokens=True)
sequence = tokenizer.encode(seq_0, seq_1, add_special_tokens=True, add_prefix_space=False)
truncated_second_sequence = tokenizer.build_inputs_with_special_tokens(
tokenizer.encode(seq_0, add_special_tokens=False), tokenizer.encode(seq_1, add_special_tokens=False)[:-2],
)
@ -314,6 +319,7 @@ class TokenizerTesterMixin:
stride=stride,
truncation_strategy="only_second",
return_overflowing_tokens=True,
add_prefix_space=False,
)
information_first_truncated = tokenizer.encode_plus(
seq_0,
@ -323,6 +329,7 @@ class TokenizerTesterMixin:
stride=stride,
truncation_strategy="only_first",
return_overflowing_tokens=True,
add_prefix_space=False,
)
truncated_sequence = information["input_ids"]
@ -342,11 +349,39 @@ class TokenizerTesterMixin:
tokens = tokenizer.tokenize(sequence)
input_ids = tokenizer.convert_tokens_to_ids(tokens)
formatted_input = tokenizer.encode(sequence, add_special_tokens=True)
formatted_input = tokenizer.encode(sequence, add_special_tokens=True, add_prefix_space=False)
self.assertEqual(tokenizer.encode(tokens, add_special_tokens=True), formatted_input)
self.assertEqual(tokenizer.encode(input_ids, add_special_tokens=True), formatted_input)
def test_swap_special_token(self):
tokenizer = self.get_tokenizer()
mask = "<mask>"
sequence = "Encode this sequence"
sequence_masked_0 = "Encode <mask> sequence"
sequence_masked_1 = "<mask> this sequence"
# Add tokens so that masked token isn't split
tokenizer.add_tokens(sequence.split())
tokenizer.add_special_tokens({"mask_token": mask})
mask_ind = tokenizer.convert_tokens_to_ids(mask)
encoded = tokenizer.encode(sequence, add_special_tokens=False)
# Test first masked sequence
encoded_masked = tokenizer.encode(sequence_masked_0, add_special_tokens=False)
mask_loc = encoded_masked.index(mask_ind)
encoded_masked[mask_loc] = encoded[mask_loc]
self.assertEqual(encoded_masked, encoded)
# Test second masked sequence
encoded_masked = tokenizer.encode(sequence_masked_1, add_special_tokens=False)
mask_loc = encoded_masked.index(mask_ind)
encoded_masked[mask_loc] = encoded[mask_loc]
self.assertEqual(encoded_masked, encoded)
def test_special_tokens_mask(self):
tokenizer = self.get_tokenizer()
@ -356,7 +391,7 @@ class TokenizerTesterMixin:
# Testing single inputs
encoded_sequence = tokenizer.encode(sequence_0, add_special_tokens=False)
encoded_sequence_dict = tokenizer.encode_plus(
sequence_0, add_special_tokens=True, return_special_tokens_mask=True
sequence_0, add_special_tokens=True, return_special_tokens_mask=True, add_prefix_space=False
)
encoded_sequence_w_special = encoded_sequence_dict["input_ids"]
special_tokens_mask = encoded_sequence_dict["special_tokens_mask"]
@ -369,11 +404,10 @@ class TokenizerTesterMixin:
self.assertEqual(encoded_sequence, filtered_sequence)
# Testing inputs pairs
encoded_sequence = tokenizer.encode(sequence_0, add_special_tokens=False) + tokenizer.encode(
sequence_1, add_special_tokens=False
)
encoded_sequence = tokenizer.encode(sequence_0, add_special_tokens=False)
encoded_sequence += tokenizer.encode(sequence_1, add_special_tokens=False)
encoded_sequence_dict = tokenizer.encode_plus(
sequence_0, sequence_1, add_special_tokens=True, return_special_tokens_mask=True
sequence_0, sequence_1, add_special_tokens=True, return_special_tokens_mask=True, add_prefix_space=False
)
encoded_sequence_w_special = encoded_sequence_dict["input_ids"]
special_tokens_mask = encoded_sequence_dict["special_tokens_mask"]

View File

@ -110,3 +110,41 @@ class RobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
assert encoded_sentence == encoded_text_from_decode
assert encoded_pair == encoded_pair_from_decode
def test_space_encoding(self):
tokenizer = self.get_tokenizer()
sequence = "Encode this sequence."
space_encoding = tokenizer.byte_encoder[" ".encode("utf-8")[0]]
# Testing encoder arguments
encoded = tokenizer.encode(sequence, add_special_tokens=False)
first_char = tokenizer.convert_ids_to_tokens(encoded[0])[0]
self.assertNotEqual(first_char, space_encoding)
encoded = tokenizer.encode(sequence, add_special_tokens=False, add_prefix_space=True)
first_char = tokenizer.convert_ids_to_tokens(encoded[0])[0]
self.assertEqual(first_char, space_encoding)
tokenizer.add_special_tokens({"bos_token": "<s>"})
encoded = tokenizer.encode(sequence, add_special_tokens=True)
first_char = tokenizer.convert_ids_to_tokens(encoded[1])[0]
self.assertEqual(first_char, space_encoding)
# Testing spaces after special tokenss
mask = "<mask>"
tokenizer.add_special_tokens({"mask_token": mask})
mask_ind = tokenizer.convert_tokens_to_ids(mask)
sequence = "Encode <mask> sequence"
sequence_nospace = "Encode <mask>sequence"
encoded = tokenizer.encode(sequence)
mask_loc = encoded.index(mask_ind)
first_char = tokenizer.convert_ids_to_tokens(encoded[mask_loc + 1])[0]
self.assertEqual(first_char, space_encoding)
encoded = tokenizer.encode(sequence_nospace)
mask_loc = encoded.index(mask_ind)
first_char = tokenizer.convert_ids_to_tokens(encoded[mask_loc + 1])[0]
self.assertNotEqual(first_char, space_encoding)