Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-01 02:31:11 +06:00)

Added option to setup pretrained tokenizer arguments

This commit is contained in:
parent: ca4baf8ca1
commit: 82462c5cba
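What the change enables, as a minimal usage sketch (not part of the diff; the `pytorch_transformers` package name and the `basic_tokenizer.do_lower_case` attribute path reflect the library at this point in its history and are assumptions here):

    from pytorch_transformers import BertTokenizer

    # 'bert-base-cased' maps to {'do_lower_case': False} in the PRETRAINED_INIT_CONFIGURATION
    # added below, so the cased checkpoint should come back case-preserving without any manual
    # flag and without the warning-based override this commit removes from BertTokenizer.from_pretrained.
    tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
    print(tokenizer.basic_tokenizer.do_lower_case)  # expected: False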
@@ -63,6 +63,23 @@ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
     'bert-base-cased-finetuned-mrpc': 512,
 }
 
+PRETRAINED_INIT_CONFIGURATION = {
+    'bert-base-uncased': {'do_lower_case': True},
+    'bert-large-uncased': {'do_lower_case': True},
+    'bert-base-cased': {'do_lower_case': False},
+    'bert-large-cased': {'do_lower_case': False},
+    'bert-base-multilingual-uncased': {'do_lower_case': True},
+    'bert-base-multilingual-cased': {'do_lower_case': False},
+    'bert-base-chinese': {'do_lower_case': False},
+    'bert-base-german-cased': {'do_lower_case': False},
+    'bert-large-uncased-whole-word-masking': {'do_lower_case': True},
+    'bert-large-cased-whole-word-masking': {'do_lower_case': False},
+    'bert-large-uncased-whole-word-masking-finetuned-squad': {'do_lower_case': True},
+    'bert-large-cased-whole-word-masking-finetuned-squad': {'do_lower_case': False},
+    'bert-base-cased-finetuned-mrpc': {'do_lower_case': False},
+}
+
 
 def load_vocab(vocab_file):
     """Loads a vocabulary file into a dictionary."""
     vocab = collections.OrderedDict()
@@ -100,6 +117,7 @@ class BertTokenizer(PreTrainedTokenizer):
 
     vocab_files_names = VOCAB_FILES_NAMES
     pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+    pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
 
     def __init__(self, vocab_file, do_lower_case=True, do_basic_tokenize=True, never_split=None,
@@ -199,24 +217,6 @@ class BertTokenizer(PreTrainedTokenizer):
                 index += 1
         return (vocab_file,)
 
-
-    @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
-        """ Instantiate a BertTokenizer from pre-trained vocabulary files.
-        """
-        if pretrained_model_name_or_path in PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES:
-            if '-cased' in pretrained_model_name_or_path and kwargs.get('do_lower_case', True):
-                logger.warning("The pre-trained model you are loading is a cased model but you have not set "
-                               "`do_lower_case` to False. We are setting `do_lower_case=False` for you but "
-                               "you may want to check this behavior.")
-                kwargs['do_lower_case'] = False
-            elif '-cased' not in pretrained_model_name_or_path and not kwargs.get('do_lower_case', True):
-                logger.warning("The pre-trained model you are loading is an uncased model but you have set "
-                               "`do_lower_case` to False. We are setting `do_lower_case=True` for you "
-                               "but you may want to check this behavior.")
-                kwargs['do_lower_case'] = True
-
-        return super(BertTokenizer, cls)._from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
 
 
 class BasicTokenizer(object):
     """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
@@ -40,6 +40,7 @@ class PreTrainedTokenizer(object):
         - ``vocab_files_names``: a python ``dict`` with, as keys, the ``__init__`` keyword name of each vocabulary file required by the model, and as associated values, the filename for saving the associated file (string).
         - ``pretrained_vocab_files_map``: a python ``dict of dict`` the high-level keys being the ``__init__`` keyword name of each vocabulary file required by the model, the low-level being the `short-cut-names` (string) of the pretrained models with, as associated values, the `url` (string) to the associated pretrained vocabulary file.
         - ``max_model_input_sizes``: a python ``dict`` with, as keys, the `short-cut-names` (string) of the pretrained models, and as associated values, the maximum length of the sequence inputs of this model, or None if the model has no maximum input size.
+        - ``pretrained_init_configuration``: a python ``dict`` with, as keys, the `short-cut-names` (string) of the pretrained models, and as associated values, a dictionnary of specific arguments to pass to the ``__init__``method of the tokenizer class for this pretrained model when loading the tokenizer with the ``from_pretrained()`` method.
 
     Parameters:
 
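For illustration only, a subclass would tie these four class attributes together roughly as below; the class name, checkpoint name, URL and file name are made up, the `pytorch_transformers` import path is assumed, and only `pretrained_init_configuration` is new in this commit:

    from pytorch_transformers import PreTrainedTokenizer

    class MyTokenizer(PreTrainedTokenizer):
        # Hypothetical example values; real tokenizers point at their actual vocabulary files.
        vocab_files_names = {'vocab_file': 'vocab.txt'}
        pretrained_vocab_files_map = {
            'vocab_file': {'my-model-base': "https://example.com/my-model-base-vocab.txt"},
        }
        max_model_input_sizes = {'my-model-base': 512}
        # New in this commit: per-checkpoint defaults that from_pretrained() forwards to __init__.
        pretrained_init_configuration = {'my-model-base': {'do_lower_case': False}}

        def __init__(self, vocab_file, do_lower_case=True, **kwargs):
            super(MyTokenizer, self).__init__(**kwargs)
            self.do_lower_case = do_lower_case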
@@ -61,6 +62,7 @@ class PreTrainedTokenizer(object):
     """
     vocab_files_names = {}
     pretrained_vocab_files_map = {}
+    pretrained_init_configuration = {}
     max_model_input_sizes = {}
 
     SPECIAL_TOKENS_ATTRIBUTES = ["bos_token", "eos_token", "unk_token", "sep_token",
@@ -235,10 +237,13 @@ class PreTrainedTokenizer(object):
 
         s3_models = list(cls.max_model_input_sizes.keys())
         vocab_files = {}
+        init_configuration = {}
         if pretrained_model_name_or_path in s3_models:
             # Get the vocabulary from AWS S3 bucket
             for file_id, map_list in cls.pretrained_vocab_files_map.items():
                 vocab_files[file_id] = map_list[pretrained_model_name_or_path]
+            if cls.pretrained_init_configuration and pretrained_model_name_or_path in cls.pretrained_init_configuration:
+                init_configuration = cls.pretrained_init_configuration[pretrained_model_name_or_path]
         else:
             # Get the vocabulary from local files
             logger.info(
@@ -312,28 +317,32 @@ class PreTrainedTokenizer(object):
                 logger.info("loading file {} from cache at {}".format(
                     file_path, resolved_vocab_files[file_id]))
 
+        # Prepare initialization kwargs
+        init_kwargs = init_configuration
+        init_kwargs.update(kwargs)
+
         # Set max length if needed
         if pretrained_model_name_or_path in cls.max_model_input_sizes:
             # if we're using a pretrained model, ensure the tokenizer
             # wont index sequences longer than the number of positional embeddings
             max_len = cls.max_model_input_sizes[pretrained_model_name_or_path]
             if max_len is not None and isinstance(max_len, (int, float)):
-                kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
+                init_kwargs['max_len'] = min(init_kwargs.get('max_len', int(1e12)), max_len)
 
-        # Merge resolved_vocab_files arguments in kwargs.
+        # Merge resolved_vocab_files arguments in init_kwargs.
         added_tokens_file = resolved_vocab_files.pop('added_tokens_file', None)
         special_tokens_map_file = resolved_vocab_files.pop('special_tokens_map_file', None)
         for args_name, file_path in resolved_vocab_files.items():
-            if args_name not in kwargs:
-                kwargs[args_name] = file_path
+            if args_name not in init_kwargs:
+                init_kwargs[args_name] = file_path
         if special_tokens_map_file is not None:
             special_tokens_map = json.load(open(special_tokens_map_file, encoding="utf-8"))
             for key, value in special_tokens_map.items():
-                if key not in kwargs:
-                    kwargs[key] = value
+                if key not in init_kwargs:
+                    init_kwargs[key] = value
 
         # Instantiate tokenizer.
-        tokenizer = cls(*inputs, **kwargs)
+        tokenizer = cls(*inputs, **init_kwargs)
 
         # Add supplementary tokens.
         if added_tokens_file is not None:
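The ordering above is what gives explicit arguments priority: `init_kwargs` starts from the per-checkpoint configuration and is then updated with whatever the caller passed to `from_pretrained()`. A standalone sketch of that merge:

    # Per-checkpoint defaults first, caller kwargs second, so the caller's value wins.
    init_configuration = {'do_lower_case': False}   # e.g. taken from pretrained_init_configuration
    kwargs = {'do_lower_case': True}                # explicitly passed by the user
    init_kwargs = init_configuration
    init_kwargs.update(kwargs)
    assert init_kwargs == {'do_lower_case': True}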
@@ -47,7 +47,9 @@ PRETRAINED_VOCAB_FILES_MAP = {
     'xlm-mlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-vocab.json",
     'xlm-clm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-enfr-1024-vocab.json",
     'xlm-clm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-ende-1024-vocab.json",
-    },
+    'xlm-mlm-17-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-vocab.json",
+    'xlm-mlm-100-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-vocab.json",
+    },
     'merges_file':
     {
     'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-merges.txt",
@@ -58,6 +60,8 @@ PRETRAINED_VOCAB_FILES_MAP = {
     'xlm-mlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-merges.txt",
     'xlm-clm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enfr-1024-merges.txt",
     'xlm-clm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-ende-1024-merges.txt",
+    'xlm-mlm-17-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-merges.txt",
+    'xlm-mlm-100-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-merges.txt",
     },
 }
 
@@ -70,6 +74,101 @@ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
     'xlm-mlm-xnli15-1024': 512,
     'xlm-clm-enfr-1024': 512,
     'xlm-clm-ende-1024': 512,
+    'xlm-mlm-17-1280': 512,
+    'xlm-mlm-100-1280': 512,
 }
 
+PRETRAINED_INIT_CONFIGURATION = {
+    'xlm-mlm-en-2048': {"do_lowercase_and_remove_accent": True},
+    'xlm-mlm-ende-1024': { "do_lowercase_and_remove_accent": True,
+                           "id2lang": { "0": "de",
+                                        "1": "en"},
+                           "lang2id": { "de": 0,
+                                        "en": 1 }},
+    'xlm-mlm-enfr-1024': { "do_lowercase_and_remove_accent": True,
+                           "id2lang": { "0": "en",
+                                        "1": "fr"},
+                           "lang2id": { "en": 0,
+                                        "fr": 1 }},
+    'xlm-mlm-enro-1024': { "do_lowercase_and_remove_accent": True,
+                           "id2lang": { "0": "en",
+                                        "1": "ro"},
+                           "lang2id": { "en": 0,
+                                        "ro": 1 }},
+    'xlm-mlm-tlm-xnli15-1024': { "do_lowercase_and_remove_accent": True,
+                                 "id2lang": { "0": "ar",
+                                              "1": "bg",
+                                              "2": "de",
+                                              "3": "el",
+                                              "4": "en",
+                                              "5": "es",
+                                              "6": "fr",
+                                              "7": "hi",
+                                              "8": "ru",
+                                              "9": "sw",
+                                              "10": "th",
+                                              "11": "tr",
+                                              "12": "ur",
+                                              "13": "vi",
+                                              "14": "zh"},
+                                 "lang2id": { "ar": 0,
+                                              "bg": 1,
+                                              "de": 2,
+                                              "el": 3,
+                                              "en": 4,
+                                              "es": 5,
+                                              "fr": 6,
+                                              "hi": 7,
+                                              "ru": 8,
+                                              "sw": 9,
+                                              "th": 10,
+                                              "tr": 11,
+                                              "ur": 12,
+                                              "vi": 13,
+                                              "zh": 14 }},
+    'xlm-mlm-xnli15-1024': { "do_lowercase_and_remove_accent": True,
+                             "id2lang": { "0": "ar",
+                                          "1": "bg",
+                                          "2": "de",
+                                          "3": "el",
+                                          "4": "en",
+                                          "5": "es",
+                                          "6": "fr",
+                                          "7": "hi",
+                                          "8": "ru",
+                                          "9": "sw",
+                                          "10": "th",
+                                          "11": "tr",
+                                          "12": "ur",
+                                          "13": "vi",
+                                          "14": "zh"},
+                             "lang2id": { "ar": 0,
+                                          "bg": 1,
+                                          "de": 2,
+                                          "el": 3,
+                                          "en": 4,
+                                          "es": 5,
+                                          "fr": 6,
+                                          "hi": 7,
+                                          "ru": 8,
+                                          "sw": 9,
+                                          "th": 10,
+                                          "tr": 11,
+                                          "ur": 12,
+                                          "vi": 13,
+                                          "zh": 14 }},
+    'xlm-clm-enfr-1024': { "do_lowercase_and_remove_accent": True,
+                           "id2lang": { "0": "en",
+                                        "1": "fr"},
+                           "lang2id": { "en": 0,
+                                        "fr": 1 }},
+    'xlm-clm-ende-1024': { "do_lowercase_and_remove_accent": True,
+                           "id2lang": { "0": "de",
+                                        "1": "en"},
+                           "lang2id": { "de": 0,
+                                        "en": 1 }},
+    'xlm-mlm-17-1280': {"do_lowercase_and_remove_accent": False},
+    'xlm-mlm-100-1280': {"do_lowercase_and_remove_accent": False},
+}
 
 def get_pairs(word):
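The practical effect for XLM, sketched below (illustrative only; the `pytorch_transformers` package name is assumed): loading one of these checkpoints should now populate the language maps and the lowercasing flag automatically instead of leaving them hard-coded.

    from pytorch_transformers import XLMTokenizer

    tokenizer = XLMTokenizer.from_pretrained('xlm-mlm-enfr-1024')
    print(tokenizer.lang2id)                         # expected: {'en': 0, 'fr': 1}
    print(tokenizer.id2lang)                         # expected: {'0': 'en', '1': 'fr'}
    print(tokenizer.do_lowercase_and_remove_accent)  # expected: True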
@@ -183,17 +282,26 @@ class XLMTokenizer(PreTrainedTokenizer):
         - (optionally) lower case & normalize all inputs text
 
         - argument ``special_tokens`` and function ``set_special_tokens``, can be used to add additional symbols \
-        (ex: "__classify__") to a vocabulary.
+        (ex: "__classify__") to a vocabulary
+
+        - `lang2id` attribute maps the languages supported by the model with their ids if provided (automatically set for pretrained vocabularies)
+
+        - `id2lang` attributes does reverse mapping if provided (automatically set for pretrained vocabularies)
+
+        - `do_lowercase_and_remove_accent` controle lower casing and accent (automatically set for pretrained vocabularies)
     """
     vocab_files_names = VOCAB_FILES_NAMES
     pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+    pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
 
     def __init__(self, vocab_file, merges_file, unk_token="<unk>", bos_token="<s>",
                  sep_token="</s>", pad_token="<pad>", cls_token="</s>",
                  mask_token="<special1>", additional_special_tokens=["<special0>",
                  "<special1>", "<special2>", "<special3>", "<special4>", "<special5>",
-                 "<special6>", "<special7>", "<special8>", "<special9>"], **kwargs):
+                 "<special6>", "<special7>", "<special8>", "<special9>"],
+                 lang2id=None, id2lang=None, do_lowercase_and_remove_accent=True,
+                 **kwargs):
         super(XLMTokenizer, self).__init__(unk_token=unk_token, bos_token=bos_token,
                                            sep_token=sep_token, pad_token=pad_token,
                                            cls_token=cls_token, mask_token=mask_token,
@@ -206,7 +314,12 @@ class XLMTokenizer(PreTrainedTokenizer):
         self.cache_moses_tokenizer = dict()
         self.lang_with_custom_tokenizer = set(['zh', 'th', 'ja'])
-        self.do_lowercase_and_remove_accent = True
+        # True for current supported model (v1.2.0), False for XLM-17 & 100
+        self.do_lowercase_and_remove_accent = do_lowercase_and_remove_accent
+        self.lang2id = lang2id
+        self.id2lang = id2lang
+        if lang2id is not None and id2lang is not None:
+            assert len(lang2id) == len(id2lang)
 
         self.ja_word_tokenizer = None
         self.zh_word_tokenizer = None
 
@@ -244,14 +357,14 @@ class XLMTokenizer(PreTrainedTokenizer):
             try:
                 import Mykytea
                 self.ja_word_tokenizer = Mykytea.Mykytea('-model %s/local/share/kytea/model.bin' % os.path.expanduser('~'))
-            except:
+            except (AttributeError, ImportError) as e:
                 logger.error("Make sure you install KyTea (https://github.com/neubig/kytea) and it's python wrapper (https://github.com/chezou/Mykytea-python) with the following steps")
                 logger.error("1. git clone git@github.com:neubig/kytea.git && cd kytea")
                 logger.error("2. autoreconf -i")
                 logger.error("3. ./configure --prefix=$HOME/local")
                 logger.error("4. make && make install")
                 logger.error("5. pip install kytea")
-                import sys; sys.exit()
+                raise e
         return list(self.ja_word_tokenizer.getWS(text))
 
     @property
@@ -336,6 +449,8 @@ class XLMTokenizer(PreTrainedTokenizer):
         Returns:
             List of tokens.
         """
+        if lang and self.lang2id and lang not in self.lang2id:
+            logger.error("Supplied language code not found in lang2id mapping. Please check that your language is supported by the loaded pretrained model.")
         if bypass_tokenizer:
             text = text.split()
         elif lang not in self.lang_with_custom_tokenizer:
@@ -349,19 +464,19 @@ class XLMTokenizer(PreTrainedTokenizer):
             try:
                 if 'pythainlp' not in sys.modules:
                     from pythainlp.tokenize import word_tokenize as th_word_tokenize
-            except:
+            except (AttributeError, ImportError) as e:
                 logger.error("Make sure you install PyThaiNLP (https://github.com/PyThaiNLP/pythainlp) with the following steps")
                 logger.error("1. pip install pythainlp")
-                import sys; sys.exit()
+                raise e
             text = th_word_tokenize(text)
         elif lang == 'zh':
             try:
                 if 'jieba' not in sys.modules:
                     import jieba
-            except:
+            except (AttributeError, ImportError) as e:
                 logger.error("Make sure you install Jieba (https://github.com/fxsjy/jieba) with the following steps")
                 logger.error("1. pip install jieba")
-                import sys; sys.exit()
+                raise e
             text = ' '.join(jieba.cut(text))
             text = self.moses_pipeline(text, lang=lang)
             text = text.split()
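Since the bare `except:` blocks no longer call `sys.exit()` but re-raise, a missing optional dependency now surfaces as an exception the caller can handle. A rough sketch (checkpoint name and sample text are arbitrary examples; the `pytorch_transformers` package name is assumed):

    from pytorch_transformers import XLMTokenizer

    tokenizer = XLMTokenizer.from_pretrained('xlm-mlm-17-1280')
    try:
        tokens = tokenizer.tokenize("这是一个测试", lang='zh')  # the 'zh' path needs jieba
    except ImportError:
        print("jieba is not installed; install it before tokenizing Chinese text")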