Merge pull request #1092 from shijie-wu/xlm-tokenization

Added cleaned-up configuration properties for tokenizers, with serialization; improved tokenization of XLM.
Commit d2f21f08f5 by Thomas Wolf, 2019-08-30 23:15:40 +02:00 (committed via GitHub)
17 changed files with 706 additions and 107 deletions
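For orientation, here is a hedged usage sketch (not part of the diff; the checkpoint name and sentence are only illustrative) of how the reworked XLM tokenization is meant to be called once this change is installed. The `lang` keyword flows from `encode()` through `tokenize()` down to the new `_tokenize()` shown further below.

```python
# Hedged usage sketch for the updated XLM tokenization.
from pytorch_transformers import XLMTokenizer

tokenizer = XLMTokenizer.from_pretrained('xlm-mlm-enfr-1024')
# Moses preprocessing + BPE; `lang` selects the language-specific rules.
tokens = tokenizer.tokenize(u"Ceci est une phrase de test.", lang='fr')
ids = tokenizer.encode(u"Ceci est une phrase de test.", lang='fr')
```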


@ -44,6 +44,8 @@ XLM_PRETRAINED_MODEL_ARCHIVE_MAP = {
'xlm-mlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-pytorch_model.bin",
'xlm-clm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-enfr-1024-pytorch_model.bin",
'xlm-clm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-ende-1024-pytorch_model.bin",
'xlm-mlm-17-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-pytorch_model.json",
'xlm-mlm-100-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-pytorch_model.json",
}
XLM_PRETRAINED_CONFIG_ARCHIVE_MAP = {
'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-config.json",
@ -54,6 +56,8 @@ XLM_PRETRAINED_CONFIG_ARCHIVE_MAP = {
'xlm-mlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-config.json",
'xlm-clm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-enfr-1024-config.json",
'xlm-clm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-ende-1024-config.json",
'xlm-mlm-17-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-config.json",
'xlm-mlm-100-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-config.json",
}
@ -114,6 +118,7 @@ class XLMConfig(PretrainedConfig):
causal=False,
asm=False,
n_langs=1,
use_lang_emb=True,
max_position_embeddings=512,
embed_init_std=2048 ** -0.5,
layer_norm_eps=1e-12,
@ -157,6 +162,7 @@ class XLMConfig(PretrainedConfig):
self.causal = causal
self.asm = asm
self.n_langs = n_langs
self.use_lang_emb = use_lang_emb
self.layer_norm_eps = layer_norm_eps
self.bos_index = bos_index
self.eos_index = eos_index
@ -488,7 +494,7 @@ class XLMModel(XLMPreTrainedModel):
""" """
ATTRIBUTES = ['encoder', 'eos_index', 'pad_index', # 'with_output', ATTRIBUTES = ['encoder', 'eos_index', 'pad_index', # 'with_output',
'n_langs', 'n_words', 'dim', 'n_layers', 'n_heads', 'n_langs', 'use_lang_emb', 'n_words', 'dim', 'n_layers', 'n_heads',
'hidden_dim', 'dropout', 'attention_dropout', 'asm', 'hidden_dim', 'dropout', 'attention_dropout', 'asm',
'asm_cutoffs', 'asm_div_value'] 'asm_cutoffs', 'asm_div_value']
@ -507,6 +513,7 @@ class XLMModel(XLMPreTrainedModel):
# dictionary / languages # dictionary / languages
self.n_langs = config.n_langs self.n_langs = config.n_langs
self.use_lang_emb = config.use_lang_emb
self.n_words = config.n_words self.n_words = config.n_words
self.eos_index = config.eos_index self.eos_index = config.eos_index
self.pad_index = config.pad_index self.pad_index = config.pad_index
@ -529,7 +536,7 @@ class XLMModel(XLMPreTrainedModel):
self.position_embeddings = nn.Embedding(config.max_position_embeddings, self.dim) self.position_embeddings = nn.Embedding(config.max_position_embeddings, self.dim)
if config.sinusoidal_embeddings: if config.sinusoidal_embeddings:
create_sinusoidal_embeddings(config.max_position_embeddings, self.dim, out=self.position_embeddings.weight) create_sinusoidal_embeddings(config.max_position_embeddings, self.dim, out=self.position_embeddings.weight)
if config.n_langs > 1: if config.n_langs > 1 and config.use_lang_emb:
self.lang_embeddings = nn.Embedding(self.n_langs, self.dim) self.lang_embeddings = nn.Embedding(self.n_langs, self.dim)
self.embeddings = nn.Embedding(self.n_words, self.dim, padding_idx=self.pad_index) self.embeddings = nn.Embedding(self.n_words, self.dim, padding_idx=self.pad_index)
self.layer_norm_emb = nn.LayerNorm(self.dim, eps=config.layer_norm_eps) self.layer_norm_emb = nn.LayerNorm(self.dim, eps=config.layer_norm_eps)
@ -628,7 +635,7 @@ class XLMModel(XLMPreTrainedModel):
# embeddings # embeddings
tensor = self.embeddings(input_ids) tensor = self.embeddings(input_ids)
tensor = tensor + self.position_embeddings(position_ids).expand_as(tensor) tensor = tensor + self.position_embeddings(position_ids).expand_as(tensor)
if langs is not None: if langs is not None and self.use_lang_emb:
tensor = tensor + self.lang_embeddings(langs) tensor = tensor + self.lang_embeddings(langs)
if token_type_ids is not None: if token_type_ids is not None:
tensor = tensor + self.embeddings(token_type_ids) tensor = tensor + self.embeddings(token_type_ids)


@ -41,8 +41,8 @@ class BertTokenizationTest(CommonTestCases.CommonTokenizerTester):
with open(self.vocab_file, "w", encoding='utf-8') as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
- def get_tokenizer(self):
+ def get_tokenizer(self, **kwargs):
- return self.tokenizer_class.from_pretrained(self.tmpdirname)
+ return BertTokenizer.from_pretrained(self.tmpdirname, **kwargs)
def get_input_output_texts(self):
input_text = u"UNwant\u00E9d,running"


@ -27,8 +27,8 @@ class DistilBertTokenizationTest(BertTokenizationTest):
tokenizer_class = DistilBertTokenizer
- def get_tokenizer(self):
+ def get_tokenizer(self, **kwargs):
- return DistilBertTokenizer.from_pretrained(self.tmpdirname)
+ return DistilBertTokenizer.from_pretrained(self.tmpdirname, **kwargs)
def test_sequence_builders(self):
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")


@ -44,8 +44,9 @@ class GPT2TokenizationTest(CommonTestCases.CommonTokenizerTester):
with open(self.merges_file, "w") as fp:
fp.write("\n".join(merges))
- def get_tokenizer(self):
+ def get_tokenizer(self, **kwargs):
- return GPT2Tokenizer.from_pretrained(self.tmpdirname, **self.special_tokens_map)
+ kwargs.update(self.special_tokens_map)
+ return GPT2Tokenizer.from_pretrained(self.tmpdirname, **kwargs)
def get_input_output_texts(self):
input_text = u"lower newer"


@ -45,8 +45,8 @@ class OpenAIGPTTokenizationTest(CommonTestCases.CommonTokenizerTester):
with open(self.merges_file, "w") as fp:
fp.write("\n".join(merges))
- def get_tokenizer(self):
+ def get_tokenizer(self, **kwargs):
- return OpenAIGPTTokenizer.from_pretrained(self.tmpdirname)
+ return OpenAIGPTTokenizer.from_pretrained(self.tmpdirname, **kwargs)
def get_input_output_texts(self):
input_text = u"lower newer"


@ -43,8 +43,9 @@ class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester):
with open(self.merges_file, "w") as fp:
fp.write("\n".join(merges))
- def get_tokenizer(self):
+ def get_tokenizer(self, **kwargs):
- return RobertaTokenizer.from_pretrained(self.tmpdirname, **self.special_tokens_map)
+ kwargs.update(self.special_tokens_map)
+ return RobertaTokenizer.from_pretrained(self.tmpdirname, **kwargs)
def get_input_output_texts(self):
input_text = u"lower newer"


@ -49,23 +49,32 @@ class CommonTestCases:
def tearDown(self):
shutil.rmtree(self.tmpdirname)
- def get_tokenizer(self):
+ def get_tokenizer(self, **kwargs):
raise NotImplementedError
def get_input_output_texts(self):
raise NotImplementedError
def test_save_and_load_tokenizer(self):
# safety check on max_len default value so we are sure the test works
tokenizer = self.get_tokenizer()
self.assertNotEqual(tokenizer.max_len, 42)
# Now let's start the test
tokenizer = self.get_tokenizer(max_len=42)
before_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running")
with TemporaryDirectory() as tmpdirname:
tokenizer.save_pretrained(tmpdirname)
- tokenizer = tokenizer.from_pretrained(tmpdirname)
+ tokenizer = self.tokenizer_class.from_pretrained(tmpdirname)
after_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running")
self.assertListEqual(before_tokens, after_tokens)
self.assertEqual(tokenizer.max_len, 42)
tokenizer = self.tokenizer_class.from_pretrained(tmpdirname, max_len=43)
self.assertEqual(tokenizer.max_len, 43)
def test_pickle_tokenizer(self):
tokenizer = self.get_tokenizer()


@ -37,8 +37,9 @@ class TransfoXLTokenizationTest(CommonTestCases.CommonTokenizerTester):
with open(self.vocab_file, "w", encoding='utf-8') as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
- def get_tokenizer(self):
+ def get_tokenizer(self, **kwargs):
- return TransfoXLTokenizer.from_pretrained(self.tmpdirname, lower_case=True)
+ kwargs['lower_case'] = True
+ return TransfoXLTokenizer.from_pretrained(self.tmpdirname, **kwargs)
def get_input_output_texts(self):
input_text = u"<unk> UNwanted , running"


@ -44,8 +44,8 @@ class XLMTokenizationTest(CommonTestCases.CommonTokenizerTester):
with open(self.merges_file, "w") as fp:
fp.write("\n".join(merges))
- def get_tokenizer(self):
+ def get_tokenizer(self, **kwargs):
- return XLMTokenizer.from_pretrained(self.tmpdirname)
+ return XLMTokenizer.from_pretrained(self.tmpdirname, **kwargs)
def get_input_output_texts(self):
input_text = u"lower newer"


@ -35,8 +35,8 @@ class XLNetTokenizationTest(CommonTestCases.CommonTokenizerTester):
tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True)
tokenizer.save_pretrained(self.tmpdirname)
- def get_tokenizer(self):
+ def get_tokenizer(self, **kwargs):
- return XLNetTokenizer.from_pretrained(self.tmpdirname)
+ return XLNetTokenizer.from_pretrained(self.tmpdirname, **kwargs)
def get_input_output_texts(self):
input_text = u"This is a test"


@ -63,6 +63,23 @@ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
'bert-base-cased-finetuned-mrpc': 512,
}
PRETRAINED_INIT_CONFIGURATION = {
'bert-base-uncased': {'do_lower_case': True},
'bert-large-uncased': {'do_lower_case': True},
'bert-base-cased': {'do_lower_case': False},
'bert-large-cased': {'do_lower_case': False},
'bert-base-multilingual-uncased': {'do_lower_case': True},
'bert-base-multilingual-cased': {'do_lower_case': False},
'bert-base-chinese': {'do_lower_case': False},
'bert-base-german-cased': {'do_lower_case': False},
'bert-large-uncased-whole-word-masking': {'do_lower_case': True},
'bert-large-cased-whole-word-masking': {'do_lower_case': False},
'bert-large-uncased-whole-word-masking-finetuned-squad': {'do_lower_case': True},
'bert-large-cased-whole-word-masking-finetuned-squad': {'do_lower_case': False},
'bert-base-cased-finetuned-mrpc': {'do_lower_case': False},
}
def load_vocab(vocab_file):
"""Loads a vocabulary file into a dictionary."""
vocab = collections.OrderedDict()
@ -100,6 +117,7 @@ class BertTokenizer(PreTrainedTokenizer):
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
def __init__(self, vocab_file, do_lower_case=True, do_basic_tokenize=True, never_split=None,
@ -202,24 +220,6 @@ class BertTokenizer(PreTrainedTokenizer):
index += 1
return (vocab_file,)
- @classmethod
- def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
-     """ Instantiate a BertTokenizer from pre-trained vocabulary files.
-     """
-     if pretrained_model_name_or_path in PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES:
-         if '-cased' in pretrained_model_name_or_path and kwargs.get('do_lower_case', True):
-             logger.warning("The pre-trained model you are loading is a cased model but you have not set "
-                            "`do_lower_case` to False. We are setting `do_lower_case=False` for you but "
-                            "you may want to check this behavior.")
-             kwargs['do_lower_case'] = False
-         elif '-cased' not in pretrained_model_name_or_path and not kwargs.get('do_lower_case', True):
-             logger.warning("The pre-trained model you are loading is an uncased model but you have set "
-                            "`do_lower_case` to False. We are setting `do_lower_case=True` for you "
-                            "but you may want to check this behavior.")
-             kwargs['do_lower_case'] = True
-     return super(BertTokenizer, cls)._from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
class BasicTokenizer(object):
"""Runs basic tokenization (punctuation splitting, lower casing, etc.)."""


@ -95,7 +95,8 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
# in a library like ours, at all.
vocab_dict = torch.load(pretrained_vocab_file)
for key, value in vocab_dict.items():
- self.__dict__[key] = value
+ if key not in self.__dict__:
+     self.__dict__[key] = value
if vocab_file is not None:
self.build_vocab()


@ -20,6 +20,7 @@ import logging
import os
import json
import six
import copy
from io import open
from .file_utils import cached_path
@ -28,6 +29,7 @@ logger = logging.getLogger(__name__)
SPECIAL_TOKENS_MAP_FILE = 'special_tokens_map.json'
ADDED_TOKENS_FILE = 'added_tokens.json'
TOKENIZER_CONFIG_FILE = 'tokenizer_config.json'
class PreTrainedTokenizer(object):
""" Base class for all tokenizers.
@ -40,6 +42,7 @@ class PreTrainedTokenizer(object):
- ``vocab_files_names``: a python ``dict`` with, as keys, the ``__init__`` keyword name of each vocabulary file required by the model, and as associated values, the filename for saving the associated file (string).
- ``pretrained_vocab_files_map``: a python ``dict of dict`` the high-level keys being the ``__init__`` keyword name of each vocabulary file required by the model, the low-level being the `short-cut-names` (string) of the pretrained models with, as associated values, the `url` (string) to the associated pretrained vocabulary file.
- ``max_model_input_sizes``: a python ``dict`` with, as keys, the `short-cut-names` (string) of the pretrained models, and as associated values, the maximum length of the sequence inputs of this model, or None if the model has no maximum input size.
- ``pretrained_init_configuration``: a python ``dict`` with, as keys, the `short-cut-names` (string) of the pretrained models, and as associated values, a dictionary of specific arguments to pass to the ``__init__`` method of the tokenizer class for this pretrained model when loading the tokenizer with the ``from_pretrained()`` method.
Parameters:
@ -61,6 +64,7 @@ class PreTrainedTokenizer(object):
"""
vocab_files_names = {}
pretrained_vocab_files_map = {}
pretrained_init_configuration = {}
max_model_input_sizes = {}
SPECIAL_TOKENS_ATTRIBUTES = ["bos_token", "eos_token", "unk_token", "sep_token",
@ -166,12 +170,15 @@ class PreTrainedTokenizer(object):
self._additional_special_tokens = []
self.max_len = max_len if max_len is not None else int(1e12)
self.max_len_single_sentence = self.max_len
self.max_len_sentences_pair = self.max_len
# Added tokens
self.added_tokens_encoder = {}
self.added_tokens_decoder = {}
# inputs and kwargs for saving and re-loading (see ``from_pretrained`` and ``save_pretrained``)
self.init_inputs = ()
self.init_kwargs = {}
for key, value in kwargs.items():
if key in self.SPECIAL_TOKENS_ATTRIBUTES:
if key == 'additional_special_tokens':
@ -231,17 +238,20 @@ class PreTrainedTokenizer(object):
@classmethod
- def _from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
+ def _from_pretrained(cls, pretrained_model_name_or_path, *init_inputs, **kwargs):
cache_dir = kwargs.pop('cache_dir', None)
force_download = kwargs.pop('force_download', False)
proxies = kwargs.pop('proxies', None)
s3_models = list(cls.max_model_input_sizes.keys())
vocab_files = {}
init_configuration = {}
if pretrained_model_name_or_path in s3_models:
# Get the vocabulary from AWS S3 bucket
for file_id, map_list in cls.pretrained_vocab_files_map.items():
vocab_files[file_id] = map_list[pretrained_model_name_or_path]
if cls.pretrained_init_configuration and pretrained_model_name_or_path in cls.pretrained_init_configuration:
init_configuration = cls.pretrained_init_configuration[pretrained_model_name_or_path]
else:
# Get the vocabulary from local files
logger.info(
@ -264,15 +274,17 @@ class PreTrainedTokenizer(object):
vocab_files[file_id] = full_file_name
# Look for the additional tokens files
- all_vocab_files_names = {'added_tokens_file': ADDED_TOKENS_FILE,
-                          'special_tokens_map_file': SPECIAL_TOKENS_MAP_FILE}
+ additional_files_names = {'added_tokens_file': ADDED_TOKENS_FILE,
+                           'special_tokens_map_file': SPECIAL_TOKENS_MAP_FILE,
+                           'tokenizer_config_file': TOKENIZER_CONFIG_FILE,
+                           }
# If a path to a file was provided, get the parent directory
saved_directory = pretrained_model_name_or_path
if os.path.exists(saved_directory) and not os.path.isdir(saved_directory):
saved_directory = os.path.dirname(saved_directory)
- for file_id, file_name in all_vocab_files_names.items():
+ for file_id, file_name in additional_files_names.items():
full_file_name = os.path.join(saved_directory, file_name)
if not os.path.exists(full_file_name):
logger.info("Didn't find file {}. We won't load it.".format(full_file_name))
@ -315,28 +327,46 @@ class PreTrainedTokenizer(object):
logger.info("loading file {} from cache at {}".format( logger.info("loading file {} from cache at {}".format(
file_path, resolved_vocab_files[file_id])) file_path, resolved_vocab_files[file_id]))
# Prepare tokenizer initialization kwargs
# Did we saved some inputs and kwargs to reload ?
tokenizer_config_file = resolved_vocab_files.pop('tokenizer_config_file', None)
if tokenizer_config_file is not None:
init_kwargs = json.load(open(tokenizer_config_file, encoding="utf-8"))
saved_init_inputs = init_kwargs.pop('init_inputs', ())
if not init_inputs:
init_inputs = saved_init_inputs
else:
init_kwargs = init_configuration
# Update with newly provided kwargs
init_kwargs.update(kwargs)
# Set max length if needed # Set max length if needed
if pretrained_model_name_or_path in cls.max_model_input_sizes: if pretrained_model_name_or_path in cls.max_model_input_sizes:
# if we're using a pretrained model, ensure the tokenizer # if we're using a pretrained model, ensure the tokenizer
# wont index sequences longer than the number of positional embeddings # wont index sequences longer than the number of positional embeddings
max_len = cls.max_model_input_sizes[pretrained_model_name_or_path] max_len = cls.max_model_input_sizes[pretrained_model_name_or_path]
if max_len is not None and isinstance(max_len, (int, float)): if max_len is not None and isinstance(max_len, (int, float)):
kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) init_kwargs['max_len'] = min(init_kwargs.get('max_len', int(1e12)), max_len)
# Merge resolved_vocab_files arguments in kwargs. # Merge resolved_vocab_files arguments in init_kwargs.
added_tokens_file = resolved_vocab_files.pop('added_tokens_file', None) added_tokens_file = resolved_vocab_files.pop('added_tokens_file', None)
special_tokens_map_file = resolved_vocab_files.pop('special_tokens_map_file', None) special_tokens_map_file = resolved_vocab_files.pop('special_tokens_map_file', None)
for args_name, file_path in resolved_vocab_files.items(): for args_name, file_path in resolved_vocab_files.items():
if args_name not in kwargs: if args_name not in init_kwargs:
kwargs[args_name] = file_path init_kwargs[args_name] = file_path
if special_tokens_map_file is not None: if special_tokens_map_file is not None:
special_tokens_map = json.load(open(special_tokens_map_file, encoding="utf-8")) special_tokens_map = json.load(open(special_tokens_map_file, encoding="utf-8"))
for key, value in special_tokens_map.items(): for key, value in special_tokens_map.items():
if key not in kwargs: if key not in init_kwargs:
kwargs[key] = value init_kwargs[key] = value
# Instantiate tokenizer. # Instantiate tokenizer.
tokenizer = cls(*inputs, **kwargs) tokenizer = cls(*init_inputs, **init_kwargs)
# Save inputs and kwargs for saving and re-loading with ``save_pretrained``
tokenizer.init_inputs = init_inputs
tokenizer.init_kwargs = init_kwargs
# Add supplementary tokens.
if added_tokens_file is not None:
@ -349,8 +379,13 @@ class PreTrainedTokenizer(object):
def save_pretrained(self, save_directory):
- """ Save the tokenizer vocabulary files (with added tokens) and the
-     special-tokens-to-class-attributes-mapping to a directory.
+ """ Save the tokenizer vocabulary files together with:
+     - added tokens,
+     - special-tokens-to-class-attributes mapping,
+     - tokenizer instantiation positional and keyword inputs (e.g. do_lower_case for Bert).
+     This won't save modifications, other than added tokens and the special tokens mapping, that you may have
+     applied to the tokenizer after instantiation (e.g. modifying tokenizer.do_lower_case after creation).
This method makes sure the full tokenizer can then be re-loaded using the :func:`~pytorch_transformers.PreTrainedTokenizer.from_pretrained` class method.
"""
@ -360,6 +395,15 @@ class PreTrainedTokenizer(object):
special_tokens_map_file = os.path.join(save_directory, SPECIAL_TOKENS_MAP_FILE)
added_tokens_file = os.path.join(save_directory, ADDED_TOKENS_FILE)
tokenizer_config_file = os.path.join(save_directory, TOKENIZER_CONFIG_FILE)
tokenizer_config = copy.deepcopy(self.init_kwargs)
tokenizer_config['init_inputs'] = copy.deepcopy(self.init_inputs)
for file_id in self.vocab_files_names.keys():
tokenizer_config.pop(file_id, None)
with open(tokenizer_config_file, 'w', encoding='utf-8') as f:
f.write(json.dumps(tokenizer_config, ensure_ascii=False))
with open(special_tokens_map_file, 'w', encoding='utf-8') as f:
f.write(json.dumps(self.special_tokens_map, ensure_ascii=False))
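To illustrate what the new `tokenizer_config.json` buys, here is a hedged sketch (the directory path is made up; `max_len=42` mirrors the updated common test above):

```python
# Init kwargs such as max_len are now written to tokenizer_config.json and
# restored on reload; explicit kwargs still override the saved ones.
import os
from pytorch_transformers import BertTokenizer

os.makedirs('./saved_tokenizer', exist_ok=True)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', max_len=42)
tokenizer.save_pretrained('./saved_tokenizer')   # vocab + special tokens + tokenizer_config.json
reloaded = BertTokenizer.from_pretrained('./saved_tokenizer')
assert reloaded.max_len == 42
assert BertTokenizer.from_pretrained('./saved_tokenizer', max_len=43).max_len == 43
```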
@ -566,7 +610,7 @@ class PreTrainedTokenizer(object):
def _convert_token_to_id(self, token):
raise NotImplementedError
- def encode(self, text, text_pair=None, add_special_tokens=False):
+ def encode(self, text, text_pair=None, add_special_tokens=False, **kwargs):
"""
Converts a string into a sequence of ids (integer), using the tokenizer and vocabulary.
@ -577,15 +621,16 @@ class PreTrainedTokenizer(object):
text_pair: Optional second sequence to be encoded.
add_special_tokens: if set to ``True``, the sequences will be encoded with the special tokens relative
to their model.
**kwargs: passed to the `self.tokenize()` method
"""
if text_pair is None:
if add_special_tokens:
- return self.add_special_tokens_single_sentence(self.convert_tokens_to_ids(self.tokenize(text)))
+ return self.add_special_tokens_single_sentence(self.convert_tokens_to_ids(self.tokenize(text, **kwargs)))
else:
- return self.convert_tokens_to_ids(self.tokenize(text))
+ return self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
- first_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text)]
+ first_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text, **kwargs)]
- second_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text_pair)]
+ second_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text_pair, **kwargs)]
if add_special_tokens:
return self.add_special_tokens_sentences_pair(first_sentence_tokens, second_sentence_tokens)
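A small sketch of the forwarding added here (the sentences and language code are arbitrary): any extra keyword argument given to `encode()` now reaches the model-specific `_tokenize()`, which is how the XLM tokenizer receives its `lang` code, including for sentence pairs.

```python
# Hypothetical pair encoding; `lang` is passed through encode() -> tokenize() -> _tokenize().
from pytorch_transformers import XLMTokenizer

tokenizer = XLMTokenizer.from_pretrained('xlm-mlm-en-2048')
first = u"Paris is the capital of France."
second = u"Berlin is the capital of Germany."
ids = tokenizer.encode(first, second, lang='en')
```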


@ -20,8 +20,12 @@ import json
import logging
import os
import re
import sys
import unicodedata
from io import open
import sacremoses as sm
from .tokenization_utils import PreTrainedTokenizer
from .tokenization_bert import BasicTokenizer
@ -43,6 +47,8 @@ PRETRAINED_VOCAB_FILES_MAP = {
'xlm-mlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-vocab.json",
'xlm-clm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-enfr-1024-vocab.json",
'xlm-clm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-ende-1024-vocab.json",
'xlm-mlm-17-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-vocab.json",
'xlm-mlm-100-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-vocab.json",
},
'merges_file':
{
@ -54,6 +60,8 @@ PRETRAINED_VOCAB_FILES_MAP = {
'xlm-mlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-merges.txt",
'xlm-clm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enfr-1024-merges.txt",
'xlm-clm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-ende-1024-merges.txt",
'xlm-mlm-17-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-merges.txt",
'xlm-mlm-100-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-merges.txt",
},
}
@ -66,6 +74,342 @@ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
'xlm-mlm-xnli15-1024': 512,
'xlm-clm-enfr-1024': 512,
'xlm-clm-ende-1024': 512,
'xlm-mlm-17-1280': 512,
'xlm-mlm-100-1280': 512,
}
PRETRAINED_INIT_CONFIGURATION = {
'xlm-mlm-en-2048': {"do_lowercase_and_remove_accent": True},
'xlm-mlm-ende-1024': { "do_lowercase_and_remove_accent": True,
"id2lang": { "0": "de",
"1": "en"},
"lang2id": { "de": 0,
"en": 1 }},
'xlm-mlm-enfr-1024': { "do_lowercase_and_remove_accent": True,
"id2lang": { "0": "en",
"1": "fr"},
"lang2id": { "en": 0,
"fr": 1 }},
'xlm-mlm-enro-1024': { "do_lowercase_and_remove_accent": True,
"id2lang": { "0": "en",
"1": "ro"},
"lang2id": { "en": 0,
"ro": 1 }},
'xlm-mlm-tlm-xnli15-1024': { "do_lowercase_and_remove_accent": True,
"id2lang": { "0": "ar",
"1": "bg",
"2": "de",
"3": "el",
"4": "en",
"5": "es",
"6": "fr",
"7": "hi",
"8": "ru",
"9": "sw",
"10": "th",
"11": "tr",
"12": "ur",
"13": "vi",
"14": "zh"},
"lang2id": { "ar": 0,
"bg": 1,
"de": 2,
"el": 3,
"en": 4,
"es": 5,
"fr": 6,
"hi": 7,
"ru": 8,
"sw": 9,
"th": 10,
"tr": 11,
"ur": 12,
"vi": 13,
"zh": 14 }},
'xlm-mlm-xnli15-1024': { "do_lowercase_and_remove_accent": True,
"id2lang": { "0": "ar",
"1": "bg",
"2": "de",
"3": "el",
"4": "en",
"5": "es",
"6": "fr",
"7": "hi",
"8": "ru",
"9": "sw",
"10": "th",
"11": "tr",
"12": "ur",
"13": "vi",
"14": "zh"},
"lang2id": { "ar": 0,
"bg": 1,
"de": 2,
"el": 3,
"en": 4,
"es": 5,
"fr": 6,
"hi": 7,
"ru": 8,
"sw": 9,
"th": 10,
"tr": 11,
"ur": 12,
"vi": 13,
"zh": 14 }},
'xlm-clm-enfr-1024': { "do_lowercase_and_remove_accent": True,
"id2lang": { "0": "en",
"1": "fr"},
"lang2id": { "en": 0,
"fr": 1 }},
'xlm-clm-ende-1024': { "do_lowercase_and_remove_accent": True,
"id2lang": { "0": "de",
"1": "en"},
"lang2id": { "de": 0,
"en": 1 }},
'xlm-mlm-17-1280': {"do_lowercase_and_remove_accent": False,
"id2lang": {
"0": "ar",
"1": "de",
"2": "en",
"3": "es",
"4": "fr",
"5": "hi",
"6": "it",
"7": "ja",
"8": "ko",
"9": "nl",
"10": "pl",
"11": "pt",
"12": "ru",
"13": "sv",
"14": "tr",
"15": "vi",
"16": "zh"
},
"lang2id": {
"ar": 0,
"de": 1,
"en": 2,
"es": 3,
"fr": 4,
"hi": 5,
"it": 6,
"ja": 7,
"ko": 8,
"nl": 9,
"pl": 10,
"pt": 11,
"ru": 12,
"sv": 13,
"tr": 14,
"vi": 15,
"zh": 16}},
'xlm-mlm-100-1280': {"do_lowercase_and_remove_accent": False,
"id2lang": {
"0": "af",
"1": "als",
"2": "am",
"3": "an",
"4": "ang",
"5": "ar",
"6": "arz",
"7": "ast",
"8": "az",
"9": "bar",
"10": "be",
"11": "bg",
"12": "bn",
"13": "br",
"14": "bs",
"15": "ca",
"16": "ceb",
"17": "ckb",
"18": "cs",
"19": "cy",
"20": "da",
"21": "de",
"22": "el",
"23": "en",
"24": "eo",
"25": "es",
"26": "et",
"27": "eu",
"28": "fa",
"29": "fi",
"30": "fr",
"31": "fy",
"32": "ga",
"33": "gan",
"34": "gl",
"35": "gu",
"36": "he",
"37": "hi",
"38": "hr",
"39": "hu",
"40": "hy",
"41": "ia",
"42": "id",
"43": "is",
"44": "it",
"45": "ja",
"46": "jv",
"47": "ka",
"48": "kk",
"49": "kn",
"50": "ko",
"51": "ku",
"52": "la",
"53": "lb",
"54": "lt",
"55": "lv",
"56": "mk",
"57": "ml",
"58": "mn",
"59": "mr",
"60": "ms",
"61": "my",
"62": "nds",
"63": "ne",
"64": "nl",
"65": "nn",
"66": "no",
"67": "oc",
"68": "pl",
"69": "pt",
"70": "ro",
"71": "ru",
"72": "scn",
"73": "sco",
"74": "sh",
"75": "si",
"76": "simple",
"77": "sk",
"78": "sl",
"79": "sq",
"80": "sr",
"81": "sv",
"82": "sw",
"83": "ta",
"84": "te",
"85": "th",
"86": "tl",
"87": "tr",
"88": "tt",
"89": "uk",
"90": "ur",
"91": "uz",
"92": "vi",
"93": "war",
"94": "wuu",
"95": "yi",
"96": "zh",
"97": "zh_classical",
"98": "zh_min_nan",
"99": "zh_yue"
},
"lang2id": {
"af": 0,
"als": 1,
"am": 2,
"an": 3,
"ang": 4,
"ar": 5,
"arz": 6,
"ast": 7,
"az": 8,
"bar": 9,
"be": 10,
"bg": 11,
"bn": 12,
"br": 13,
"bs": 14,
"ca": 15,
"ceb": 16,
"ckb": 17,
"cs": 18,
"cy": 19,
"da": 20,
"de": 21,
"el": 22,
"en": 23,
"eo": 24,
"es": 25,
"et": 26,
"eu": 27,
"fa": 28,
"fi": 29,
"fr": 30,
"fy": 31,
"ga": 32,
"gan": 33,
"gl": 34,
"gu": 35,
"he": 36,
"hi": 37,
"hr": 38,
"hu": 39,
"hy": 40,
"ia": 41,
"id": 42,
"is": 43,
"it": 44,
"ja": 45,
"jv": 46,
"ka": 47,
"kk": 48,
"kn": 49,
"ko": 50,
"ku": 51,
"la": 52,
"lb": 53,
"lt": 54,
"lv": 55,
"mk": 56,
"ml": 57,
"mn": 58,
"mr": 59,
"ms": 60,
"my": 61,
"nds": 62,
"ne": 63,
"nl": 64,
"nn": 65,
"no": 66,
"oc": 67,
"pl": 68,
"pt": 69,
"ro": 70,
"ru": 71,
"scn": 72,
"sco": 73,
"sh": 74,
"si": 75,
"simple": 76,
"sk": 77,
"sl": 78,
"sq": 79,
"sr": 80,
"sv": 81,
"sw": 82,
"ta": 83,
"te": 84,
"th": 85,
"tl": 86,
"tr": 87,
"tt": 88,
"uk": 89,
"ur": 90,
"uz": 91,
"vi": 92,
"war": 93,
"wuu": 94,
"yi": 95,
"zh": 96,
"zh_classical": 97,
"zh_min_nan": 98,
"zh_yue": 99
}},
}
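The map above is consumed through the new `pretrained_init_configuration` hook in `tokenization_utils.py`. A brief sketch of the resulting behaviour (values shown follow the dictionary above):

```python
# Sketch: per-checkpoint init kwargs are applied automatically.
from pytorch_transformers import XLMTokenizer

tokenizer = XLMTokenizer.from_pretrained('xlm-clm-enfr-1024')
print(tokenizer.lang2id)                          # {'en': 0, 'fr': 1}
print(tokenizer.do_lowercase_and_remove_accent)   # True

# The 17/100-language checkpoints keep the original casing and accents instead.
big = XLMTokenizer.from_pretrained('xlm-mlm-17-1280')
print(big.do_lowercase_and_remove_accent)         # False
```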
def get_pairs(word):
@ -80,62 +424,145 @@ def get_pairs(word):
prev_char = char
return pairs
- def text_standardize(text):
-     """
-     fixes some issues the spacy tokenizer had on books corpus
-     also does some whitespace standardization
-     """
-     text = text.replace('—', '-')
-     text = text.replace('–', '-')
-     text = text.replace('―', '-')
-     text = text.replace('…', '...')
-     text = text.replace('´', "'")
-     text = re.sub(r'''(-+|~+|!+|"+|;+|\?+|\++|,+|\)+|\(+|\\+|\/+|\*+|\[+|\]+|}+|{+|\|+|_+)''', r' \1 ', text)
-     text = re.sub(r'\s*\n\s*', ' \n ', text)
-     text = re.sub(r'[^\S\n]+', ' ', text)
-     return text.strip()

def lowercase_and_remove_accent(text):
    """
    Lowercases and strips accents from a piece of text based on
    https://github.com/facebookresearch/XLM/blob/master/tools/lowercase_and_remove_accent.py
    """
    text = ' '.join(text)
    text = text.lower()
    text = unicodedata.normalize("NFD", text)
    output = []
    for char in text:
        cat = unicodedata.category(char)
        if cat == "Mn":
            continue
        output.append(char)
    return "".join(output).lower().split(' ')

def replace_unicode_punct(text):
    '''
    Port of https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/replace-unicode-punctuation.perl
    '''
    text = text.replace('，', ',')
    text = re.sub(r'。\s*', '. ', text)
    text = text.replace('、', ',')
    text = text.replace('”', '"')
    text = text.replace('“', '"')
    text = text.replace('∶', ':')
    text = text.replace('：', ':')
    text = text.replace('？', '?')
    text = text.replace('《', '"')
    text = text.replace('》', '"')
    text = text.replace('）', ')')
    text = text.replace('！', '!')
    text = text.replace('（', '(')
    text = text.replace('；', ';')
    text = text.replace('１', '"')
    text = text.replace('」', '"')
    text = text.replace('「', '"')
    text = text.replace('０', '0')
    text = text.replace('３', '3')
    text = text.replace('２', '2')
    text = text.replace('５', '5')
    text = text.replace('６', '6')
    text = text.replace('９', '9')
    text = text.replace('７', '7')
    text = text.replace('８', '8')
    text = text.replace('４', '4')
    text = re.sub(r'．\s*', '. ', text)
    text = text.replace('～', '~')
    text = text.replace('’', '\'')
    text = text.replace('…', '...')
    text = text.replace('━', '-')
    text = text.replace('〈', '<')
    text = text.replace('〉', '>')
    text = text.replace('【', '[')
    text = text.replace('】', ']')
    text = text.replace('％', '%')
    return text

def remove_non_printing_char(text):
    '''
    Port of https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/remove-non-printing-char.perl
    '''
    output = []
    for char in text:
        cat = unicodedata.category(char)
        if cat.startswith('C'):
            continue
        output.append(char)
    return "".join(output)
def romanian_preprocessing(text):
'''Sennrich's WMT16 scripts for Romanian preprocessing, used by model `xlm-mlm-enro-1024`'''
# https://github.com/rsennrich/wmt16-scripts/blob/master/preprocess/normalise-romanian.py
text = text.replace("\u015e", "\u0218").replace("\u015f", "\u0219")
text = text.replace("\u0162", "\u021a").replace("\u0163", "\u021b")
# https://github.com/rsennrich/wmt16-scripts/blob/master/preprocess/remove-diacritics.py
text = text.replace("\u0218", "S").replace("\u0219", "s") #s-comma
text = text.replace("\u021a", "T").replace("\u021b", "t") #t-comma
text = text.replace("\u0102", "A").replace("\u0103", "a")
text = text.replace("\u00C2", "A").replace("\u00E2", "a")
text = text.replace("\u00CE", "I").replace("\u00EE", "i")
return text
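A quick illustration of the helper above (the input word is made up): S/T-cedilla forms are first mapped to the comma-below forms, then the Romanian diacritics are stripped so the text matches the preprocessing used to train `xlm-mlm-enro-1024`.

```python
# Hypothetical input; diacritics are removed before Moses tokenization and BPE.
print(romanian_preprocessing(u"\u0218tiin\u021ba \u00eenv\u0103\u021bat\u0103"))
# -> 'Stiinta invatata'
```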
class XLMTokenizer(PreTrainedTokenizer):
    """
-   BPE tokenizer for XLM, adapted from OpenAI BPE tokenizer. Peculiarities:
-       - lower case all inputs
-       - uses `SpaCy tokenizer <https://spacy.io/api/tokenizer/>`_ and \
-         `ftfy <https://ftfy.readthedocs.io/en/latest/>`_ for pre-BPE tokenization if they are installed, \
-         fallback to BERT's BasicTokenizer if not.
    BPE tokenizer for XLM

        - Moses preprocessing & tokenization for most supported languages
        - Language specific tokenization for Chinese (Jieba), Japanese (KyTea) and Thai (PyThaiNLP)
        - (optionally) lower case & normalize all input text
        - argument ``special_tokens`` and function ``set_special_tokens``, can be used to add additional symbols \
          (ex: "__classify__") to a vocabulary
        - `lang2id` attribute maps the languages supported by the model with their ids if provided (automatically set for pretrained vocabularies)
        - `id2lang` attribute does the reverse mapping if provided (automatically set for pretrained vocabularies)
        - `do_lowercase_and_remove_accent` controls lower casing and accent normalization (automatically set for pretrained vocabularies)
    """
    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

    def __init__(self, vocab_file, merges_file, unk_token="<unk>", bos_token="<s>",
                 sep_token="</s>", pad_token="<pad>", cls_token="</s>",
                 mask_token="<special1>", additional_special_tokens=["<special0>",
                 "<special1>", "<special2>", "<special3>", "<special4>", "<special5>",
-                "<special6>", "<special7>", "<special8>", "<special9>"], **kwargs):
+                "<special6>", "<special7>", "<special8>", "<special9>"],
+                lang2id=None, id2lang=None, do_lowercase_and_remove_accent=True,
+                **kwargs):
        super(XLMTokenizer, self).__init__(unk_token=unk_token, bos_token=bos_token,
                                           sep_token=sep_token, pad_token=pad_token,
                                           cls_token=cls_token, mask_token=mask_token,
                                           additional_special_tokens=additional_special_tokens,
                                           **kwargs)
-       self.max_len_single_sentence = self.max_len - 2  # take into account special tokens
-       self.max_len_sentences_pair = self.max_len - 3  # take into account special tokens
        # cache of sm.MosesPunctNormalizer instance
        self.cache_moses_punct_normalizer = dict()
        # cache of sm.MosesTokenizer instance
        self.cache_moses_tokenizer = dict()
        self.lang_with_custom_tokenizer = set(['zh', 'th', 'ja'])
        # True for current supported model (v1.2.0), False for XLM-17 & 100
        self.do_lowercase_and_remove_accent = do_lowercase_and_remove_accent
        self.lang2id = lang2id
        self.id2lang = id2lang
        if lang2id is not None and id2lang is not None:
            assert len(lang2id) == len(id2lang)
        self.ja_word_tokenizer = None
        self.zh_word_tokenizer = None
-       try:
-           import ftfy
-           from spacy.lang.en import English
-           _nlp = English()
-           self.nlp = _nlp.Defaults.create_tokenizer(_nlp)
-           self.fix_text = ftfy.fix_text
-       except ImportError:
-           logger.warning("ftfy or spacy is not installed using BERT BasicTokenizer instead of SpaCy & ftfy.")
-           self.nlp = BasicTokenizer(do_lower_case=True)
-           self.fix_text = None
        self.encoder = json.load(open(vocab_file, encoding="utf-8"))
        self.decoder = {v:k for k,v in self.encoder.items()}
@ -144,6 +571,43 @@ class XLMTokenizer(PreTrainedTokenizer):
self.bpe_ranks = dict(zip(merges, range(len(merges))))
self.cache = {}
def moses_punct_norm(self, text, lang):
if lang not in self.cache_moses_punct_normalizer:
punct_normalizer = sm.MosesPunctNormalizer(lang=lang)
self.cache_moses_punct_normalizer[lang] = punct_normalizer
else:
punct_normalizer = self.cache_moses_punct_normalizer[lang]
return punct_normalizer.normalize(text)
def moses_tokenize(self, text, lang):
if lang not in self.cache_moses_tokenizer:
moses_tokenizer = sm.MosesTokenizer(lang=lang)
self.cache_moses_tokenizer[lang] = moses_tokenizer
else:
moses_tokenizer = self.cache_moses_tokenizer[lang]
return moses_tokenizer.tokenize(text, return_str=False, escape=False)
def moses_pipeline(self, text, lang):
text = replace_unicode_punct(text)
text = self.moses_punct_norm(text, lang)
text = remove_non_printing_char(text)
return text
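As an illustration (hedged; the input string and `tokenizer` instance are assumptions, the latter being an `XLMTokenizer` obtained via `from_pretrained`), the helpers above compose into the standard Moses preprocessing used for most languages: unicode punctuation replacement, punctuation normalisation, control-character removal, then Moses tokenization.

```python
# Rough sketch of the per-language pipeline defined above.
text = u"\u00ABBonjour\u00BB\u2026 le monde\u00A0!"
normalized = tokenizer.moses_pipeline(text, lang='fr')    # replace/normalize punctuation, drop control chars
tokens = tokenizer.moses_tokenize(normalized, lang='fr')  # sacremoses MosesTokenizer, no escaping
```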
def ja_tokenize(self, text):
if self.ja_word_tokenizer is None:
try:
import Mykytea
self.ja_word_tokenizer = Mykytea.Mykytea('-model %s/local/share/kytea/model.bin' % os.path.expanduser('~'))
except (AttributeError, ImportError) as e:
logger.error("Make sure you install KyTea (https://github.com/neubig/kytea) and it's python wrapper (https://github.com/chezou/Mykytea-python) with the following steps")
logger.error("1. git clone git@github.com:neubig/kytea.git && cd kytea")
logger.error("2. autoreconf -i")
logger.error("3. ./configure --prefix=$HOME/local")
logger.error("4. make && make install")
logger.error("5. pip install kytea")
raise e
return list(self.ja_word_tokenizer.getWS(text))
@property
def vocab_size(self):
return len(self.encoder)
@ -191,19 +655,86 @@ class XLMTokenizer(PreTrainedTokenizer):
self.cache[token] = word
return word

-   def _tokenize(self, text):
-       """ Tokenize a string. """
-       split_tokens = []
-       if self.fix_text is None:
-           # Using BERT's BasicTokenizer
-           text = self.nlp.tokenize(text)
-           for token in text:
-               split_tokens.extend([t for t in self.bpe(token).split(' ')])
-       else:
-           # Using SpaCy & ftfy (original tokenization process of OpenAI GPT)
-           text = self.nlp(text_standardize(self.fix_text(text)))
-           for token in text:
-               split_tokens.extend([t for t in self.bpe(token.text.lower()).split(' ')])
-       return split_tokens

    def _tokenize(self, text, lang='en', bypass_tokenizer=False):
        """
        Tokenize a string given a language code. For Chinese, Japanese and Thai, we use a language specific tokenizer. Otherwise, we use Moses.

        Details of tokenization:
            - [sacremoses](https://github.com/alvations/sacremoses): port of Moses
                - Install with `pip install sacremoses`
            - [pythainlp](https://github.com/PyThaiNLP/pythainlp): Thai tokenizer
                - Install with `pip install pythainlp`
            - [kytea](https://github.com/chezou/Mykytea-python): Japanese tokenizer, wrapper of [KyTea](https://github.com/neubig/kytea)
                - Install with the following steps:
                ```
                git clone git@github.com:neubig/kytea.git && cd kytea
                autoreconf -i
                ./configure --prefix=$HOME/local
                make && make install
                pip install kytea
                ```
            - [jieba](https://github.com/fxsjy/jieba): Chinese tokenizer *
                - Install with `pip install jieba`

        \* The original XLM used [Stanford Segmenter](https://nlp.stanford.edu/software/stanford-segmenter-2018-10-16.zip).
        However, the wrapper (`nltk.tokenize.stanford_segmenter`) is slow due to JVM overhead, and it will be deprecated.
        Jieba is a lot faster and pip-installable. Note there is some mismatch with the Stanford Segmenter. It should be fine
        if you fine-tune the model with Chinese supervision. If you want the exact same behaviour, use the original XLM
        [preprocessing script](https://github.com/facebookresearch/XLM/tree/master/tools) to tokenize the sentence externally,
        and set `bypass_tokenizer=True` to bypass the tokenizer.

        Args:
            - lang: ISO language code (default = 'en') (string). Languages should belong to the model's supported languages. However, we don't enforce it.
            - bypass_tokenizer: Allow users to preprocess and tokenize the sentences externally (default = False) (bool). If True, we only apply BPE.

        Returns:
            List of tokens.
        """
        if lang and self.lang2id and lang not in self.lang2id:
            logger.error("Supplied language code not found in lang2id mapping. Please check that your language is supported by the loaded pretrained model.")
        if bypass_tokenizer:
            text = text.split()
        elif lang not in self.lang_with_custom_tokenizer:
            text = self.moses_pipeline(text, lang=lang)
            # TODO: make sure we are using `xlm-mlm-enro-1024`, since XLM-100 doesn't have this step
            if lang == 'ro':
                text = romanian_preprocessing(text)
            text = self.moses_tokenize(text, lang=lang)
        elif lang == 'th':
            text = self.moses_pipeline(text, lang=lang)
            try:
                if 'pythainlp' not in sys.modules:
                    from pythainlp.tokenize import word_tokenize as th_word_tokenize
            except (AttributeError, ImportError) as e:
                logger.error("Make sure you install PyThaiNLP (https://github.com/PyThaiNLP/pythainlp) with the following steps")
                logger.error("1. pip install pythainlp")
                raise e
            text = th_word_tokenize(text)
        elif lang == 'zh':
            try:
                if 'jieba' not in sys.modules:
                    import jieba
            except (AttributeError, ImportError) as e:
                logger.error("Make sure you install Jieba (https://github.com/fxsjy/jieba) with the following steps")
                logger.error("1. pip install jieba")
                raise e
            text = ' '.join(jieba.cut(text))
            text = self.moses_pipeline(text, lang=lang)
            text = text.split()
        elif lang == 'ja':
            text = self.moses_pipeline(text, lang=lang)
            text = self.ja_tokenize(text)
        else:
            raise ValueError('It should not reach here')

        if self.do_lowercase_and_remove_accent and not bypass_tokenizer:
            text = lowercase_and_remove_accent(text)

        split_tokens = []
        for token in text:
            if token:
                split_tokens.extend([t for t in self.bpe(token).split(' ')])

        return split_tokens
def _convert_token_to_id(self, token):
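To make the two entry points of `_tokenize` concrete, here is a hedged sketch (the sentences are arbitrary, `tokenizer` is an `XLMTokenizer` instance, and the Chinese path requires `jieba` to be installed):

```python
# (1) Let the tokenizer run the language-specific pipeline itself
#     ("this is a test" in Chinese, via jieba + Moses + BPE).
zh_tokens = tokenizer.tokenize(u"\u8fd9\u662f\u4e00\u4e2a\u6d4b\u8bd5", lang='zh')

# (2) Pre-tokenize externally (e.g. with the original XLM tools) and only apply BPE.
pre_tokenized = u"this is an externally tokenized sentence"
bpe_only = tokenizer.tokenize(pre_tokenized, bypass_tokenizer=True)
```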


@ -61,7 +61,7 @@ class XLNetTokenizer(PreTrainedTokenizer):
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
- def __init__(self, vocab_file, max_len=None,
+ def __init__(self, vocab_file,
do_lower_case=False, remove_space=True, keep_accents=False,
bos_token="<s>", eos_token="</s>", unk_token="<unk>", sep_token="<sep>",
pad_token="<pad>", cls_token="<cls>", mask_token="<mask>",


@ -9,4 +9,6 @@ requests
# For OpenAI GPT
regex
# For XLNet
sentencepiece
# For XLM
sacremoses


@ -55,7 +55,8 @@ setup(
'requests',
'tqdm',
'regex',
- 'sentencepiece'],
+ 'sentencepiece',
+ 'sacremoses'],
entry_points={
'console_scripts': [
"pytorch_transformers=pytorch_transformers.__main__:main",