From 436ce072183e3e134d2fbc286f6c72f012f31e74 Mon Sep 17 00:00:00 2001 From: Shijie Wu Date: Fri, 23 Aug 2019 14:40:17 -0400 Subject: [PATCH 01/11] Tokenization behave the same as original XLM proprocessing for most languages except zh, ja and th; Change API to allow specifying language in `tokenize` --- pytorch_transformers/tokenization_xlm.py | 148 ++++++++++++++++++++--- requirements.txt | 4 +- setup.py | 3 +- 3 files changed, 135 insertions(+), 20 deletions(-) diff --git a/pytorch_transformers/tokenization_xlm.py b/pytorch_transformers/tokenization_xlm.py index 2d2f3a8cd4d..8418a5d6f3b 100644 --- a/pytorch_transformers/tokenization_xlm.py +++ b/pytorch_transformers/tokenization_xlm.py @@ -20,8 +20,11 @@ import json import logging import os import re +import unicodedata from io import open +import sacremoses as sm + from .tokenization_utils import PreTrainedTokenizer from .tokenization_bert import BasicTokenizer @@ -95,6 +98,93 @@ def text_standardize(text): text = re.sub(r'[^\S\n]+', ' ', text) return text.strip() + +def lowercase_and_remove_accent(text): + """ + Lowercase and strips accents from a piece of text based on + https://github.com/facebookresearch/XLM/blob/master/tools/lowercase_and_remove_accent.py + """ + text = text.lower() + text = unicodedata.normalize("NFD", text) + output = [] + for char in text: + cat = unicodedata.category(char) + if cat == "Mn": + continue + output.append(char) + return "".join(output).lower() + + +def replace_unicode_punct(text): + ''' + Port of https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/replace-unicode-punctuation.perl + ''' + text = text.replace(',', ',') + text = text.replace('。 *', '. ') + text = text.replace('、', ',') + text = text.replace('”', '"') + text = text.replace('“', '"') + text = text.replace('∶', ':') + text = text.replace(':', ':') + text = text.replace('?', '?') + text = text.replace('《', '"') + text = text.replace('》', '"') + text = text.replace(')', ')') + text = text.replace('!', '!') + text = text.replace('(', '(') + text = text.replace(';', ';') + text = text.replace('1', '"') + text = text.replace('」', '"') + text = text.replace('「', '"') + text = text.replace('0', '0') + text = text.replace('3', '3') + text = text.replace('2', '2') + text = text.replace('5', '5') + text = text.replace('6', '6') + text = text.replace('9', '9') + text = text.replace('7', '7') + text = text.replace('8', '8') + text = text.replace('4', '4') + text = re.sub(r'.\s*', '. 
', text) + text = text.replace('~', '~') + text = text.replace('’', '\'') + text = text.replace('…', '...') + text = text.replace('━', '-') + text = text.replace('〈', '<') + text = text.replace('〉', '>') + text = text.replace('【', '[') + text = text.replace('】', ']') + text = text.replace('%', '%') + return text + + +def remove_non_printing_char(text): + ''' + Port of https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/remove-non-printing-char.perl + ''' + output = [] + for char in text: + cat = unicodedata.category(char) + if cat.startswith('C'): + continue + output.append(char) + return "".join(output) + + +def romanian_preprocessing(text): + '''Sennrich's WMT16 scripts for Romanian preprocessing, used by model `xlm-mlm-enro-1024`''' + # https://github.com/rsennrich/wmt16-scripts/blob/master/preprocess/normalise-romanian.py + text = text.replace("\u015e", "\u0218").replace("\u015f", "\u0219") + text = text.replace("\u0162", "\u021a").replace("\u0163", "\u021b") + # https://github.com/rsennrich/wmt16-scripts/blob/master/preprocess/remove-diacritics.py + text = text.replace("\u0218", "S").replace("\u0219", "s") #s-comma + text = text.replace("\u021a", "T").replace("\u021b", "t") #t-comma + text = text.replace("\u0102", "A").replace("\u0103", "a") + text = text.replace("\u00C2", "A").replace("\u00E2", "a") + text = text.replace("\u00CE", "I").replace("\u00EE", "i") + return text + + class XLMTokenizer(PreTrainedTokenizer): """ BPE tokenizer for XLM, adapted from OpenAI BPE tokenizer. Peculiarities: @@ -122,16 +212,14 @@ class XLMTokenizer(PreTrainedTokenizer): cls_token=cls_token, mask_token=mask_token, additional_special_tokens=additional_special_tokens, **kwargs) - try: - import ftfy - from spacy.lang.en import English - _nlp = English() - self.nlp = _nlp.Defaults.create_tokenizer(_nlp) - self.fix_text = ftfy.fix_text - except ImportError: - logger.warning("ftfy or spacy is not installed using BERT BasicTokenizer instead of SpaCy & ftfy.") - self.nlp = BasicTokenizer(do_lower_case=True) - self.fix_text = None + + # cache of sm.MosesPunctNormalizer instance + self.cache_moses_punct_normalizer = dict() + # cache of sm.MosesTokenizer instance + self.cache_moses_tokenizer = dict() + self.lang_with_custom_tokenizer = set(['zh', 'th', 'ja']) + # True for current supported model (v1.2.0), False for XLM-17 & 100 + self.do_lowercase_and_remove_accent = True self.encoder = json.load(open(vocab_file, encoding="utf-8")) self.decoder = {v:k for k,v in self.encoder.items()} @@ -140,6 +228,28 @@ class XLMTokenizer(PreTrainedTokenizer): self.bpe_ranks = dict(zip(merges, range(len(merges)))) self.cache = {} + def moses_punct_norm(self, text, lang): + if lang not in self.cache_moses_punct_normalizer: + punct_normalizer = sm.MosesPunctNormalizer(lang=lang) + self.cache_moses_punct_normalizer[lang] = punct_normalizer + else: + punct_normalizer = self.cache_moses_punct_normalizer[lang] + return punct_normalizer.normalize(text) + + def moses_tokenize(self, text, lang): + if lang not in self.cache_moses_tokenizer: + moses_tokenizer = sm.MosesTokenizer(lang=lang) + self.cache_moses_tokenizer[lang] = moses_tokenizer + else: + moses_tokenizer = self.cache_moses_tokenizer[lang] + return moses_tokenizer.tokenize(text, return_str=False, escape=False) + + def moses_pipeline(self, text, lang): + text = replace_unicode_punct(text) + text = self.moses_punct_norm(text, lang) + text = remove_non_printing_char(text) + return text + @property def vocab_size(self): return len(self.encoder) @@ -187,19 
+297,21 @@ class XLMTokenizer(PreTrainedTokenizer): self.cache[token] = word return word - def _tokenize(self, text): + def _tokenize(self, text, lang='en'): """ Tokenize a string. """ split_tokens = [] - if self.fix_text is None: - # Using BERT's BasicTokenizer - text = self.nlp.tokenize(text) + if self.do_lowercase_and_remove_accent: + text = lowercase_and_remove_accent(text) + if lang not in self.lang_with_custom_tokenizer: + text = self.moses_pipeline(text, lang=lang) + # TODO: make sure we are using `xlm-mlm-enro-1024`, since XLM-100 doesn't have this step + if lang == 'ro': + text = romanian_preprocessing(text) + text = self.moses_tokenize(text, lang=lang) for token in text: split_tokens.extend([t for t in self.bpe(token).split(' ')]) else: - # Using SpaCy & ftfy (original tokenization process of OpenAI GPT) - text = self.nlp(text_standardize(self.fix_text(text))) - for token in text: - split_tokens.extend([t for t in self.bpe(token.text.lower()).split(' ')]) + raise ValueError return split_tokens def _convert_token_to_id(self, token): diff --git a/requirements.txt b/requirements.txt index 76532d18a59..01dca79d23b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,4 +9,6 @@ requests # For OpenAI GPT regex # For XLNet -sentencepiece \ No newline at end of file +sentencepiece +# For XLM +sacremoses \ No newline at end of file diff --git a/setup.py b/setup.py index c9f80fc224a..29797222681 100644 --- a/setup.py +++ b/setup.py @@ -55,7 +55,8 @@ setup( 'requests', 'tqdm', 'regex', - 'sentencepiece'], + 'sentencepiece', + 'sacremoses'], entry_points={ 'console_scripts': [ "pytorch_transformers=pytorch_transformers.__main__:main", From e85123d398bfc2e58f6f6539e524ee9c4619ec0d Mon Sep 17 00:00:00 2001 From: Shijie Wu Date: Fri, 23 Aug 2019 20:27:52 -0400 Subject: [PATCH 02/11] Add custom tokenizer for zh and ja --- pytorch_transformers/tokenization_xlm.py | 71 +++++++++++++++++------- requirements.txt | 6 +- setup.py | 6 +- 3 files changed, 61 insertions(+), 22 deletions(-) diff --git a/pytorch_transformers/tokenization_xlm.py b/pytorch_transformers/tokenization_xlm.py index 8418a5d6f3b..a459dea9b91 100644 --- a/pytorch_transformers/tokenization_xlm.py +++ b/pytorch_transformers/tokenization_xlm.py @@ -23,7 +23,11 @@ import re import unicodedata from io import open +import jieba +import Mykytea import sacremoses as sm +from nltk.tokenize.stanford_segmenter import StanfordSegmenter +from pythainlp.tokenize import word_tokenize as th_word_tokenize from .tokenization_utils import PreTrainedTokenizer from .tokenization_bert import BasicTokenizer @@ -83,21 +87,6 @@ def get_pairs(word): prev_char = char return pairs -def text_standardize(text): - """ - fixes some issues the spacy tokenizer had on books corpus - also does some whitespace standardization - """ - text = text.replace('—', '-') - text = text.replace('–', '-') - text = text.replace('―', '-') - text = text.replace('…', '...') - text = text.replace('´', "'") - text = re.sub(r'''(-+|~+|!+|"+|;+|\?+|\++|,+|\)+|\(+|\\+|\/+|\*+|\[+|\]+|}+|{+|\|+|_+)''', r' \1 ', text) - text = re.sub(r'\s*\n\s*', ' \n ', text) - text = re.sub(r'[^\S\n]+', ' ', text) - return text.strip() - def lowercase_and_remove_accent(text): """ @@ -120,7 +109,7 @@ def replace_unicode_punct(text): Port of https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/replace-unicode-punctuation.perl ''' text = text.replace(',', ',') - text = text.replace('。 *', '. ') + text = re.sub(r'。\s*', '. 
', text) text = text.replace('、', ',') text = text.replace('”', '"') text = text.replace('“', '"') @@ -220,6 +209,8 @@ class XLMTokenizer(PreTrainedTokenizer): self.lang_with_custom_tokenizer = set(['zh', 'th', 'ja']) # True for current supported model (v1.2.0), False for XLM-17 & 100 self.do_lowercase_and_remove_accent = True + self.ja_word_tokenizer = None + self.zh_word_tokenizer = None self.encoder = json.load(open(vocab_file, encoding="utf-8")) self.decoder = {v:k for k,v in self.encoder.items()} @@ -250,6 +241,33 @@ class XLMTokenizer(PreTrainedTokenizer): text = remove_non_printing_char(text) return text + def ja_tokenize(self, text): + if self.ja_word_tokenizer is None: + try: + self.ja_word_tokenizer = Mykytea.Mykytea('-model %s/local/share/kytea/model.bin' % os.path.expanduser('~')) + except RuntimeError: + logger.error("Make sure you install KyTea (https://github.com/neubig/kytea) with the following steps") + logger.error("1. git clone git@github.com:neubig/kytea.git && cd kytea") + logger.error("2. autoreconf -i") + logger.error("3. ./configure --prefix=$HOME/local") + logger.error("4. make && make install") + import sys; sys.exit() + return list(self.ja_word_tokenizer.getWS(text)) + + def zh_tokenize(self, text): + if self.zh_word_tokenizer is None: + try: + self.zh_word_tokenizer = StanfordSegmenter() + self.zh_word_tokenizer.default_config('zh') + except LookupError: + logger.error("Make sure you download stanford-segmenter (https://nlp.stanford.edu/software/stanford-segmenter-2018-10-16.zip) with the following steps") + logger.error("1. wget https://nlp.stanford.edu/software/stanford-segmenter-2018-10-16.zip -O /path/to/stanford-segmenter-2018-10-16.zip") + logger.error("2. cd /path/to && unzip stanford-segmenter-2018-10-16.zip") + logger.error("3. cd stanford-segmenter-2018-10-16 && cp stanford-segmenter-3.9.2.jar stanford-segmenter.jar") + logger.error("4. set env variable STANFORD_SEGMENTER=/path/to/stanford-segmenter-2018-10-16") + import sys; sys.exit() + return self.zh_word_tokenizer.segment(text) + @property def vocab_size(self): return len(self.encoder) @@ -299,7 +317,6 @@ class XLMTokenizer(PreTrainedTokenizer): def _tokenize(self, text, lang='en'): """ Tokenize a string. 
""" - split_tokens = [] if self.do_lowercase_and_remove_accent: text = lowercase_and_remove_accent(text) if lang not in self.lang_with_custom_tokenizer: @@ -308,10 +325,24 @@ class XLMTokenizer(PreTrainedTokenizer): if lang == 'ro': text = romanian_preprocessing(text) text = self.moses_tokenize(text, lang=lang) - for token in text: - split_tokens.extend([t for t in self.bpe(token).split(' ')]) + elif lang == 'th': + text = self.moses_pipeline(text, lang=lang) + text = th_word_tokenize(text) + elif lang == 'zh': + # text = self.zh_tokenize(text) + text = ' '.join(jieba.cut(text)) + text = self.moses_pipeline(text, lang=lang) + text = text.split() + elif lang == 'ja': + text = self.moses_pipeline(text, lang=lang) + text = self.ja_tokenize(text) else: - raise ValueError + raise ValueError('It should not reach here') + + split_tokens = [] + for token in text: + split_tokens.extend([t for t in self.bpe(token).split(' ')]) + return split_tokens def _convert_token_to_id(self, token): diff --git a/requirements.txt b/requirements.txt index 01dca79d23b..2e3f8ace51d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,4 +11,8 @@ regex # For XLNet sentencepiece # For XLM -sacremoses \ No newline at end of file +sacremoses +pythainlp +kytea +nltk +jieba \ No newline at end of file diff --git a/setup.py b/setup.py index 29797222681..e37e948fb4d 100644 --- a/setup.py +++ b/setup.py @@ -56,7 +56,11 @@ setup( 'tqdm', 'regex', 'sentencepiece', - 'sacremoses'], + 'sacremoses', + 'pythainlp', + 'kytea', + 'nltk', + 'jieba'], entry_points={ 'console_scripts': [ "pytorch_transformers=pytorch_transformers.__main__:main", From f1b018740c9355f0bcf0093fc993724eaa737445 Mon Sep 17 00:00:00 2001 From: Shijie Wu Date: Fri, 23 Aug 2019 20:33:01 -0400 Subject: [PATCH 03/11] Add use_lang_emb to config --- pytorch_transformers/modeling_xlm.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/pytorch_transformers/modeling_xlm.py b/pytorch_transformers/modeling_xlm.py index 19800da2edf..10be972ea53 100644 --- a/pytorch_transformers/modeling_xlm.py +++ b/pytorch_transformers/modeling_xlm.py @@ -114,6 +114,7 @@ class XLMConfig(PretrainedConfig): causal=False, asm=False, n_langs=1, + use_lang_emb=True, max_position_embeddings=512, embed_init_std=2048 ** -0.5, layer_norm_eps=1e-12, @@ -157,6 +158,7 @@ class XLMConfig(PretrainedConfig): self.causal = causal self.asm = asm self.n_langs = n_langs + self.use_lang_emb = use_lang_emb self.layer_norm_eps = layer_norm_eps self.bos_index = bos_index self.eos_index = eos_index @@ -488,7 +490,7 @@ class XLMModel(XLMPreTrainedModel): """ ATTRIBUTES = ['encoder', 'eos_index', 'pad_index', # 'with_output', - 'n_langs', 'n_words', 'dim', 'n_layers', 'n_heads', + 'n_langs', 'use_lang_emb', 'n_words', 'dim', 'n_layers', 'n_heads', 'hidden_dim', 'dropout', 'attention_dropout', 'asm', 'asm_cutoffs', 'asm_div_value'] @@ -507,6 +509,7 @@ class XLMModel(XLMPreTrainedModel): # dictionary / languages self.n_langs = config.n_langs + self.use_lang_emb = config.use_lang_emb self.n_words = config.n_words self.eos_index = config.eos_index self.pad_index = config.pad_index @@ -529,7 +532,7 @@ class XLMModel(XLMPreTrainedModel): self.position_embeddings = nn.Embedding(config.max_position_embeddings, self.dim) if config.sinusoidal_embeddings: create_sinusoidal_embeddings(config.max_position_embeddings, self.dim, out=self.position_embeddings.weight) - if config.n_langs > 1: + if config.n_langs > 1 and config.use_lang_emb: self.lang_embeddings = nn.Embedding(self.n_langs, 
self.dim) self.embeddings = nn.Embedding(self.n_words, self.dim, padding_idx=self.pad_index) self.layer_norm_emb = nn.LayerNorm(self.dim, eps=config.layer_norm_eps) @@ -628,7 +631,7 @@ class XLMModel(XLMPreTrainedModel): # embeddings tensor = self.embeddings(input_ids) tensor = tensor + self.position_embeddings(position_ids).expand_as(tensor) - if langs is not None: + if langs is not None and self.use_lang_emb: tensor = tensor + self.lang_embeddings(langs) if token_type_ids is not None: tensor = tensor + self.embeddings(token_type_ids) From a175a9dc0188a367400c2121391fa3abf536748e Mon Sep 17 00:00:00 2001 From: thomwolf Date: Tue, 27 Aug 2019 14:05:59 +0200 Subject: [PATCH 04/11] add kwargs to base encode function --- pytorch_transformers/tokenization_utils.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/pytorch_transformers/tokenization_utils.py b/pytorch_transformers/tokenization_utils.py index 4fef0e34fb0..1d054415935 100644 --- a/pytorch_transformers/tokenization_utils.py +++ b/pytorch_transformers/tokenization_utils.py @@ -563,7 +563,7 @@ class PreTrainedTokenizer(object): def _convert_token_to_id(self, token): raise NotImplementedError - def encode(self, text, text_pair=None, add_special_tokens=False): + def encode(self, text, text_pair=None, add_special_tokens=False, **kwargs): """ Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary. @@ -574,15 +574,16 @@ class PreTrainedTokenizer(object): text_pair: Optional second sequence to be encoded. add_special_tokens: if set to ``True``, the sequences will be encoded with the special tokens relative to their model. + **kwargs: passed to the `self.tokenize()` method """ if text_pair is None: if add_special_tokens: - return self.add_special_tokens_single_sentence(self.convert_tokens_to_ids(self.tokenize(text))) + return self.add_special_tokens_single_sentence(self.convert_tokens_to_ids(self.tokenize(text, **kwargs))) else: - return self.convert_tokens_to_ids(self.tokenize(text)) + return self.convert_tokens_to_ids(self.tokenize(text, **kwargs)) - first_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text)] - second_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text_pair)] + first_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text, **kwargs)] + second_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text_pair, **kwargs)] if add_special_tokens: return self.add_special_tokens_sentences_pair(first_sentence_tokens, second_sentence_tokens) From ca4baf8ca1f09e379c5e396c3332ff570f4422fc Mon Sep 17 00:00:00 2001 From: Shijie Wu Date: Tue, 27 Aug 2019 20:03:18 -0400 Subject: [PATCH 05/11] Match order of casing in OSS XLM; Improve document; Clean up dependency --- pytorch_transformers/tokenization_xlm.py | 102 +++++++++++++++-------- requirements.txt | 6 +- setup.py | 6 +- 3 files changed, 71 insertions(+), 43 deletions(-) diff --git a/pytorch_transformers/tokenization_xlm.py b/pytorch_transformers/tokenization_xlm.py index a459dea9b91..71bf1193873 100644 --- a/pytorch_transformers/tokenization_xlm.py +++ b/pytorch_transformers/tokenization_xlm.py @@ -20,14 +20,11 @@ import json import logging import os import re +import sys import unicodedata from io import open -import jieba -import Mykytea import sacremoses as sm -from nltk.tokenize.stanford_segmenter import StanfordSegmenter -from pythainlp.tokenize import word_tokenize as th_word_tokenize from 
.tokenization_utils import PreTrainedTokenizer from .tokenization_bert import BasicTokenizer @@ -93,6 +90,7 @@ def lowercase_and_remove_accent(text): Lowercase and strips accents from a piece of text based on https://github.com/facebookresearch/XLM/blob/master/tools/lowercase_and_remove_accent.py """ + text = ' '.join(text) text = text.lower() text = unicodedata.normalize("NFD", text) output = [] @@ -101,7 +99,7 @@ def lowercase_and_remove_accent(text): if cat == "Mn": continue output.append(char) - return "".join(output).lower() + return "".join(output).lower().split(' ') def replace_unicode_punct(text): @@ -176,13 +174,13 @@ def romanian_preprocessing(text): class XLMTokenizer(PreTrainedTokenizer): """ - BPE tokenizer for XLM, adapted from OpenAI BPE tokenizer. Peculiarities: + BPE tokenizer for XLM - - lower case all inputs + - Moses preprocessing & tokenization for most supported languages - - uses `SpaCy tokenizer `_ and \ - `ftfy `_ for pre-BPE tokenization if they are installed, \ - fallback to BERT's BasicTokenizer if not. + - Language specific tokenization for Chinese (Jieba), Japanese (KyTea) and Thai (PyThaiNLP) + + - (optionally) lower case & normalize all inputs text - argument ``special_tokens`` and function ``set_special_tokens``, can be used to add additional symbols \ (ex: "__classify__") to a vocabulary. @@ -244,30 +242,18 @@ class XLMTokenizer(PreTrainedTokenizer): def ja_tokenize(self, text): if self.ja_word_tokenizer is None: try: + import Mykytea self.ja_word_tokenizer = Mykytea.Mykytea('-model %s/local/share/kytea/model.bin' % os.path.expanduser('~')) - except RuntimeError: - logger.error("Make sure you install KyTea (https://github.com/neubig/kytea) with the following steps") + except: + logger.error("Make sure you install KyTea (https://github.com/neubig/kytea) and it's python wrapper (https://github.com/chezou/Mykytea-python) with the following steps") logger.error("1. git clone git@github.com:neubig/kytea.git && cd kytea") logger.error("2. autoreconf -i") logger.error("3. ./configure --prefix=$HOME/local") logger.error("4. make && make install") + logger.error("5. pip install kytea") import sys; sys.exit() return list(self.ja_word_tokenizer.getWS(text)) - def zh_tokenize(self, text): - if self.zh_word_tokenizer is None: - try: - self.zh_word_tokenizer = StanfordSegmenter() - self.zh_word_tokenizer.default_config('zh') - except LookupError: - logger.error("Make sure you download stanford-segmenter (https://nlp.stanford.edu/software/stanford-segmenter-2018-10-16.zip) with the following steps") - logger.error("1. wget https://nlp.stanford.edu/software/stanford-segmenter-2018-10-16.zip -O /path/to/stanford-segmenter-2018-10-16.zip") - logger.error("2. cd /path/to && unzip stanford-segmenter-2018-10-16.zip") - logger.error("3. cd stanford-segmenter-2018-10-16 && cp stanford-segmenter-3.9.2.jar stanford-segmenter.jar") - logger.error("4. set env variable STANFORD_SEGMENTER=/path/to/stanford-segmenter-2018-10-16") - import sys; sys.exit() - return self.zh_word_tokenizer.segment(text) - @property def vocab_size(self): return len(self.encoder) @@ -315,11 +301,44 @@ class XLMTokenizer(PreTrainedTokenizer): self.cache[token] = word return word - def _tokenize(self, text, lang='en'): - """ Tokenize a string. """ - if self.do_lowercase_and_remove_accent: - text = lowercase_and_remove_accent(text) - if lang not in self.lang_with_custom_tokenizer: + def _tokenize(self, text, lang='en', bypass_tokenizer=False): + """ + Tokenize a string given language code. 
For Chinese, Japanese and Thai, we use a language specific tokenizerself. Otherwise, we use Moses. + + Details of tokenization: + - [sacremoses](https://github.com/alvations/sacremoses): port of Moses + - Install with `pip install sacremoses` + - [pythainlp](https://github.com/PyThaiNLP/pythainlp): Thai tokenizer + - Install with `pip install pythainlp` + - [kytea](https://github.com/chezou/Mykytea-python): Japanese tokenizer, wrapper of [KyTea](https://github.com/neubig/kytea) + - Install with the following steps: + ``` + git clone git@github.com:neubig/kytea.git && cd kytea + autoreconf -i + ./configure --prefix=$HOME/local + make && make install + pip install kytea + ``` + - [jieba](https://github.com/fxsjy/jieba): Chinese tokenizer * + - Install with `pip install jieba` + + \* The original XLM used [Stanford Segmenter](https://nlp.stanford.edu/software/stanford-segmenter-2018-10-16.zip). + However, the wrapper (`nltk.tokenize.stanford_segmenter`) is slow due to JVM overhead, and it will be deprecated. + Jieba is a lot faster and pip-installable. Note there is some mismatch with the Stanford Segmenter. It should be fine + if you fine-tune the model with Chinese supervisionself. If you want the same exact behaviour, use the original XLM + [preprocessing script](https://github.com/facebookresearch/XLM/tree/master/tools) to tokenize the sentence externally, + and set `bypass_tokenizer=True` to bypass the tokenizer. + + Args: + - lang: ISO language code (default = 'en') (string). Languages should belong of the model supported languages. However, we don't enforce it. + - bypass_tokenizer: Allow users to preprocess and tokenize the sentences externally (default = False) (bool). If True, we only apply BPE. + + Returns: + List of tokens. + """ + if bypass_tokenizer: + text = text.split() + elif lang not in self.lang_with_custom_tokenizer: text = self.moses_pipeline(text, lang=lang) # TODO: make sure we are using `xlm-mlm-enro-1024`, since XLM-100 doesn't have this step if lang == 'ro': @@ -327,9 +346,22 @@ class XLMTokenizer(PreTrainedTokenizer): text = self.moses_tokenize(text, lang=lang) elif lang == 'th': text = self.moses_pipeline(text, lang=lang) + try: + if 'pythainlp' not in sys.modules: + from pythainlp.tokenize import word_tokenize as th_word_tokenize + except: + logger.error("Make sure you install PyThaiNLP (https://github.com/PyThaiNLP/pythainlp) with the following steps") + logger.error("1. pip install pythainlp") + import sys; sys.exit() text = th_word_tokenize(text) elif lang == 'zh': - # text = self.zh_tokenize(text) + try: + if 'jieba' not in sys.modules: + import jieba + except: + logger.error("Make sure you install Jieba (https://github.com/fxsjy/jieba) with the following steps") + logger.error("1. 
pip install jieba") + import sys; sys.exit() text = ' '.join(jieba.cut(text)) text = self.moses_pipeline(text, lang=lang) text = text.split() @@ -339,9 +371,13 @@ class XLMTokenizer(PreTrainedTokenizer): else: raise ValueError('It should not reach here') + if self.do_lowercase_and_remove_accent and not bypass_tokenizer: + text = lowercase_and_remove_accent(text) + split_tokens = [] for token in text: - split_tokens.extend([t for t in self.bpe(token).split(' ')]) + if token: + split_tokens.extend([t for t in self.bpe(token).split(' ')]) return split_tokens diff --git a/requirements.txt b/requirements.txt index 2e3f8ace51d..01dca79d23b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,8 +11,4 @@ regex # For XLNet sentencepiece # For XLM -sacremoses -pythainlp -kytea -nltk -jieba \ No newline at end of file +sacremoses \ No newline at end of file diff --git a/setup.py b/setup.py index e37e948fb4d..29797222681 100644 --- a/setup.py +++ b/setup.py @@ -56,11 +56,7 @@ setup( 'tqdm', 'regex', 'sentencepiece', - 'sacremoses', - 'pythainlp', - 'kytea', - 'nltk', - 'jieba'], + 'sacremoses'], entry_points={ 'console_scripts': [ "pytorch_transformers=pytorch_transformers.__main__:main", From 82462c5cba0ec07a3eeb1e9455d229ceaf43b5f2 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Fri, 30 Aug 2019 15:30:41 +0200 Subject: [PATCH 06/11] Added option to setup pretrained tokenizer arguments --- pytorch_transformers/tokenization_bert.py | 36 +++--- pytorch_transformers/tokenization_utils.py | 23 ++-- pytorch_transformers/tokenization_xlm.py | 135 +++++++++++++++++++-- 3 files changed, 159 insertions(+), 35 deletions(-) diff --git a/pytorch_transformers/tokenization_bert.py b/pytorch_transformers/tokenization_bert.py index 04f35aa4662..d1ace940f06 100644 --- a/pytorch_transformers/tokenization_bert.py +++ b/pytorch_transformers/tokenization_bert.py @@ -63,6 +63,23 @@ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 'bert-base-cased-finetuned-mrpc': 512, } +PRETRAINED_INIT_CONFIGURATION = { + 'bert-base-uncased': {'do_lower_case': True}, + 'bert-large-uncased': {'do_lower_case': True}, + 'bert-base-cased': {'do_lower_case': False}, + 'bert-large-cased': {'do_lower_case': False}, + 'bert-base-multilingual-uncased': {'do_lower_case': True}, + 'bert-base-multilingual-cased': {'do_lower_case': False}, + 'bert-base-chinese': {'do_lower_case': False}, + 'bert-base-german-cased': {'do_lower_case': False}, + 'bert-large-uncased-whole-word-masking': {'do_lower_case': True}, + 'bert-large-cased-whole-word-masking': {'do_lower_case': False}, + 'bert-large-uncased-whole-word-masking-finetuned-squad': {'do_lower_case': True}, + 'bert-large-cased-whole-word-masking-finetuned-squad': {'do_lower_case': False}, + 'bert-base-cased-finetuned-mrpc': {'do_lower_case': False}, +} + + def load_vocab(vocab_file): """Loads a vocabulary file into a dictionary.""" vocab = collections.OrderedDict() @@ -100,6 +117,7 @@ class BertTokenizer(PreTrainedTokenizer): vocab_files_names = VOCAB_FILES_NAMES pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES def __init__(self, vocab_file, do_lower_case=True, do_basic_tokenize=True, never_split=None, @@ -199,24 +217,6 @@ class BertTokenizer(PreTrainedTokenizer): index += 1 return (vocab_file,) - @classmethod - def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs): - """ Instantiate a BertTokenizer from pre-trained vocabulary files. 
- """ - if pretrained_model_name_or_path in PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES: - if '-cased' in pretrained_model_name_or_path and kwargs.get('do_lower_case', True): - logger.warning("The pre-trained model you are loading is a cased model but you have not set " - "`do_lower_case` to False. We are setting `do_lower_case=False` for you but " - "you may want to check this behavior.") - kwargs['do_lower_case'] = False - elif '-cased' not in pretrained_model_name_or_path and not kwargs.get('do_lower_case', True): - logger.warning("The pre-trained model you are loading is an uncased model but you have set " - "`do_lower_case` to False. We are setting `do_lower_case=True` for you " - "but you may want to check this behavior.") - kwargs['do_lower_case'] = True - - return super(BertTokenizer, cls)._from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) - class BasicTokenizer(object): """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" diff --git a/pytorch_transformers/tokenization_utils.py b/pytorch_transformers/tokenization_utils.py index 1d054415935..19b37da8c8b 100644 --- a/pytorch_transformers/tokenization_utils.py +++ b/pytorch_transformers/tokenization_utils.py @@ -40,6 +40,7 @@ class PreTrainedTokenizer(object): - ``vocab_files_names``: a python ``dict`` with, as keys, the ``__init__`` keyword name of each vocabulary file required by the model, and as associated values, the filename for saving the associated file (string). - ``pretrained_vocab_files_map``: a python ``dict of dict`` the high-level keys being the ``__init__`` keyword name of each vocabulary file required by the model, the low-level being the `short-cut-names` (string) of the pretrained models with, as associated values, the `url` (string) to the associated pretrained vocabulary file. - ``max_model_input_sizes``: a python ``dict`` with, as keys, the `short-cut-names` (string) of the pretrained models, and as associated values, the maximum length of the sequence inputs of this model, or None if the model has no maximum input size. + - ``pretrained_init_configuration``: a python ``dict`` with, as keys, the `short-cut-names` (string) of the pretrained models, and as associated values, a dictionnary of specific arguments to pass to the ``__init__``method of the tokenizer class for this pretrained model when loading the tokenizer with the ``from_pretrained()`` method. 
Parameters: @@ -61,6 +62,7 @@ class PreTrainedTokenizer(object): """ vocab_files_names = {} pretrained_vocab_files_map = {} + pretrained_init_configuration = {} max_model_input_sizes = {} SPECIAL_TOKENS_ATTRIBUTES = ["bos_token", "eos_token", "unk_token", "sep_token", @@ -235,10 +237,13 @@ class PreTrainedTokenizer(object): s3_models = list(cls.max_model_input_sizes.keys()) vocab_files = {} + init_configuration = {} if pretrained_model_name_or_path in s3_models: # Get the vocabulary from AWS S3 bucket for file_id, map_list in cls.pretrained_vocab_files_map.items(): vocab_files[file_id] = map_list[pretrained_model_name_or_path] + if cls.pretrained_init_configuration and pretrained_model_name_or_path in cls.pretrained_init_configuration: + init_configuration = cls.pretrained_init_configuration[pretrained_model_name_or_path] else: # Get the vocabulary from local files logger.info( @@ -312,28 +317,32 @@ class PreTrainedTokenizer(object): logger.info("loading file {} from cache at {}".format( file_path, resolved_vocab_files[file_id])) + # Prepare initialization kwargs + init_kwargs = init_configuration + init_kwargs.update(kwargs) + # Set max length if needed if pretrained_model_name_or_path in cls.max_model_input_sizes: # if we're using a pretrained model, ensure the tokenizer # wont index sequences longer than the number of positional embeddings max_len = cls.max_model_input_sizes[pretrained_model_name_or_path] if max_len is not None and isinstance(max_len, (int, float)): - kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) + init_kwargs['max_len'] = min(init_kwargs.get('max_len', int(1e12)), max_len) - # Merge resolved_vocab_files arguments in kwargs. + # Merge resolved_vocab_files arguments in init_kwargs. added_tokens_file = resolved_vocab_files.pop('added_tokens_file', None) special_tokens_map_file = resolved_vocab_files.pop('special_tokens_map_file', None) for args_name, file_path in resolved_vocab_files.items(): - if args_name not in kwargs: - kwargs[args_name] = file_path + if args_name not in init_kwargs: + init_kwargs[args_name] = file_path if special_tokens_map_file is not None: special_tokens_map = json.load(open(special_tokens_map_file, encoding="utf-8")) for key, value in special_tokens_map.items(): - if key not in kwargs: - kwargs[key] = value + if key not in init_kwargs: + init_kwargs[key] = value # Instantiate tokenizer. - tokenizer = cls(*inputs, **kwargs) + tokenizer = cls(*inputs, **init_kwargs) # Add supplementary tokens. 
if added_tokens_file is not None: diff --git a/pytorch_transformers/tokenization_xlm.py b/pytorch_transformers/tokenization_xlm.py index 71bf1193873..c40d4cd16e7 100644 --- a/pytorch_transformers/tokenization_xlm.py +++ b/pytorch_transformers/tokenization_xlm.py @@ -47,7 +47,9 @@ PRETRAINED_VOCAB_FILES_MAP = { 'xlm-mlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-vocab.json", 'xlm-clm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-enfr-1024-vocab.json", 'xlm-clm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-ende-1024-vocab.json", - }, + 'xlm-mlm-17-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-vocab.json", + 'xlm-mlm-100-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-vocab.json", + } 'merges_file': { 'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-merges.txt", @@ -58,6 +60,8 @@ PRETRAINED_VOCAB_FILES_MAP = { 'xlm-mlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-merges.txt", 'xlm-clm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enfr-1024-merges.txt", 'xlm-clm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-ende-1024-merges.txt", + 'xlm-mlm-17-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-merges.txt", + 'xlm-mlm-100-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-merges.txt", }, } @@ -70,6 +74,101 @@ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 'xlm-mlm-xnli15-1024': 512, 'xlm-clm-enfr-1024': 512, 'xlm-clm-ende-1024': 512, + 'xlm-mlm-17-1280': 512, + 'xlm-mlm-100-1280': 512, +} + +PRETRAINED_INIT_CONFIGURATION = { + 'xlm-mlm-en-2048': {"do_lowercase_and_remove_accent": True}, + 'xlm-mlm-ende-1024': { "do_lowercase_and_remove_accent": True, + "id2lang": { "0": "de", + "1": "en"}, + "lang2id": { "de": 0, + "en": 1 }}, + 'xlm-mlm-enfr-1024': { "do_lowercase_and_remove_accent": True, + "id2lang": { "0": "en", + "1": "fr"}, + "lang2id": { "en": 0, + "fr": 1 }}, + 'xlm-mlm-enro-1024': { "do_lowercase_and_remove_accent": True, + "id2lang": { "0": "en", + "1": "ro"}, + "lang2id": { "en": 0, + "ro": 1 }}, + 'xlm-mlm-tlm-xnli15-1024': { "do_lowercase_and_remove_accent": True, + "id2lang": { "0": "ar", + "1": "bg", + "2": "de", + "3": "el", + "4": "en", + "5": "es", + "6": "fr", + "7": "hi", + "8": "ru", + "9": "sw", + "10": "th", + "11": "tr", + "12": "ur", + "13": "vi", + "14": "zh"}, + "lang2id": { "ar": 0, + "bg": 1, + "de": 2, + "el": 3, + "en": 4, + "es": 5, + "fr": 6, + "hi": 7, + "ru": 8, + "sw": 9, + "th": 10, + "tr": 11, + "ur": 12, + "vi": 13, + "zh": 14 }}, + 'xlm-mlm-xnli15-1024': { "do_lowercase_and_remove_accent": True, + "id2lang": { "0": "ar", + "1": "bg", + "2": "de", + "3": "el", + "4": "en", + "5": "es", + "6": "fr", + "7": "hi", + "8": "ru", + "9": "sw", + "10": "th", + "11": "tr", + "12": "ur", + "13": "vi", + "14": "zh"}, + "lang2id": { "ar": 0, + "bg": 1, + "de": 2, + "el": 3, + "en": 4, + "es": 5, + "fr": 6, + "hi": 7, + "ru": 8, + "sw": 9, + "th": 10, + "tr": 11, + "ur": 12, + "vi": 13, + "zh": 14 }}, + 'xlm-clm-enfr-1024': { "do_lowercase_and_remove_accent": True, + "id2lang": { "0": "en", + "1": "fr"}, + "lang2id": { "en": 0, + "fr": 1 }}, + 'xlm-clm-ende-1024': { "do_lowercase_and_remove_accent": True, + "id2lang": { "0": "de", + "1": "en"}, + "lang2id": { "de": 0, + "en": 1 }}, + 'xlm-mlm-17-1280': {"do_lowercase_and_remove_accent": 
False}, + 'xlm-mlm-100-1280': {"do_lowercase_and_remove_accent": False}, } def get_pairs(word): @@ -183,17 +282,26 @@ class XLMTokenizer(PreTrainedTokenizer): - (optionally) lower case & normalize all inputs text - argument ``special_tokens`` and function ``set_special_tokens``, can be used to add additional symbols \ - (ex: "__classify__") to a vocabulary. + (ex: "__classify__") to a vocabulary + + - `lang2id` attribute maps the languages supported by the model with their ids if provided (automatically set for pretrained vocabularies) + + - `id2lang` attributes does reverse mapping if provided (automatically set for pretrained vocabularies) + + - `do_lowercase_and_remove_accent` controle lower casing and accent (automatically set for pretrained vocabularies) """ vocab_files_names = VOCAB_FILES_NAMES pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES def __init__(self, vocab_file, merges_file, unk_token="", bos_token="", sep_token="", pad_token="", cls_token="", mask_token="", additional_special_tokens=["", "", "", "", "", "", - "", "", "", ""], **kwargs): + "", "", "", ""], + lang2id=None, id2lang=None, do_lowercase_and_remove_accent=True, + **kwargs): super(XLMTokenizer, self).__init__(unk_token=unk_token, bos_token=bos_token, sep_token=sep_token, pad_token=pad_token, cls_token=cls_token, mask_token=mask_token, @@ -206,7 +314,12 @@ class XLMTokenizer(PreTrainedTokenizer): self.cache_moses_tokenizer = dict() self.lang_with_custom_tokenizer = set(['zh', 'th', 'ja']) # True for current supported model (v1.2.0), False for XLM-17 & 100 - self.do_lowercase_and_remove_accent = True + self.do_lowercase_and_remove_accent = do_lowercase_and_remove_accent + self.lang2id = lang2id + self.id2lang = id2lang + if lang2id is not None and id2lang is not None: + assert len(lang2id) == len(id2lang) + self.ja_word_tokenizer = None self.zh_word_tokenizer = None @@ -244,14 +357,14 @@ class XLMTokenizer(PreTrainedTokenizer): try: import Mykytea self.ja_word_tokenizer = Mykytea.Mykytea('-model %s/local/share/kytea/model.bin' % os.path.expanduser('~')) - except: + except (AttributeError, ImportError) as e: logger.error("Make sure you install KyTea (https://github.com/neubig/kytea) and it's python wrapper (https://github.com/chezou/Mykytea-python) with the following steps") logger.error("1. git clone git@github.com:neubig/kytea.git && cd kytea") logger.error("2. autoreconf -i") logger.error("3. ./configure --prefix=$HOME/local") logger.error("4. make && make install") logger.error("5. pip install kytea") - import sys; sys.exit() + raise e return list(self.ja_word_tokenizer.getWS(text)) @property @@ -336,6 +449,8 @@ class XLMTokenizer(PreTrainedTokenizer): Returns: List of tokens. """ + if lang and self.lang2id and lang not in self.lang2id: + logger.error("Supplied language code not found in lang2id mapping. Please check that your language is supported by the loaded pretrained model.") if bypass_tokenizer: text = text.split() elif lang not in self.lang_with_custom_tokenizer: @@ -349,19 +464,19 @@ class XLMTokenizer(PreTrainedTokenizer): try: if 'pythainlp' not in sys.modules: from pythainlp.tokenize import word_tokenize as th_word_tokenize - except: + except (AttributeError, ImportError) as e: logger.error("Make sure you install PyThaiNLP (https://github.com/PyThaiNLP/pythainlp) with the following steps") logger.error("1. 
pip install pythainlp") - import sys; sys.exit() + raise e text = th_word_tokenize(text) elif lang == 'zh': try: if 'jieba' not in sys.modules: import jieba - except: + except (AttributeError, ImportError) as e: logger.error("Make sure you install Jieba (https://github.com/fxsjy/jieba) with the following steps") logger.error("1. pip install jieba") - import sys; sys.exit() + raise e text = ' '.join(jieba.cut(text)) text = self.moses_pipeline(text, lang=lang) text = text.split() From 8678ff8df5cc9997537fb62251ba91e58eefc0ec Mon Sep 17 00:00:00 2001 From: thomwolf Date: Fri, 30 Aug 2019 16:26:04 +0200 Subject: [PATCH 07/11] adding 17 and 100 xlm models --- pytorch_transformers/tokenization_xlm.py | 247 ++++++++++++++++++++++- 1 file changed, 244 insertions(+), 3 deletions(-) diff --git a/pytorch_transformers/tokenization_xlm.py b/pytorch_transformers/tokenization_xlm.py index c40d4cd16e7..d14acb39c63 100644 --- a/pytorch_transformers/tokenization_xlm.py +++ b/pytorch_transformers/tokenization_xlm.py @@ -49,7 +49,7 @@ PRETRAINED_VOCAB_FILES_MAP = { 'xlm-clm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-ende-1024-vocab.json", 'xlm-mlm-17-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-vocab.json", 'xlm-mlm-100-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-vocab.json", - } + }, 'merges_file': { 'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-merges.txt", @@ -167,8 +167,249 @@ PRETRAINED_INIT_CONFIGURATION = { "1": "en"}, "lang2id": { "de": 0, "en": 1 }}, - 'xlm-mlm-17-1280': {"do_lowercase_and_remove_accent": False}, - 'xlm-mlm-100-1280': {"do_lowercase_and_remove_accent": False}, + 'xlm-mlm-17-1280': {"do_lowercase_and_remove_accent": False, + "id2lang": { + "0": "ar", + "1": "de", + "2": "en", + "3": "es", + "4": "fr", + "5": "hi", + "6": "it", + "7": "ja", + "8": "ko", + "9": "nl", + "10": "pl", + "11": "pt", + "12": "ru", + "13": "sv", + "14": "tr", + "15": "vi", + "16": "zh" + }, + "lang2id": { + "ar": 0, + "de": 1, + "en": 2, + "es": 3, + "fr": 4, + "hi": 5, + "it": 6, + "ja": 7, + "ko": 8, + "nl": 9, + "pl": 10, + "pt": 11, + "ru": 12, + "sv": 13, + "tr": 14, + "vi": 15, + "zh": 16}}, + 'xlm-mlm-100-1280': {"do_lowercase_and_remove_accent": False, + "id2lang": { + "0": "af", + "1": "als", + "2": "am", + "3": "an", + "4": "ang", + "5": "ar", + "6": "arz", + "7": "ast", + "8": "az", + "9": "bar", + "10": "be", + "11": "bg", + "12": "bn", + "13": "br", + "14": "bs", + "15": "ca", + "16": "ceb", + "17": "ckb", + "18": "cs", + "19": "cy", + "20": "da", + "21": "de", + "22": "el", + "23": "en", + "24": "eo", + "25": "es", + "26": "et", + "27": "eu", + "28": "fa", + "29": "fi", + "30": "fr", + "31": "fy", + "32": "ga", + "33": "gan", + "34": "gl", + "35": "gu", + "36": "he", + "37": "hi", + "38": "hr", + "39": "hu", + "40": "hy", + "41": "ia", + "42": "id", + "43": "is", + "44": "it", + "45": "ja", + "46": "jv", + "47": "ka", + "48": "kk", + "49": "kn", + "50": "ko", + "51": "ku", + "52": "la", + "53": "lb", + "54": "lt", + "55": "lv", + "56": "mk", + "57": "ml", + "58": "mn", + "59": "mr", + "60": "ms", + "61": "my", + "62": "nds", + "63": "ne", + "64": "nl", + "65": "nn", + "66": "no", + "67": "oc", + "68": "pl", + "69": "pt", + "70": "ro", + "71": "ru", + "72": "scn", + "73": "sco", + "74": "sh", + "75": "si", + "76": "simple", + "77": "sk", + "78": "sl", + "79": "sq", + "80": "sr", + "81": "sv", + "82": "sw", + "83": "ta", + "84": "te", + "85": "th", + "86": "tl", + 
"87": "tr", + "88": "tt", + "89": "uk", + "90": "ur", + "91": "uz", + "92": "vi", + "93": "war", + "94": "wuu", + "95": "yi", + "96": "zh", + "97": "zh_classical", + "98": "zh_min_nan", + "99": "zh_yue" + }, + "lang2id": { + "af": 0, + "als": 1, + "am": 2, + "an": 3, + "ang": 4, + "ar": 5, + "arz": 6, + "ast": 7, + "az": 8, + "bar": 9, + "be": 10, + "bg": 11, + "bn": 12, + "br": 13, + "bs": 14, + "ca": 15, + "ceb": 16, + "ckb": 17, + "cs": 18, + "cy": 19, + "da": 20, + "de": 21, + "el": 22, + "en": 23, + "eo": 24, + "es": 25, + "et": 26, + "eu": 27, + "fa": 28, + "fi": 29, + "fr": 30, + "fy": 31, + "ga": 32, + "gan": 33, + "gl": 34, + "gu": 35, + "he": 36, + "hi": 37, + "hr": 38, + "hu": 39, + "hy": 40, + "ia": 41, + "id": 42, + "is": 43, + "it": 44, + "ja": 45, + "jv": 46, + "ka": 47, + "kk": 48, + "kn": 49, + "ko": 50, + "ku": 51, + "la": 52, + "lb": 53, + "lt": 54, + "lv": 55, + "mk": 56, + "ml": 57, + "mn": 58, + "mr": 59, + "ms": 60, + "my": 61, + "nds": 62, + "ne": 63, + "nl": 64, + "nn": 65, + "no": 66, + "oc": 67, + "pl": 68, + "pt": 69, + "ro": 70, + "ru": 71, + "scn": 72, + "sco": 73, + "sh": 74, + "si": 75, + "simple": 76, + "sk": 77, + "sl": 78, + "sq": 79, + "sr": 80, + "sv": 81, + "sw": 82, + "ta": 83, + "te": 84, + "th": 85, + "tl": 86, + "tr": 87, + "tt": 88, + "uk": 89, + "ur": 90, + "uz": 91, + "vi": 92, + "war": 93, + "wuu": 94, + "yi": 95, + "zh": 96, + "zh_classical": 97, + "zh_min_nan": 98, + "zh_yue": 99 + }}, } def get_pairs(word): From 3871b8a10757c1b67b29324e872f5b865e49c86c Mon Sep 17 00:00:00 2001 From: thomwolf Date: Fri, 30 Aug 2019 16:28:42 +0200 Subject: [PATCH 08/11] adding xlm 17 and 100 models and config on aws --- pytorch_transformers/modeling_xlm.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pytorch_transformers/modeling_xlm.py b/pytorch_transformers/modeling_xlm.py index 10be972ea53..d82d45fc27d 100644 --- a/pytorch_transformers/modeling_xlm.py +++ b/pytorch_transformers/modeling_xlm.py @@ -44,6 +44,8 @@ XLM_PRETRAINED_MODEL_ARCHIVE_MAP = { 'xlm-mlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-pytorch_model.bin", 'xlm-clm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-enfr-1024-pytorch_model.bin", 'xlm-clm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-ende-1024-pytorch_model.bin", + 'xlm-mlm-17-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-pytorch_model.json", + 'xlm-mlm-100-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-pytorch_model.json", } XLM_PRETRAINED_CONFIG_ARCHIVE_MAP = { 'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-config.json", @@ -54,6 +56,8 @@ XLM_PRETRAINED_CONFIG_ARCHIVE_MAP = { 'xlm-mlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-config.json", 'xlm-clm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-enfr-1024-config.json", 'xlm-clm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-ende-1024-config.json", + 'xlm-mlm-17-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-config.json", + 'xlm-mlm-100-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-config.json", } From 88111de07c40797aaca619be693616c3c4cda4bd Mon Sep 17 00:00:00 2001 From: thomwolf Date: Fri, 30 Aug 2019 16:55:48 +0200 Subject: [PATCH 09/11] saving and reloading tokenizer configurations --- 
pytorch_transformers/tokenization_utils.py | 54 ++++++++++++++++++---- 1 file changed, 45 insertions(+), 9 deletions(-) diff --git a/pytorch_transformers/tokenization_utils.py b/pytorch_transformers/tokenization_utils.py index 19b37da8c8b..51e59fe46c3 100644 --- a/pytorch_transformers/tokenization_utils.py +++ b/pytorch_transformers/tokenization_utils.py @@ -20,6 +20,7 @@ import logging import os import json import six +import copy from io import open from .file_utils import cached_path @@ -28,6 +29,7 @@ logger = logging.getLogger(__name__) SPECIAL_TOKENS_MAP_FILE = 'special_tokens_map.json' ADDED_TOKENS_FILE = 'added_tokens.json' +TOKENIZER_CONFIG_FILE = 'tokenizer_config.json' class PreTrainedTokenizer(object): """ Base class for all tokenizers. @@ -168,9 +170,15 @@ class PreTrainedTokenizer(object): self._additional_special_tokens = [] self.max_len = max_len if max_len is not None else int(1e12) + + # Added tokens self.added_tokens_encoder = {} self.added_tokens_decoder = {} + # inputs and kwargs for saving and re-loading (see ``from_pretrained`` and ``save_pretrained``) + self.init_inputs = () + self.init_kwargs = {} + for key, value in kwargs.items(): if key in self.SPECIAL_TOKENS_ATTRIBUTES: if key == 'additional_special_tokens': @@ -230,7 +238,7 @@ class PreTrainedTokenizer(object): @classmethod - def _from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs): + def _from_pretrained(cls, pretrained_model_name_or_path, *init_inputs, **kwargs): cache_dir = kwargs.pop('cache_dir', None) force_download = kwargs.pop('force_download', False) proxies = kwargs.pop('proxies', None) @@ -266,15 +274,17 @@ class PreTrainedTokenizer(object): vocab_files[file_id] = full_file_name # Look for the additional tokens files - all_vocab_files_names = {'added_tokens_file': ADDED_TOKENS_FILE, - 'special_tokens_map_file': SPECIAL_TOKENS_MAP_FILE} + additional_files_names = {'added_tokens_file': ADDED_TOKENS_FILE, + 'special_tokens_map_file': SPECIAL_TOKENS_MAP_FILE, + 'tokenizer_config_file': TOKENIZER_CONFIG_FILE, + } # If a path to a file was provided, get the parent directory saved_directory = pretrained_model_name_or_path if os.path.exists(saved_directory) and not os.path.isdir(saved_directory): saved_directory = os.path.dirname(saved_directory) - for file_id, file_name in all_vocab_files_names.items(): + for file_id, file_name in additional_files_names.items(): full_file_name = os.path.join(saved_directory, file_name) if not os.path.exists(full_file_name): logger.info("Didn't find file {}. We won't load it.".format(full_file_name)) @@ -317,8 +327,18 @@ class PreTrainedTokenizer(object): logger.info("loading file {} from cache at {}".format( file_path, resolved_vocab_files[file_id])) - # Prepare initialization kwargs - init_kwargs = init_configuration + # Prepare tokenizer initialization kwargs + # Did we saved some inputs and kwargs to reload ? + tokenizer_config_file = resolved_vocab_files.pop('tokenizer_config_file', None) + if tokenizer_config_file is not None: + init_kwargs = json.load(open(tokenizer_config_file, encoding="utf-8")) + saved_init_inputs = init_kwargs.pop('init_inputs', []) + if not init_inputs: + init_inputs = saved_init_inputs + else: + init_kwargs = init_configuration + + # Update with newly provided kwargs init_kwargs.update(kwargs) # Set max length if needed @@ -342,7 +362,11 @@ class PreTrainedTokenizer(object): init_kwargs[key] = value # Instantiate tokenizer. 
- tokenizer = cls(*inputs, **init_kwargs) + tokenizer = cls(*init_inputs, **init_kwargs) + + # Save inputs and kwargs for saving and re-loading with ``save_pretrained`` + tokenizer.init_inputs = init_inputs + tokenizer.init_kwargs = init_kwargs # Add supplementary tokens. if added_tokens_file is not None: @@ -355,8 +379,13 @@ class PreTrainedTokenizer(object): def save_pretrained(self, save_directory): - """ Save the tokenizer vocabulary files (with added tokens) and the - special-tokens-to-class-attributes-mapping to a directory. + """ Save the tokenizer vocabulary files together with: + - added tokens, + - special-tokens-to-class-attributes-mapping, + - tokenizer instantiation positional and keywords inputs (e.g. do_lower_case for Bert). + + This won't save modifications other than (added tokens and special token mapping) you may have + applied to the tokenizer after the instantion (e.g. modifying tokenizer.do_lower_case after creation). This method make sure the full tokenizer can then be re-loaded using the :func:`~pytorch_transformers.PreTrainedTokenizer.from_pretrained` class method. """ @@ -366,6 +395,13 @@ class PreTrainedTokenizer(object): special_tokens_map_file = os.path.join(save_directory, SPECIAL_TOKENS_MAP_FILE) added_tokens_file = os.path.join(save_directory, ADDED_TOKENS_FILE) + tokenizer_config_file = os.path.join(save_directory, TOKENIZER_CONFIG_FILE) + + tokenizer_config = copy.deepcopy(self.init_kwargs) + tokenizer_config['init_inputs'] = copy.deepcopy(self.init_inputs) + + with open(tokenizer_config_file, 'w', encoding='utf-8') as f: + f.write(json.dumps(tokenizer_config, ensure_ascii=False)) with open(special_tokens_map_file, 'w', encoding='utf-8') as f: f.write(json.dumps(self.special_tokens_map, ensure_ascii=False)) From 69da972ace6fd574a528ef269ebcee32305d18ff Mon Sep 17 00:00:00 2001 From: thomwolf Date: Fri, 30 Aug 2019 17:09:36 +0200 Subject: [PATCH 10/11] added test and debug tokenizer configuration serialization --- .../tests/tokenization_bert_test.py | 4 ++-- .../tests/tokenization_gpt2_test.py | 5 +++-- .../tests/tokenization_openai_test.py | 4 ++-- .../tests/tokenization_roberta_test.py | 5 +++-- .../tests/tokenization_tests_commons.py | 15 ++++++++++++--- .../tests/tokenization_transfo_xl_test.py | 5 +++-- .../tests/tokenization_xlm_test.py | 4 ++-- .../tests/tokenization_xlnet_test.py | 4 ++-- pytorch_transformers/tokenization_utils.py | 4 +++- 9 files changed, 32 insertions(+), 18 deletions(-) diff --git a/pytorch_transformers/tests/tokenization_bert_test.py b/pytorch_transformers/tests/tokenization_bert_test.py index db507317a8e..290b3578209 100644 --- a/pytorch_transformers/tests/tokenization_bert_test.py +++ b/pytorch_transformers/tests/tokenization_bert_test.py @@ -41,8 +41,8 @@ class BertTokenizationTest(CommonTestCases.CommonTokenizerTester): with open(self.vocab_file, "w", encoding='utf-8') as vocab_writer: vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) - def get_tokenizer(self): - return BertTokenizer.from_pretrained(self.tmpdirname) + def get_tokenizer(self, **kwargs): + return BertTokenizer.from_pretrained(self.tmpdirname, **kwargs) def get_input_output_texts(self): input_text = u"UNwant\u00E9d,running" diff --git a/pytorch_transformers/tests/tokenization_gpt2_test.py b/pytorch_transformers/tests/tokenization_gpt2_test.py index da7028c27d7..252dbfe6f47 100644 --- a/pytorch_transformers/tests/tokenization_gpt2_test.py +++ b/pytorch_transformers/tests/tokenization_gpt2_test.py @@ -44,8 +44,9 @@ class 
GPT2TokenizationTest(CommonTestCases.CommonTokenizerTester): with open(self.merges_file, "w") as fp: fp.write("\n".join(merges)) - def get_tokenizer(self): - return GPT2Tokenizer.from_pretrained(self.tmpdirname, **self.special_tokens_map) + def get_tokenizer(self, **kwargs): + kwargs.update(self.special_tokens_map) + return GPT2Tokenizer.from_pretrained(self.tmpdirname, **kwargs) def get_input_output_texts(self): input_text = u"lower newer" diff --git a/pytorch_transformers/tests/tokenization_openai_test.py b/pytorch_transformers/tests/tokenization_openai_test.py index bb354f3fb77..6b86416d2d6 100644 --- a/pytorch_transformers/tests/tokenization_openai_test.py +++ b/pytorch_transformers/tests/tokenization_openai_test.py @@ -45,8 +45,8 @@ class OpenAIGPTTokenizationTest(CommonTestCases.CommonTokenizerTester): with open(self.merges_file, "w") as fp: fp.write("\n".join(merges)) - def get_tokenizer(self): - return OpenAIGPTTokenizer.from_pretrained(self.tmpdirname) + def get_tokenizer(self, **kwargs): + return OpenAIGPTTokenizer.from_pretrained(self.tmpdirname, **kwargs) def get_input_output_texts(self): input_text = u"lower newer" diff --git a/pytorch_transformers/tests/tokenization_roberta_test.py b/pytorch_transformers/tests/tokenization_roberta_test.py index a8f940ae432..5f9b65a7a30 100644 --- a/pytorch_transformers/tests/tokenization_roberta_test.py +++ b/pytorch_transformers/tests/tokenization_roberta_test.py @@ -43,8 +43,9 @@ class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester): with open(self.merges_file, "w") as fp: fp.write("\n".join(merges)) - def get_tokenizer(self): - return RobertaTokenizer.from_pretrained(self.tmpdirname, **self.special_tokens_map) + def get_tokenizer(self, **kwargs): + kwargs.update(self.special_tokens_map) + return RobertaTokenizer.from_pretrained(self.tmpdirname, **kwargs) def get_input_output_texts(self): input_text = u"lower newer" diff --git a/pytorch_transformers/tests/tokenization_tests_commons.py b/pytorch_transformers/tests/tokenization_tests_commons.py index ebcf6f48d87..779a3ba6c3f 100644 --- a/pytorch_transformers/tests/tokenization_tests_commons.py +++ b/pytorch_transformers/tests/tokenization_tests_commons.py @@ -49,14 +49,19 @@ class CommonTestCases: def tearDown(self): shutil.rmtree(self.tmpdirname) - def get_tokenizer(self): + def get_tokenizer(self, **kwargs): raise NotImplementedError def get_input_output_texts(self): raise NotImplementedError def test_save_and_load_tokenizer(self): + # safety check on max_len default value so we are sure the test works tokenizer = self.get_tokenizer() + self.assertNotEqual(tokenizer.max_len, 42) + + # Now let's start the test + tokenizer = self.get_tokenizer(max_len=42) before_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running") @@ -64,8 +69,12 @@ class CommonTestCases: tokenizer.save_pretrained(tmpdirname) tokenizer = tokenizer.from_pretrained(tmpdirname) - after_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running") - self.assertListEqual(before_tokens, after_tokens) + after_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running") + self.assertListEqual(before_tokens, after_tokens) + + self.assertEqual(tokenizer.max_len, 42) + tokenizer = tokenizer.from_pretrained(tmpdirname, max_len=43) + self.assertEqual(tokenizer.max_len, 43) def test_pickle_tokenizer(self): tokenizer = self.get_tokenizer() diff --git a/pytorch_transformers/tests/tokenization_transfo_xl_test.py b/pytorch_transformers/tests/tokenization_transfo_xl_test.py index 
diff --git a/pytorch_transformers/tests/tokenization_transfo_xl_test.py b/pytorch_transformers/tests/tokenization_transfo_xl_test.py
index fbd06cf47e7..f881cf1d2b4 100644
--- a/pytorch_transformers/tests/tokenization_transfo_xl_test.py
+++ b/pytorch_transformers/tests/tokenization_transfo_xl_test.py
@@ -37,8 +37,9 @@ class TransfoXLTokenizationTest(CommonTestCases.CommonTokenizerTester):
         with open(self.vocab_file, "w", encoding='utf-8') as vocab_writer:
             vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
 
-    def get_tokenizer(self):
-        return TransfoXLTokenizer.from_pretrained(self.tmpdirname, lower_case=True)
+    def get_tokenizer(self, **kwargs):
+        kwargs['lower_case'] = True
+        return TransfoXLTokenizer.from_pretrained(self.tmpdirname, **kwargs)
 
     def get_input_output_texts(self):
         input_text = u"<unk> UNwanted , running"
diff --git a/pytorch_transformers/tests/tokenization_xlm_test.py b/pytorch_transformers/tests/tokenization_xlm_test.py
index ede77a1f988..43f1e0c5dd7 100644
--- a/pytorch_transformers/tests/tokenization_xlm_test.py
+++ b/pytorch_transformers/tests/tokenization_xlm_test.py
@@ -44,8 +44,8 @@ class XLMTokenizationTest(CommonTestCases.CommonTokenizerTester):
         with open(self.merges_file, "w") as fp:
             fp.write("\n".join(merges))
 
-    def get_tokenizer(self):
-        return XLMTokenizer.from_pretrained(self.tmpdirname)
+    def get_tokenizer(self, **kwargs):
+        return XLMTokenizer.from_pretrained(self.tmpdirname, **kwargs)
 
     def get_input_output_texts(self):
         input_text = u"lower newer"
diff --git a/pytorch_transformers/tests/tokenization_xlnet_test.py b/pytorch_transformers/tests/tokenization_xlnet_test.py
index 9feab7c0bdf..c603ce55f9d 100644
--- a/pytorch_transformers/tests/tokenization_xlnet_test.py
+++ b/pytorch_transformers/tests/tokenization_xlnet_test.py
@@ -35,8 +35,8 @@ class XLNetTokenizationTest(CommonTestCases.CommonTokenizerTester):
         tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True)
         tokenizer.save_pretrained(self.tmpdirname)
 
-    def get_tokenizer(self):
-        return XLNetTokenizer.from_pretrained(self.tmpdirname)
+    def get_tokenizer(self, **kwargs):
+        return XLNetTokenizer.from_pretrained(self.tmpdirname, **kwargs)
 
     def get_input_output_texts(self):
         input_text = u"This is a test"
diff --git a/pytorch_transformers/tokenization_utils.py b/pytorch_transformers/tokenization_utils.py
index 51e59fe46c3..8d7c29b16c0 100644
--- a/pytorch_transformers/tokenization_utils.py
+++ b/pytorch_transformers/tokenization_utils.py
@@ -332,7 +332,7 @@ class PreTrainedTokenizer(object):
         tokenizer_config_file = resolved_vocab_files.pop('tokenizer_config_file', None)
         if tokenizer_config_file is not None:
             init_kwargs = json.load(open(tokenizer_config_file, encoding="utf-8"))
-            saved_init_inputs = init_kwargs.pop('init_inputs', [])
+            saved_init_inputs = init_kwargs.pop('init_inputs', ())
             if not init_inputs:
                 init_inputs = saved_init_inputs
             else:
@@ -399,6 +399,8 @@ class PreTrainedTokenizer(object):
 
         tokenizer_config = copy.deepcopy(self.init_kwargs)
         tokenizer_config['init_inputs'] = copy.deepcopy(self.init_inputs)
+        for file_id in self.vocab_files_names.keys():
+            tokenizer_config.pop(file_id, None)
 
         with open(tokenizer_config_file, 'w', encoding='utf-8') as f:
            f.write(json.dumps(tokenizer_config, ensure_ascii=False))
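
To illustrate the effect of popping the vocabulary file entries, here is an assumed example for a BERT-style tokenizer whose vocab_files_names key is 'vocab_file'; it also assumes TOKENIZER_CONFIG_FILE resolves to 'tokenizer_config.json', and the exact keys depend on the kwargs passed at instantiation:

    import json
    import os

    save_dir = './my_tokenizer'   # placeholder path written to by save_pretrained(save_dir)
    with open(os.path.join(save_dir, 'tokenizer_config.json'), encoding='utf-8') as f:
        config = json.load(f)

    print(config)                         # e.g. {"max_len": 42, "init_inputs": []}
    assert 'vocab_file' not in config     # popped above, so machine-specific paths are never serialized
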
From 7044ed6b059c7305b0a1ab8576c775829afd9226 Mon Sep 17 00:00:00 2001
From: thomwolf
Date: Fri, 30 Aug 2019 17:36:11 +0200
Subject: [PATCH 11/11] fix tokenizers serialization

---
 pytorch_transformers/tests/tokenization_dilbert_test.py  | 4 ++--
 pytorch_transformers/tests/tokenization_tests_commons.py | 4 ++--
 pytorch_transformers/tokenization_transfo_xl.py          | 3 ++-
 pytorch_transformers/tokenization_xlnet.py               | 2 +-
 4 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/pytorch_transformers/tests/tokenization_dilbert_test.py b/pytorch_transformers/tests/tokenization_dilbert_test.py
index 30268db2166..42f80609981 100644
--- a/pytorch_transformers/tests/tokenization_dilbert_test.py
+++ b/pytorch_transformers/tests/tokenization_dilbert_test.py
@@ -27,8 +27,8 @@ class DistilBertTokenizationTest(BertTokenizationTest):
 
     tokenizer_class = DistilBertTokenizer
 
-    def get_tokenizer(self):
-        return DistilBertTokenizer.from_pretrained(self.tmpdirname)
+    def get_tokenizer(self, **kwargs):
+        return DistilBertTokenizer.from_pretrained(self.tmpdirname, **kwargs)
 
     def test_sequence_builders(self):
         tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
diff --git a/pytorch_transformers/tests/tokenization_tests_commons.py b/pytorch_transformers/tests/tokenization_tests_commons.py
index 779a3ba6c3f..6578c5c3a56 100644
--- a/pytorch_transformers/tests/tokenization_tests_commons.py
+++ b/pytorch_transformers/tests/tokenization_tests_commons.py
@@ -67,13 +67,13 @@ class CommonTestCases:
 
             with TemporaryDirectory() as tmpdirname:
                 tokenizer.save_pretrained(tmpdirname)
-                tokenizer = tokenizer.from_pretrained(tmpdirname)
+                tokenizer = self.tokenizer_class.from_pretrained(tmpdirname)
 
                 after_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running")
                 self.assertListEqual(before_tokens, after_tokens)
 
                 self.assertEqual(tokenizer.max_len, 42)
-                tokenizer = tokenizer.from_pretrained(tmpdirname, max_len=43)
+                tokenizer = self.tokenizer_class.from_pretrained(tmpdirname, max_len=43)
                 self.assertEqual(tokenizer.max_len, 43)
 
         def test_pickle_tokenizer(self):
diff --git a/pytorch_transformers/tokenization_transfo_xl.py b/pytorch_transformers/tokenization_transfo_xl.py
index c603ba695c1..66bc01c1bb0 100644
--- a/pytorch_transformers/tokenization_transfo_xl.py
+++ b/pytorch_transformers/tokenization_transfo_xl.py
@@ -95,7 +95,8 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
             # in a library like ours, at all.
             vocab_dict = torch.load(pretrained_vocab_file)
             for key, value in vocab_dict.items():
-                self.__dict__[key] = value
+                if key not in self.__dict__:
+                    self.__dict__[key] = value
 
         if vocab_file is not None:
             self.build_vocab()
diff --git a/pytorch_transformers/tokenization_xlnet.py b/pytorch_transformers/tokenization_xlnet.py
index ac7231bb680..bf9b9dc782f 100644
--- a/pytorch_transformers/tokenization_xlnet.py
+++ b/pytorch_transformers/tokenization_xlnet.py
@@ -61,7 +61,7 @@ class XLNetTokenizer(PreTrainedTokenizer):
     pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
 
-    def __init__(self, vocab_file, max_len=None,
+    def __init__(self, vocab_file,
                  do_lower_case=False, remove_space=True, keep_accents=False,
                  bos_token="<s>", eos_token="</s>", unk_token="<unk>", sep_token="<sep>",
                  pad_token="<pad>", cls_token="<cls>", mask_token="<mask>",