Removed interdependency of BERT's Tokenizer in tokenization of prophetnet (#19331)
* removed interdependency of BERTTokenizer in tokenization of prophetnet

* fix: style
parent 07e94bf159
commit 512fa41c53
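The refactor is purely internal: the ProphetNet tokenizer now carries its own copies of `BasicTokenizer` and `WordpieceTokenizer` (marked "# Copied from ..." in the diff below) instead of importing them from the BERT module, so the tokenizer's public behavior is unchanged. A minimal usage sketch, using the `microsoft/prophetnet-large-uncased` checkpoint only as an illustration (the printed tokens are indicative, not verified output):

from transformers import ProphetNetTokenizer

# Loading and using a pretrained checkpoint works the same before and after this change.
tokenizer = ProphetNetTokenizer.from_pretrained("microsoft/prophetnet-large-uncased")
print(tokenizer.tokenize("Hello, ProphetNet!"))  # e.g. ['hello', ',', 'prophet', '##net', '!']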
@@ -15,11 +15,11 @@
import collections
import os
import unicodedata
from typing import Iterable, List, Optional, Tuple

from ...tokenization_utils import PreTrainedTokenizer
from ...tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace
from ...utils import logging
from ..bert.tokenization_bert import BasicTokenizer, WordpieceTokenizer


logger = logging.get_logger(__name__)
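The helper predicates that were previously reached only through the BERT tokenizer module are now imported directly from `tokenization_utils` (see the widened import above). A quick sanity sketch of what they classify, illustration only and not part of the commit:

from transformers.tokenization_utils import _is_control, _is_punctuation, _is_whitespace

print(_is_punctuation(","))   # True: ASCII punctuation
print(_is_whitespace("\t"))   # True: tab counts as whitespace
print(_is_control("\x00"))    # True: NUL is a control character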
@@ -43,6 +43,224 @@ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
}


# Copied from transformers.models.bert.tokenization_bert.whitespace_tokenize
def whitespace_tokenize(text):
    """Runs basic whitespace cleaning and splitting on a piece of text."""
    text = text.strip()
    if not text:
        return []
    tokens = text.split()
    return tokens


# Copied from transformers.models.bert.tokenization_bert.BasicTokenizer
class BasicTokenizer(object):
    """
    Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.).

    Args:
        do_lower_case (`bool`, *optional*, defaults to `True`):
            Whether or not to lowercase the input when tokenizing.
        never_split (`Iterable`, *optional*):
            Collection of tokens which will never be split during tokenization. Only has an effect when
            `do_basic_tokenize=True`
        tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
            Whether or not to tokenize Chinese characters.

            This should likely be deactivated for Japanese (see this
            [issue](https://github.com/huggingface/transformers/issues/328)).
        strip_accents (`bool`, *optional*):
            Whether or not to strip all accents. If this option is not specified, then it will be determined by the
            value for `lowercase` (as in the original BERT).
    """

    def __init__(self, do_lower_case=True, never_split=None, tokenize_chinese_chars=True, strip_accents=None):
        if never_split is None:
            never_split = []
        self.do_lower_case = do_lower_case
        self.never_split = set(never_split)
        self.tokenize_chinese_chars = tokenize_chinese_chars
        self.strip_accents = strip_accents

    def tokenize(self, text, never_split=None):
        """
        Basic Tokenization of a piece of text. Split on "white spaces" only, for sub-word tokenization, see
        WordPieceTokenizer.

        Args:
            never_split (`List[str]`, *optional*):
                Kept for backward compatibility purposes. Now implemented directly at the base class level (see
                [`PreTrainedTokenizer.tokenize`]) List of tokens not to split.
        """
        # union() returns a new set by concatenating the two sets.
        never_split = self.never_split.union(set(never_split)) if never_split else self.never_split
        text = self._clean_text(text)

        # This was added on November 1st, 2018 for the multilingual and Chinese
        # models. This is also applied to the English models now, but it doesn't
        # matter since the English models were not trained on any Chinese data
        # and generally don't have any Chinese data in them (there are Chinese
        # characters in the vocabulary because Wikipedia does have some Chinese
        # words in the English Wikipedia).
        if self.tokenize_chinese_chars:
            text = self._tokenize_chinese_chars(text)
        orig_tokens = whitespace_tokenize(text)
        split_tokens = []
        for token in orig_tokens:
            if token not in never_split:
                if self.do_lower_case:
                    token = token.lower()
                    if self.strip_accents is not False:
                        token = self._run_strip_accents(token)
                elif self.strip_accents:
                    token = self._run_strip_accents(token)
            split_tokens.extend(self._run_split_on_punc(token, never_split))

        output_tokens = whitespace_tokenize(" ".join(split_tokens))
        return output_tokens

    def _run_strip_accents(self, text):
        """Strips accents from a piece of text."""
        text = unicodedata.normalize("NFD", text)
        output = []
        for char in text:
            cat = unicodedata.category(char)
            if cat == "Mn":
                continue
            output.append(char)
        return "".join(output)

    def _run_split_on_punc(self, text, never_split=None):
        """Splits punctuation on a piece of text."""
        if never_split is not None and text in never_split:
            return [text]
        chars = list(text)
        i = 0
        start_new_word = True
        output = []
        while i < len(chars):
            char = chars[i]
            if _is_punctuation(char):
                output.append([char])
                start_new_word = True
            else:
                if start_new_word:
                    output.append([])
                start_new_word = False
                output[-1].append(char)
            i += 1

        return ["".join(x) for x in output]

    def _tokenize_chinese_chars(self, text):
        """Adds whitespace around any CJK character."""
        output = []
        for char in text:
            cp = ord(char)
            if self._is_chinese_char(cp):
                output.append(" ")
                output.append(char)
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)

    def _is_chinese_char(self, cp):
        """Checks whether CP is the codepoint of a CJK character."""
        # This defines a "chinese character" as anything in the CJK Unicode block:
        # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
        #
        # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
        # despite its name. The modern Korean Hangul alphabet is a different block,
        # as is Japanese Hiragana and Katakana. Those alphabets are used to write
        # space-separated words, so they are not treated specially and handled
        # like all of the other languages.
        if (
            (cp >= 0x4E00 and cp <= 0x9FFF)
            or (cp >= 0x3400 and cp <= 0x4DBF)  #
            or (cp >= 0x20000 and cp <= 0x2A6DF)  #
            or (cp >= 0x2A700 and cp <= 0x2B73F)  #
            or (cp >= 0x2B740 and cp <= 0x2B81F)  #
            or (cp >= 0x2B820 and cp <= 0x2CEAF)  #
            or (cp >= 0xF900 and cp <= 0xFAFF)
            or (cp >= 0x2F800 and cp <= 0x2FA1F)  #
        ):  #
            return True

        return False

    def _clean_text(self, text):
        """Performs invalid character removal and whitespace cleanup on text."""
        output = []
        for char in text:
            cp = ord(char)
            if cp == 0 or cp == 0xFFFD or _is_control(char):
                continue
            if _is_whitespace(char):
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)


# Copied from transformers.models.bert.tokenization_bert.WordpieceTokenizer
class WordpieceTokenizer(object):
    """Runs WordPiece tokenization."""

    def __init__(self, vocab, unk_token, max_input_chars_per_word=100):
        self.vocab = vocab
        self.unk_token = unk_token
        self.max_input_chars_per_word = max_input_chars_per_word

    def tokenize(self, text):
        """
        Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform
        tokenization using the given vocabulary.

        For example, `input = "unaffable"` will return as output `["un", "##aff", "##able"]`.

        Args:
            text: A single token or whitespace separated tokens. This should have
                already been passed through *BasicTokenizer*.

        Returns:
            A list of wordpiece tokens.
        """

        output_tokens = []
        for token in whitespace_tokenize(text):
            chars = list(token)
            if len(chars) > self.max_input_chars_per_word:
                output_tokens.append(self.unk_token)
                continue

            is_bad = False
            start = 0
            sub_tokens = []
            while start < len(chars):
                end = len(chars)
                cur_substr = None
                while start < end:
                    substr = "".join(chars[start:end])
                    if start > 0:
                        substr = "##" + substr
                    if substr in self.vocab:
                        cur_substr = substr
                        break
                    end -= 1
                if cur_substr is None:
                    is_bad = True
                    break
                sub_tokens.append(cur_substr)
                start = end

            if is_bad:
                output_tokens.append(self.unk_token)
            else:
                output_tokens.extend(sub_tokens)
        return output_tokens


def load_vocab(vocab_file):
    """Loads a vocabulary file into a dictionary."""
    vocab = collections.OrderedDict()
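For reference, a minimal sketch of how the two copied classes compose inside the ProphetNet tokenizer: `BasicTokenizer` cleans the text, lower-cases it, and splits on whitespace and punctuation, then `WordpieceTokenizer` applies greedy longest-match-first sub-word splitting against a vocabulary. The toy vocabulary below is invented purely for illustration (a real checkpoint's vocabulary is read from file, see `load_vocab` above); the snippet assumes it runs inside this module or after importing the two classes from `transformers.models.prophetnet.tokenization_prophetnet`:

# Toy vocabulary for illustration only.
basic = BasicTokenizer(do_lower_case=True)
wordpiece = WordpieceTokenizer(vocab={"un", "##aff", "##able", "[UNK]"}, unk_token="[UNK]")

tokens = []
for word in basic.tokenize("UNAFFABLE!"):
    tokens.extend(wordpiece.tokenize(word))
print(tokens)  # ['un', '##aff', '##able', '[UNK]'] -- '!' is split off by BasicTokenizer and is not in the toy vocab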