Added option to setup pretrained tokenizer arguments

This commit is contained in:
thomwolf 2019-08-30 15:30:41 +02:00
parent ca4baf8ca1
commit 82462c5cba
3 changed files with 159 additions and 35 deletions

View File

@ -63,6 +63,23 @@ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
'bert-base-cased-finetuned-mrpc': 512,
}
PRETRAINED_INIT_CONFIGURATION = {
'bert-base-uncased': {'do_lower_case': True},
'bert-large-uncased': {'do_lower_case': True},
'bert-base-cased': {'do_lower_case': False},
'bert-large-cased': {'do_lower_case': False},
'bert-base-multilingual-uncased': {'do_lower_case': True},
'bert-base-multilingual-cased': {'do_lower_case': False},
'bert-base-chinese': {'do_lower_case': False},
'bert-base-german-cased': {'do_lower_case': False},
'bert-large-uncased-whole-word-masking': {'do_lower_case': True},
'bert-large-cased-whole-word-masking': {'do_lower_case': False},
'bert-large-uncased-whole-word-masking-finetuned-squad': {'do_lower_case': True},
'bert-large-cased-whole-word-masking-finetuned-squad': {'do_lower_case': False},
'bert-base-cased-finetuned-mrpc': {'do_lower_case': False},
}
def load_vocab(vocab_file):
"""Loads a vocabulary file into a dictionary."""
vocab = collections.OrderedDict()
@ -100,6 +117,7 @@ class BertTokenizer(PreTrainedTokenizer):
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
def __init__(self, vocab_file, do_lower_case=True, do_basic_tokenize=True, never_split=None,
@ -199,24 +217,6 @@ class BertTokenizer(PreTrainedTokenizer):
index += 1
return (vocab_file,)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
""" Instantiate a BertTokenizer from pre-trained vocabulary files.
"""
if pretrained_model_name_or_path in PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES:
if '-cased' in pretrained_model_name_or_path and kwargs.get('do_lower_case', True):
logger.warning("The pre-trained model you are loading is a cased model but you have not set "
"`do_lower_case` to False. We are setting `do_lower_case=False` for you but "
"you may want to check this behavior.")
kwargs['do_lower_case'] = False
elif '-cased' not in pretrained_model_name_or_path and not kwargs.get('do_lower_case', True):
logger.warning("The pre-trained model you are loading is an uncased model but you have set "
"`do_lower_case` to False. We are setting `do_lower_case=True` for you "
"but you may want to check this behavior.")
kwargs['do_lower_case'] = True
return super(BertTokenizer, cls)._from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
class BasicTokenizer(object):
"""Runs basic tokenization (punctuation splitting, lower casing, etc.)."""

View File

@ -40,6 +40,7 @@ class PreTrainedTokenizer(object):
- ``vocab_files_names``: a python ``dict`` with, as keys, the ``__init__`` keyword name of each vocabulary file required by the model, and as associated values, the filename for saving the associated file (string).
- ``pretrained_vocab_files_map``: a python ``dict of dict`` the high-level keys being the ``__init__`` keyword name of each vocabulary file required by the model, the low-level being the `short-cut-names` (string) of the pretrained models with, as associated values, the `url` (string) to the associated pretrained vocabulary file.
- ``max_model_input_sizes``: a python ``dict`` with, as keys, the `short-cut-names` (string) of the pretrained models, and as associated values, the maximum length of the sequence inputs of this model, or None if the model has no maximum input size.
- ``pretrained_init_configuration``: a python ``dict`` with, as keys, the `short-cut-names` (string) of the pretrained models, and as associated values, a dictionnary of specific arguments to pass to the ``__init__``method of the tokenizer class for this pretrained model when loading the tokenizer with the ``from_pretrained()`` method.
Parameters:
@ -61,6 +62,7 @@ class PreTrainedTokenizer(object):
"""
vocab_files_names = {}
pretrained_vocab_files_map = {}
pretrained_init_configuration = {}
max_model_input_sizes = {}
SPECIAL_TOKENS_ATTRIBUTES = ["bos_token", "eos_token", "unk_token", "sep_token",
@ -235,10 +237,13 @@ class PreTrainedTokenizer(object):
s3_models = list(cls.max_model_input_sizes.keys())
vocab_files = {}
init_configuration = {}
if pretrained_model_name_or_path in s3_models:
# Get the vocabulary from AWS S3 bucket
for file_id, map_list in cls.pretrained_vocab_files_map.items():
vocab_files[file_id] = map_list[pretrained_model_name_or_path]
if cls.pretrained_init_configuration and pretrained_model_name_or_path in cls.pretrained_init_configuration:
init_configuration = cls.pretrained_init_configuration[pretrained_model_name_or_path]
else:
# Get the vocabulary from local files
logger.info(
@ -312,28 +317,32 @@ class PreTrainedTokenizer(object):
logger.info("loading file {} from cache at {}".format(
file_path, resolved_vocab_files[file_id]))
# Prepare initialization kwargs
init_kwargs = init_configuration
init_kwargs.update(kwargs)
# Set max length if needed
if pretrained_model_name_or_path in cls.max_model_input_sizes:
# if we're using a pretrained model, ensure the tokenizer
# wont index sequences longer than the number of positional embeddings
max_len = cls.max_model_input_sizes[pretrained_model_name_or_path]
if max_len is not None and isinstance(max_len, (int, float)):
kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
init_kwargs['max_len'] = min(init_kwargs.get('max_len', int(1e12)), max_len)
# Merge resolved_vocab_files arguments in kwargs.
# Merge resolved_vocab_files arguments in init_kwargs.
added_tokens_file = resolved_vocab_files.pop('added_tokens_file', None)
special_tokens_map_file = resolved_vocab_files.pop('special_tokens_map_file', None)
for args_name, file_path in resolved_vocab_files.items():
if args_name not in kwargs:
kwargs[args_name] = file_path
if args_name not in init_kwargs:
init_kwargs[args_name] = file_path
if special_tokens_map_file is not None:
special_tokens_map = json.load(open(special_tokens_map_file, encoding="utf-8"))
for key, value in special_tokens_map.items():
if key not in kwargs:
kwargs[key] = value
if key not in init_kwargs:
init_kwargs[key] = value
# Instantiate tokenizer.
tokenizer = cls(*inputs, **kwargs)
tokenizer = cls(*inputs, **init_kwargs)
# Add supplementary tokens.
if added_tokens_file is not None:

View File

@ -47,7 +47,9 @@ PRETRAINED_VOCAB_FILES_MAP = {
'xlm-mlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-vocab.json",
'xlm-clm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-enfr-1024-vocab.json",
'xlm-clm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-ende-1024-vocab.json",
},
'xlm-mlm-17-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-vocab.json",
'xlm-mlm-100-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-vocab.json",
}
'merges_file':
{
'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-merges.txt",
@ -58,6 +60,8 @@ PRETRAINED_VOCAB_FILES_MAP = {
'xlm-mlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-merges.txt",
'xlm-clm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enfr-1024-merges.txt",
'xlm-clm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-ende-1024-merges.txt",
'xlm-mlm-17-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-merges.txt",
'xlm-mlm-100-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-merges.txt",
},
}
@ -70,6 +74,101 @@ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
'xlm-mlm-xnli15-1024': 512,
'xlm-clm-enfr-1024': 512,
'xlm-clm-ende-1024': 512,
'xlm-mlm-17-1280': 512,
'xlm-mlm-100-1280': 512,
}
PRETRAINED_INIT_CONFIGURATION = {
'xlm-mlm-en-2048': {"do_lowercase_and_remove_accent": True},
'xlm-mlm-ende-1024': { "do_lowercase_and_remove_accent": True,
"id2lang": { "0": "de",
"1": "en"},
"lang2id": { "de": 0,
"en": 1 }},
'xlm-mlm-enfr-1024': { "do_lowercase_and_remove_accent": True,
"id2lang": { "0": "en",
"1": "fr"},
"lang2id": { "en": 0,
"fr": 1 }},
'xlm-mlm-enro-1024': { "do_lowercase_and_remove_accent": True,
"id2lang": { "0": "en",
"1": "ro"},
"lang2id": { "en": 0,
"ro": 1 }},
'xlm-mlm-tlm-xnli15-1024': { "do_lowercase_and_remove_accent": True,
"id2lang": { "0": "ar",
"1": "bg",
"2": "de",
"3": "el",
"4": "en",
"5": "es",
"6": "fr",
"7": "hi",
"8": "ru",
"9": "sw",
"10": "th",
"11": "tr",
"12": "ur",
"13": "vi",
"14": "zh"},
"lang2id": { "ar": 0,
"bg": 1,
"de": 2,
"el": 3,
"en": 4,
"es": 5,
"fr": 6,
"hi": 7,
"ru": 8,
"sw": 9,
"th": 10,
"tr": 11,
"ur": 12,
"vi": 13,
"zh": 14 }},
'xlm-mlm-xnli15-1024': { "do_lowercase_and_remove_accent": True,
"id2lang": { "0": "ar",
"1": "bg",
"2": "de",
"3": "el",
"4": "en",
"5": "es",
"6": "fr",
"7": "hi",
"8": "ru",
"9": "sw",
"10": "th",
"11": "tr",
"12": "ur",
"13": "vi",
"14": "zh"},
"lang2id": { "ar": 0,
"bg": 1,
"de": 2,
"el": 3,
"en": 4,
"es": 5,
"fr": 6,
"hi": 7,
"ru": 8,
"sw": 9,
"th": 10,
"tr": 11,
"ur": 12,
"vi": 13,
"zh": 14 }},
'xlm-clm-enfr-1024': { "do_lowercase_and_remove_accent": True,
"id2lang": { "0": "en",
"1": "fr"},
"lang2id": { "en": 0,
"fr": 1 }},
'xlm-clm-ende-1024': { "do_lowercase_and_remove_accent": True,
"id2lang": { "0": "de",
"1": "en"},
"lang2id": { "de": 0,
"en": 1 }},
'xlm-mlm-17-1280': {"do_lowercase_and_remove_accent": False},
'xlm-mlm-100-1280': {"do_lowercase_and_remove_accent": False},
}
def get_pairs(word):
@ -183,17 +282,26 @@ class XLMTokenizer(PreTrainedTokenizer):
- (optionally) lower case & normalize all inputs text
- argument ``special_tokens`` and function ``set_special_tokens``, can be used to add additional symbols \
(ex: "__classify__") to a vocabulary.
(ex: "__classify__") to a vocabulary
- `lang2id` attribute maps the languages supported by the model with their ids if provided (automatically set for pretrained vocabularies)
- `id2lang` attributes does reverse mapping if provided (automatically set for pretrained vocabularies)
- `do_lowercase_and_remove_accent` controle lower casing and accent (automatically set for pretrained vocabularies)
"""
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
def __init__(self, vocab_file, merges_file, unk_token="<unk>", bos_token="<s>",
sep_token="</s>", pad_token="<pad>", cls_token="</s>",
mask_token="<special1>", additional_special_tokens=["<special0>",
"<special1>", "<special2>", "<special3>", "<special4>", "<special5>",
"<special6>", "<special7>", "<special8>", "<special9>"], **kwargs):
"<special6>", "<special7>", "<special8>", "<special9>"],
lang2id=None, id2lang=None, do_lowercase_and_remove_accent=True,
**kwargs):
super(XLMTokenizer, self).__init__(unk_token=unk_token, bos_token=bos_token,
sep_token=sep_token, pad_token=pad_token,
cls_token=cls_token, mask_token=mask_token,
@ -206,7 +314,12 @@ class XLMTokenizer(PreTrainedTokenizer):
self.cache_moses_tokenizer = dict()
self.lang_with_custom_tokenizer = set(['zh', 'th', 'ja'])
# True for current supported model (v1.2.0), False for XLM-17 & 100
self.do_lowercase_and_remove_accent = True
self.do_lowercase_and_remove_accent = do_lowercase_and_remove_accent
self.lang2id = lang2id
self.id2lang = id2lang
if lang2id is not None and id2lang is not None:
assert len(lang2id) == len(id2lang)
self.ja_word_tokenizer = None
self.zh_word_tokenizer = None
@ -244,14 +357,14 @@ class XLMTokenizer(PreTrainedTokenizer):
try:
import Mykytea
self.ja_word_tokenizer = Mykytea.Mykytea('-model %s/local/share/kytea/model.bin' % os.path.expanduser('~'))
except:
except (AttributeError, ImportError) as e:
logger.error("Make sure you install KyTea (https://github.com/neubig/kytea) and it's python wrapper (https://github.com/chezou/Mykytea-python) with the following steps")
logger.error("1. git clone git@github.com:neubig/kytea.git && cd kytea")
logger.error("2. autoreconf -i")
logger.error("3. ./configure --prefix=$HOME/local")
logger.error("4. make && make install")
logger.error("5. pip install kytea")
import sys; sys.exit()
raise e
return list(self.ja_word_tokenizer.getWS(text))
@property
@ -336,6 +449,8 @@ class XLMTokenizer(PreTrainedTokenizer):
Returns:
List of tokens.
"""
if lang and self.lang2id and lang not in self.lang2id:
logger.error("Supplied language code not found in lang2id mapping. Please check that your language is supported by the loaded pretrained model.")
if bypass_tokenizer:
text = text.split()
elif lang not in self.lang_with_custom_tokenizer:
@ -349,19 +464,19 @@ class XLMTokenizer(PreTrainedTokenizer):
try:
if 'pythainlp' not in sys.modules:
from pythainlp.tokenize import word_tokenize as th_word_tokenize
except:
except (AttributeError, ImportError) as e:
logger.error("Make sure you install PyThaiNLP (https://github.com/PyThaiNLP/pythainlp) with the following steps")
logger.error("1. pip install pythainlp")
import sys; sys.exit()
raise e
text = th_word_tokenize(text)
elif lang == 'zh':
try:
if 'jieba' not in sys.modules:
import jieba
except:
except (AttributeError, ImportError) as e:
logger.error("Make sure you install Jieba (https://github.com/fxsjy/jieba) with the following steps")
logger.error("1. pip install jieba")
import sys; sys.exit()
raise e
text = ' '.join(jieba.cut(text))
text = self.moses_pipeline(text, lang=lang)
text = text.split()