mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 11:11:05 +06:00
XLM-R Tokenizer now passes common tests + Integration tests (#3198)
* XLM-R now passes common tests + Integration tests * Correct mask index * Model input names * Style * Remove text preprocessing * Unneccessary import
This commit is contained in:
parent
292186a3e7
commit
d6afbd323d
@ -104,6 +104,7 @@ class XLMRobertaTokenizer(PreTrainedTokenizer):
|
|||||||
vocab_files_names = VOCAB_FILES_NAMES
|
vocab_files_names = VOCAB_FILES_NAMES
|
||||||
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
||||||
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
||||||
|
model_input_names = ["attention_mask"]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -155,7 +156,7 @@ class XLMRobertaTokenizer(PreTrainedTokenizer):
|
|||||||
# The first "real" token "," has position 4 in the original fairseq vocab and position 3 in the spm vocab
|
# The first "real" token "," has position 4 in the original fairseq vocab and position 3 in the spm vocab
|
||||||
self.fairseq_offset = 1
|
self.fairseq_offset = 1
|
||||||
|
|
||||||
self.fairseq_tokens_to_ids["<mask>"] = len(self.sp_model) + len(self.fairseq_tokens_to_ids)
|
self.fairseq_tokens_to_ids["<mask>"] = len(self.sp_model) + self.fairseq_offset
|
||||||
self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}
|
self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}
|
||||||
|
|
||||||
def __getstate__(self):
|
def __getstate__(self):
|
||||||
@ -261,7 +262,7 @@ class XLMRobertaTokenizer(PreTrainedTokenizer):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def vocab_size(self):
|
def vocab_size(self):
|
||||||
return len(self.sp_model) + len(self.fairseq_tokens_to_ids)
|
return len(self.sp_model) + self.fairseq_offset + 1 # Add the <mask> token
|
||||||
|
|
||||||
def get_vocab(self):
|
def get_vocab(self):
|
||||||
vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
|
vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
|
||||||
@ -275,7 +276,10 @@ class XLMRobertaTokenizer(PreTrainedTokenizer):
|
|||||||
""" Converts a token (str) in an id using the vocab. """
|
""" Converts a token (str) in an id using the vocab. """
|
||||||
if token in self.fairseq_tokens_to_ids:
|
if token in self.fairseq_tokens_to_ids:
|
||||||
return self.fairseq_tokens_to_ids[token]
|
return self.fairseq_tokens_to_ids[token]
|
||||||
return self.sp_model.PieceToId(token) + self.fairseq_offset
|
spm_id = self.sp_model.PieceToId(token)
|
||||||
|
|
||||||
|
# Need to return unknown token if the SP model returned 0
|
||||||
|
return spm_id + self.fairseq_offset if spm_id else self.unk_token_id
|
||||||
|
|
||||||
def _convert_id_to_token(self, index):
|
def _convert_id_to_token(self, index):
|
||||||
"""Converts an index (integer) in a token (str) using the vocab."""
|
"""Converts an index (integer) in a token (str) using the vocab."""
|
||||||
|
@ -14,14 +14,113 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import os
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers.tokenization_xlm_roberta import XLMRobertaTokenizer
|
from transformers.tokenization_xlm_roberta import SPIECE_UNDERLINE, XLMRobertaTokenizer
|
||||||
|
|
||||||
|
from .test_tokenization_common import TokenizerTesterMixin
|
||||||
from .utils import slow
|
from .utils import slow
|
||||||
|
|
||||||
|
|
||||||
class XLMRobertaTokenizationIntegrationTest(unittest.TestCase):
|
SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_sentencepiece.model")
|
||||||
|
|
||||||
|
|
||||||
|
class XLMRobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||||
|
|
||||||
|
tokenizer_class = XLMRobertaTokenizer
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
super().setUp()
|
||||||
|
|
||||||
|
# We have a SentencePiece fixture for testing
|
||||||
|
tokenizer = XLMRobertaTokenizer(SAMPLE_VOCAB, keep_accents=True)
|
||||||
|
tokenizer.save_pretrained(self.tmpdirname)
|
||||||
|
|
||||||
|
def get_tokenizer(self, **kwargs):
|
||||||
|
return XLMRobertaTokenizer.from_pretrained(self.tmpdirname, **kwargs)
|
||||||
|
|
||||||
|
def get_input_output_texts(self):
|
||||||
|
input_text = "This is a test"
|
||||||
|
output_text = "This is a test"
|
||||||
|
return input_text, output_text
|
||||||
|
|
||||||
|
def test_full_tokenizer(self):
|
||||||
|
tokenizer = XLMRobertaTokenizer(SAMPLE_VOCAB, keep_accents=True)
|
||||||
|
|
||||||
|
tokens = tokenizer.tokenize("This is a test")
|
||||||
|
self.assertListEqual(tokens, ["▁This", "▁is", "▁a", "▁t", "est"])
|
||||||
|
|
||||||
|
self.assertListEqual(
|
||||||
|
tokenizer.convert_tokens_to_ids(tokens),
|
||||||
|
[value + tokenizer.fairseq_offset for value in [285, 46, 10, 170, 382]],
|
||||||
|
)
|
||||||
|
|
||||||
|
tokens = tokenizer.tokenize("I was born in 92000, and this is falsé.")
|
||||||
|
self.assertListEqual(
|
||||||
|
tokens,
|
||||||
|
[
|
||||||
|
SPIECE_UNDERLINE + "I",
|
||||||
|
SPIECE_UNDERLINE + "was",
|
||||||
|
SPIECE_UNDERLINE + "b",
|
||||||
|
"or",
|
||||||
|
"n",
|
||||||
|
SPIECE_UNDERLINE + "in",
|
||||||
|
SPIECE_UNDERLINE + "",
|
||||||
|
"9",
|
||||||
|
"2",
|
||||||
|
"0",
|
||||||
|
"0",
|
||||||
|
"0",
|
||||||
|
",",
|
||||||
|
SPIECE_UNDERLINE + "and",
|
||||||
|
SPIECE_UNDERLINE + "this",
|
||||||
|
SPIECE_UNDERLINE + "is",
|
||||||
|
SPIECE_UNDERLINE + "f",
|
||||||
|
"al",
|
||||||
|
"s",
|
||||||
|
"é",
|
||||||
|
".",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
ids = tokenizer.convert_tokens_to_ids(tokens)
|
||||||
|
self.assertListEqual(
|
||||||
|
ids,
|
||||||
|
[
|
||||||
|
value + tokenizer.fairseq_offset
|
||||||
|
for value in [8, 21, 84, 55, 24, 19, 7, 2, 602, 347, 347, 347, 3, 12, 66, 46, 72, 80, 6, 2, 4]
|
||||||
|
# ^ unk: 2 + 1 = 3 unk: 2 + 1 = 3 ^
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
back_tokens = tokenizer.convert_ids_to_tokens(ids)
|
||||||
|
self.assertListEqual(
|
||||||
|
back_tokens,
|
||||||
|
[
|
||||||
|
SPIECE_UNDERLINE + "I",
|
||||||
|
SPIECE_UNDERLINE + "was",
|
||||||
|
SPIECE_UNDERLINE + "b",
|
||||||
|
"or",
|
||||||
|
"n",
|
||||||
|
SPIECE_UNDERLINE + "in",
|
||||||
|
SPIECE_UNDERLINE + "",
|
||||||
|
"<unk>",
|
||||||
|
"2",
|
||||||
|
"0",
|
||||||
|
"0",
|
||||||
|
"0",
|
||||||
|
",",
|
||||||
|
SPIECE_UNDERLINE + "and",
|
||||||
|
SPIECE_UNDERLINE + "this",
|
||||||
|
SPIECE_UNDERLINE + "is",
|
||||||
|
SPIECE_UNDERLINE + "f",
|
||||||
|
"al",
|
||||||
|
"s",
|
||||||
|
"<unk>",
|
||||||
|
".",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_tokenization_base_easy_symbols(self):
|
def test_tokenization_base_easy_symbols(self):
|
||||||
tokenizer = XLMRobertaTokenizer.from_pretrained("xlm-roberta-base")
|
tokenizer = XLMRobertaTokenizer.from_pretrained("xlm-roberta-base")
|
||||||
@ -89,9 +188,11 @@ class XLMRobertaTokenizationIntegrationTest(unittest.TestCase):
|
|||||||
1098,
|
1098,
|
||||||
29367,
|
29367,
|
||||||
47,
|
47,
|
||||||
4426,
|
# 4426, # What fairseq tokenizes from "<unk>": "_<"
|
||||||
3678,
|
# 3678, # What fairseq tokenizes from "<unk>": "unk"
|
||||||
2740,
|
# 2740, # What fairseq tokenizes from "<unk>": ">"
|
||||||
|
3, # What we tokenize from "<unk>": "<unk>"
|
||||||
|
6, # Residue from the tokenization: an extra sentencepiece underline
|
||||||
4,
|
4,
|
||||||
6044,
|
6044,
|
||||||
237,
|
237,
|
||||||
|
Loading…
Reference in New Issue
Block a user