# coding=utf-8
# Copyright 2019 The Open AI Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for XLM."""
from __future__ import (absolute_import, division, print_function,
                        unicode_literals)

import json
import logging
import os
import re
import sys
import unicodedata
from io import open

import sacremoses as sm

from .tokenization_utils import PreTrainedTokenizer
from .tokenization_bert import BasicTokenizer

logger = logging.getLogger(__name__)

VOCAB_FILES_NAMES = {
    'vocab_file': 'vocab.json',
    'merges_file': 'merges.txt',
}

PRETRAINED_VOCAB_FILES_MAP = {
    'vocab_file':
    {
        'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-vocab.json",
        'xlm-mlm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-ende-1024-vocab.json",
        'xlm-mlm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enfr-1024-vocab.json",
        'xlm-mlm-enro-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enro-1024-vocab.json",
        'xlm-mlm-tlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-tlm-xnli15-1024-vocab.json",
        'xlm-mlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-vocab.json",
        'xlm-clm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-enfr-1024-vocab.json",
        'xlm-clm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-ende-1024-vocab.json",
        'xlm-mlm-17-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-vocab.json",
        'xlm-mlm-100-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-100-1280-vocab.json",
    },
    'merges_file':
    {
        'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-merges.txt",
        'xlm-mlm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-ende-1024-merges.txt",
        'xlm-mlm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enfr-1024-merges.txt",
        'xlm-mlm-enro-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enro-1024-merges.txt",
        'xlm-mlm-tlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-tlm-xnli15-1024-merges.txt",
        'xlm-mlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-merges.txt",
        'xlm-clm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enfr-1024-merges.txt",
        'xlm-clm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-ende-1024-merges.txt",
        'xlm-mlm-17-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-merges.txt",
        'xlm-mlm-100-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-100-1280-merges.txt",
    },
}

PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    'xlm-mlm-en-2048': 512,
    'xlm-mlm-ende-1024': 512,
    'xlm-mlm-enfr-1024': 512,
    'xlm-mlm-enro-1024': 512,
    'xlm-mlm-tlm-xnli15-1024': 512,
    'xlm-mlm-xnli15-1024': 512,
    'xlm-clm-enfr-1024': 512,
    'xlm-clm-ende-1024': 512,
    'xlm-mlm-17-1280': 512,
    'xlm-mlm-100-1280': 512,
}

PRETRAINED_INIT_CONFIGURATION = {
    'xlm-mlm-en-2048': {"do_lowercase_and_remove_accent": True},
    'xlm-mlm-ende-1024': {"do_lowercase_and_remove_accent": True,
                          "id2lang": {"0": "de", "1": "en"},
                          "lang2id": {"de": 0, "en": 1}},
    'xlm-mlm-enfr-1024': {"do_lowercase_and_remove_accent": True,
                          "id2lang": {"0": "en", "1": "fr"},
                          "lang2id": {"en": 0, "fr": 1}},
    'xlm-mlm-enro-1024': {"do_lowercase_and_remove_accent": True,
                          "id2lang": {"0": "en", "1": "ro"},
                          "lang2id": {"en": 0, "ro": 1}},
    'xlm-mlm-tlm-xnli15-1024': {"do_lowercase_and_remove_accent": True,
                                "id2lang": {"0": "ar", "1": "bg", "2": "de", "3": "el", "4": "en",
                                            "5": "es", "6": "fr", "7": "hi", "8": "ru", "9": "sw",
                                            "10": "th", "11": "tr", "12": "ur", "13": "vi", "14": "zh"},
                                "lang2id": {"ar": 0, "bg": 1, "de": 2, "el": 3, "en": 4,
                                            "es": 5, "fr": 6, "hi": 7, "ru": 8, "sw": 9,
                                            "th": 10, "tr": 11, "ur": 12, "vi": 13, "zh": 14}},
    'xlm-mlm-xnli15-1024': {"do_lowercase_and_remove_accent": True,
                            "id2lang": {"0": "ar", "1": "bg", "2": "de", "3": "el", "4": "en",
                                        "5": "es", "6": "fr", "7": "hi", "8": "ru", "9": "sw",
                                        "10": "th", "11": "tr", "12": "ur", "13": "vi", "14": "zh"},
                            "lang2id": {"ar": 0, "bg": 1, "de": 2, "el": 3, "en": 4,
                                        "es": 5, "fr": 6, "hi": 7, "ru": 8, "sw": 9,
                                        "th": 10, "tr": 11, "ur": 12, "vi": 13, "zh": 14}},
    'xlm-clm-enfr-1024': {"do_lowercase_and_remove_accent": True,
                          "id2lang": {"0": "en", "1": "fr"},
                          "lang2id": {"en": 0, "fr": 1}},
    'xlm-clm-ende-1024': {"do_lowercase_and_remove_accent": True,
                          "id2lang": {"0": "de", "1": "en"},
                          "lang2id": {"de": 0, "en": 1}},
    'xlm-mlm-17-1280': {"do_lowercase_and_remove_accent": False,
                        "id2lang": {"0": "ar", "1": "de", "2": "en", "3": "es", "4": "fr",
                                    "5": "hi", "6": "it", "7": "ja", "8": "ko", "9": "nl",
                                    "10": "pl", "11": "pt", "12": "ru", "13": "sv", "14": "tr",
                                    "15": "vi", "16": "zh"},
                        "lang2id": {"ar": 0, "de": 1, "en": 2, "es": 3, "fr": 4,
                                    "hi": 5, "it": 6, "ja": 7, "ko": 8, "nl": 9,
                                    "pl": 10, "pt": 11, "ru": 12, "sv": 13, "tr": 14,
                                    "vi": 15, "zh": 16}},
    'xlm-mlm-100-1280': {"do_lowercase_and_remove_accent": False,
                         "id2lang": {"0": "af", "1": "als", "2": "am", "3": "an", "4": "ang",
                                     "5": "ar", "6": "arz", "7": "ast", "8": "az", "9": "bar",
                                     "10": "be", "11": "bg", "12": "bn", "13": "br", "14": "bs",
                                     "15": "ca", "16": "ceb", "17": "ckb", "18": "cs", "19": "cy",
                                     "20": "da", "21": "de", "22": "el", "23": "en", "24": "eo",
                                     "25": "es", "26": "et", "27": "eu", "28": "fa", "29": "fi",
                                     "30": "fr", "31": "fy", "32": "ga", "33": "gan", "34": "gl",
                                     "35": "gu", "36": "he", "37": "hi", "38": "hr", "39": "hu",
                                     "40": "hy", "41": "ia", "42": "id", "43": "is", "44": "it",
                                     "45": "ja", "46": "jv", "47": "ka", "48": "kk", "49": "kn",
                                     "50": "ko", "51": "ku", "52": "la", "53": "lb", "54": "lt",
                                     "55": "lv", "56": "mk", "57": "ml", "58": "mn", "59": "mr",
                                     "60": "ms", "61": "my", "62": "nds", "63": "ne", "64": "nl",
                                     "65": "nn", "66": "no", "67": "oc", "68": "pl", "69": "pt",
                                     "70": "ro", "71": "ru", "72": "scn", "73": "sco", "74": "sh",
                                     "75": "si", "76": "simple", "77": "sk", "78": "sl", "79": "sq",
                                     "80": "sr", "81": "sv", "82": "sw", "83": "ta", "84": "te",
                                     "85": "th", "86": "tl", "87": "tr", "88": "tt", "89": "uk",
                                     "90": "ur", "91": "uz", "92": "vi", "93": "war", "94": "wuu",
                                     "95": "yi", "96": "zh", "97": "zh_classical", "98": "zh_min_nan",
                                     "99": "zh_yue"},
                         "lang2id": {"af": 0, "als": 1, "am": 2, "an": 3, "ang": 4,
                                     "ar": 5, "arz": 6, "ast": 7, "az": 8, "bar": 9,
                                     "be": 10, "bg": 11, "bn": 12, "br": 13, "bs": 14,
                                     "ca": 15, "ceb": 16, "ckb": 17, "cs": 18, "cy": 19,
                                     "da": 20, "de": 21, "el": 22, "en": 23, "eo": 24,
                                     "es": 25, "et": 26, "eu": 27, "fa": 28, "fi": 29,
                                     "fr": 30, "fy": 31, "ga": 32, "gan": 33, "gl": 34,
                                     "gu": 35, "he": 36, "hi": 37, "hr": 38, "hu": 39,
                                     "hy": 40, "ia": 41, "id": 42, "is": 43, "it": 44,
                                     "ja": 45, "jv": 46, "ka": 47, "kk": 48, "kn": 49,
                                     "ko": 50, "ku": 51, "la": 52, "lb": 53, "lt": 54,
                                     "lv": 55, "mk": 56, "ml": 57, "mn": 58, "mr": 59,
                                     "ms": 60, "my": 61, "nds": 62, "ne": 63, "nl": 64,
                                     "nn": 65, "no": 66, "oc": 67, "pl": 68, "pt": 69,
                                     "ro": 70, "ru": 71, "scn": 72, "sco": 73, "sh": 74,
                                     "si": 75, "simple": 76, "sk": 77, "sl": 78, "sq": 79,
                                     "sr": 80, "sv": 81, "sw": 82, "ta": 83, "te": 84,
                                     "th": 85, "tl": 86, "tr": 87, "tt": 88, "uk": 89,
                                     "ur": 90, "uz": 91, "vi": 92, "war": 93, "wuu": 94,
                                     "yi": 95, "zh": 96, "zh_classical": 97, "zh_min_nan": 98,
                                     "zh_yue": 99}},
}


def get_pairs(word):
    """
    Return set of symbol pairs in a word.
    word is represented as tuple of symbols (symbols being variable-length strings)
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs

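
# Example (doctest-style; set ordering is arbitrary): for the symbol tuple that
# `bpe` below builds from the token "hello", the pairs are the adjacent bigrams.
#     >>> get_pairs(('h', 'e', 'l', 'l', 'o</w>'))
#     {('h', 'e'), ('e', 'l'), ('l', 'l'), ('l', 'o</w>')}
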
def lowercase_and_remove_accent(text):
    """
    Lowercase and strip accents from a piece of text based on
    https://github.com/facebookresearch/XLM/blob/master/tools/lowercase_and_remove_accent.py
    """
    text = ' '.join(text)
    text = text.lower()
    text = unicodedata.normalize("NFD", text)
    output = []
    for char in text:
        cat = unicodedata.category(char)
        if cat == "Mn":
            # skip combining marks left over from NFD decomposition
            continue
        output.append(char)
    return "".join(output).lower().split(' ')

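
# Example: the helper operates on a list of tokens (it joins them, normalizes,
# then re-splits on single spaces).
#     >>> lowercase_and_remove_accent(['Héllo', 'Wörld'])
#     ['hello', 'world']
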
def replace_unicode_punct(text):
    '''
    Port of https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/replace-unicode-punctuation.perl
    '''
    text = text.replace('，', ',')
    text = re.sub(r'。\s*', '. ', text)
    text = text.replace('、', ',')
    text = text.replace('”', '"')
    text = text.replace('“', '"')
    text = text.replace('∶', ':')
    text = text.replace('：', ':')
    text = text.replace('？', '?')
    text = text.replace('《', '"')
    text = text.replace('》', '"')
    text = text.replace('）', ')')
    text = text.replace('！', '!')
    text = text.replace('（', '(')
    text = text.replace('；', ';')
    text = text.replace('１', '"')  # note: the upstream Perl script maps fullwidth '１' to '"'
    text = text.replace('」', '"')
    text = text.replace('「', '"')
    text = text.replace('０', '0')
    text = text.replace('３', '3')
    text = text.replace('２', '2')
    text = text.replace('５', '5')
    text = text.replace('６', '6')
    text = text.replace('９', '9')
    text = text.replace('７', '7')
    text = text.replace('８', '8')
    text = text.replace('４', '4')
    text = re.sub(r'．\s*', '. ', text)
    text = text.replace('～', '~')
    text = text.replace('’', '\'')
    text = text.replace('…', '...')
    text = text.replace('━', '-')
    text = text.replace('〈', '<')
    text = text.replace('〉', '>')
    text = text.replace('【', '[')
    text = text.replace('】', ']')
    text = text.replace('％', '%')
    return text

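
# Example: CJK and fullwidth punctuation is mapped to its ASCII counterpart,
# with '。' additionally normalized to '. '.
#     >>> replace_unicode_punct('你好。再见？')
#     '你好. 再见?'
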
def remove_non_printing_char(text):
    '''
    Port of https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/remove-non-printing-char.perl
    '''
    output = []
    for char in text:
        cat = unicodedata.category(char)
        if cat.startswith('C'):
            continue
        output.append(char)
    return "".join(output)

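
# Example: characters in the Unicode "C*" categories (controls and format
# characters, which includes tabs and newlines) are dropped.
#     >>> remove_non_printing_char('foo\u200bbar\x07')
#     'foobar'
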
def romanian_preprocessing(text):
    '''Sennrich's WMT16 scripts for Romanian preprocessing, used by model `xlm-mlm-enro-1024`'''
    # https://github.com/rsennrich/wmt16-scripts/blob/master/preprocess/normalise-romanian.py
    text = text.replace("\u015e", "\u0218").replace("\u015f", "\u0219")
    text = text.replace("\u0162", "\u021a").replace("\u0163", "\u021b")
    # https://github.com/rsennrich/wmt16-scripts/blob/master/preprocess/remove-diacritics.py
    text = text.replace("\u0218", "S").replace("\u0219", "s")  # s-comma
    text = text.replace("\u021a", "T").replace("\u021b", "t")  # t-comma
    text = text.replace("\u0102", "A").replace("\u0103", "a")
    text = text.replace("\u00C2", "A").replace("\u00E2", "a")
    text = text.replace("\u00CE", "I").replace("\u00EE", "i")
    return text

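
# Example: cedilla and comma-below forms of s/t are folded together, then all
# Romanian diacritics are stripped to their base letters.
#     >>> romanian_preprocessing('\u0163ar\u0103')  # 'tara' written with cedilla-t and breve-a
#     'tara'

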
class XLMTokenizer(PreTrainedTokenizer):
    """
    BPE tokenizer for XLM

        - Moses preprocessing & tokenization for most supported languages

        - Language specific tokenization for Chinese (Jieba), Japanese (KyTea) and Thai (PyThaiNLP)

        - (optionally) lower cases & normalizes all input text

        - the ``special_tokens`` argument and the ``set_special_tokens`` function can be used to add additional symbols \
        (ex: "__classify__") to a vocabulary

        - the `lang2id` attribute maps the languages supported by the model to their ids if provided (automatically set for pretrained vocabularies)

        - the `id2lang` attribute does the reverse mapping if provided (automatically set for pretrained vocabularies)

        - `do_lowercase_and_remove_accent` controls lower casing and accent removal (automatically set for pretrained vocabularies)
    """
    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

    def __init__(self, vocab_file, merges_file, unk_token="<unk>", bos_token="<s>",
                 sep_token="</s>", pad_token="<pad>", cls_token="</s>",
                 mask_token="<special1>", additional_special_tokens=["<special0>",
                 "<special1>", "<special2>", "<special3>", "<special4>", "<special5>",
                 "<special6>", "<special7>", "<special8>", "<special9>"],
                 lang2id=None, id2lang=None, do_lowercase_and_remove_accent=True,
                 **kwargs):
        super(XLMTokenizer, self).__init__(unk_token=unk_token, bos_token=bos_token,
                                           sep_token=sep_token, pad_token=pad_token,
                                           cls_token=cls_token, mask_token=mask_token,
                                           additional_special_tokens=additional_special_tokens,
                                           **kwargs)

        # cache of sm.MosesPunctNormalizer instances, keyed by language
        self.cache_moses_punct_normalizer = dict()
        # cache of sm.MosesTokenizer instances, keyed by language
        self.cache_moses_tokenizer = dict()
        self.lang_with_custom_tokenizer = set(['zh', 'th', 'ja'])
        # True for the currently supported models (v1.2.0), False for XLM-17 & XLM-100
        self.do_lowercase_and_remove_accent = do_lowercase_and_remove_accent
        self.lang2id = lang2id
        self.id2lang = id2lang
        if lang2id is not None and id2lang is not None:
            assert len(lang2id) == len(id2lang)

        # lazily-initialized language-specific tokenizers (see ja_tokenize / _tokenize)
        self.ja_word_tokenizer = None
        self.zh_word_tokenizer = None

        with open(vocab_file, encoding="utf-8") as vocab_handle:
            self.encoder = json.load(vocab_handle)
        self.decoder = {v: k for k, v in self.encoder.items()}
        with open(merges_file, encoding='utf-8') as merges_handle:
            merges = merges_handle.read().split('\n')[:-1]
        merges = [tuple(merge.split()[:2]) for merge in merges]
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {}

    def moses_punct_norm(self, text, lang):
        if lang not in self.cache_moses_punct_normalizer:
            punct_normalizer = sm.MosesPunctNormalizer(lang=lang)
            self.cache_moses_punct_normalizer[lang] = punct_normalizer
        else:
            punct_normalizer = self.cache_moses_punct_normalizer[lang]
        return punct_normalizer.normalize(text)

    def moses_tokenize(self, text, lang):
        if lang not in self.cache_moses_tokenizer:
            moses_tokenizer = sm.MosesTokenizer(lang=lang)
            self.cache_moses_tokenizer[lang] = moses_tokenizer
        else:
            moses_tokenizer = self.cache_moses_tokenizer[lang]
        return moses_tokenizer.tokenize(text, return_str=False, escape=False)

    def moses_pipeline(self, text, lang):
        text = replace_unicode_punct(text)
        text = self.moses_punct_norm(text, lang)
        text = remove_non_printing_char(text)
        return text

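    # Example (illustrative; exact spacing can vary with the sacremoses version):
    #     self.moses_pipeline('Hi ！ “quoted” …', lang='en')
    # first maps the unicode punctuation to ASCII ('Hi ! "quoted" ...'), then runs
    # Moses punctuation normalization and finally drops non-printing characters.
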
    def ja_tokenize(self, text):
        if self.ja_word_tokenizer is None:
            try:
                import Mykytea
                self.ja_word_tokenizer = Mykytea.Mykytea('-model %s/local/share/kytea/model.bin' % os.path.expanduser('~'))
            except (AttributeError, ImportError) as e:
                logger.error("Make sure you install KyTea (https://github.com/neubig/kytea) and its Python wrapper (https://github.com/chezou/Mykytea-python) with the following steps")
                logger.error("1. git clone git@github.com:neubig/kytea.git && cd kytea")
                logger.error("2. autoreconf -i")
                logger.error("3. ./configure --prefix=$HOME/local")
                logger.error("4. make && make install")
                logger.error("5. pip install kytea")
                raise e
        return list(self.ja_word_tokenizer.getWS(text))

    @property
    def vocab_size(self):
        return len(self.encoder)

    def bpe(self, token):
        if token in self.cache:
            return self.cache[token]
        word = tuple(token[:-1]) + (token[-1] + '</w>',)
        pairs = get_pairs(word)

        if not pairs:
            return token + '</w>'

        while True:
            # merge the highest-ranked (lowest merge index) pair first
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except ValueError:
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        if word == '\n </w>':
            word = '\n</w>'
        self.cache[token] = word
        return word

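    # Worked example under hypothetical merge ranks {('h', 'e'): 0, ('l', 'l'): 1}:
    #     bpe('hello')
    # starts from ('h', 'e', 'l', 'l', 'o</w>') and merges the best-ranked pair
    # at each step:
    #     ('h', 'e', 'l', 'l', 'o</w>') -> ('he', 'l', 'l', 'o</w>') -> ('he', 'll', 'o</w>')
    # returning the space-joined result 'he ll o</w>'. Real outputs depend on
    # the merges.txt file the tokenizer was loaded with.
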
    def _tokenize(self, text, lang='en', bypass_tokenizer=False):
        """
        Tokenize a string given a language code. For Chinese, Japanese and Thai, we use a language specific tokenizer. Otherwise, we use Moses.

        Details of tokenization:
            - [sacremoses](https://github.com/alvations/sacremoses): port of Moses
                - Install with `pip install sacremoses`
            - [pythainlp](https://github.com/PyThaiNLP/pythainlp): Thai tokenizer
                - Install with `pip install pythainlp`
            - [kytea](https://github.com/chezou/Mykytea-python): Japanese tokenizer, wrapper of [KyTea](https://github.com/neubig/kytea)
                - Install with the following steps:
                ```
                git clone git@github.com:neubig/kytea.git && cd kytea
                autoreconf -i
                ./configure --prefix=$HOME/local
                make && make install
                pip install kytea
                ```
            - [jieba](https://github.com/fxsjy/jieba): Chinese tokenizer *
                - Install with `pip install jieba`

        \* The original XLM used the [Stanford Segmenter](https://nlp.stanford.edu/software/stanford-segmenter-2018-10-16.zip).
        However, the wrapper (`nltk.tokenize.stanford_segmenter`) is slow due to JVM overhead, and it will be deprecated.
        Jieba is a lot faster and pip-installable. Note there is some mismatch with the Stanford Segmenter. It should be fine
        if you fine-tune the model with Chinese supervised data. If you want the exact same behaviour, use the original XLM
        [preprocessing script](https://github.com/facebookresearch/XLM/tree/master/tools) to tokenize the sentence externally,
        and set `bypass_tokenizer=True` to bypass the tokenizer.

        Args:
            - lang: ISO language code (default = 'en') (string). The language should be among the languages supported by the model. However, we don't enforce it.
            - bypass_tokenizer: Allow users to preprocess and tokenize the sentences externally (default = False) (bool). If True, we only apply BPE.

        Returns:
            List of tokens.
        """
        if lang and self.lang2id and lang not in self.lang2id:
            logger.error("Supplied language code not found in lang2id mapping. Please check that your language is supported by the loaded pretrained model.")
        if bypass_tokenizer:
            text = text.split()
        elif lang not in self.lang_with_custom_tokenizer:
            text = self.moses_pipeline(text, lang=lang)
            # TODO: make sure we are using `xlm-mlm-enro-1024`, since XLM-100 doesn't have this step
            if lang == 'ro':
                text = romanian_preprocessing(text)
            text = self.moses_tokenize(text, lang=lang)
        elif lang == 'th':
            text = self.moses_pipeline(text, lang=lang)
            try:
                if 'pythainlp' not in sys.modules:
                    from pythainlp.tokenize import word_tokenize as th_word_tokenize
                else:
                    th_word_tokenize = sys.modules['pythainlp'].word_tokenize
            except (AttributeError, ImportError) as e:
                logger.error("Make sure you install PyThaiNLP (https://github.com/PyThaiNLP/pythainlp) with the following steps")
                logger.error("1. pip install pythainlp")
                raise e
            text = th_word_tokenize(text)
        elif lang == 'zh':
            try:
                if 'jieba' not in sys.modules:
                    import jieba
                else:
                    jieba = sys.modules['jieba']
            except (AttributeError, ImportError) as e:
                logger.error("Make sure you install Jieba (https://github.com/fxsjy/jieba) with the following steps")
                logger.error("1. pip install jieba")
                raise e
            text = ' '.join(jieba.cut(text))
            text = self.moses_pipeline(text, lang=lang)
            text = text.split()
        elif lang == 'ja':
            text = self.moses_pipeline(text, lang=lang)
            text = self.ja_tokenize(text)
        else:
            raise ValueError('It should not reach here')

        if self.do_lowercase_and_remove_accent and not bypass_tokenizer:
            text = lowercase_and_remove_accent(text)

        split_tokens = []
        for token in text:
            if token:
                split_tokens.extend([t for t in self.bpe(token).split(' ')])

        return split_tokens

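    # Sketch of the default (Moses) path; the final BPE pieces depend on the
    # loaded vocabulary and merges:
    #     self._tokenize("Hello world!", lang='en')
    # runs moses_pipeline -> moses_tokenize -> lowercase_and_remove_accent,
    # giving word-level tokens like ['hello', 'world', '!'], each of which is
    # then split into BPE subword units ending in '</w>'.
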
    def _convert_token_to_id(self, token):
        """ Converts a token (str/unicode) to an id using the vocab. """
        return self.encoder.get(token, self.encoder.get(self.unk_token))

    def _convert_id_to_token(self, index):
        """Converts an index (integer) to a token (string/unicode) using the vocab."""
        return self.decoder.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens):
        """ Converts a sequence of tokens (string) to a single string. """
        out_string = ''.join(tokens).replace('</w>', ' ').strip()
        return out_string

    def add_special_tokens_single_sequence(self, token_ids):
        """
        Adds special tokens to a sequence for sequence classification tasks.
        An XLM sequence has the following format: [CLS] X [SEP]
        """
        return [self.cls_token_id] + token_ids + [self.sep_token_id]

    def add_special_tokens_sequence_pair(self, token_ids_0, token_ids_1):
        """
        Adds special tokens to a sequence pair for sequence classification tasks.
        An XLM sequence pair has the following format: [CLS] A [SEP] B [SEP]
        """
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]
        return cls + token_ids_0 + sep + token_ids_1 + sep

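    # Example: with hypothetical token-id lists [5, 6] and [7, 8] this returns
    #     [cls_token_id, 5, 6, sep_token_id, 7, 8, sep_token_id]
    # where cls_token_id and sep_token_id are looked up in the loaded vocabulary.
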
    def create_mask_from_sequences(self, sequence_0, sequence_1):
        """
        Creates a mask from the two sequences passed to be used in a sequence-pair classification task.
        An XLM sequence pair mask has the following format:
        0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1
        | first sequence    | second sequence
        """
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]

        return len(cls + self.encode(sequence_0) + sep) * [0] + len(self.encode(sequence_1) + sep) * [1]

    def save_vocabulary(self, save_directory):
        """Save the tokenizer vocabulary and merge files to a directory."""
        if not os.path.isdir(save_directory):
            logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
            return
        vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file'])
        merge_file = os.path.join(save_directory, VOCAB_FILES_NAMES['merges_file'])

        with open(vocab_file, 'w', encoding='utf-8') as f:
            f.write(json.dumps(self.encoder, ensure_ascii=False))

        index = 0
        with open(merge_file, "w", encoding="utf-8") as writer:
            # write merges in rank order; indices are expected to be consecutive
            for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
                if index != token_index:
                    logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive."
                                   " Please check that the tokenizer is not corrupted!".format(merge_file))
                    index = token_index
                writer.write(' '.join(bpe_tokens) + u'\n')
                index += 1

        return vocab_file, merge_file
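

# Minimal usage sketch (assumes `tokenize` forwards the `lang` keyword to
# `_tokenize`, as the tokenization utils in this library version do; output
# depends on the downloaded vocabulary and merges):
#
#     tokenizer = XLMTokenizer.from_pretrained('xlm-mlm-en-2048')
#     tokens = tokenizer.tokenize("Hello, world!", lang='en')
#     ids = tokenizer.convert_tokens_to_ids(tokens)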