diff --git a/pytorch_transformers/modeling_xlm.py b/pytorch_transformers/modeling_xlm.py index 5a659e02f92..035787a97b2 100644 --- a/pytorch_transformers/modeling_xlm.py +++ b/pytorch_transformers/modeling_xlm.py @@ -44,6 +44,8 @@ XLM_PRETRAINED_MODEL_ARCHIVE_MAP = { 'xlm-mlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-pytorch_model.bin", 'xlm-clm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-enfr-1024-pytorch_model.bin", 'xlm-clm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-ende-1024-pytorch_model.bin", + 'xlm-mlm-17-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-pytorch_model.bin", + 'xlm-mlm-100-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-100-1280-pytorch_model.bin", } XLM_PRETRAINED_CONFIG_ARCHIVE_MAP = { 'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-config.json", @@ -54,6 +56,8 @@ XLM_PRETRAINED_CONFIG_ARCHIVE_MAP = { 'xlm-mlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-config.json", 'xlm-clm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-enfr-1024-config.json", 'xlm-clm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-ende-1024-config.json", + 'xlm-mlm-17-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-config.json", + 'xlm-mlm-100-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-100-1280-config.json", } @@ -114,6 +118,7 @@ class XLMConfig(PretrainedConfig): causal=False, asm=False, n_langs=1, + use_lang_emb=True, max_position_embeddings=512, embed_init_std=2048 ** -0.5, layer_norm_eps=1e-12, @@ -157,6 +162,7 @@ class XLMConfig(PretrainedConfig): self.causal = causal self.asm = asm self.n_langs = n_langs + self.use_lang_emb = use_lang_emb self.layer_norm_eps = layer_norm_eps self.bos_index = bos_index self.eos_index = eos_index @@ -488,7 +494,7 @@ class XLMModel(XLMPreTrainedModel): """ ATTRIBUTES = ['encoder', 'eos_index', 'pad_index', # 'with_output', - 'n_langs', 'n_words', 'dim', 'n_layers', 'n_heads', + 'n_langs', 'use_lang_emb', 'n_words', 'dim', 'n_layers', 'n_heads', 'hidden_dim', 'dropout', 'attention_dropout', 'asm', 'asm_cutoffs', 'asm_div_value'] @@ -507,6 +513,7 @@ class XLMModel(XLMPreTrainedModel): # dictionary / languages self.n_langs = config.n_langs + self.use_lang_emb = config.use_lang_emb self.n_words = config.n_words self.eos_index = config.eos_index self.pad_index = config.pad_index @@ -529,7 +536,7 @@ class XLMModel(XLMPreTrainedModel): self.position_embeddings = nn.Embedding(config.max_position_embeddings, self.dim) if config.sinusoidal_embeddings: create_sinusoidal_embeddings(config.max_position_embeddings, self.dim, out=self.position_embeddings.weight) - if config.n_langs > 1: + if config.n_langs > 1 and config.use_lang_emb: self.lang_embeddings = nn.Embedding(self.n_langs, self.dim) self.embeddings = nn.Embedding(self.n_words, self.dim, padding_idx=self.pad_index) self.layer_norm_emb = nn.LayerNorm(self.dim, eps=config.layer_norm_eps) @@ -628,7 +635,7 @@ class XLMModel(XLMPreTrainedModel): # embeddings tensor = self.embeddings(input_ids) tensor = tensor + self.position_embeddings(position_ids).expand_as(tensor) - if langs is not None: + if langs is not None and self.use_lang_emb: tensor = tensor + self.lang_embeddings(langs) if token_type_ids is not None: tensor = tensor + self.embeddings(token_type_ids) diff --git 
a/pytorch_transformers/tests/tokenization_bert_test.py b/pytorch_transformers/tests/tokenization_bert_test.py index aaca746d469..1111683ecc5 100644 --- a/pytorch_transformers/tests/tokenization_bert_test.py +++ b/pytorch_transformers/tests/tokenization_bert_test.py @@ -41,8 +41,8 @@ class BertTokenizationTest(CommonTestCases.CommonTokenizerTester): with open(self.vocab_file, "w", encoding='utf-8') as vocab_writer: vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) - def get_tokenizer(self): - return self.tokenizer_class.from_pretrained(self.tmpdirname) + def get_tokenizer(self, **kwargs): + return BertTokenizer.from_pretrained(self.tmpdirname, **kwargs) def get_input_output_texts(self): input_text = u"UNwant\u00E9d,running" diff --git a/pytorch_transformers/tests/tokenization_dilbert_test.py b/pytorch_transformers/tests/tokenization_dilbert_test.py index 30268db2166..42f80609981 100644 --- a/pytorch_transformers/tests/tokenization_dilbert_test.py +++ b/pytorch_transformers/tests/tokenization_dilbert_test.py @@ -27,8 +27,8 @@ class DistilBertTokenizationTest(BertTokenizationTest): tokenizer_class = DistilBertTokenizer - def get_tokenizer(self): - return DistilBertTokenizer.from_pretrained(self.tmpdirname) + def get_tokenizer(self, **kwargs): + return DistilBertTokenizer.from_pretrained(self.tmpdirname, **kwargs) def test_sequence_builders(self): tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased") diff --git a/pytorch_transformers/tests/tokenization_gpt2_test.py b/pytorch_transformers/tests/tokenization_gpt2_test.py index da7028c27d7..252dbfe6f47 100644 --- a/pytorch_transformers/tests/tokenization_gpt2_test.py +++ b/pytorch_transformers/tests/tokenization_gpt2_test.py @@ -44,8 +44,9 @@ class GPT2TokenizationTest(CommonTestCases.CommonTokenizerTester): with open(self.merges_file, "w") as fp: fp.write("\n".join(merges)) - def get_tokenizer(self): - return GPT2Tokenizer.from_pretrained(self.tmpdirname, **self.special_tokens_map) + def get_tokenizer(self, **kwargs): + kwargs.update(self.special_tokens_map) + return GPT2Tokenizer.from_pretrained(self.tmpdirname, **kwargs) def get_input_output_texts(self): input_text = u"lower newer" diff --git a/pytorch_transformers/tests/tokenization_openai_test.py b/pytorch_transformers/tests/tokenization_openai_test.py index bb354f3fb77..6b86416d2d6 100644 --- a/pytorch_transformers/tests/tokenization_openai_test.py +++ b/pytorch_transformers/tests/tokenization_openai_test.py @@ -45,8 +45,8 @@ class OpenAIGPTTokenizationTest(CommonTestCases.CommonTokenizerTester): with open(self.merges_file, "w") as fp: fp.write("\n".join(merges)) - def get_tokenizer(self): - return OpenAIGPTTokenizer.from_pretrained(self.tmpdirname) + def get_tokenizer(self, **kwargs): + return OpenAIGPTTokenizer.from_pretrained(self.tmpdirname, **kwargs) def get_input_output_texts(self): input_text = u"lower newer" diff --git a/pytorch_transformers/tests/tokenization_roberta_test.py b/pytorch_transformers/tests/tokenization_roberta_test.py index a8f940ae432..5f9b65a7a30 100644 --- a/pytorch_transformers/tests/tokenization_roberta_test.py +++ b/pytorch_transformers/tests/tokenization_roberta_test.py @@ -43,8 +43,9 @@ class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester): with open(self.merges_file, "w") as fp: fp.write("\n".join(merges)) - def get_tokenizer(self): - return RobertaTokenizer.from_pretrained(self.tmpdirname, **self.special_tokens_map) + def get_tokenizer(self, **kwargs): + kwargs.update(self.special_tokens_map) + return 
RobertaTokenizer.from_pretrained(self.tmpdirname, **kwargs) def get_input_output_texts(self): input_text = u"lower newer" diff --git a/pytorch_transformers/tests/tokenization_tests_commons.py b/pytorch_transformers/tests/tokenization_tests_commons.py index ebcf6f48d87..6578c5c3a56 100644 --- a/pytorch_transformers/tests/tokenization_tests_commons.py +++ b/pytorch_transformers/tests/tokenization_tests_commons.py @@ -49,23 +49,32 @@ class CommonTestCases: def tearDown(self): shutil.rmtree(self.tmpdirname) - def get_tokenizer(self): + def get_tokenizer(self, **kwargs): raise NotImplementedError def get_input_output_texts(self): raise NotImplementedError def test_save_and_load_tokenizer(self): + # safety check on max_len default value so we are sure the test works tokenizer = self.get_tokenizer() + self.assertNotEqual(tokenizer.max_len, 42) + + # Now let's start the test + tokenizer = self.get_tokenizer(max_len=42) before_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running") with TemporaryDirectory() as tmpdirname: tokenizer.save_pretrained(tmpdirname) - tokenizer = tokenizer.from_pretrained(tmpdirname) + tokenizer = self.tokenizer_class.from_pretrained(tmpdirname) - after_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running") - self.assertListEqual(before_tokens, after_tokens) + after_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running") + self.assertListEqual(before_tokens, after_tokens) + + self.assertEqual(tokenizer.max_len, 42) + tokenizer = self.tokenizer_class.from_pretrained(tmpdirname, max_len=43) + self.assertEqual(tokenizer.max_len, 43) def test_pickle_tokenizer(self): tokenizer = self.get_tokenizer() diff --git a/pytorch_transformers/tests/tokenization_transfo_xl_test.py b/pytorch_transformers/tests/tokenization_transfo_xl_test.py index fbd06cf47e7..f881cf1d2b4 100644 --- a/pytorch_transformers/tests/tokenization_transfo_xl_test.py +++ b/pytorch_transformers/tests/tokenization_transfo_xl_test.py @@ -37,8 +37,9 @@ class TransfoXLTokenizationTest(CommonTestCases.CommonTokenizerTester): with open(self.vocab_file, "w", encoding='utf-8') as vocab_writer: vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) - def get_tokenizer(self): - return TransfoXLTokenizer.from_pretrained(self.tmpdirname, lower_case=True) + def get_tokenizer(self, **kwargs): + kwargs['lower_case'] = True + return TransfoXLTokenizer.from_pretrained(self.tmpdirname, **kwargs) def get_input_output_texts(self): input_text = u" UNwanted , running" diff --git a/pytorch_transformers/tests/tokenization_xlm_test.py b/pytorch_transformers/tests/tokenization_xlm_test.py index ede77a1f988..43f1e0c5dd7 100644 --- a/pytorch_transformers/tests/tokenization_xlm_test.py +++ b/pytorch_transformers/tests/tokenization_xlm_test.py @@ -44,8 +44,8 @@ class XLMTokenizationTest(CommonTestCases.CommonTokenizerTester): with open(self.merges_file, "w") as fp: fp.write("\n".join(merges)) - def get_tokenizer(self): - return XLMTokenizer.from_pretrained(self.tmpdirname) + def get_tokenizer(self, **kwargs): + return XLMTokenizer.from_pretrained(self.tmpdirname, **kwargs) def get_input_output_texts(self): input_text = u"lower newer" diff --git a/pytorch_transformers/tests/tokenization_xlnet_test.py b/pytorch_transformers/tests/tokenization_xlnet_test.py index 9feab7c0bdf..c603ce55f9d 100644 --- a/pytorch_transformers/tests/tokenization_xlnet_test.py +++ b/pytorch_transformers/tests/tokenization_xlnet_test.py @@ -35,8 +35,8 @@ class XLNetTokenizationTest(CommonTestCases.CommonTokenizerTester): 
tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True) tokenizer.save_pretrained(self.tmpdirname) - def get_tokenizer(self): - return XLNetTokenizer.from_pretrained(self.tmpdirname) + def get_tokenizer(self, **kwargs): + return XLNetTokenizer.from_pretrained(self.tmpdirname, **kwargs) def get_input_output_texts(self): input_text = u"This is a test" diff --git a/pytorch_transformers/tokenization_bert.py b/pytorch_transformers/tokenization_bert.py index 92f027038df..46f839c6270 100644 --- a/pytorch_transformers/tokenization_bert.py +++ b/pytorch_transformers/tokenization_bert.py @@ -63,6 +63,23 @@ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 'bert-base-cased-finetuned-mrpc': 512, } +PRETRAINED_INIT_CONFIGURATION = { + 'bert-base-uncased': {'do_lower_case': True}, + 'bert-large-uncased': {'do_lower_case': True}, + 'bert-base-cased': {'do_lower_case': False}, + 'bert-large-cased': {'do_lower_case': False}, + 'bert-base-multilingual-uncased': {'do_lower_case': True}, + 'bert-base-multilingual-cased': {'do_lower_case': False}, + 'bert-base-chinese': {'do_lower_case': False}, + 'bert-base-german-cased': {'do_lower_case': False}, + 'bert-large-uncased-whole-word-masking': {'do_lower_case': True}, + 'bert-large-cased-whole-word-masking': {'do_lower_case': False}, + 'bert-large-uncased-whole-word-masking-finetuned-squad': {'do_lower_case': True}, + 'bert-large-cased-whole-word-masking-finetuned-squad': {'do_lower_case': False}, + 'bert-base-cased-finetuned-mrpc': {'do_lower_case': False}, +} + + def load_vocab(vocab_file): """Loads a vocabulary file into a dictionary.""" vocab = collections.OrderedDict() @@ -100,6 +117,7 @@ class BertTokenizer(PreTrainedTokenizer): vocab_files_names = VOCAB_FILES_NAMES pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES def __init__(self, vocab_file, do_lower_case=True, do_basic_tokenize=True, never_split=None, @@ -202,24 +220,6 @@ class BertTokenizer(PreTrainedTokenizer): index += 1 return (vocab_file,) - @classmethod - def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs): - """ Instantiate a BertTokenizer from pre-trained vocabulary files. - """ - if pretrained_model_name_or_path in PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES: - if '-cased' in pretrained_model_name_or_path and kwargs.get('do_lower_case', True): - logger.warning("The pre-trained model you are loading is a cased model but you have not set " - "`do_lower_case` to False. We are setting `do_lower_case=False` for you but " - "you may want to check this behavior.") - kwargs['do_lower_case'] = False - elif '-cased' not in pretrained_model_name_or_path and not kwargs.get('do_lower_case', True): - logger.warning("The pre-trained model you are loading is an uncased model but you have set " - "`do_lower_case` to False. 
We are setting `do_lower_case=True` for you " - "but you may want to check this behavior.") - kwargs['do_lower_case'] = True - - return super(BertTokenizer, cls)._from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) - class BasicTokenizer(object): """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" diff --git a/pytorch_transformers/tokenization_transfo_xl.py b/pytorch_transformers/tokenization_transfo_xl.py index c603ba695c1..66bc01c1bb0 100644 --- a/pytorch_transformers/tokenization_transfo_xl.py +++ b/pytorch_transformers/tokenization_transfo_xl.py @@ -95,7 +95,8 @@ class TransfoXLTokenizer(PreTrainedTokenizer): # in a library like ours, at all. vocab_dict = torch.load(pretrained_vocab_file) for key, value in vocab_dict.items(): - self.__dict__[key] = value + if key not in self.__dict__: + self.__dict__[key] = value if vocab_file is not None: self.build_vocab() diff --git a/pytorch_transformers/tokenization_utils.py b/pytorch_transformers/tokenization_utils.py index 9d814f704f7..56ccdbcb915 100644 --- a/pytorch_transformers/tokenization_utils.py +++ b/pytorch_transformers/tokenization_utils.py @@ -20,6 +20,7 @@ import logging import os import json import six +import copy from io import open from .file_utils import cached_path @@ -28,6 +29,7 @@ logger = logging.getLogger(__name__) SPECIAL_TOKENS_MAP_FILE = 'special_tokens_map.json' ADDED_TOKENS_FILE = 'added_tokens.json' +TOKENIZER_CONFIG_FILE = 'tokenizer_config.json' class PreTrainedTokenizer(object): """ Base class for all tokenizers. @@ -40,6 +42,7 @@ class PreTrainedTokenizer(object): - ``vocab_files_names``: a python ``dict`` with, as keys, the ``__init__`` keyword name of each vocabulary file required by the model, and as associated values, the filename for saving the associated file (string). - ``pretrained_vocab_files_map``: a python ``dict of dict`` the high-level keys being the ``__init__`` keyword name of each vocabulary file required by the model, the low-level being the `short-cut-names` (string) of the pretrained models with, as associated values, the `url` (string) to the associated pretrained vocabulary file. - ``max_model_input_sizes``: a python ``dict`` with, as keys, the `short-cut-names` (string) of the pretrained models, and as associated values, the maximum length of the sequence inputs of this model, or None if the model has no maximum input size. + - ``pretrained_init_configuration``: a python ``dict`` with, as keys, the `short-cut-names` (string) of the pretrained models, and as associated values, a dictionary of specific arguments to pass to the ``__init__`` method of the tokenizer class for this pretrained model when loading the tokenizer with the ``from_pretrained()`` method. 
Parameters: @@ -61,6 +64,7 @@ class PreTrainedTokenizer(object): """ vocab_files_names = {} pretrained_vocab_files_map = {} + pretrained_init_configuration = {} max_model_input_sizes = {} SPECIAL_TOKENS_ATTRIBUTES = ["bos_token", "eos_token", "unk_token", "sep_token", @@ -166,12 +170,15 @@ class PreTrainedTokenizer(object): self._additional_special_tokens = [] self.max_len = max_len if max_len is not None else int(1e12) - self.max_len_single_sentence = self.max_len - self.max_len_sentences_pair = self.max_len + # Added tokens self.added_tokens_encoder = {} self.added_tokens_decoder = {} + # inputs and kwargs for saving and re-loading (see ``from_pretrained`` and ``save_pretrained``) + self.init_inputs = () + self.init_kwargs = {} + for key, value in kwargs.items(): if key in self.SPECIAL_TOKENS_ATTRIBUTES: if key == 'additional_special_tokens': @@ -231,17 +238,20 @@ class PreTrainedTokenizer(object): @classmethod - def _from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs): + def _from_pretrained(cls, pretrained_model_name_or_path, *init_inputs, **kwargs): cache_dir = kwargs.pop('cache_dir', None) force_download = kwargs.pop('force_download', False) proxies = kwargs.pop('proxies', None) s3_models = list(cls.max_model_input_sizes.keys()) vocab_files = {} + init_configuration = {} if pretrained_model_name_or_path in s3_models: # Get the vocabulary from AWS S3 bucket for file_id, map_list in cls.pretrained_vocab_files_map.items(): vocab_files[file_id] = map_list[pretrained_model_name_or_path] + if cls.pretrained_init_configuration and pretrained_model_name_or_path in cls.pretrained_init_configuration: + init_configuration = cls.pretrained_init_configuration[pretrained_model_name_or_path] else: # Get the vocabulary from local files logger.info( @@ -264,15 +274,17 @@ class PreTrainedTokenizer(object): vocab_files[file_id] = full_file_name # Look for the additional tokens files - all_vocab_files_names = {'added_tokens_file': ADDED_TOKENS_FILE, - 'special_tokens_map_file': SPECIAL_TOKENS_MAP_FILE} + additional_files_names = {'added_tokens_file': ADDED_TOKENS_FILE, + 'special_tokens_map_file': SPECIAL_TOKENS_MAP_FILE, + 'tokenizer_config_file': TOKENIZER_CONFIG_FILE, + } # If a path to a file was provided, get the parent directory saved_directory = pretrained_model_name_or_path if os.path.exists(saved_directory) and not os.path.isdir(saved_directory): saved_directory = os.path.dirname(saved_directory) - for file_id, file_name in all_vocab_files_names.items(): + for file_id, file_name in additional_files_names.items(): full_file_name = os.path.join(saved_directory, file_name) if not os.path.exists(full_file_name): logger.info("Didn't find file {}. We won't load it.".format(full_file_name)) @@ -315,28 +327,46 @@ class PreTrainedTokenizer(object): logger.info("loading file {} from cache at {}".format( file_path, resolved_vocab_files[file_id])) + # Prepare tokenizer initialization kwargs + # Did we save some inputs and kwargs to reload? 
+ tokenizer_config_file = resolved_vocab_files.pop('tokenizer_config_file', None) + if tokenizer_config_file is not None: + init_kwargs = json.load(open(tokenizer_config_file, encoding="utf-8")) + saved_init_inputs = init_kwargs.pop('init_inputs', ()) + if not init_inputs: + init_inputs = saved_init_inputs + else: + init_kwargs = init_configuration + + # Update with newly provided kwargs + init_kwargs.update(kwargs) + # Set max length if needed if pretrained_model_name_or_path in cls.max_model_input_sizes: # if we're using a pretrained model, ensure the tokenizer # wont index sequences longer than the number of positional embeddings max_len = cls.max_model_input_sizes[pretrained_model_name_or_path] if max_len is not None and isinstance(max_len, (int, float)): - kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) + init_kwargs['max_len'] = min(init_kwargs.get('max_len', int(1e12)), max_len) - # Merge resolved_vocab_files arguments in kwargs. + # Merge resolved_vocab_files arguments in init_kwargs. added_tokens_file = resolved_vocab_files.pop('added_tokens_file', None) special_tokens_map_file = resolved_vocab_files.pop('special_tokens_map_file', None) for args_name, file_path in resolved_vocab_files.items(): - if args_name not in kwargs: - kwargs[args_name] = file_path + if args_name not in init_kwargs: + init_kwargs[args_name] = file_path if special_tokens_map_file is not None: special_tokens_map = json.load(open(special_tokens_map_file, encoding="utf-8")) for key, value in special_tokens_map.items(): - if key not in kwargs: - kwargs[key] = value + if key not in init_kwargs: + init_kwargs[key] = value # Instantiate tokenizer. - tokenizer = cls(*inputs, **kwargs) + tokenizer = cls(*init_inputs, **init_kwargs) + + # Save inputs and kwargs for saving and re-loading with ``save_pretrained`` + tokenizer.init_inputs = init_inputs + tokenizer.init_kwargs = init_kwargs # Add supplementary tokens. if added_tokens_file is not None: @@ -349,8 +379,13 @@ class PreTrainedTokenizer(object): def save_pretrained(self, save_directory): - """ Save the tokenizer vocabulary files (with added tokens) and the - special-tokens-to-class-attributes-mapping to a directory. + """ Save the tokenizer vocabulary files together with: + - added tokens, + - special-tokens-to-class-attributes-mapping, + - tokenizer instantiation positional and keyword inputs (e.g. do_lower_case for Bert). + + This won't save modifications, other than added tokens and the special tokens mapping, that you may have + applied to the tokenizer after instantiation (e.g. modifying tokenizer.do_lower_case after creation). This method make sure the full tokenizer can then be re-loaded using the :func:`~pytorch_transformers.PreTrainedTokenizer.from_pretrained` class method. 
""" @@ -360,6 +395,15 @@ class PreTrainedTokenizer(object): special_tokens_map_file = os.path.join(save_directory, SPECIAL_TOKENS_MAP_FILE) added_tokens_file = os.path.join(save_directory, ADDED_TOKENS_FILE) + tokenizer_config_file = os.path.join(save_directory, TOKENIZER_CONFIG_FILE) + + tokenizer_config = copy.deepcopy(self.init_kwargs) + tokenizer_config['init_inputs'] = copy.deepcopy(self.init_inputs) + for file_id in self.vocab_files_names.keys(): + tokenizer_config.pop(file_id, None) + + with open(tokenizer_config_file, 'w', encoding='utf-8') as f: + f.write(json.dumps(tokenizer_config, ensure_ascii=False)) with open(special_tokens_map_file, 'w', encoding='utf-8') as f: f.write(json.dumps(self.special_tokens_map, ensure_ascii=False)) @@ -566,7 +610,7 @@ class PreTrainedTokenizer(object): def _convert_token_to_id(self, token): raise NotImplementedError - def encode(self, text, text_pair=None, add_special_tokens=False): + def encode(self, text, text_pair=None, add_special_tokens=False, **kwargs): """ Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary. @@ -577,15 +621,16 @@ class PreTrainedTokenizer(object): text_pair: Optional second sequence to be encoded. add_special_tokens: if set to ``True``, the sequences will be encoded with the special tokens relative to their model. + **kwargs: passed to the `self.tokenize()` method """ if text_pair is None: if add_special_tokens: - return self.add_special_tokens_single_sentence(self.convert_tokens_to_ids(self.tokenize(text))) + return self.add_special_tokens_single_sentence(self.convert_tokens_to_ids(self.tokenize(text, **kwargs))) else: - return self.convert_tokens_to_ids(self.tokenize(text)) + return self.convert_tokens_to_ids(self.tokenize(text, **kwargs)) - first_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text)] - second_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text_pair)] + first_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text, **kwargs)] + second_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text_pair, **kwargs)] if add_special_tokens: return self.add_special_tokens_sentences_pair(first_sentence_tokens, second_sentence_tokens) diff --git a/pytorch_transformers/tokenization_xlm.py b/pytorch_transformers/tokenization_xlm.py index 2b930458bbf..d14acb39c63 100644 --- a/pytorch_transformers/tokenization_xlm.py +++ b/pytorch_transformers/tokenization_xlm.py @@ -20,8 +20,12 @@ import json import logging import os import re +import sys +import unicodedata from io import open +import sacremoses as sm + from .tokenization_utils import PreTrainedTokenizer from .tokenization_bert import BasicTokenizer @@ -43,6 +47,8 @@ PRETRAINED_VOCAB_FILES_MAP = { 'xlm-mlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-vocab.json", 'xlm-clm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-enfr-1024-vocab.json", 'xlm-clm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-ende-1024-vocab.json", + 'xlm-mlm-17-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-vocab.json", + 'xlm-mlm-100-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-vocab.json", }, 'merges_file': { @@ -54,6 +60,8 @@ PRETRAINED_VOCAB_FILES_MAP = { 'xlm-mlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-merges.txt", 'xlm-clm-enfr-1024': 
"https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enfr-1024-merges.txt", 'xlm-clm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-ende-1024-merges.txt", + 'xlm-mlm-17-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-merges.txt", + 'xlm-mlm-100-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-merges.txt", }, } @@ -66,6 +74,342 @@ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 'xlm-mlm-xnli15-1024': 512, 'xlm-clm-enfr-1024': 512, 'xlm-clm-ende-1024': 512, + 'xlm-mlm-17-1280': 512, + 'xlm-mlm-100-1280': 512, +} + +PRETRAINED_INIT_CONFIGURATION = { + 'xlm-mlm-en-2048': {"do_lowercase_and_remove_accent": True}, + 'xlm-mlm-ende-1024': { "do_lowercase_and_remove_accent": True, + "id2lang": { "0": "de", + "1": "en"}, + "lang2id": { "de": 0, + "en": 1 }}, + 'xlm-mlm-enfr-1024': { "do_lowercase_and_remove_accent": True, + "id2lang": { "0": "en", + "1": "fr"}, + "lang2id": { "en": 0, + "fr": 1 }}, + 'xlm-mlm-enro-1024': { "do_lowercase_and_remove_accent": True, + "id2lang": { "0": "en", + "1": "ro"}, + "lang2id": { "en": 0, + "ro": 1 }}, + 'xlm-mlm-tlm-xnli15-1024': { "do_lowercase_and_remove_accent": True, + "id2lang": { "0": "ar", + "1": "bg", + "2": "de", + "3": "el", + "4": "en", + "5": "es", + "6": "fr", + "7": "hi", + "8": "ru", + "9": "sw", + "10": "th", + "11": "tr", + "12": "ur", + "13": "vi", + "14": "zh"}, + "lang2id": { "ar": 0, + "bg": 1, + "de": 2, + "el": 3, + "en": 4, + "es": 5, + "fr": 6, + "hi": 7, + "ru": 8, + "sw": 9, + "th": 10, + "tr": 11, + "ur": 12, + "vi": 13, + "zh": 14 }}, + 'xlm-mlm-xnli15-1024': { "do_lowercase_and_remove_accent": True, + "id2lang": { "0": "ar", + "1": "bg", + "2": "de", + "3": "el", + "4": "en", + "5": "es", + "6": "fr", + "7": "hi", + "8": "ru", + "9": "sw", + "10": "th", + "11": "tr", + "12": "ur", + "13": "vi", + "14": "zh"}, + "lang2id": { "ar": 0, + "bg": 1, + "de": 2, + "el": 3, + "en": 4, + "es": 5, + "fr": 6, + "hi": 7, + "ru": 8, + "sw": 9, + "th": 10, + "tr": 11, + "ur": 12, + "vi": 13, + "zh": 14 }}, + 'xlm-clm-enfr-1024': { "do_lowercase_and_remove_accent": True, + "id2lang": { "0": "en", + "1": "fr"}, + "lang2id": { "en": 0, + "fr": 1 }}, + 'xlm-clm-ende-1024': { "do_lowercase_and_remove_accent": True, + "id2lang": { "0": "de", + "1": "en"}, + "lang2id": { "de": 0, + "en": 1 }}, + 'xlm-mlm-17-1280': {"do_lowercase_and_remove_accent": False, + "id2lang": { + "0": "ar", + "1": "de", + "2": "en", + "3": "es", + "4": "fr", + "5": "hi", + "6": "it", + "7": "ja", + "8": "ko", + "9": "nl", + "10": "pl", + "11": "pt", + "12": "ru", + "13": "sv", + "14": "tr", + "15": "vi", + "16": "zh" + }, + "lang2id": { + "ar": 0, + "de": 1, + "en": 2, + "es": 3, + "fr": 4, + "hi": 5, + "it": 6, + "ja": 7, + "ko": 8, + "nl": 9, + "pl": 10, + "pt": 11, + "ru": 12, + "sv": 13, + "tr": 14, + "vi": 15, + "zh": 16}}, + 'xlm-mlm-100-1280': {"do_lowercase_and_remove_accent": False, + "id2lang": { + "0": "af", + "1": "als", + "2": "am", + "3": "an", + "4": "ang", + "5": "ar", + "6": "arz", + "7": "ast", + "8": "az", + "9": "bar", + "10": "be", + "11": "bg", + "12": "bn", + "13": "br", + "14": "bs", + "15": "ca", + "16": "ceb", + "17": "ckb", + "18": "cs", + "19": "cy", + "20": "da", + "21": "de", + "22": "el", + "23": "en", + "24": "eo", + "25": "es", + "26": "et", + "27": "eu", + "28": "fa", + "29": "fi", + "30": "fr", + "31": "fy", + "32": "ga", + "33": "gan", + "34": "gl", + "35": "gu", + "36": "he", + "37": "hi", + "38": "hr", + "39": "hu", + "40": "hy", + "41": "ia", + "42": "id", + 
"43": "is", + "44": "it", + "45": "ja", + "46": "jv", + "47": "ka", + "48": "kk", + "49": "kn", + "50": "ko", + "51": "ku", + "52": "la", + "53": "lb", + "54": "lt", + "55": "lv", + "56": "mk", + "57": "ml", + "58": "mn", + "59": "mr", + "60": "ms", + "61": "my", + "62": "nds", + "63": "ne", + "64": "nl", + "65": "nn", + "66": "no", + "67": "oc", + "68": "pl", + "69": "pt", + "70": "ro", + "71": "ru", + "72": "scn", + "73": "sco", + "74": "sh", + "75": "si", + "76": "simple", + "77": "sk", + "78": "sl", + "79": "sq", + "80": "sr", + "81": "sv", + "82": "sw", + "83": "ta", + "84": "te", + "85": "th", + "86": "tl", + "87": "tr", + "88": "tt", + "89": "uk", + "90": "ur", + "91": "uz", + "92": "vi", + "93": "war", + "94": "wuu", + "95": "yi", + "96": "zh", + "97": "zh_classical", + "98": "zh_min_nan", + "99": "zh_yue" + }, + "lang2id": { + "af": 0, + "als": 1, + "am": 2, + "an": 3, + "ang": 4, + "ar": 5, + "arz": 6, + "ast": 7, + "az": 8, + "bar": 9, + "be": 10, + "bg": 11, + "bn": 12, + "br": 13, + "bs": 14, + "ca": 15, + "ceb": 16, + "ckb": 17, + "cs": 18, + "cy": 19, + "da": 20, + "de": 21, + "el": 22, + "en": 23, + "eo": 24, + "es": 25, + "et": 26, + "eu": 27, + "fa": 28, + "fi": 29, + "fr": 30, + "fy": 31, + "ga": 32, + "gan": 33, + "gl": 34, + "gu": 35, + "he": 36, + "hi": 37, + "hr": 38, + "hu": 39, + "hy": 40, + "ia": 41, + "id": 42, + "is": 43, + "it": 44, + "ja": 45, + "jv": 46, + "ka": 47, + "kk": 48, + "kn": 49, + "ko": 50, + "ku": 51, + "la": 52, + "lb": 53, + "lt": 54, + "lv": 55, + "mk": 56, + "ml": 57, + "mn": 58, + "mr": 59, + "ms": 60, + "my": 61, + "nds": 62, + "ne": 63, + "nl": 64, + "nn": 65, + "no": 66, + "oc": 67, + "pl": 68, + "pt": 69, + "ro": 70, + "ru": 71, + "scn": 72, + "sco": 73, + "sh": 74, + "si": 75, + "simple": 76, + "sk": 77, + "sl": 78, + "sq": 79, + "sr": 80, + "sv": 81, + "sw": 82, + "ta": 83, + "te": 84, + "th": 85, + "tl": 86, + "tr": 87, + "tt": 88, + "uk": 89, + "ur": 90, + "uz": 91, + "vi": 92, + "war": 93, + "wuu": 94, + "yi": 95, + "zh": 96, + "zh_classical": 97, + "zh_min_nan": 98, + "zh_yue": 99 + }}, } def get_pairs(word): @@ -80,62 +424,145 @@ def get_pairs(word): prev_char = char return pairs -def text_standardize(text): + +def lowercase_and_remove_accent(text): """ - fixes some issues the spacy tokenizer had on books corpus - also does some whitespace standardization + Lowercase and strips accents from a piece of text based on + https://github.com/facebookresearch/XLM/blob/master/tools/lowercase_and_remove_accent.py """ - text = text.replace('—', '-') - text = text.replace('–', '-') - text = text.replace('―', '-') + text = ' '.join(text) + text = text.lower() + text = unicodedata.normalize("NFD", text) + output = [] + for char in text: + cat = unicodedata.category(char) + if cat == "Mn": + continue + output.append(char) + return "".join(output).lower().split(' ') + + +def replace_unicode_punct(text): + ''' + Port of https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/replace-unicode-punctuation.perl + ''' + text = text.replace(',', ',') + text = re.sub(r'。\s*', '. 
', text) + text = text.replace('、', ',') + text = text.replace('”', '"') + text = text.replace('“', '"') + text = text.replace('∶', ':') + text = text.replace('：', ':') + text = text.replace('？', '?') + text = text.replace('《', '"') + text = text.replace('》', '"') + text = text.replace('）', ')') + text = text.replace('！', '!') + text = text.replace('（', '(') + text = text.replace('；', ';') + text = text.replace('１', '"') + text = text.replace('」', '"') + text = text.replace('「', '"') + text = text.replace('０', '0') + text = text.replace('３', '3') + text = text.replace('２', '2') + text = text.replace('５', '5') + text = text.replace('６', '6') + text = text.replace('９', '9') + text = text.replace('７', '7') + text = text.replace('８', '8') + text = text.replace('４', '4') + text = re.sub(r'．\s*', '. ', text) + text = text.replace('～', '~') + text = text.replace('’', '\'') text = text.replace('…', '...') - text = text.replace('´', "'") - text = re.sub(r'''(-+|~+|!+|"+|;+|\?+|\++|,+|\)+|\(+|\\+|\/+|\*+|\[+|\]+|}+|{+|\|+|_+)''', r' \1 ', text) - text = re.sub(r'\s*\n\s*', ' \n ', text) - text = re.sub(r'[^\S\n]+', ' ', text) - return text.strip() + text = text.replace('━', '-') + text = text.replace('〈', '<') + text = text.replace('〉', '>') + text = text.replace('【', '[') + text = text.replace('】', ']') + text = text.replace('％', '%') + return text + + +def remove_non_printing_char(text): + ''' + Port of https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/remove-non-printing-char.perl + ''' + output = [] + for char in text: + cat = unicodedata.category(char) + if cat.startswith('C'): + continue + output.append(char) + return "".join(output) + + +def romanian_preprocessing(text): + '''Sennrich's WMT16 scripts for Romanian preprocessing, used by model `xlm-mlm-enro-1024`''' + # https://github.com/rsennrich/wmt16-scripts/blob/master/preprocess/normalise-romanian.py + text = text.replace("\u015e", "\u0218").replace("\u015f", "\u0219") + text = text.replace("\u0162", "\u021a").replace("\u0163", "\u021b") + # https://github.com/rsennrich/wmt16-scripts/blob/master/preprocess/remove-diacritics.py + text = text.replace("\u0218", "S").replace("\u0219", "s") #s-comma + text = text.replace("\u021a", "T").replace("\u021b", "t") #t-comma + text = text.replace("\u0102", "A").replace("\u0103", "a") + text = text.replace("\u00C2", "A").replace("\u00E2", "a") + text = text.replace("\u00CE", "I").replace("\u00EE", "i") + return text + class XLMTokenizer(PreTrainedTokenizer): """ - BPE tokenizer for XLM, adapted from OpenAI BPE tokenizer. Peculiarities: + BPE tokenizer for XLM - - lower case all inputs + - Moses preprocessing & tokenization for most supported languages - - uses `SpaCy tokenizer <https://spacy.io>`_ and \ - `ftfy <https://ftfy.readthedocs.io/en/latest/>`_ for pre-BPE tokenization if they are installed, \ - fallback to BERT's BasicTokenizer if not. + - Language specific tokenization for Chinese (Jieba), Japanese (KyTea) and Thai (PyThaiNLP) + + - (optionally) lower case & normalize all input text - argument ``special_tokens`` and function ``set_special_tokens``, can be used to add additional symbols \ (ex: "__classify__") to a vocabulary. 
+ (ex: "__classify__") to a vocabulary + + - `lang2id` attribute maps the languages supported by the model with their ids if provided (automatically set for pretrained vocabularies) + + - `id2lang` attributes does reverse mapping if provided (automatically set for pretrained vocabularies) + + - `do_lowercase_and_remove_accent` controle lower casing and accent (automatically set for pretrained vocabularies) """ vocab_files_names = VOCAB_FILES_NAMES pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES def __init__(self, vocab_file, merges_file, unk_token="", bos_token="", sep_token="", pad_token="", cls_token="", mask_token="", additional_special_tokens=["", "", "", "", "", "", - "", "", "", ""], **kwargs): + "", "", "", ""], + lang2id=None, id2lang=None, do_lowercase_and_remove_accent=True, + **kwargs): super(XLMTokenizer, self).__init__(unk_token=unk_token, bos_token=bos_token, sep_token=sep_token, pad_token=pad_token, cls_token=cls_token, mask_token=mask_token, additional_special_tokens=additional_special_tokens, **kwargs) - self.max_len_single_sentence = self.max_len - 2 # take into account special tokens - self.max_len_sentences_pair = self.max_len - 3 # take into account special tokens + # cache of sm.MosesPunctNormalizer instance + self.cache_moses_punct_normalizer = dict() + # cache of sm.MosesTokenizer instance + self.cache_moses_tokenizer = dict() + self.lang_with_custom_tokenizer = set(['zh', 'th', 'ja']) + # True for current supported model (v1.2.0), False for XLM-17 & 100 + self.do_lowercase_and_remove_accent = do_lowercase_and_remove_accent + self.lang2id = lang2id + self.id2lang = id2lang + if lang2id is not None and id2lang is not None: + assert len(lang2id) == len(id2lang) - try: - import ftfy - from spacy.lang.en import English - _nlp = English() - self.nlp = _nlp.Defaults.create_tokenizer(_nlp) - self.fix_text = ftfy.fix_text - except ImportError: - logger.warning("ftfy or spacy is not installed using BERT BasicTokenizer instead of SpaCy & ftfy.") - self.nlp = BasicTokenizer(do_lower_case=True) - self.fix_text = None + self.ja_word_tokenizer = None + self.zh_word_tokenizer = None self.encoder = json.load(open(vocab_file, encoding="utf-8")) self.decoder = {v:k for k,v in self.encoder.items()} @@ -144,6 +571,43 @@ class XLMTokenizer(PreTrainedTokenizer): self.bpe_ranks = dict(zip(merges, range(len(merges)))) self.cache = {} + def moses_punct_norm(self, text, lang): + if lang not in self.cache_moses_punct_normalizer: + punct_normalizer = sm.MosesPunctNormalizer(lang=lang) + self.cache_moses_punct_normalizer[lang] = punct_normalizer + else: + punct_normalizer = self.cache_moses_punct_normalizer[lang] + return punct_normalizer.normalize(text) + + def moses_tokenize(self, text, lang): + if lang not in self.cache_moses_tokenizer: + moses_tokenizer = sm.MosesTokenizer(lang=lang) + self.cache_moses_tokenizer[lang] = moses_tokenizer + else: + moses_tokenizer = self.cache_moses_tokenizer[lang] + return moses_tokenizer.tokenize(text, return_str=False, escape=False) + + def moses_pipeline(self, text, lang): + text = replace_unicode_punct(text) + text = self.moses_punct_norm(text, lang) + text = remove_non_printing_char(text) + return text + + def ja_tokenize(self, text): + if self.ja_word_tokenizer is None: + try: + import Mykytea + self.ja_word_tokenizer = Mykytea.Mykytea('-model %s/local/share/kytea/model.bin' % os.path.expanduser('~')) + except 
(AttributeError, ImportError) as e: + logger.error("Make sure you install KyTea (https://github.com/neubig/kytea) and its Python wrapper (https://github.com/chezou/Mykytea-python) with the following steps") + logger.error("1. git clone git@github.com:neubig/kytea.git && cd kytea") + logger.error("2. autoreconf -i") + logger.error("3. ./configure --prefix=$HOME/local") + logger.error("4. make && make install") + logger.error("5. pip install kytea") + raise e + return list(self.ja_word_tokenizer.getWS(text)) + @property def vocab_size(self): return len(self.encoder) @@ -191,19 +655,86 @@ class XLMTokenizer(PreTrainedTokenizer): self.cache[token] = word return word - def _tokenize(self, text): - """ Tokenize a string. """ - split_tokens = [] - if self.fix_text is None: - # Using BERT's BasicTokenizer - text = self.nlp.tokenize(text) - for token in text: - split_tokens.extend([t for t in self.bpe(token).split(' ')]) + def _tokenize(self, text, lang='en', bypass_tokenizer=False): + """ + Tokenize a string given a language code. For Chinese, Japanese and Thai, we use a language specific tokenizer. Otherwise, we use Moses. + + Details of tokenization: + - [sacremoses](https://github.com/alvations/sacremoses): port of Moses + - Install with `pip install sacremoses` + - [pythainlp](https://github.com/PyThaiNLP/pythainlp): Thai tokenizer + - Install with `pip install pythainlp` + - [kytea](https://github.com/chezou/Mykytea-python): Japanese tokenizer, wrapper of [KyTea](https://github.com/neubig/kytea) + - Install with the following steps: + ``` + git clone git@github.com:neubig/kytea.git && cd kytea + autoreconf -i + ./configure --prefix=$HOME/local + make && make install + pip install kytea + ``` + - [jieba](https://github.com/fxsjy/jieba): Chinese tokenizer * + - Install with `pip install jieba` + + \* The original XLM used [Stanford Segmenter](https://nlp.stanford.edu/software/stanford-segmenter-2018-10-16.zip). + However, the wrapper (`nltk.tokenize.stanford_segmenter`) is slow due to JVM overhead, and it will be deprecated. + Jieba is a lot faster and pip-installable. Note there is some mismatch with the Stanford Segmenter. It should be fine + if you fine-tune the model with Chinese supervision. If you want the same exact behaviour, use the original XLM + [preprocessing script](https://github.com/facebookresearch/XLM/tree/master/tools) to tokenize the sentence externally, + and set `bypass_tokenizer=True` to bypass the tokenizer. + + Args: + - lang: ISO language code (default = 'en') (string). Languages should be among the languages supported by the model. However, we don't enforce it. + - bypass_tokenizer: Allow users to preprocess and tokenize the sentences externally (default = False) (bool). If True, we only apply BPE. + + Returns: + List of tokens. + """ + if lang and self.lang2id and lang not in self.lang2id: + logger.error("Supplied language code not found in lang2id mapping. 
Please check that your language is supported by the loaded pretrained model.") + if bypass_tokenizer: + text = text.split() + elif lang not in self.lang_with_custom_tokenizer: + text = self.moses_pipeline(text, lang=lang) + # TODO: make sure we are using `xlm-mlm-enro-1024`, since XLM-100 doesn't have this step + if lang == 'ro': + text = romanian_preprocessing(text) + text = self.moses_tokenize(text, lang=lang) + elif lang == 'th': + text = self.moses_pipeline(text, lang=lang) + try: + if 'pythainlp' not in sys.modules: + from pythainlp.tokenize import word_tokenize as th_word_tokenize + except (AttributeError, ImportError) as e: + logger.error("Make sure you install PyThaiNLP (https://github.com/PyThaiNLP/pythainlp) with the following steps") + logger.error("1. pip install pythainlp") + raise e + text = th_word_tokenize(text) + elif lang == 'zh': + try: + if 'jieba' not in sys.modules: + import jieba + except (AttributeError, ImportError) as e: + logger.error("Make sure you install Jieba (https://github.com/fxsjy/jieba) with the following steps") + logger.error("1. pip install jieba") + raise e + text = ' '.join(jieba.cut(text)) + text = self.moses_pipeline(text, lang=lang) + text = text.split() + elif lang == 'ja': + text = self.moses_pipeline(text, lang=lang) + text = self.ja_tokenize(text) else: - # Using SpaCy & ftfy (original tokenization process of OpenAI GPT) - text = self.nlp(text_standardize(self.fix_text(text))) - for token in text: - split_tokens.extend([t for t in self.bpe(token.text.lower()).split(' ')]) + raise ValueError('It should not reach here') + + if self.do_lowercase_and_remove_accent and not bypass_tokenizer: + text = lowercase_and_remove_accent(text) + + split_tokens = [] + for token in text: + if token: + split_tokens.extend([t for t in self.bpe(token).split(' ')]) + return split_tokens def _convert_token_to_id(self, token): diff --git a/pytorch_transformers/tokenization_xlnet.py b/pytorch_transformers/tokenization_xlnet.py index ac7231bb680..bf9b9dc782f 100644 --- a/pytorch_transformers/tokenization_xlnet.py +++ b/pytorch_transformers/tokenization_xlnet.py @@ -61,7 +61,7 @@ class XLNetTokenizer(PreTrainedTokenizer): pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES - def __init__(self, vocab_file, max_len=None, + def __init__(self, vocab_file, do_lower_case=False, remove_space=True, keep_accents=False, bos_token="<s>", eos_token="</s>", unk_token="<unk>", sep_token="<sep>", pad_token="<pad>", cls_token="<cls>", mask_token="<mask>", diff --git a/requirements.txt b/requirements.txt index 76532d18a59..01dca79d23b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,4 +9,6 @@ requests # For OpenAI GPT regex # For XLNet -sentencepiece \ No newline at end of file +sentencepiece +# For XLM +sacremoses \ No newline at end of file diff --git a/setup.py b/setup.py index c9f80fc224a..29797222681 100644 --- a/setup.py +++ b/setup.py @@ -55,7 +55,8 @@ setup( 'requests', 'tqdm', 'regex', - 'sentencepiece', + 'sentencepiece', + 'sacremoses'], entry_points={ 'console_scripts': [ "pytorch_transformers=pytorch_transformers.__main__:main",
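
Usage sketch (editor's note, not part of the patch): the snippet below is a minimal, hedged illustration of how the pieces added above are meant to fit together — the new `xlm-mlm-17-1280` shortcut registered in the archive maps, the `lang` keyword that `encode()` now forwards down to the XLM tokenizer, and the `tokenizer_config.json` round-trip added to `save_pretrained()`/`from_pretrained()`. The sample sentence and the local save directory are made up for illustration, and the pretrained weights are large to download.

import os
import torch
from pytorch_transformers import XLMModel, XLMTokenizer

# 'xlm-mlm-17-1280' is one of the two shortcuts registered by this patch;
# vocab, merges, config and weights are fetched from the S3 URLs added above.
tokenizer = XLMTokenizer.from_pretrained('xlm-mlm-17-1280')
model = XLMModel.from_pretrained('xlm-mlm-17-1280')

# `lang` is forwarded from encode() to tokenize()/_tokenize() and selects the
# Moses (or Jieba/KyTea/PyThaiNLP) pre-tokenization path applied before BPE.
input_ids = torch.tensor([tokenizer.encode(u"Wie alt sind Sie?", lang='de')])
last_hidden_state = model(input_ids)[0]

# Tokenizer init kwargs (e.g. do_lowercase_and_remove_accent, lang2id) now
# round-trip through tokenizer_config.json, so a saved tokenizer reloads with
# the same configuration. The directory name is only an example.
save_dir = './xlm-17-checkpoint'
os.makedirs(save_dir, exist_ok=True)  # save_pretrained expects an existing directory
tokenizer.save_pretrained(save_dir)
reloaded = XLMTokenizer.from_pretrained(save_dir)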