Merge pull request #1092 from shijie-wu/xlm-tokenization
Added cleaned configuration properties for tokenizer with serialization - improve tokenization of XLM
Commit d2f21f08f5
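In short, this PR makes tokenizer construction arguments part of the saved state: `save_pretrained()` now writes a `tokenizer_config.json` next to the vocabulary files, and `from_pretrained()` reads it back (or falls back to a per-model `pretrained_init_configuration`). A minimal round-trip sketch of the intended usage — the directory name and the `basic_tokenizer` attribute check are illustrative, not part of the diff:

```
import os
from pytorch_transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-cased', do_lower_case=False)

os.makedirs('./my_tokenizer', exist_ok=True)
# writes vocab.txt, special_tokens_map.json, added_tokens.json and tokenizer_config.json
tokenizer.save_pretrained('./my_tokenizer')

reloaded = BertTokenizer.from_pretrained('./my_tokenizer')
print(reloaded.basic_tokenizer.do_lower_case)   # False, restored from tokenizer_config.json
```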
@@ -44,6 +44,8 @@ XLM_PRETRAINED_MODEL_ARCHIVE_MAP = {
     'xlm-mlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-pytorch_model.bin",
     'xlm-clm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-enfr-1024-pytorch_model.bin",
     'xlm-clm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-ende-1024-pytorch_model.bin",
+    'xlm-mlm-17-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-pytorch_model.json",
+    'xlm-mlm-100-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-pytorch_model.json",
 }
 XLM_PRETRAINED_CONFIG_ARCHIVE_MAP = {
     'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-config.json",
@@ -54,6 +56,8 @@ XLM_PRETRAINED_CONFIG_ARCHIVE_MAP = {
     'xlm-mlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-config.json",
     'xlm-clm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-enfr-1024-config.json",
     'xlm-clm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-ende-1024-config.json",
+    'xlm-mlm-17-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-config.json",
+    'xlm-mlm-100-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-config.json",
 }
@@ -114,6 +118,7 @@ class XLMConfig(PretrainedConfig):
                  causal=False,
                  asm=False,
                  n_langs=1,
+                 use_lang_emb=True,
                  max_position_embeddings=512,
                  embed_init_std=2048 ** -0.5,
                  layer_norm_eps=1e-12,
@@ -157,6 +162,7 @@
             self.causal = causal
             self.asm = asm
             self.n_langs = n_langs
+            self.use_lang_emb = use_lang_emb
             self.layer_norm_eps = layer_norm_eps
             self.bos_index = bos_index
             self.eos_index = eos_index
@@ -488,7 +494,7 @@ class XLMModel(XLMPreTrainedModel):

     """
     ATTRIBUTES = ['encoder', 'eos_index', 'pad_index',  # 'with_output',
-                  'n_langs', 'n_words', 'dim', 'n_layers', 'n_heads',
+                  'n_langs', 'use_lang_emb', 'n_words', 'dim', 'n_layers', 'n_heads',
                   'hidden_dim', 'dropout', 'attention_dropout', 'asm',
                   'asm_cutoffs', 'asm_div_value']
@@ -507,6 +513,7 @@ class XLMModel(XLMPreTrainedModel):

         # dictionary / languages
         self.n_langs = config.n_langs
+        self.use_lang_emb = config.use_lang_emb
         self.n_words = config.n_words
         self.eos_index = config.eos_index
         self.pad_index = config.pad_index
@@ -529,7 +536,7 @@ class XLMModel(XLMPreTrainedModel):
         self.position_embeddings = nn.Embedding(config.max_position_embeddings, self.dim)
         if config.sinusoidal_embeddings:
             create_sinusoidal_embeddings(config.max_position_embeddings, self.dim, out=self.position_embeddings.weight)
-        if config.n_langs > 1:
+        if config.n_langs > 1 and config.use_lang_emb:
             self.lang_embeddings = nn.Embedding(self.n_langs, self.dim)
         self.embeddings = nn.Embedding(self.n_words, self.dim, padding_idx=self.pad_index)
         self.layer_norm_emb = nn.LayerNorm(self.dim, eps=config.layer_norm_eps)
@@ -628,7 +635,7 @@ class XLMModel(XLMPreTrainedModel):
         # embeddings
         tensor = self.embeddings(input_ids)
         tensor = tensor + self.position_embeddings(position_ids).expand_as(tensor)
-        if langs is not None:
+        if langs is not None and self.use_lang_emb:
             tensor = tensor + self.lang_embeddings(langs)
         if token_type_ids is not None:
             tensor = tensor + self.embeddings(token_type_ids)
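The new `use_lang_emb` flag lets the XLM-17 and XLM-100 checkpoints, which were trained without language embeddings, skip the `lang_embeddings` term, while the XNLI-15 / CLM checkpoints still add it when `langs` ids are passed. A hedged sketch of how a `langs` tensor is typically built from the tokenizer's `lang2id` mapping (checkpoint name and sentence are illustrative and require the pretrained weights to be downloadable):

```
import torch
from pytorch_transformers import XLMModel, XLMTokenizer

tokenizer = XLMTokenizer.from_pretrained('xlm-mlm-xnli15-1024')
model = XLMModel.from_pretrained('xlm-mlm-xnli15-1024')

input_ids = torch.tensor([tokenizer.encode("Wikipedia was used to train this model")])
# one language id per position; for xlm-mlm-17-1280 / xlm-mlm-100-1280
# use_lang_emb is False and `langs` can simply be omitted
langs = torch.full_like(input_ids, tokenizer.lang2id['en'])
outputs = model(input_ids, langs=langs)
```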
@@ -41,8 +41,8 @@ class BertTokenizationTest(CommonTestCases.CommonTokenizerTester):
         with open(self.vocab_file, "w", encoding='utf-8') as vocab_writer:
             vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))

-    def get_tokenizer(self):
-        return self.tokenizer_class.from_pretrained(self.tmpdirname)
+    def get_tokenizer(self, **kwargs):
+        return BertTokenizer.from_pretrained(self.tmpdirname, **kwargs)

     def get_input_output_texts(self):
         input_text = u"UNwant\u00E9d,running"
@@ -27,8 +27,8 @@ class DistilBertTokenizationTest(BertTokenizationTest):

     tokenizer_class = DistilBertTokenizer

-    def get_tokenizer(self):
-        return DistilBertTokenizer.from_pretrained(self.tmpdirname)
+    def get_tokenizer(self, **kwargs):
+        return DistilBertTokenizer.from_pretrained(self.tmpdirname, **kwargs)

     def test_sequence_builders(self):
         tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
@@ -44,8 +44,9 @@ class GPT2TokenizationTest(CommonTestCases.CommonTokenizerTester):
         with open(self.merges_file, "w") as fp:
             fp.write("\n".join(merges))

-    def get_tokenizer(self):
-        return GPT2Tokenizer.from_pretrained(self.tmpdirname, **self.special_tokens_map)
+    def get_tokenizer(self, **kwargs):
+        kwargs.update(self.special_tokens_map)
+        return GPT2Tokenizer.from_pretrained(self.tmpdirname, **kwargs)

     def get_input_output_texts(self):
         input_text = u"lower newer"
@@ -45,8 +45,8 @@ class OpenAIGPTTokenizationTest(CommonTestCases.CommonTokenizerTester):
         with open(self.merges_file, "w") as fp:
             fp.write("\n".join(merges))

-    def get_tokenizer(self):
-        return OpenAIGPTTokenizer.from_pretrained(self.tmpdirname)
+    def get_tokenizer(self, **kwargs):
+        return OpenAIGPTTokenizer.from_pretrained(self.tmpdirname, **kwargs)

     def get_input_output_texts(self):
         input_text = u"lower newer"
@@ -43,8 +43,9 @@ class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester):
         with open(self.merges_file, "w") as fp:
             fp.write("\n".join(merges))

-    def get_tokenizer(self):
-        return RobertaTokenizer.from_pretrained(self.tmpdirname, **self.special_tokens_map)
+    def get_tokenizer(self, **kwargs):
+        kwargs.update(self.special_tokens_map)
+        return RobertaTokenizer.from_pretrained(self.tmpdirname, **kwargs)

     def get_input_output_texts(self):
         input_text = u"lower newer"
@@ -49,23 +49,32 @@ class CommonTestCases:
         def tearDown(self):
             shutil.rmtree(self.tmpdirname)

-        def get_tokenizer(self):
+        def get_tokenizer(self, **kwargs):
             raise NotImplementedError

         def get_input_output_texts(self):
             raise NotImplementedError

         def test_save_and_load_tokenizer(self):
+            # safety check on max_len default value so we are sure the test works
             tokenizer = self.get_tokenizer()
+            self.assertNotEqual(tokenizer.max_len, 42)
+
+            # Now let's start the test
+            tokenizer = self.get_tokenizer(max_len=42)

             before_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running")

             with TemporaryDirectory() as tmpdirname:
                 tokenizer.save_pretrained(tmpdirname)
-                tokenizer = tokenizer.from_pretrained(tmpdirname)
+                tokenizer = self.tokenizer_class.from_pretrained(tmpdirname)

-            after_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running")
-            self.assertListEqual(before_tokens, after_tokens)
+                after_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running")
+                self.assertListEqual(before_tokens, after_tokens)
+
+                self.assertEqual(tokenizer.max_len, 42)
+                tokenizer = self.tokenizer_class.from_pretrained(tmpdirname, max_len=43)
+                self.assertEqual(tokenizer.max_len, 43)

         def test_pickle_tokenizer(self):
             tokenizer = self.get_tokenizer()
@@ -37,8 +37,9 @@ class TransfoXLTokenizationTest(CommonTestCases.CommonTokenizerTester):
         with open(self.vocab_file, "w", encoding='utf-8') as vocab_writer:
             vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))

-    def get_tokenizer(self):
-        return TransfoXLTokenizer.from_pretrained(self.tmpdirname, lower_case=True)
+    def get_tokenizer(self, **kwargs):
+        kwargs['lower_case'] = True
+        return TransfoXLTokenizer.from_pretrained(self.tmpdirname, **kwargs)

     def get_input_output_texts(self):
         input_text = u"<unk> UNwanted , running"
@@ -44,8 +44,8 @@ class XLMTokenizationTest(CommonTestCases.CommonTokenizerTester):
         with open(self.merges_file, "w") as fp:
             fp.write("\n".join(merges))

-    def get_tokenizer(self):
-        return XLMTokenizer.from_pretrained(self.tmpdirname)
+    def get_tokenizer(self, **kwargs):
+        return XLMTokenizer.from_pretrained(self.tmpdirname, **kwargs)

     def get_input_output_texts(self):
         input_text = u"lower newer"
@@ -35,8 +35,8 @@ class XLNetTokenizationTest(CommonTestCases.CommonTokenizerTester):
         tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True)
         tokenizer.save_pretrained(self.tmpdirname)

-    def get_tokenizer(self):
-        return XLNetTokenizer.from_pretrained(self.tmpdirname)
+    def get_tokenizer(self, **kwargs):
+        return XLNetTokenizer.from_pretrained(self.tmpdirname, **kwargs)

     def get_input_output_texts(self):
         input_text = u"This is a test"
@@ -63,6 +63,23 @@ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
     'bert-base-cased-finetuned-mrpc': 512,
 }

+PRETRAINED_INIT_CONFIGURATION = {
+    'bert-base-uncased': {'do_lower_case': True},
+    'bert-large-uncased': {'do_lower_case': True},
+    'bert-base-cased': {'do_lower_case': False},
+    'bert-large-cased': {'do_lower_case': False},
+    'bert-base-multilingual-uncased': {'do_lower_case': True},
+    'bert-base-multilingual-cased': {'do_lower_case': False},
+    'bert-base-chinese': {'do_lower_case': False},
+    'bert-base-german-cased': {'do_lower_case': False},
+    'bert-large-uncased-whole-word-masking': {'do_lower_case': True},
+    'bert-large-cased-whole-word-masking': {'do_lower_case': False},
+    'bert-large-uncased-whole-word-masking-finetuned-squad': {'do_lower_case': True},
+    'bert-large-cased-whole-word-masking-finetuned-squad': {'do_lower_case': False},
+    'bert-base-cased-finetuned-mrpc': {'do_lower_case': False},
+}
+

 def load_vocab(vocab_file):
     """Loads a vocabulary file into a dictionary."""
     vocab = collections.OrderedDict()
@@ -100,6 +117,7 @@ class BertTokenizer(PreTrainedTokenizer):

     vocab_files_names = VOCAB_FILES_NAMES
     pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+    pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

     def __init__(self, vocab_file, do_lower_case=True, do_basic_tokenize=True, never_split=None,
@@ -202,24 +220,6 @@ class BertTokenizer(PreTrainedTokenizer):
             index += 1
         return (vocab_file,)

-    @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
-        """ Instantiate a BertTokenizer from pre-trained vocabulary files.
-        """
-        if pretrained_model_name_or_path in PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES:
-            if '-cased' in pretrained_model_name_or_path and kwargs.get('do_lower_case', True):
-                logger.warning("The pre-trained model you are loading is a cased model but you have not set "
-                               "`do_lower_case` to False. We are setting `do_lower_case=False` for you but "
-                               "you may want to check this behavior.")
-                kwargs['do_lower_case'] = False
-            elif '-cased' not in pretrained_model_name_or_path and not kwargs.get('do_lower_case', True):
-                logger.warning("The pre-trained model you are loading is an uncased model but you have set "
-                               "`do_lower_case` to False. We are setting `do_lower_case=True` for you "
-                               "but you may want to check this behavior.")
-                kwargs['do_lower_case'] = True
-
-        return super(BertTokenizer, cls)._from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)


 class BasicTokenizer(object):
     """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
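With the hard-coded `-cased`/`-uncased` warning hack removed, the correct `do_lower_case` now flows through the generic `_from_pretrained` path via `PRETRAINED_INIT_CONFIGURATION`, and an explicit keyword argument still wins. A small sketch of the intended behaviour (the printed values follow the mapping above; `basic_tokenizer` is the usual BERT attribute, shown here only for illustration):

```
from pytorch_transformers import BertTokenizer

tok = BertTokenizer.from_pretrained('bert-base-cased')
print(tok.basic_tokenizer.do_lower_case)   # False, taken from PRETRAINED_INIT_CONFIGURATION

# an explicit keyword argument still overrides the per-model default
tok = BertTokenizer.from_pretrained('bert-base-cased', do_lower_case=True)
print(tok.basic_tokenizer.do_lower_case)   # True
```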
@@ -95,7 +95,8 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
             # in a library like ours, at all.
             vocab_dict = torch.load(pretrained_vocab_file)
             for key, value in vocab_dict.items():
-                self.__dict__[key] = value
+                if key not in self.__dict__:
+                    self.__dict__[key] = value

         if vocab_file is not None:
             self.build_vocab()
@@ -20,6 +20,7 @@ import logging
 import os
 import json
 import six
+import copy
 from io import open

 from .file_utils import cached_path
@@ -28,6 +29,7 @@ logger = logging.getLogger(__name__)

 SPECIAL_TOKENS_MAP_FILE = 'special_tokens_map.json'
 ADDED_TOKENS_FILE = 'added_tokens.json'
+TOKENIZER_CONFIG_FILE = 'tokenizer_config.json'

 class PreTrainedTokenizer(object):
     """ Base class for all tokenizers.
@@ -40,6 +42,7 @@ class PreTrainedTokenizer(object):
         - ``vocab_files_names``: a python ``dict`` with, as keys, the ``__init__`` keyword name of each vocabulary file required by the model, and as associated values, the filename for saving the associated file (string).
         - ``pretrained_vocab_files_map``: a python ``dict of dict`` the high-level keys being the ``__init__`` keyword name of each vocabulary file required by the model, the low-level being the `short-cut-names` (string) of the pretrained models with, as associated values, the `url` (string) to the associated pretrained vocabulary file.
         - ``max_model_input_sizes``: a python ``dict`` with, as keys, the `short-cut-names` (string) of the pretrained models, and as associated values, the maximum length of the sequence inputs of this model, or None if the model has no maximum input size.
+        - ``pretrained_init_configuration``: a python ``dict`` with, as keys, the `short-cut-names` (string) of the pretrained models, and as associated values, a dictionary of specific arguments to pass to the ``__init__`` method of the tokenizer class for this pretrained model when loading the tokenizer with the ``from_pretrained()`` method.

     Parameters:
@@ -61,6 +64,7 @@ class PreTrainedTokenizer(object):
     """
     vocab_files_names = {}
     pretrained_vocab_files_map = {}
+    pretrained_init_configuration = {}
     max_model_input_sizes = {}

     SPECIAL_TOKENS_ATTRIBUTES = ["bos_token", "eos_token", "unk_token", "sep_token",
@@ -166,12 +170,15 @@ class PreTrainedTokenizer(object):
         self._additional_special_tokens = []

         self.max_len = max_len if max_len is not None else int(1e12)
         self.max_len_single_sentence = self.max_len
         self.max_len_sentences_pair = self.max_len

         # Added tokens
         self.added_tokens_encoder = {}
         self.added_tokens_decoder = {}

+        # inputs and kwargs for saving and re-loading (see ``from_pretrained`` and ``save_pretrained``)
+        self.init_inputs = ()
+        self.init_kwargs = {}
+
         for key, value in kwargs.items():
             if key in self.SPECIAL_TOKENS_ATTRIBUTES:
                 if key == 'additional_special_tokens':
@@ -231,17 +238,20 @@ class PreTrainedTokenizer(object):

     @classmethod
-    def _from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
+    def _from_pretrained(cls, pretrained_model_name_or_path, *init_inputs, **kwargs):
         cache_dir = kwargs.pop('cache_dir', None)
         force_download = kwargs.pop('force_download', False)
         proxies = kwargs.pop('proxies', None)

         s3_models = list(cls.max_model_input_sizes.keys())
         vocab_files = {}
+        init_configuration = {}
         if pretrained_model_name_or_path in s3_models:
             # Get the vocabulary from AWS S3 bucket
             for file_id, map_list in cls.pretrained_vocab_files_map.items():
                 vocab_files[file_id] = map_list[pretrained_model_name_or_path]
+            if cls.pretrained_init_configuration and pretrained_model_name_or_path in cls.pretrained_init_configuration:
+                init_configuration = cls.pretrained_init_configuration[pretrained_model_name_or_path]
         else:
             # Get the vocabulary from local files
             logger.info(
@@ -264,15 +274,17 @@ class PreTrainedTokenizer(object):
                     vocab_files[file_id] = full_file_name

             # Look for the additional tokens files
-            all_vocab_files_names = {'added_tokens_file': ADDED_TOKENS_FILE,
-                                     'special_tokens_map_file': SPECIAL_TOKENS_MAP_FILE}
+            additional_files_names = {'added_tokens_file': ADDED_TOKENS_FILE,
+                                      'special_tokens_map_file': SPECIAL_TOKENS_MAP_FILE,
+                                      'tokenizer_config_file': TOKENIZER_CONFIG_FILE,
+                                      }

             # If a path to a file was provided, get the parent directory
             saved_directory = pretrained_model_name_or_path
             if os.path.exists(saved_directory) and not os.path.isdir(saved_directory):
                 saved_directory = os.path.dirname(saved_directory)

-            for file_id, file_name in all_vocab_files_names.items():
+            for file_id, file_name in additional_files_names.items():
                 full_file_name = os.path.join(saved_directory, file_name)
                 if not os.path.exists(full_file_name):
                     logger.info("Didn't find file {}. We won't load it.".format(full_file_name))
@@ -315,28 +327,46 @@ class PreTrainedTokenizer(object):
                 logger.info("loading file {} from cache at {}".format(
                     file_path, resolved_vocab_files[file_id]))

+        # Prepare tokenizer initialization kwargs
+        # Did we saved some inputs and kwargs to reload ?
+        tokenizer_config_file = resolved_vocab_files.pop('tokenizer_config_file', None)
+        if tokenizer_config_file is not None:
+            init_kwargs = json.load(open(tokenizer_config_file, encoding="utf-8"))
+            saved_init_inputs = init_kwargs.pop('init_inputs', ())
+            if not init_inputs:
+                init_inputs = saved_init_inputs
+        else:
+            init_kwargs = init_configuration
+
+        # Update with newly provided kwargs
+        init_kwargs.update(kwargs)
+
         # Set max length if needed
         if pretrained_model_name_or_path in cls.max_model_input_sizes:
             # if we're using a pretrained model, ensure the tokenizer
             # wont index sequences longer than the number of positional embeddings
             max_len = cls.max_model_input_sizes[pretrained_model_name_or_path]
             if max_len is not None and isinstance(max_len, (int, float)):
-                kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
+                init_kwargs['max_len'] = min(init_kwargs.get('max_len', int(1e12)), max_len)

-        # Merge resolved_vocab_files arguments in kwargs.
+        # Merge resolved_vocab_files arguments in init_kwargs.
         added_tokens_file = resolved_vocab_files.pop('added_tokens_file', None)
         special_tokens_map_file = resolved_vocab_files.pop('special_tokens_map_file', None)
         for args_name, file_path in resolved_vocab_files.items():
-            if args_name not in kwargs:
-                kwargs[args_name] = file_path
+            if args_name not in init_kwargs:
+                init_kwargs[args_name] = file_path
         if special_tokens_map_file is not None:
             special_tokens_map = json.load(open(special_tokens_map_file, encoding="utf-8"))
             for key, value in special_tokens_map.items():
-                if key not in kwargs:
-                    kwargs[key] = value
+                if key not in init_kwargs:
+                    init_kwargs[key] = value

         # Instantiate tokenizer.
-        tokenizer = cls(*inputs, **kwargs)
+        tokenizer = cls(*init_inputs, **init_kwargs)

+        # Save inputs and kwargs for saving and re-loading with ``save_pretrained``
+        tokenizer.init_inputs = init_inputs
+        tokenizer.init_kwargs = init_kwargs
+
         # Add supplementary tokens.
         if added_tokens_file is not None:
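The resolution order implemented above is, roughly: start from the saved `tokenizer_config.json` when the files come from `save_pretrained` (otherwise from the hard-coded `pretrained_init_configuration`), let explicitly passed kwargs override, then merge resolved vocabulary file paths and the saved special-token map without clobbering user values. A condensed sketch of that precedence — a simplified illustration, not the actual method:

```
def resolve_init_kwargs(saved_config, init_configuration, user_kwargs,
                        resolved_vocab_files, special_tokens_map):
    """Illustrative only: mirrors the precedence used in _from_pretrained."""
    init_kwargs = dict(saved_config) if saved_config is not None else dict(init_configuration)
    init_kwargs.update(user_kwargs)                   # explicit kwargs win
    for name, path in resolved_vocab_files.items():   # vocabulary file paths
        init_kwargs.setdefault(name, path)
    for name, token in special_tokens_map.items():    # saved special tokens
        init_kwargs.setdefault(name, token)
    return init_kwargs
```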
@@ -349,8 +379,13 @@ class PreTrainedTokenizer(object):


     def save_pretrained(self, save_directory):
-        """ Save the tokenizer vocabulary files (with added tokens) and the
-            special-tokens-to-class-attributes-mapping to a directory.
+        """ Save the tokenizer vocabulary files together with:
+            - added tokens,
+            - special-tokens-to-class-attributes-mapping,
+            - tokenizer instantiation positional and keyword inputs (e.g. do_lower_case for Bert).
+
+            This won't save modifications (other than added tokens and the special token mapping) you may have
+            applied to the tokenizer after instantiation (e.g. modifying tokenizer.do_lower_case after creation).
+
+            This method makes sure the full tokenizer can then be re-loaded using the :func:`~pytorch_transformers.PreTrainedTokenizer.from_pretrained` class method.
         """
@@ -360,6 +395,15 @@ class PreTrainedTokenizer(object):

         special_tokens_map_file = os.path.join(save_directory, SPECIAL_TOKENS_MAP_FILE)
         added_tokens_file = os.path.join(save_directory, ADDED_TOKENS_FILE)
+        tokenizer_config_file = os.path.join(save_directory, TOKENIZER_CONFIG_FILE)
+
+        tokenizer_config = copy.deepcopy(self.init_kwargs)
+        tokenizer_config['init_inputs'] = copy.deepcopy(self.init_inputs)
+        for file_id in self.vocab_files_names.keys():
+            tokenizer_config.pop(file_id, None)
+
+        with open(tokenizer_config_file, 'w', encoding='utf-8') as f:
+            f.write(json.dumps(tokenizer_config, ensure_ascii=False))

         with open(special_tokens_map_file, 'w', encoding='utf-8') as f:
             f.write(json.dumps(self.special_tokens_map, ensure_ascii=False))
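With this change, saving a tokenizer produces a third JSON file alongside the vocabulary and special-token files. For a BERT tokenizer created with `do_lower_case=False` and a 512-token limit, the written `tokenizer_config.json` would look roughly as follows; the directory and the exact keys are illustrative (key set depends on the kwargs actually passed and resolved):

```
import json

with open('./my_tokenizer/tokenizer_config.json', encoding='utf-8') as f:
    print(json.load(f))
# e.g. {'do_lower_case': False, 'max_len': 512, 'init_inputs': []}
```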
@@ -566,7 +610,7 @@ class PreTrainedTokenizer(object):
     def _convert_token_to_id(self, token):
         raise NotImplementedError

-    def encode(self, text, text_pair=None, add_special_tokens=False):
+    def encode(self, text, text_pair=None, add_special_tokens=False, **kwargs):
         """
         Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary.

@@ -577,15 +621,16 @@ class PreTrainedTokenizer(object):
             text_pair: Optional second sequence to be encoded.
             add_special_tokens: if set to ``True``, the sequences will be encoded with the special tokens relative
                 to their model.
+            **kwargs: passed to the `self.tokenize()` method
         """
         if text_pair is None:
             if add_special_tokens:
-                return self.add_special_tokens_single_sentence(self.convert_tokens_to_ids(self.tokenize(text)))
+                return self.add_special_tokens_single_sentence(self.convert_tokens_to_ids(self.tokenize(text, **kwargs)))
             else:
-                return self.convert_tokens_to_ids(self.tokenize(text))
+                return self.convert_tokens_to_ids(self.tokenize(text, **kwargs))

-        first_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text)]
-        second_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text_pair)]
+        first_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text, **kwargs)]
+        second_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text_pair, **kwargs)]

         if add_special_tokens:
             return self.add_special_tokens_sentences_pair(first_sentence_tokens, second_sentence_tokens)
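Because `encode()` now forwards extra keyword arguments to `tokenize()`, language-aware tokenizers can receive their language code through the usual entry point. A short sketch with XLM (checkpoint name and sentences are illustrative):

```
from pytorch_transformers import XLMTokenizer

tokenizer = XLMTokenizer.from_pretrained('xlm-mlm-enfr-1024')
# `lang` is consumed by XLMTokenizer._tokenize via the new **kwargs pass-through
ids_en = tokenizer.encode("This is a test", lang='en')
ids_fr = tokenizer.encode("Ceci est un test", lang='fr')
```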
@@ -20,8 +20,12 @@ import json
 import logging
 import os
 import re
+import sys
+import unicodedata
 from io import open

+import sacremoses as sm
+
 from .tokenization_utils import PreTrainedTokenizer
 from .tokenization_bert import BasicTokenizer

@@ -43,6 +47,8 @@ PRETRAINED_VOCAB_FILES_MAP = {
         'xlm-mlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-vocab.json",
         'xlm-clm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-enfr-1024-vocab.json",
         'xlm-clm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-ende-1024-vocab.json",
+        'xlm-mlm-17-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-vocab.json",
+        'xlm-mlm-100-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-vocab.json",
     },
     'merges_file':
     {
@@ -54,6 +60,8 @@ PRETRAINED_VOCAB_FILES_MAP = {
         'xlm-mlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-merges.txt",
         'xlm-clm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enfr-1024-merges.txt",
         'xlm-clm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-ende-1024-merges.txt",
+        'xlm-mlm-17-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-merges.txt",
+        'xlm-mlm-100-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-merges.txt",
     },
 }
@@ -66,6 +74,342 @@ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
     'xlm-mlm-xnli15-1024': 512,
     'xlm-clm-enfr-1024': 512,
     'xlm-clm-ende-1024': 512,
+    'xlm-mlm-17-1280': 512,
+    'xlm-mlm-100-1280': 512,
 }

The hunk also adds a per-checkpoint PRETRAINED_INIT_CONFIGURATION (shown here with the language dictionaries compacted onto fewer lines; entries and values are unchanged):

PRETRAINED_INIT_CONFIGURATION = {
    'xlm-mlm-en-2048': {"do_lowercase_and_remove_accent": True},
    'xlm-mlm-ende-1024': {"do_lowercase_and_remove_accent": True,
                          "id2lang": {"0": "de", "1": "en"},
                          "lang2id": {"de": 0, "en": 1}},
    'xlm-mlm-enfr-1024': {"do_lowercase_and_remove_accent": True,
                          "id2lang": {"0": "en", "1": "fr"},
                          "lang2id": {"en": 0, "fr": 1}},
    'xlm-mlm-enro-1024': {"do_lowercase_and_remove_accent": True,
                          "id2lang": {"0": "en", "1": "ro"},
                          "lang2id": {"en": 0, "ro": 1}},
    'xlm-mlm-tlm-xnli15-1024': {"do_lowercase_and_remove_accent": True,
                                "id2lang": {"0": "ar", "1": "bg", "2": "de", "3": "el", "4": "en", "5": "es", "6": "fr", "7": "hi",
                                            "8": "ru", "9": "sw", "10": "th", "11": "tr", "12": "ur", "13": "vi", "14": "zh"},
                                "lang2id": {"ar": 0, "bg": 1, "de": 2, "el": 3, "en": 4, "es": 5, "fr": 6, "hi": 7,
                                            "ru": 8, "sw": 9, "th": 10, "tr": 11, "ur": 12, "vi": 13, "zh": 14}},
    'xlm-mlm-xnli15-1024': {"do_lowercase_and_remove_accent": True,
                            "id2lang": {"0": "ar", "1": "bg", "2": "de", "3": "el", "4": "en", "5": "es", "6": "fr", "7": "hi",
                                        "8": "ru", "9": "sw", "10": "th", "11": "tr", "12": "ur", "13": "vi", "14": "zh"},
                            "lang2id": {"ar": 0, "bg": 1, "de": 2, "el": 3, "en": 4, "es": 5, "fr": 6, "hi": 7,
                                        "ru": 8, "sw": 9, "th": 10, "tr": 11, "ur": 12, "vi": 13, "zh": 14}},
    'xlm-clm-enfr-1024': {"do_lowercase_and_remove_accent": True,
                          "id2lang": {"0": "en", "1": "fr"},
                          "lang2id": {"en": 0, "fr": 1}},
    'xlm-clm-ende-1024': {"do_lowercase_and_remove_accent": True,
                          "id2lang": {"0": "de", "1": "en"},
                          "lang2id": {"de": 0, "en": 1}},
    'xlm-mlm-17-1280': {"do_lowercase_and_remove_accent": False,
                        "id2lang": {"0": "ar", "1": "de", "2": "en", "3": "es", "4": "fr", "5": "hi", "6": "it", "7": "ja", "8": "ko",
                                    "9": "nl", "10": "pl", "11": "pt", "12": "ru", "13": "sv", "14": "tr", "15": "vi", "16": "zh"},
                        "lang2id": {"ar": 0, "de": 1, "en": 2, "es": 3, "fr": 4, "hi": 5, "it": 6, "ja": 7, "ko": 8,
                                    "nl": 9, "pl": 10, "pt": 11, "ru": 12, "sv": 13, "tr": 14, "vi": 15, "zh": 16}},
    'xlm-mlm-100-1280': {"do_lowercase_and_remove_accent": False,
                         "id2lang": {"0": "af", "1": "als", "2": "am", "3": "an", "4": "ang", "5": "ar", "6": "arz", "7": "ast",
                                     "8": "az", "9": "bar", "10": "be", "11": "bg", "12": "bn", "13": "br", "14": "bs", "15": "ca",
                                     "16": "ceb", "17": "ckb", "18": "cs", "19": "cy", "20": "da", "21": "de", "22": "el", "23": "en",
                                     "24": "eo", "25": "es", "26": "et", "27": "eu", "28": "fa", "29": "fi", "30": "fr", "31": "fy",
                                     "32": "ga", "33": "gan", "34": "gl", "35": "gu", "36": "he", "37": "hi", "38": "hr", "39": "hu",
                                     "40": "hy", "41": "ia", "42": "id", "43": "is", "44": "it", "45": "ja", "46": "jv", "47": "ka",
                                     "48": "kk", "49": "kn", "50": "ko", "51": "ku", "52": "la", "53": "lb", "54": "lt", "55": "lv",
                                     "56": "mk", "57": "ml", "58": "mn", "59": "mr", "60": "ms", "61": "my", "62": "nds", "63": "ne",
                                     "64": "nl", "65": "nn", "66": "no", "67": "oc", "68": "pl", "69": "pt", "70": "ro", "71": "ru",
                                     "72": "scn", "73": "sco", "74": "sh", "75": "si", "76": "simple", "77": "sk", "78": "sl", "79": "sq",
                                     "80": "sr", "81": "sv", "82": "sw", "83": "ta", "84": "te", "85": "th", "86": "tl", "87": "tr",
                                     "88": "tt", "89": "uk", "90": "ur", "91": "uz", "92": "vi", "93": "war", "94": "wuu", "95": "yi",
                                     "96": "zh", "97": "zh_classical", "98": "zh_min_nan", "99": "zh_yue"},
                         "lang2id": {"af": 0, "als": 1, "am": 2, "an": 3, "ang": 4, "ar": 5, "arz": 6, "ast": 7,
                                     "az": 8, "bar": 9, "be": 10, "bg": 11, "bn": 12, "br": 13, "bs": 14, "ca": 15,
                                     "ceb": 16, "ckb": 17, "cs": 18, "cy": 19, "da": 20, "de": 21, "el": 22, "en": 23,
                                     "eo": 24, "es": 25, "et": 26, "eu": 27, "fa": 28, "fi": 29, "fr": 30, "fy": 31,
                                     "ga": 32, "gan": 33, "gl": 34, "gu": 35, "he": 36, "hi": 37, "hr": 38, "hu": 39,
                                     "hy": 40, "ia": 41, "id": 42, "is": 43, "it": 44, "ja": 45, "jv": 46, "ka": 47,
                                     "kk": 48, "kn": 49, "ko": 50, "ku": 51, "la": 52, "lb": 53, "lt": 54, "lv": 55,
                                     "mk": 56, "ml": 57, "mn": 58, "mr": 59, "ms": 60, "my": 61, "nds": 62, "ne": 63,
                                     "nl": 64, "nn": 65, "no": 66, "oc": 67, "pl": 68, "pt": 69, "ro": 70, "ru": 71,
                                     "scn": 72, "sco": 73, "sh": 74, "si": 75, "simple": 76, "sk": 77, "sl": 78, "sq": 79,
                                     "sr": 80, "sv": 81, "sw": 82, "ta": 83, "te": 84, "th": 85, "tl": 86, "tr": 87,
                                     "tt": 88, "uk": 89, "ur": 90, "uz": 91, "vi": 92, "war": 93, "wuu": 94, "yi": 95,
                                     "zh": 96, "zh_classical": 97, "zh_min_nan": 98, "zh_yue": 99}},
}
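These per-checkpoint dictionaries are what `from_pretrained()` now feeds into the tokenizer constructor, so `lang2id`/`id2lang` and the lowercasing flag no longer have to be set by hand. A quick hedged look at what gets attached (checkpoint name is illustrative and requires a download; note the `id2lang` keys are strings, as defined above):

```
from pytorch_transformers import XLMTokenizer

tokenizer = XLMTokenizer.from_pretrained('xlm-mlm-17-1280')
print(tokenizer.do_lowercase_and_remove_accent)   # False for the 17/100-language checkpoints
print(tokenizer.lang2id['zh'])                    # 16, per the mapping above
print(tokenizer.id2lang['16'])                    # 'zh'
```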
@@ -80,62 +424,145 @@ def get_pairs(word):
         prev_char = char
     return pairs

-def text_standardize(text):
-    """
-    fixes some issues the spacy tokenizer had on books corpus
-    also does some whitespace standardization
-    """
-    text = text.replace('—', '-')
-    text = text.replace('–', '-')
-    text = text.replace('―', '-')
-    text = text.replace('´', "'")
-    text = re.sub(r'''(-+|~+|!+|"+|;+|\?+|\++|,+|\)+|\(+|\\+|\/+|\*+|\[+|\]+|}+|{+|\|+|_+)''', r' \1 ', text)
-    text = re.sub(r'\s*\n\s*', ' \n ', text)
-    text = re.sub(r'[^\S\n]+', ' ', text)
-    return text.strip()

+def lowercase_and_remove_accent(text):
+    """
+    Lowercase and strips accents from a piece of text based on
+    https://github.com/facebookresearch/XLM/blob/master/tools/lowercase_and_remove_accent.py
+    """
+    text = ' '.join(text)
+    text = text.lower()
+    text = unicodedata.normalize("NFD", text)
+    output = []
+    for char in text:
+        cat = unicodedata.category(char)
+        if cat == "Mn":
+            continue
+        output.append(char)
+    return "".join(output).lower().split(' ')
+
+
+def replace_unicode_punct(text):
+    '''
+    Port of https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/replace-unicode-punctuation.perl
+    '''
+    text = text.replace('，', ',')
+    text = re.sub(r'。\s*', '. ', text)
+    text = text.replace('、', ',')
+    text = text.replace('”', '"')
+    text = text.replace('“', '"')
+    text = text.replace('∶', ':')
+    text = text.replace('：', ':')
+    text = text.replace('？', '?')
+    text = text.replace('《', '"')
+    text = text.replace('》', '"')
+    text = text.replace('）', ')')
+    text = text.replace('！', '!')
+    text = text.replace('（', '(')
+    text = text.replace('；', ';')
+    text = text.replace('１', '"')
+    text = text.replace('」', '"')
+    text = text.replace('「', '"')
+    text = text.replace('０', '0')
+    text = text.replace('３', '3')
+    text = text.replace('２', '2')
+    text = text.replace('５', '5')
+    text = text.replace('６', '6')
+    text = text.replace('９', '9')
+    text = text.replace('７', '7')
+    text = text.replace('８', '8')
+    text = text.replace('４', '4')
+    text = re.sub(r'．\s*', '. ', text)
+    text = text.replace('～', '~')
+    text = text.replace('’', '\'')
+    text = text.replace('…', '...')
+    text = text.replace('━', '-')
+    text = text.replace('〈', '<')
+    text = text.replace('〉', '>')
+    text = text.replace('【', '[')
+    text = text.replace('】', ']')
+    text = text.replace('％', '%')
+    return text
+
+
+def remove_non_printing_char(text):
+    '''
+    Port of https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/remove-non-printing-char.perl
+    '''
+    output = []
+    for char in text:
+        cat = unicodedata.category(char)
+        if cat.startswith('C'):
+            continue
+        output.append(char)
+    return "".join(output)
+
+
+def romanian_preprocessing(text):
+    '''Sennrich's WMT16 scripts for Romanian preprocessing, used by model `xlm-mlm-enro-1024`'''
+    # https://github.com/rsennrich/wmt16-scripts/blob/master/preprocess/normalise-romanian.py
+    text = text.replace("\u015e", "\u0218").replace("\u015f", "\u0219")
+    text = text.replace("\u0162", "\u021a").replace("\u0163", "\u021b")
+    # https://github.com/rsennrich/wmt16-scripts/blob/master/preprocess/remove-diacritics.py
+    text = text.replace("\u0218", "S").replace("\u0219", "s")  # s-comma
+    text = text.replace("\u021a", "T").replace("\u021b", "t")  # t-comma
+    text = text.replace("\u0102", "A").replace("\u0103", "a")
+    text = text.replace("\u00C2", "A").replace("\u00E2", "a")
+    text = text.replace("\u00CE", "I").replace("\u00EE", "i")
+    return text
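Note that `lowercase_and_remove_accent` operates on a list of tokens (it joins them, NFD-normalizes, drops combining marks, and splits again), which is why `_tokenize` below only applies it after Moses/Jieba/KyTea tokenization. A tiny hedged illustration of the new helpers — the import path assumes the module layout introduced by this PR:

```
from pytorch_transformers.tokenization_xlm import (
    lowercase_and_remove_accent, replace_unicode_punct, remove_non_printing_char)

print(lowercase_and_remove_accent(['Héllo', 'Wörld']))   # ['hello', 'world']
print(replace_unicode_punct('Hello！ 100％'))             # 'Hello! 100%'
print(remove_non_printing_char('a\u200bb'))              # 'ab' (zero-width space has category Cf)
```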
 class XLMTokenizer(PreTrainedTokenizer):
     """
-    BPE tokenizer for XLM, adapted from OpenAI BPE tokenizer. Peculiarities:
+    BPE tokenizer for XLM

-        - lower case all inputs
+        - Moses preprocessing & tokenization for most supported languages

-        - uses `SpaCy tokenizer <https://spacy.io/api/tokenizer/>`_ and \
-          `ftfy <https://ftfy.readthedocs.io/en/latest/>`_ for pre-BPE tokenization if they are installed, \
-          fallback to BERT's BasicTokenizer if not.
+        - Language specific tokenization for Chinese (Jieba), Japanese (KyTea) and Thai (PyThaiNLP)

+        - (optionally) lower case & normalize all inputs text

         - argument ``special_tokens`` and function ``set_special_tokens``, can be used to add additional symbols \
-          (ex: "__classify__") to a vocabulary.
+          (ex: "__classify__") to a vocabulary

+        - `lang2id` attribute maps the languages supported by the model with their ids if provided (automatically set for pretrained vocabularies)

+        - `id2lang` attribute does reverse mapping if provided (automatically set for pretrained vocabularies)

+        - `do_lowercase_and_remove_accent` controls lower casing and accent removal (automatically set for pretrained vocabularies)
     """
     vocab_files_names = VOCAB_FILES_NAMES
     pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+    pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

     def __init__(self, vocab_file, merges_file, unk_token="<unk>", bos_token="<s>",
                  sep_token="</s>", pad_token="<pad>", cls_token="</s>",
                  mask_token="<special1>", additional_special_tokens=["<special0>",
                  "<special1>", "<special2>", "<special3>", "<special4>", "<special5>",
-                 "<special6>", "<special7>", "<special8>", "<special9>"], **kwargs):
+                 "<special6>", "<special7>", "<special8>", "<special9>"],
+                 lang2id=None, id2lang=None, do_lowercase_and_remove_accent=True,
+                 **kwargs):
         super(XLMTokenizer, self).__init__(unk_token=unk_token, bos_token=bos_token,
                                            sep_token=sep_token, pad_token=pad_token,
                                            cls_token=cls_token, mask_token=mask_token,
                                            additional_special_tokens=additional_special_tokens,
                                            **kwargs)

         self.max_len_single_sentence = self.max_len - 2  # take into account special tokens
         self.max_len_sentences_pair = self.max_len - 3  # take into account special tokens
+        # cache of sm.MosesPunctNormalizer instance
+        self.cache_moses_punct_normalizer = dict()
+        # cache of sm.MosesTokenizer instance
+        self.cache_moses_tokenizer = dict()
+        self.lang_with_custom_tokenizer = set(['zh', 'th', 'ja'])
+        # True for current supported model (v1.2.0), False for XLM-17 & 100
+        self.do_lowercase_and_remove_accent = do_lowercase_and_remove_accent
+        self.lang2id = lang2id
+        self.id2lang = id2lang
+        if lang2id is not None and id2lang is not None:
+            assert len(lang2id) == len(id2lang)
+
         try:
             import ftfy
             from spacy.lang.en import English
             _nlp = English()
             self.nlp = _nlp.Defaults.create_tokenizer(_nlp)
             self.fix_text = ftfy.fix_text
         except ImportError:
             logger.warning("ftfy or spacy is not installed using BERT BasicTokenizer instead of SpaCy & ftfy.")
             self.nlp = BasicTokenizer(do_lower_case=True)
             self.fix_text = None
+        self.ja_word_tokenizer = None
+        self.zh_word_tokenizer = None

         self.encoder = json.load(open(vocab_file, encoding="utf-8"))
         self.decoder = {v:k for k,v in self.encoder.items()}
@@ -144,6 +571,43 @@ class XLMTokenizer(PreTrainedTokenizer):
         self.bpe_ranks = dict(zip(merges, range(len(merges))))
         self.cache = {}

+    def moses_punct_norm(self, text, lang):
+        if lang not in self.cache_moses_punct_normalizer:
+            punct_normalizer = sm.MosesPunctNormalizer(lang=lang)
+            self.cache_moses_punct_normalizer[lang] = punct_normalizer
+        else:
+            punct_normalizer = self.cache_moses_punct_normalizer[lang]
+        return punct_normalizer.normalize(text)
+
+    def moses_tokenize(self, text, lang):
+        if lang not in self.cache_moses_tokenizer:
+            moses_tokenizer = sm.MosesTokenizer(lang=lang)
+            self.cache_moses_tokenizer[lang] = moses_tokenizer
+        else:
+            moses_tokenizer = self.cache_moses_tokenizer[lang]
+        return moses_tokenizer.tokenize(text, return_str=False, escape=False)
+
+    def moses_pipeline(self, text, lang):
+        text = replace_unicode_punct(text)
+        text = self.moses_punct_norm(text, lang)
+        text = remove_non_printing_char(text)
+        return text
+
+    def ja_tokenize(self, text):
+        if self.ja_word_tokenizer is None:
+            try:
+                import Mykytea
+                self.ja_word_tokenizer = Mykytea.Mykytea('-model %s/local/share/kytea/model.bin' % os.path.expanduser('~'))
+            except (AttributeError, ImportError) as e:
+                logger.error("Make sure you install KyTea (https://github.com/neubig/kytea) and it's python wrapper (https://github.com/chezou/Mykytea-python) with the following steps")
+                logger.error("1. git clone git@github.com:neubig/kytea.git && cd kytea")
+                logger.error("2. autoreconf -i")
+                logger.error("3. ./configure --prefix=$HOME/local")
+                logger.error("4. make && make install")
+                logger.error("5. pip install kytea")
+                raise e
+        return list(self.ja_word_tokenizer.getWS(text))
+
     @property
     def vocab_size(self):
         return len(self.encoder)
@@ -191,19 +655,86 @@ class XLMTokenizer(PreTrainedTokenizer):
         self.cache[token] = word
         return word

-    def _tokenize(self, text):
-        """ Tokenize a string. """
-        split_tokens = []
-        if self.fix_text is None:
-            # Using BERT's BasicTokenizer
-            text = self.nlp.tokenize(text)
-            for token in text:
-                split_tokens.extend([t for t in self.bpe(token).split(' ')])
+    def _tokenize(self, text, lang='en', bypass_tokenizer=False):
+        """
+        Tokenize a string given a language code. For Chinese, Japanese and Thai, we use a language specific tokenizer. Otherwise, we use Moses.
+
+        Details of tokenization:
+            - [sacremoses](https://github.com/alvations/sacremoses): port of Moses
+                - Install with `pip install sacremoses`
+            - [pythainlp](https://github.com/PyThaiNLP/pythainlp): Thai tokenizer
+                - Install with `pip install pythainlp`
+            - [kytea](https://github.com/chezou/Mykytea-python): Japanese tokenizer, wrapper of [KyTea](https://github.com/neubig/kytea)
+                - Install with the following steps:
+                ```
+                git clone git@github.com:neubig/kytea.git && cd kytea
+                autoreconf -i
+                ./configure --prefix=$HOME/local
+                make && make install
+                pip install kytea
+                ```
+            - [jieba](https://github.com/fxsjy/jieba): Chinese tokenizer *
+                - Install with `pip install jieba`
+
+        \* The original XLM used [Stanford Segmenter](https://nlp.stanford.edu/software/stanford-segmenter-2018-10-16.zip).
+        However, the wrapper (`nltk.tokenize.stanford_segmenter`) is slow due to JVM overhead, and it will be deprecated.
+        Jieba is a lot faster and pip-installable. Note there is some mismatch with the Stanford Segmenter. It should be fine
+        if you fine-tune the model with Chinese supervision. If you want the same exact behaviour, use the original XLM
+        [preprocessing script](https://github.com/facebookresearch/XLM/tree/master/tools) to tokenize the sentence externally,
+        and set `bypass_tokenizer=True` to bypass the tokenizer.
+
+        Args:
+            - lang: ISO language code (default = 'en') (string). Languages should belong to the set of languages supported by the model. However, we don't enforce it.
+            - bypass_tokenizer: Allow users to preprocess and tokenize the sentences externally (default = False) (bool). If True, we only apply BPE.
+
+        Returns:
+            List of tokens.
+        """
+        if lang and self.lang2id and lang not in self.lang2id:
+            logger.error("Supplied language code not found in lang2id mapping. Please check that your language is supported by the loaded pretrained model.")
+        if bypass_tokenizer:
+            text = text.split()
+        elif lang not in self.lang_with_custom_tokenizer:
+            text = self.moses_pipeline(text, lang=lang)
+            # TODO: make sure we are using `xlm-mlm-enro-1024`, since XLM-100 doesn't have this step
+            if lang == 'ro':
+                text = romanian_preprocessing(text)
+            text = self.moses_tokenize(text, lang=lang)
+        elif lang == 'th':
+            text = self.moses_pipeline(text, lang=lang)
+            try:
+                if 'pythainlp' not in sys.modules:
+                    from pythainlp.tokenize import word_tokenize as th_word_tokenize
+            except (AttributeError, ImportError) as e:
+                logger.error("Make sure you install PyThaiNLP (https://github.com/PyThaiNLP/pythainlp) with the following steps")
+                logger.error("1. pip install pythainlp")
+                raise e
+            text = th_word_tokenize(text)
+        elif lang == 'zh':
+            try:
+                if 'jieba' not in sys.modules:
+                    import jieba
+            except (AttributeError, ImportError) as e:
+                logger.error("Make sure you install Jieba (https://github.com/fxsjy/jieba) with the following steps")
+                logger.error("1. pip install jieba")
+                raise e
+            text = ' '.join(jieba.cut(text))
+            text = self.moses_pipeline(text, lang=lang)
+            text = text.split()
+        elif lang == 'ja':
+            text = self.moses_pipeline(text, lang=lang)
+            text = self.ja_tokenize(text)
         else:
-            # Using SpaCy & ftfy (original tokenization process of OpenAI GPT)
-            text = self.nlp(text_standardize(self.fix_text(text)))
-            for token in text:
-                split_tokens.extend([t for t in self.bpe(token.text.lower()).split(' ')])
+            raise ValueError('It should not reach here')
+
+        if self.do_lowercase_and_remove_accent and not bypass_tokenizer:
+            text = lowercase_and_remove_accent(text)
+
+        split_tokens = []
+        for token in text:
+            if token:
+                split_tokens.extend([t for t in self.bpe(token).split(' ')])
+
         return split_tokens

     def _convert_token_to_id(self, token):
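End to end, the reworked `_tokenize` means language selection happens per call, through the kwargs forwarded by `tokenize()`/`encode()`. A hedged usage sketch (checkpoint and sentences are illustrative; `jieba`, `pythainlp` and Mykytea must be installed for zh / th / ja respectively):

```
from pytorch_transformers import XLMTokenizer

tokenizer = XLMTokenizer.from_pretrained('xlm-mlm-100-1280')

tokens_en = tokenizer.tokenize("Hello, world!", lang='en')    # Moses pipeline + BPE
tokens_zh = tokenizer.tokenize("你好，世界", lang='zh')          # requires `pip install jieba`

# externally pre-tokenized text: skip everything except BPE
tokens_pre = tokenizer.tokenize("hello , world !", bypass_tokenizer=True)
```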
@@ -61,7 +61,7 @@ class XLNetTokenizer(PreTrainedTokenizer):
     pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

-    def __init__(self, vocab_file, max_len=None,
+    def __init__(self, vocab_file,
                  do_lower_case=False, remove_space=True, keep_accents=False,
                  bos_token="<s>", eos_token="</s>", unk_token="<unk>", sep_token="<sep>",
                  pad_token="<pad>", cls_token="<cls>", mask_token="<mask>",
@@ -9,4 +9,6 @@ requests
 # For OpenAI GPT
 regex
 # For XLNet
 sentencepiece
+# For XLM
+sacremoses