Added option to setup pretrained tokenizer arguments
Commit 82462c5cba (parent ca4baf8ca1) in huggingface/transformers.
@@ -63,6 +63,23 @@ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
     'bert-base-cased-finetuned-mrpc': 512,
 }
 
+PRETRAINED_INIT_CONFIGURATION = {
+    'bert-base-uncased': {'do_lower_case': True},
+    'bert-large-uncased': {'do_lower_case': True},
+    'bert-base-cased': {'do_lower_case': False},
+    'bert-large-cased': {'do_lower_case': False},
+    'bert-base-multilingual-uncased': {'do_lower_case': True},
+    'bert-base-multilingual-cased': {'do_lower_case': False},
+    'bert-base-chinese': {'do_lower_case': False},
+    'bert-base-german-cased': {'do_lower_case': False},
+    'bert-large-uncased-whole-word-masking': {'do_lower_case': True},
+    'bert-large-cased-whole-word-masking': {'do_lower_case': False},
+    'bert-large-uncased-whole-word-masking-finetuned-squad': {'do_lower_case': True},
+    'bert-large-cased-whole-word-masking-finetuned-squad': {'do_lower_case': False},
+    'bert-base-cased-finetuned-mrpc': {'do_lower_case': False},
+}
+
+
 def load_vocab(vocab_file):
     """Loads a vocabulary file into a dictionary."""
     vocab = collections.OrderedDict()
@@ -100,6 +117,7 @@ class BertTokenizer(PreTrainedTokenizer):
 
     vocab_files_names = VOCAB_FILES_NAMES
     pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+    pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
 
     def __init__(self, vocab_file, do_lower_case=True, do_basic_tokenize=True, never_split=None,
@@ -199,24 +217,6 @@ class BertTokenizer(PreTrainedTokenizer):
             index += 1
         return (vocab_file,)
 
-    @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
-        """ Instantiate a BertTokenizer from pre-trained vocabulary files.
-        """
-        if pretrained_model_name_or_path in PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES:
-            if '-cased' in pretrained_model_name_or_path and kwargs.get('do_lower_case', True):
-                logger.warning("The pre-trained model you are loading is a cased model but you have not set "
-                               "`do_lower_case` to False. We are setting `do_lower_case=False` for you but "
-                               "you may want to check this behavior.")
-                kwargs['do_lower_case'] = False
-            elif '-cased' not in pretrained_model_name_or_path and not kwargs.get('do_lower_case', True):
-                logger.warning("The pre-trained model you are loading is an uncased model but you have set "
-                               "`do_lower_case` to False. We are setting `do_lower_case=True` for you "
-                               "but you may want to check this behavior.")
-                kwargs['do_lower_case'] = True
-
-        return super(BertTokenizer, cls)._from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
-
 
 class BasicTokenizer(object):
     """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
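With the per-model configuration above, BertTokenizer no longer needs the cased/uncased name heuristic removed in this hunk: the correct do_lower_case value is looked up in PRETRAINED_INIT_CONFIGURATION when a known shortcut name is loaded, and an explicit keyword argument still takes precedence. A minimal usage sketch; the pytorch_transformers import path and the basic_tokenizer.do_lower_case attribute are assumptions based on the package of this era, not part of the diff:

    from pytorch_transformers import BertTokenizer

    # Cased checkpoint: do_lower_case=False now comes from PRETRAINED_INIT_CONFIGURATION
    # instead of the removed warning-based heuristic.
    tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
    print(tokenizer.basic_tokenizer.do_lower_case)  # expected: False

    # An explicit keyword argument still overrides the pretrained configuration.
    tokenizer = BertTokenizer.from_pretrained('bert-base-cased', do_lower_case=True)
    print(tokenizer.basic_tokenizer.do_lower_case)  # expected: True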
@@ -40,6 +40,7 @@ class PreTrainedTokenizer(object):
         - ``vocab_files_names``: a python ``dict`` with, as keys, the ``__init__`` keyword name of each vocabulary file required by the model, and as associated values, the filename for saving the associated file (string).
         - ``pretrained_vocab_files_map``: a python ``dict of dict`` the high-level keys being the ``__init__`` keyword name of each vocabulary file required by the model, the low-level being the `short-cut-names` (string) of the pretrained models with, as associated values, the `url` (string) to the associated pretrained vocabulary file.
         - ``max_model_input_sizes``: a python ``dict`` with, as keys, the `short-cut-names` (string) of the pretrained models, and as associated values, the maximum length of the sequence inputs of this model, or None if the model has no maximum input size.
+        - ``pretrained_init_configuration``: a python ``dict`` with, as keys, the `short-cut-names` (string) of the pretrained models, and as associated values, a dictionary of specific arguments to pass to the ``__init__`` method of the tokenizer class for this pretrained model when loading the tokenizer with the ``from_pretrained()`` method.
 
     Parameters:
 
@@ -61,6 +62,7 @@ class PreTrainedTokenizer(object):
     """
    vocab_files_names = {}
    pretrained_vocab_files_map = {}
+   pretrained_init_configuration = {}
    max_model_input_sizes = {}

    SPECIAL_TOKENS_ATTRIBUTES = ["bos_token", "eos_token", "unk_token", "sep_token",
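Concrete tokenizers wire these class attributes to their own module-level constants, as the BertTokenizer and XLMTokenizer hunks in this commit do. A schematic sketch of the pattern; the class, shortcut name, and URL below are illustrative placeholders, not values from the diff:

    from pytorch_transformers import PreTrainedTokenizer  # assumed top-level export

    class MyTokenizer(PreTrainedTokenizer):
        # Keyword name of each vocabulary file -> filename used when saving it.
        vocab_files_names = {'vocab_file': 'vocab.txt'}
        # Keyword name -> {shortcut name -> URL of the hosted vocabulary file}.
        pretrained_vocab_files_map = {
            'vocab_file': {'my-model-base': 'https://example.com/my-model-base-vocab.txt'},
        }
        # Shortcut name -> kwargs forwarded to __init__ by from_pretrained().
        pretrained_init_configuration = {'my-model-base': {'do_lower_case': True}}
        # Shortcut name -> maximum input length in tokens (positional embeddings).
        max_model_input_sizes = {'my-model-base': 512}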
@@ -235,10 +237,13 @@ class PreTrainedTokenizer(object):
 
         s3_models = list(cls.max_model_input_sizes.keys())
         vocab_files = {}
+        init_configuration = {}
         if pretrained_model_name_or_path in s3_models:
             # Get the vocabulary from AWS S3 bucket
             for file_id, map_list in cls.pretrained_vocab_files_map.items():
                 vocab_files[file_id] = map_list[pretrained_model_name_or_path]
+            if cls.pretrained_init_configuration and pretrained_model_name_or_path in cls.pretrained_init_configuration:
+                init_configuration = cls.pretrained_init_configuration[pretrained_model_name_or_path]
         else:
             # Get the vocabulary from local files
             logger.info(
@@ -312,28 +317,32 @@ class PreTrainedTokenizer(object):
                 logger.info("loading file {} from cache at {}".format(
                     file_path, resolved_vocab_files[file_id]))
 
+        # Prepare initialization kwargs
+        init_kwargs = init_configuration
+        init_kwargs.update(kwargs)
+
         # Set max length if needed
         if pretrained_model_name_or_path in cls.max_model_input_sizes:
             # if we're using a pretrained model, ensure the tokenizer
             # wont index sequences longer than the number of positional embeddings
             max_len = cls.max_model_input_sizes[pretrained_model_name_or_path]
             if max_len is not None and isinstance(max_len, (int, float)):
-                kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
+                init_kwargs['max_len'] = min(init_kwargs.get('max_len', int(1e12)), max_len)
 
-        # Merge resolved_vocab_files arguments in kwargs.
+        # Merge resolved_vocab_files arguments in init_kwargs.
         added_tokens_file = resolved_vocab_files.pop('added_tokens_file', None)
         special_tokens_map_file = resolved_vocab_files.pop('special_tokens_map_file', None)
         for args_name, file_path in resolved_vocab_files.items():
-            if args_name not in kwargs:
-                kwargs[args_name] = file_path
+            if args_name not in init_kwargs:
+                init_kwargs[args_name] = file_path
         if special_tokens_map_file is not None:
             special_tokens_map = json.load(open(special_tokens_map_file, encoding="utf-8"))
             for key, value in special_tokens_map.items():
-                if key not in kwargs:
-                    kwargs[key] = value
+                if key not in init_kwargs:
+                    init_kwargs[key] = value
 
         # Instantiate tokenizer.
-        tokenizer = cls(*inputs, **kwargs)
+        tokenizer = cls(*inputs, **init_kwargs)
 
         # Add supplementary tokens.
         if added_tokens_file is not None:
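The merge above establishes a clear precedence: the per-model pretrained_init_configuration entry seeds init_kwargs, keyword arguments passed by the caller override it, and resolved vocabulary file paths and saved special tokens only fill slots that are still empty. A standalone sketch of that precedence using plain dictionaries (the values are made up for illustration):

    # Pretrained defaults for the selected shortcut name (the seed).
    init_configuration = {'do_lower_case': False}
    # Keyword arguments the caller passed to from_pretrained().
    kwargs = {'do_lower_case': True, 'max_len': 128}

    init_kwargs = dict(init_configuration)  # copy so the class-level dict stays untouched
    init_kwargs.update(kwargs)              # caller-supplied values win

    # Resolved vocabulary files are added only where nothing was provided.
    resolved_vocab_files = {'vocab_file': '/tmp/cache/vocab.txt'}
    for args_name, file_path in resolved_vocab_files.items():
        init_kwargs.setdefault(args_name, file_path)

    print(init_kwargs)
    # {'do_lower_case': True, 'max_len': 128, 'vocab_file': '/tmp/cache/vocab.txt'}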
@@ -47,7 +47,9 @@ PRETRAINED_VOCAB_FILES_MAP = {
         'xlm-mlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-vocab.json",
         'xlm-clm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-enfr-1024-vocab.json",
         'xlm-clm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-ende-1024-vocab.json",
+        'xlm-mlm-17-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-vocab.json",
+        'xlm-mlm-100-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-vocab.json",
     },
     'merges_file':
     {
         'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-merges.txt",
@@ -58,6 +60,8 @@ PRETRAINED_VOCAB_FILES_MAP = {
         'xlm-mlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-merges.txt",
         'xlm-clm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enfr-1024-merges.txt",
         'xlm-clm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-ende-1024-merges.txt",
+        'xlm-mlm-17-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-merges.txt",
+        'xlm-mlm-100-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-merges.txt",
     },
 }
 
@@ -70,6 +74,101 @@ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
     'xlm-mlm-xnli15-1024': 512,
     'xlm-clm-enfr-1024': 512,
     'xlm-clm-ende-1024': 512,
+    'xlm-mlm-17-1280': 512,
+    'xlm-mlm-100-1280': 512,
 }
 
+PRETRAINED_INIT_CONFIGURATION = {
+    'xlm-mlm-en-2048': {"do_lowercase_and_remove_accent": True},
+    'xlm-mlm-ende-1024': { "do_lowercase_and_remove_accent": True,
+                           "id2lang": { "0": "de",
+                                        "1": "en"},
+                           "lang2id": { "de": 0,
+                                        "en": 1 }},
+    'xlm-mlm-enfr-1024': { "do_lowercase_and_remove_accent": True,
+                           "id2lang": { "0": "en",
+                                        "1": "fr"},
+                           "lang2id": { "en": 0,
+                                        "fr": 1 }},
+    'xlm-mlm-enro-1024': { "do_lowercase_and_remove_accent": True,
+                           "id2lang": { "0": "en",
+                                        "1": "ro"},
+                           "lang2id": { "en": 0,
+                                        "ro": 1 }},
+    'xlm-mlm-tlm-xnli15-1024': { "do_lowercase_and_remove_accent": True,
+                                 "id2lang": { "0": "ar",
+                                              "1": "bg",
+                                              "2": "de",
+                                              "3": "el",
+                                              "4": "en",
+                                              "5": "es",
+                                              "6": "fr",
+                                              "7": "hi",
+                                              "8": "ru",
+                                              "9": "sw",
+                                              "10": "th",
+                                              "11": "tr",
+                                              "12": "ur",
+                                              "13": "vi",
+                                              "14": "zh"},
+                                 "lang2id": { "ar": 0,
+                                              "bg": 1,
+                                              "de": 2,
+                                              "el": 3,
+                                              "en": 4,
+                                              "es": 5,
+                                              "fr": 6,
+                                              "hi": 7,
+                                              "ru": 8,
+                                              "sw": 9,
+                                              "th": 10,
+                                              "tr": 11,
+                                              "ur": 12,
+                                              "vi": 13,
+                                              "zh": 14 }},
+    'xlm-mlm-xnli15-1024': { "do_lowercase_and_remove_accent": True,
+                             "id2lang": { "0": "ar",
+                                          "1": "bg",
+                                          "2": "de",
+                                          "3": "el",
+                                          "4": "en",
+                                          "5": "es",
+                                          "6": "fr",
+                                          "7": "hi",
+                                          "8": "ru",
+                                          "9": "sw",
+                                          "10": "th",
+                                          "11": "tr",
+                                          "12": "ur",
+                                          "13": "vi",
+                                          "14": "zh"},
+                             "lang2id": { "ar": 0,
+                                          "bg": 1,
+                                          "de": 2,
+                                          "el": 3,
+                                          "en": 4,
+                                          "es": 5,
+                                          "fr": 6,
+                                          "hi": 7,
+                                          "ru": 8,
+                                          "sw": 9,
+                                          "th": 10,
+                                          "tr": 11,
+                                          "ur": 12,
+                                          "vi": 13,
+                                          "zh": 14 }},
+    'xlm-clm-enfr-1024': { "do_lowercase_and_remove_accent": True,
+                           "id2lang": { "0": "en",
+                                        "1": "fr"},
+                           "lang2id": { "en": 0,
+                                        "fr": 1 }},
+    'xlm-clm-ende-1024': { "do_lowercase_and_remove_accent": True,
+                           "id2lang": { "0": "de",
+                                        "1": "en"},
+                           "lang2id": { "de": 0,
+                                        "en": 1 }},
+    'xlm-mlm-17-1280': {"do_lowercase_and_remove_accent": False},
+    'xlm-mlm-100-1280': {"do_lowercase_and_remove_accent": False},
+}
+
 def get_pairs(word):
@@ -183,17 +282,26 @@ class XLMTokenizer(PreTrainedTokenizer):
         - (optionally) lower case & normalize all inputs text
 
         - argument ``special_tokens`` and function ``set_special_tokens``, can be used to add additional symbols \
-        (ex: "__classify__") to a vocabulary.
+        (ex: "__classify__") to a vocabulary
+
+        - `lang2id` attribute maps the languages supported by the model with their ids if provided (automatically set for pretrained vocabularies)
+
+        - `id2lang` attribute does the reverse mapping if provided (automatically set for pretrained vocabularies)
+
+        - `do_lowercase_and_remove_accent` controls lower casing and accent removal (automatically set for pretrained vocabularies)
     """
     vocab_files_names = VOCAB_FILES_NAMES
     pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+    pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
 
     def __init__(self, vocab_file, merges_file, unk_token="<unk>", bos_token="<s>",
                  sep_token="</s>", pad_token="<pad>", cls_token="</s>",
                  mask_token="<special1>", additional_special_tokens=["<special0>",
                  "<special1>", "<special2>", "<special3>", "<special4>", "<special5>",
-                 "<special6>", "<special7>", "<special8>", "<special9>"], **kwargs):
+                 "<special6>", "<special7>", "<special8>", "<special9>"],
+                 lang2id=None, id2lang=None, do_lowercase_and_remove_accent=True,
+                 **kwargs):
         super(XLMTokenizer, self).__init__(unk_token=unk_token, bos_token=bos_token,
                                            sep_token=sep_token, pad_token=pad_token,
                                            cls_token=cls_token, mask_token=mask_token,
@@ -206,7 +314,12 @@ class XLMTokenizer(PreTrainedTokenizer):
         self.cache_moses_tokenizer = dict()
         self.lang_with_custom_tokenizer = set(['zh', 'th', 'ja'])
         # True for current supported model (v1.2.0), False for XLM-17 & 100
-        self.do_lowercase_and_remove_accent = True
+        self.do_lowercase_and_remove_accent = do_lowercase_and_remove_accent
+        self.lang2id = lang2id
+        self.id2lang = id2lang
+        if lang2id is not None and id2lang is not None:
+            assert len(lang2id) == len(id2lang)
+
         self.ja_word_tokenizer = None
         self.zh_word_tokenizer = None
 
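After this change, loading one of the pretrained XLM shortcuts wires the language maps and the lowercase/accent flag automatically, while passing the arguments explicitly still works for custom vocabularies. A minimal sketch; the pytorch_transformers import path is an assumption from the package of this era, and the expected values follow PRETRAINED_INIT_CONFIGURATION above:

    from pytorch_transformers import XLMTokenizer

    # Pretrained shortcut: language maps come from PRETRAINED_INIT_CONFIGURATION.
    tokenizer = XLMTokenizer.from_pretrained('xlm-mlm-ende-1024')
    print(tokenizer.lang2id)                         # expected: {'de': 0, 'en': 1}
    print(tokenizer.id2lang)                         # expected: {'0': 'de', '1': 'en'}
    print(tokenizer.do_lowercase_and_remove_accent)  # expected: True

    # The XLM-17 and XLM-100 checkpoints keep casing and accents.
    tokenizer = XLMTokenizer.from_pretrained('xlm-mlm-17-1280')
    print(tokenizer.do_lowercase_and_remove_accent)  # expected: False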
@@ -244,14 +357,14 @@ class XLMTokenizer(PreTrainedTokenizer):
             try:
                 import Mykytea
                 self.ja_word_tokenizer = Mykytea.Mykytea('-model %s/local/share/kytea/model.bin' % os.path.expanduser('~'))
-            except:
+            except (AttributeError, ImportError) as e:
                 logger.error("Make sure you install KyTea (https://github.com/neubig/kytea) and it's python wrapper (https://github.com/chezou/Mykytea-python) with the following steps")
                 logger.error("1. git clone git@github.com:neubig/kytea.git && cd kytea")
                 logger.error("2. autoreconf -i")
                 logger.error("3. ./configure --prefix=$HOME/local")
                 logger.error("4. make && make install")
                 logger.error("5. pip install kytea")
-                import sys; sys.exit()
+                raise e
         return list(self.ja_word_tokenizer.getWS(text))
 
     @property
@@ -336,6 +449,8 @@ class XLMTokenizer(PreTrainedTokenizer):
         Returns:
             List of tokens.
         """
+        if lang and self.lang2id and lang not in self.lang2id:
+            logger.error("Supplied language code not found in lang2id mapping. Please check that your language is supported by the loaded pretrained model.")
         if bypass_tokenizer:
             text = text.split()
         elif lang not in self.lang_with_custom_tokenizer:
@@ -349,19 +464,19 @@ class XLMTokenizer(PreTrainedTokenizer):
             try:
                 if 'pythainlp' not in sys.modules:
                     from pythainlp.tokenize import word_tokenize as th_word_tokenize
-            except:
+            except (AttributeError, ImportError) as e:
                 logger.error("Make sure you install PyThaiNLP (https://github.com/PyThaiNLP/pythainlp) with the following steps")
                 logger.error("1. pip install pythainlp")
-                import sys; sys.exit()
+                raise e
             text = th_word_tokenize(text)
         elif lang == 'zh':
             try:
                 if 'jieba' not in sys.modules:
                     import jieba
-            except:
+            except (AttributeError, ImportError) as e:
                 logger.error("Make sure you install Jieba (https://github.com/fxsjy/jieba) with the following steps")
                 logger.error("1. pip install jieba")
-                import sys; sys.exit()
+                raise e
             text = ' '.join(jieba.cut(text))
             text = self.moses_pipeline(text, lang=lang)
             text = text.split()
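The bare except: blocks that ended with sys.exit() are narrowed to except (AttributeError, ImportError) and now re-raise, so a missing optional dependency surfaces as an ordinary exception (with install instructions already logged) instead of terminating the interpreter. The same pattern in isolation, shown for a hypothetical optional dependency rather than any package named in the diff:

    import logging

    logger = logging.getLogger(__name__)

    def require_optional_package():
        try:
            import some_optional_package  # hypothetical optional dependency
        except ImportError as e:
            # Log actionable install instructions, then let the caller decide how
            # to handle the failure instead of calling sys.exit().
            logger.error("Make sure you install some_optional_package: pip install some-optional-package")
            raise e
        return some_optional_package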