Mirror of https://github.com/huggingface/transformers.git
Preserve spaces in GPT-2 tokenizers (#2778)

* Preserve spaces in GPT-2 tokenizers

  Preserves spaces after special tokens in GPT-2 and inherited (RoBERTa) tokenizers, enabling correct BPE encoding. Automatically inserts a space in front of the first token in the encode function when adding special tokens.

* Add tokenization preprocessing method

* Add framework argument to pipeline factory

  Also fixes a pipeline test issue: each test input is now treated as a distinct sequence.
This commit is contained in:
parent 0ed630f139
commit f1e8a51f08
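Context for the change: GPT-2's byte-level BPE maps every byte to a printable symbol before applying merges, and the space byte becomes "Ġ", so the same word with and without a leading space encodes to different tokens. A minimal sketch of the effect (assuming the standard gpt2 vocabulary is available):

from transformers import GPT2Tokenizer

tok = GPT2Tokenizer.from_pretrained("gpt2")
print(tok.byte_encoder[ord(" ")])   # 'Ġ', the byte-level stand-in for a space
print(tok.tokenize("world"))        # ['world']
print(tok.tokenize(" world"))       # ['Ġworld'], a different token with different merges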
@@ -1001,6 +1001,7 @@ def pipeline(
     config: Optional[Union[str, PretrainedConfig]] = None,
     tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None,
     modelcard: Optional[Union[str, ModelCard]] = None,
+    framework: Optional[str] = None,
     **kwargs
 ) -> Pipeline:
     """
@@ -1021,7 +1022,7 @@ def pipeline(
     if task not in SUPPORTED_TASKS:
         raise KeyError("Unknown task {}, available tasks are {}".format(task, list(SUPPORTED_TASKS.keys())))

-    framework = get_framework(model)
+    framework = framework or get_framework(model)

     targeted_task = SUPPORTED_TASKS[task]
     task, model_class = targeted_task["impl"], targeted_task[framework]
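The new framework argument lets a caller pin the backend explicitly; `framework = framework or get_framework(model)` keeps model-based autodetection as the fallback. A usage sketch (assumes the task's default checkpoint has TensorFlow weights):

from transformers import pipeline

nlp_auto = pipeline(task="sentiment-analysis")                # backend inferred from the model
nlp_tf = pipeline(task="sentiment-analysis", framework="tf")  # backend forced to TensorFlow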
@@ -191,15 +191,8 @@ class GPT2Tokenizer(PreTrainedTokenizer):
         self.cache[token] = word
         return word

-    def _tokenize(self, text, add_prefix_space=False):
-        """ Tokenize a string.
-            Args:
-                - add_prefix_space (boolean, default False):
-                    Begin the sentence with at least one space to get invariance to word order in GPT-2 (and RoBERTa) tokenizers.
-        """
-        if add_prefix_space:
-            text = " " + text
-
+    def _tokenize(self, text):
+        """ Tokenize a string. """
         bpe_tokens = []
         for token in re.findall(self.pat, text):
             token = "".join(
@@ -248,6 +241,11 @@ class GPT2Tokenizer(PreTrainedTokenizer):

         return vocab_file, merge_file

+    def prepare_for_tokenization(self, text, **kwargs):
+        if "add_prefix_space" in kwargs and kwargs["add_prefix_space"]:
+            return " " + text
+        return text
+

 class GPT2TokenizerFast(PreTrainedTokenizerFast):
     vocab_files_names = VOCAB_FILES_NAMES
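With _tokenize() simplified, the add_prefix_space switch now lives in prepare_for_tokenization(), which the base tokenize() runs before any splitting (see the tokenization_utils hunks below). A rough sketch of the resulting behavior, assuming the gpt2 checkpoint:

from transformers import GPT2Tokenizer

tok = GPT2Tokenizer.from_pretrained("gpt2")
print(tok.tokenize("world"))                         # ['world']
print(tok.tokenize("world", add_prefix_space=True))  # ['Ġworld'], space prepended before _tokenize runs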
@@ -154,3 +154,12 @@ class RobertaTokenizer(GPT2Tokenizer):
         if token_ids_1 is None:
             return len(cls + token_ids_0 + sep) * [0]
         return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
+
+    def prepare_for_tokenization(self, text, add_special_tokens=False, **kwargs):
+        if "add_prefix_space" in kwargs:
+            add_prefix_space = kwargs["add_prefix_space"]
+        else:
+            add_prefix_space = add_special_tokens
+        if add_prefix_space and not text[0].isspace():
+            text = " " + text
+        return text
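RobertaTokenizer defaults the prefix space to the value of add_special_tokens, so encoding a full sequence with special tokens also gets the leading space that makes sentence-initial words encode the same way as mid-sentence ones; an explicit add_prefix_space always overrides the default. Sketch (assuming roberta-base):

from transformers import RobertaTokenizer

roberta = RobertaTokenizer.from_pretrained("roberta-base")
roberta.encode("Hello", add_special_tokens=True)                          # prefix space added by default
roberta.encode("Hello", add_special_tokens=True, add_prefix_space=False)  # pre-change behavior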
@@ -662,9 +662,12 @@ class PreTrainedTokenizer(object):
             Take care of added tokens.

             text: The sequence to be encoded.
-            **kwargs: passed to the child `self.tokenize()` method
+            add_prefix_space: Only applies to GPT-2 and RoBERTa tokenizers. When `True`, this ensures that the sequence
+                begins with an empty space. False by default except for when using RoBERTa with `add_special_tokens=True`.
+            **kwargs: passed to the `prepare_for_tokenization` preprocessing method.
         """
         all_special_tokens = self.all_special_tokens
+        text = self.prepare_for_tokenization(text, **kwargs)

         def lowercase_text(t):
             # convert non-special tokens to lowercase
@@ -679,7 +682,7 @@ class PreTrainedTokenizer(object):
             result = []
             split_text = text.split(tok)
             for i, sub_text in enumerate(split_text):
-                sub_text = sub_text.strip()
+                sub_text = sub_text.rstrip()
                 if i == 0 and not sub_text:
                     result += [tok]
                 elif i == len(split_text) - 1:
@@ -697,7 +700,7 @@ class PreTrainedTokenizer(object):
             if not text.strip():
                 return []
             if not tok_list:
-                return self._tokenize(text, **kwargs)
+                return self._tokenize(text)

             tokenized_text = []
             text_list = [text]
@@ -713,7 +716,7 @@ class PreTrainedTokenizer(object):
            return list(
                itertools.chain.from_iterable(
                    (
-                        self._tokenize(token, **kwargs) if token not in self.unique_added_tokens_encoder else [token]
+                        self._tokenize(token) if token not in self.unique_added_tokens_encoder else [token]
                        for token in tokenized_text
                    )
                )
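The strip() to rstrip() change is the heart of the fix: tokenize() splits the input on each special token, and strip() also removed the leading space of the fragment that follows, so the next word lost its "Ġ" prefix during BPE. The splitting step in isolation:

text = "Encode <mask> sequence"
parts = text.split("<mask>")        # ['Encode ', ' sequence']
print([p.strip() for p in parts])   # ['Encode', 'sequence']: old behavior, leading space lost
print([p.rstrip() for p in parts])  # ['Encode', ' sequence']: new behavior, ' sequence' becomes 'Ġsequence'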
@@ -802,6 +805,8 @@ class PreTrainedTokenizer(object):
                 Defaults to False: no padding.
             return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
                 or PyTorch torch.Tensor instead of a list of python integers.
+            add_prefix_space: Only applies to GPT-2 and RoBERTa tokenizers. When `True`, this ensures that the sequence
+                begins with an empty space. False by default except for when using RoBERTa with `add_special_tokens=True`.
             **kwargs: passed to the `self.tokenize()` method
         """
         encoded_inputs = self.encode_plus(
@@ -865,6 +870,8 @@ class PreTrainedTokenizer(object):
                 Defaults to False: no padding.
             return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
                 or PyTorch torch.Tensor instead of a list of python integers.
+            add_prefix_space: Only applies to GPT-2 and RoBERTa tokenizers. When `True`, this ensures that the sequence
+                begins with an empty space. False by default except for when using RoBERTa with `add_special_tokens=True`.
             return_token_type_ids: (optional) Set to False to avoid returning token_type_ids (default True).
             return_attention_mask: (optional) Set to False to avoid returning attention mask (default True)
             return_overflowing_tokens: (optional) Set to True to return overflowing token information (default False).
@@ -895,7 +902,8 @@ class PreTrainedTokenizer(object):

         def get_input_ids(text):
             if isinstance(text, str):
-                return self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
+                tokens = self.tokenize(text, add_special_tokens=add_special_tokens, **kwargs)
+                return self.convert_tokens_to_ids(tokens)
             elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], str):
                 return self.convert_tokens_to_ids(text)
             elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int):
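Forwarding add_special_tokens into tokenize() is what lets RobertaTokenizer.prepare_for_tokenization derive its default prefix space. The call chain, sketched for a single encode call (assuming roberta-base):

from transformers import RobertaTokenizer

roberta = RobertaTokenizer.from_pretrained("roberta-base")
ids = roberta.encode("Hello", add_special_tokens=True)
# encode() -> encode_plus() -> get_input_ids()
#   -> tokenize("Hello", add_special_tokens=True)
#     -> prepare_for_tokenization() returns " Hello" (default derived from add_special_tokens)
#       -> _tokenize(" Hello") yields ['ĠHello']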
@@ -1215,6 +1223,10 @@ class PreTrainedTokenizer(object):

         return encoded_inputs

+    def prepare_for_tokenization(self, text, **kwargs):
+        """ Performs any necessary transformations before tokenization """
+        return text
+
     def truncate_sequences(
         self, ids, pair_ids=None, num_tokens_to_remove=0, truncation_strategy="longest_first", stride=0
     ):
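The base-class prepare_for_tokenization() is a deliberate no-op hook; tokenizers that need text-level preprocessing override it. A toy override showing the pattern (hypothetical subclass, not part of this commit):

from transformers import PreTrainedTokenizer

class WhitespaceNormalizingTokenizer(PreTrainedTokenizer):
    def prepare_for_tokenization(self, text, **kwargs):
        # collapse runs of whitespace before tokenization starts
        return " ".join(text.split())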
@@ -94,7 +94,7 @@ class MonoColumnInputTestCase(unittest.TestCase):
             for key in output_keys:
                 self.assertIn(key, mono_result[0])

-            multi_result = nlp(valid_inputs)
+            multi_result = [nlp(input) for input in valid_inputs]
             self.assertIsInstance(multi_result, list)
             self.assertIsInstance(multi_result[0], (dict, list))

@@ -129,7 +129,7 @@ class MonoColumnInputTestCase(unittest.TestCase):
         valid_inputs = ["HuggingFace is solving NLP one commit at a time.", "HuggingFace is based in New-York & Paris"]
         invalid_inputs = [None]
         for tokenizer, model, config in TF_NER_FINETUNED_MODELS:
-            nlp = pipeline(task="ner", model=model, config=config, tokenizer=tokenizer)
+            nlp = pipeline(task="ner", model=model, config=config, tokenizer=tokenizer, framework="tf")
             self._test_mono_column_pipeline(nlp, valid_inputs, invalid_inputs, mandatory_keys)

     @require_torch
@@ -147,7 +147,7 @@ class MonoColumnInputTestCase(unittest.TestCase):
         valid_inputs = ["HuggingFace is solving NLP one commit at a time.", "HuggingFace is based in New-York & Paris"]
         invalid_inputs = [None]
         for tokenizer, model, config in TF_TEXT_CLASSIF_FINETUNED_MODELS:
-            nlp = pipeline(task="sentiment-analysis", model=model, config=config, tokenizer=tokenizer)
+            nlp = pipeline(task="sentiment-analysis", model=model, config=config, tokenizer=tokenizer, framework="tf")
             self._test_mono_column_pipeline(nlp, valid_inputs, invalid_inputs, mandatory_keys)

     @require_torch
@@ -163,7 +163,7 @@ class MonoColumnInputTestCase(unittest.TestCase):
         valid_inputs = ["HuggingFace is solving NLP one commit at a time.", "HuggingFace is based in New-York & Paris"]
         invalid_inputs = [None]
         for tokenizer, model, config in TF_FEATURE_EXTRACT_FINETUNED_MODELS:
-            nlp = pipeline(task="feature-extraction", model=model, config=config, tokenizer=tokenizer)
+            nlp = pipeline(task="feature-extraction", model=model, config=config, tokenizer=tokenizer, framework="tf")
             self._test_mono_column_pipeline(nlp, valid_inputs, invalid_inputs, {})

     @require_torch
@@ -176,14 +176,18 @@ class MonoColumnInputTestCase(unittest.TestCase):
         invalid_inputs = [None]
         expected_multi_result = [
             [
-                {"score": 0.008698059245944023, "sequence": "<s>My name is John</s>", "token": 610},
-                {"score": 0.007750614080578089, "sequence": "<s>My name is Chris</s>", "token": 1573},
+                {"sequence": "<s> My name is:</s>", "score": 0.009954338893294334, "token": 35},
+                {"sequence": "<s> My name is John</s>", "score": 0.0080940006300807, "token": 610},
             ],
             [
-                {"score": 0.2721288502216339, "sequence": "<s>The largest city in France is Paris</s>", "token": 2201},
                 {
-                    "score": 0.19764970242977142,
-                    "sequence": "<s>The largest city in France is Lyon</s>",
+                    "sequence": "<s> The largest city in France is Paris</s>",
+                    "score": 0.3185044229030609,
+                    "token": 2201,
+                },
+                {
+                    "sequence": "<s> The largest city in France is Lyon</s>",
+                    "score": 0.21112334728240967,
                     "token": 12790,
                 },
             ],
@@ -209,20 +213,24 @@ class MonoColumnInputTestCase(unittest.TestCase):
         invalid_inputs = [None]
         expected_multi_result = [
             [
-                {"score": 0.008698059245944023, "sequence": "<s>My name is John</s>", "token": 610},
-                {"score": 0.007750614080578089, "sequence": "<s>My name is Chris</s>", "token": 1573},
+                {"sequence": "<s> My name is:</s>", "score": 0.009954338893294334, "token": 35},
+                {"sequence": "<s> My name is John</s>", "score": 0.0080940006300807, "token": 610},
             ],
             [
-                {"score": 0.2721288502216339, "sequence": "<s>The largest city in France is Paris</s>", "token": 2201},
                 {
-                    "score": 0.19764970242977142,
-                    "sequence": "<s>The largest city in France is Lyon</s>",
+                    "sequence": "<s> The largest city in France is Paris</s>",
+                    "score": 0.3185044229030609,
+                    "token": 2201,
+                },
+                {
+                    "sequence": "<s> The largest city in France is Lyon</s>",
+                    "score": 0.21112334728240967,
                     "token": 12790,
                 },
             ],
         ]
         for tokenizer, model, config in TF_FILL_MASK_FINETUNED_MODELS:
-            nlp = pipeline(task="fill-mask", model=model, config=config, tokenizer=tokenizer, topk=2)
+            nlp = pipeline(task="fill-mask", model=model, config=config, tokenizer=tokenizer, framework="tf", topk=2)
             self._test_mono_column_pipeline(
                 nlp,
                 valid_inputs,
@@ -293,5 +301,5 @@ class MultiColumnInputTestCase(unittest.TestCase):
         ]

         for tokenizer, model, config in TF_QA_FINETUNED_MODELS:
-            nlp = pipeline(task="question-answering", model=model, config=config, tokenizer=tokenizer)
+            nlp = pipeline(task="question-answering", model=model, config=config, tokenizer=tokenizer, framework="tf")
             self._test_multicolumn_pipeline(nlp, valid_samples, invalid_samples, mandatory_output_keys)
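The updated fill-mask expectations reflect the preserved space: decoded sequences now read "<s> My name is ...</s>", and the scores and token ids shift because candidates are scored as "Ġ"-prefixed tokens. The comprehension `[nlp(input) for input in valid_inputs]` likewise runs each test input as its own sequence instead of batching the whole list into one call. Illustrative only (default checkpoint assumed):

from transformers import pipeline

nlp = pipeline(task="fill-mask", topk=2)
print(nlp("My name is <mask>")[0]["sequence"])  # e.g. '<s> My name is John</s>', with the space preserved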
@@ -204,7 +204,7 @@ class TokenizerTesterMixin:
         encoded = tokenizer.encode(text, add_special_tokens=False)

         input_encoded = tokenizer.encode(input_text, add_special_tokens=False)
-        output_encoded = tokenizer.encode(output_text, add_special_tokens=False)
+        output_encoded = tokenizer.encode(" " + output_text, add_special_tokens=False)
         special_token_id = tokenizer.encode(special_token, add_special_tokens=False)
         assert encoded == input_encoded + special_token_id + output_encoded

@@ -264,7 +264,7 @@ class TokenizerTesterMixin:
         seq_1 = "With these inputs."

         sequences = tokenizer.encode(seq_0, seq_1, add_special_tokens=False)
-        attached_sequences = tokenizer.encode(seq_0, seq_1, add_special_tokens=True)
+        attached_sequences = tokenizer.encode(seq_0, seq_1, add_special_tokens=True, add_prefix_space=False)

         # Method is implemented (e.g. not GPT-2)
         if len(attached_sequences) != 2:
@@ -280,7 +280,12 @@ class TokenizerTesterMixin:
         num_added_tokens = tokenizer.num_added_tokens()
         total_length = len(sequence) + num_added_tokens
         information = tokenizer.encode_plus(
-            seq_0, max_length=total_length - 2, add_special_tokens=True, stride=stride, return_overflowing_tokens=True,
+            seq_0,
+            max_length=total_length - 2,
+            add_special_tokens=True,
+            stride=stride,
+            return_overflowing_tokens=True,
+            add_prefix_space=False,
         )

         truncated_sequence = information["input_ids"]
@@ -301,7 +306,7 @@ class TokenizerTesterMixin:
         sequence_0_no_special_tokens = tokenizer.encode(seq_0, add_special_tokens=False)
         sequence_1_no_special_tokens = tokenizer.encode(seq_1, add_special_tokens=False)

-        sequence = tokenizer.encode(seq_0, seq_1, add_special_tokens=True)
+        sequence = tokenizer.encode(seq_0, seq_1, add_special_tokens=True, add_prefix_space=False)
         truncated_second_sequence = tokenizer.build_inputs_with_special_tokens(
             tokenizer.encode(seq_0, add_special_tokens=False), tokenizer.encode(seq_1, add_special_tokens=False)[:-2],
         )
@@ -314,6 +319,7 @@ class TokenizerTesterMixin:
             stride=stride,
             truncation_strategy="only_second",
             return_overflowing_tokens=True,
+            add_prefix_space=False,
         )
         information_first_truncated = tokenizer.encode_plus(
             seq_0,
@@ -323,6 +329,7 @@ class TokenizerTesterMixin:
             stride=stride,
             truncation_strategy="only_first",
             return_overflowing_tokens=True,
+            add_prefix_space=False,
         )

         truncated_sequence = information["input_ids"]
@@ -342,11 +349,39 @@ class TokenizerTesterMixin:

         tokens = tokenizer.tokenize(sequence)
         input_ids = tokenizer.convert_tokens_to_ids(tokens)
-        formatted_input = tokenizer.encode(sequence, add_special_tokens=True)
+        formatted_input = tokenizer.encode(sequence, add_special_tokens=True, add_prefix_space=False)

         self.assertEqual(tokenizer.encode(tokens, add_special_tokens=True), formatted_input)
         self.assertEqual(tokenizer.encode(input_ids, add_special_tokens=True), formatted_input)

+    def test_swap_special_token(self):
+        tokenizer = self.get_tokenizer()
+
+        mask = "<mask>"
+        sequence = "Encode this sequence"
+        sequence_masked_0 = "Encode <mask> sequence"
+        sequence_masked_1 = "<mask> this sequence"
+
+        # Add tokens so that masked token isn't split
+        tokenizer.add_tokens(sequence.split())
+        tokenizer.add_special_tokens({"mask_token": mask})
+        mask_ind = tokenizer.convert_tokens_to_ids(mask)
+        encoded = tokenizer.encode(sequence, add_special_tokens=False)
+
+        # Test first masked sequence
+        encoded_masked = tokenizer.encode(sequence_masked_0, add_special_tokens=False)
+        mask_loc = encoded_masked.index(mask_ind)
+        encoded_masked[mask_loc] = encoded[mask_loc]
+
+        self.assertEqual(encoded_masked, encoded)
+
+        # Test second masked sequence
+        encoded_masked = tokenizer.encode(sequence_masked_1, add_special_tokens=False)
+        mask_loc = encoded_masked.index(mask_ind)
+        encoded_masked[mask_loc] = encoded[mask_loc]
+
+        self.assertEqual(encoded_masked, encoded)
+
     def test_special_tokens_mask(self):
         tokenizer = self.get_tokenizer()

@@ -356,7 +391,7 @@ class TokenizerTesterMixin:
         # Testing single inputs
         encoded_sequence = tokenizer.encode(sequence_0, add_special_tokens=False)
         encoded_sequence_dict = tokenizer.encode_plus(
-            sequence_0, add_special_tokens=True, return_special_tokens_mask=True
+            sequence_0, add_special_tokens=True, return_special_tokens_mask=True, add_prefix_space=False
         )
         encoded_sequence_w_special = encoded_sequence_dict["input_ids"]
         special_tokens_mask = encoded_sequence_dict["special_tokens_mask"]
@@ -369,11 +404,10 @@ class TokenizerTesterMixin:
         self.assertEqual(encoded_sequence, filtered_sequence)

         # Testing inputs pairs
-        encoded_sequence = tokenizer.encode(sequence_0, add_special_tokens=False) + tokenizer.encode(
-            sequence_1, add_special_tokens=False
-        )
+        encoded_sequence = tokenizer.encode(sequence_0, add_special_tokens=False)
+        encoded_sequence += tokenizer.encode(sequence_1, add_special_tokens=False)
         encoded_sequence_dict = tokenizer.encode_plus(
-            sequence_0, sequence_1, add_special_tokens=True, return_special_tokens_mask=True
+            sequence_0, sequence_1, add_special_tokens=True, return_special_tokens_mask=True, add_prefix_space=False
         )
         encoded_sequence_w_special = encoded_sequence_dict["input_ids"]
         special_tokens_mask = encoded_sequence_dict["special_tokens_mask"]
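The mixin tests pin add_prefix_space=False wherever they compare against manually assembled token ids; otherwise the new RoBERTa default (prefix space whenever add_special_tokens=True) would prepend a space and change the first token. The invariant they preserve, in isolation (assuming roberta-base):

from transformers import RobertaTokenizer

roberta = RobertaTokenizer.from_pretrained("roberta-base")
manual = roberta.build_inputs_with_special_tokens(
    roberta.encode("Sequence", add_special_tokens=False)
)
auto = roberta.encode("Sequence", add_special_tokens=True, add_prefix_space=False)
assert manual == auto  # holds only with the prefix space disabled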
@@ -110,3 +110,41 @@ class RobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):

         assert encoded_sentence == encoded_text_from_decode
         assert encoded_pair == encoded_pair_from_decode
+
+    def test_space_encoding(self):
+        tokenizer = self.get_tokenizer()
+
+        sequence = "Encode this sequence."
+        space_encoding = tokenizer.byte_encoder[" ".encode("utf-8")[0]]
+
+        # Testing encoder arguments
+        encoded = tokenizer.encode(sequence, add_special_tokens=False)
+        first_char = tokenizer.convert_ids_to_tokens(encoded[0])[0]
+        self.assertNotEqual(first_char, space_encoding)
+
+        encoded = tokenizer.encode(sequence, add_special_tokens=False, add_prefix_space=True)
+        first_char = tokenizer.convert_ids_to_tokens(encoded[0])[0]
+        self.assertEqual(first_char, space_encoding)
+
+        tokenizer.add_special_tokens({"bos_token": "<s>"})
+        encoded = tokenizer.encode(sequence, add_special_tokens=True)
+        first_char = tokenizer.convert_ids_to_tokens(encoded[1])[0]
+        self.assertEqual(first_char, space_encoding)
+
+        # Testing spaces after special tokens
+        mask = "<mask>"
+        tokenizer.add_special_tokens({"mask_token": mask})
+        mask_ind = tokenizer.convert_tokens_to_ids(mask)
+
+        sequence = "Encode <mask> sequence"
+        sequence_nospace = "Encode <mask>sequence"
+
+        encoded = tokenizer.encode(sequence)
+        mask_loc = encoded.index(mask_ind)
+        first_char = tokenizer.convert_ids_to_tokens(encoded[mask_loc + 1])[0]
+        self.assertEqual(first_char, space_encoding)
+
+        encoded = tokenizer.encode(sequence_nospace)
+        mask_loc = encoded.index(mask_ind)
+        first_char = tokenizer.convert_ids_to_tokens(encoded[mask_loc + 1])[0]
+        self.assertNotEqual(first_char, space_encoding)