mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
* closes #10258 * typo * reworked deberta test * implemented the comments from BigBird01 regarding sequence pair encoding of deberta * Update style * VOCAB_FILES_NAMES is now a oneliner as suggested by @sgugger Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * added #fmt: on as requested by @sgugger * Style Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
parent
30677dc743
commit
57c1749efa
@ -14,41 +14,34 @@
|
||||
# limitations under the License.
|
||||
""" Tokenization class for model DeBERTa."""
|
||||
|
||||
import os
|
||||
import pathlib
|
||||
import random
|
||||
import unicodedata
|
||||
from functools import lru_cache
|
||||
from typing import Optional, Tuple
|
||||
from zipfile import ZipFile
|
||||
from typing import List, Optional
|
||||
|
||||
import tqdm
|
||||
|
||||
import requests
|
||||
|
||||
from ...tokenization_utils import PreTrainedTokenizer
|
||||
from ...tokenization_utils import AddedToken
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
try:
|
||||
import regex as re
|
||||
except ImportError:
|
||||
raise ImportError("Please install regex with: pip install regex")
|
||||
from ..gpt2.tokenization_gpt2 import GPT2Tokenizer
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
VOCAB_FILES_NAMES = {"vocab_file": "bpe_encoder.bin"}
|
||||
VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt"}
|
||||
|
||||
PRETRAINED_VOCAB_FILES_MAP = {
|
||||
"vocab_file": {
|
||||
"microsoft/deberta-base": "https://huggingface.co/microsoft/deberta-base/resolve/main/bpe_encoder.bin",
|
||||
"microsoft/deberta-large": "https://huggingface.co/microsoft/deberta-large/resolve/main/bpe_encoder.bin",
|
||||
"microsoft/deberta-xlarge": "https://huggingface.co/microsoft/deberta-xlarge/resolve/main/bpe_encoder.bin",
|
||||
"microsoft/deberta-base-mnli": "https://huggingface.co/microsoft/deberta-base-mnli/resolve/main/bpe_encoder.bin",
|
||||
"microsoft/deberta-large-mnli": "https://huggingface.co/microsoft/deberta-large-mnli/resolve/main/bpe_encoder.bin",
|
||||
"microsoft/deberta-xlarge-mnli": "https://huggingface.co/microsoft/deberta-xlarge-mnli/resolve/main/bpe_encoder.bin",
|
||||
}
|
||||
"microsoft/deberta-base": "https://huggingface.co/microsoft/deberta-base/resolve/main/vocab.json",
|
||||
"microsoft/deberta-large": "https://huggingface.co/microsoft/deberta-large/resolve/main/vocab.json",
|
||||
"microsoft/deberta-xlarge": "https://huggingface.co/microsoft/deberta-xlarge/resolve/main/vocab.json",
|
||||
"microsoft/deberta-base-mnli": "https://huggingface.co/microsoft/deberta-base-mnli/resolve/main/vocab.json",
|
||||
"microsoft/deberta-large-mnli": "https://huggingface.co/microsoft/deberta-large-mnli/resolve/main/vocab.json",
|
||||
"microsoft/deberta-xlarge-mnli": "https://huggingface.co/microsoft/deberta-xlarge-mnli/resolve/main/vocab.json",
|
||||
},
|
||||
"merges_file": {
|
||||
"microsoft/deberta-base": "https://huggingface.co/microsoft/deberta-base/resolve/main/merges.txt",
|
||||
"microsoft/deberta-large": "https://huggingface.co/microsoft/deberta-large/resolve/main/merges.txt",
|
||||
"microsoft/deberta-xlarge": "https://huggingface.co/microsoft/deberta-xlarge/resolve/main/merges.txt",
|
||||
"microsoft/deberta-base-mnli": "https://huggingface.co/microsoft/deberta-base-mnli/resolve/main/merges.txt",
|
||||
"microsoft/deberta-large-mnli": "https://huggingface.co/microsoft/deberta-large-mnli/resolve/main/merges.txt",
|
||||
"microsoft/deberta-xlarge-mnli": "https://huggingface.co/microsoft/deberta-xlarge-mnli/resolve/main/merges.txt",
|
||||
},
|
||||
}
|
||||
|
||||
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
|
||||
@ -65,437 +58,8 @@ PRETRAINED_INIT_CONFIGURATION = {
|
||||
"microsoft/deberta-large": {"do_lower_case": False},
|
||||
}
|
||||
|
||||
__all__ = ["DebertaTokenizer"]
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def bytes_to_unicode():
|
||||
"""
|
||||
Returns list of utf-8 byte and a corresponding list of unicode strings. The reversible bpe codes work on unicode
|
||||
strings. This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. When you're
|
||||
at something like a 10B token dataset you end up needing around 5K for decent coverage. This is a signficant
|
||||
percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup tables between utf-8 bytes and unicode
|
||||
strings. And avoids mapping to whitespace/control characters the bpe code barfs on.
|
||||
"""
|
||||
bs = (
|
||||
list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
|
||||
)
|
||||
cs = bs[:]
|
||||
n = 0
|
||||
for b in range(2 ** 8):
|
||||
if b not in bs:
|
||||
bs.append(b)
|
||||
cs.append(2 ** 8 + n)
|
||||
n += 1
|
||||
cs = [chr(n) for n in cs]
|
||||
return dict(zip(bs, cs))
|
||||
|
||||
|
||||
def get_pairs(word):
|
||||
"""
|
||||
Return set of symbol pairs in a word. Word is represented as tuple of symbols (symbols being variable-length
|
||||
strings).
|
||||
"""
|
||||
pairs = set()
|
||||
prev_char = word[0]
|
||||
for char in word[1:]:
|
||||
pairs.add((prev_char, char))
|
||||
prev_char = char
|
||||
return pairs
|
||||
|
||||
|
||||
class Encoder:
|
||||
def __init__(self, encoder, bpe_merges, errors="replace"):
|
||||
self.encoder = encoder
|
||||
self.decoder = {v: k for k, v in self.encoder.items()}
|
||||
self.errors = errors # how to handle errors in decoding
|
||||
self.byte_encoder = bytes_to_unicode()
|
||||
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
|
||||
self.bpe_ranks = dict(zip([tuple(k) for k in bpe_merges], range(len(bpe_merges))))
|
||||
self.cache = {}
|
||||
self.random = random.Random(0)
|
||||
|
||||
# Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
|
||||
self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
|
||||
|
||||
def bpe(self, token):
|
||||
if token in self.cache:
|
||||
return self.cache[token]
|
||||
word = tuple(token)
|
||||
pairs = get_pairs(word)
|
||||
|
||||
if not pairs:
|
||||
return token
|
||||
|
||||
while True:
|
||||
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
|
||||
if bigram not in self.bpe_ranks:
|
||||
break
|
||||
first, second = bigram
|
||||
new_word = []
|
||||
i = 0
|
||||
while i < len(word):
|
||||
try:
|
||||
j = word.index(first, i)
|
||||
new_word.extend(word[i:j])
|
||||
i = j
|
||||
except Exception:
|
||||
new_word.extend(word[i:])
|
||||
break
|
||||
|
||||
if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
|
||||
new_word.append(first + second)
|
||||
i += 2
|
||||
else:
|
||||
new_word.append(word[i])
|
||||
i += 1
|
||||
new_word = tuple(new_word)
|
||||
word = new_word
|
||||
if len(word) == 1:
|
||||
break
|
||||
else:
|
||||
pairs = get_pairs(word)
|
||||
word = " ".join(word)
|
||||
self.cache[token] = word
|
||||
return word
|
||||
|
||||
def split_to_words(self, text):
|
||||
return list(re.findall(self.pat, text))
|
||||
|
||||
def encode(self, text):
|
||||
bpe_tokens = []
|
||||
for token in self.split_to_words(text):
|
||||
token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
|
||||
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" "))
|
||||
return bpe_tokens
|
||||
|
||||
def decode(self, tokens):
|
||||
text = "".join([self.decoder[token] for token in tokens])
|
||||
text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
|
||||
return text
|
||||
|
||||
|
||||
def get_encoder(encoder, vocab):
|
||||
return Encoder(
|
||||
encoder=encoder,
|
||||
bpe_merges=vocab,
|
||||
)
|
||||
|
||||
|
||||
def _is_whitespace(char):
|
||||
"""Checks whether `chars` is a whitespace character."""
|
||||
# \t, \n, and \r are technically contorl characters but we treat them
|
||||
# as whitespace since they are generally considered as such.
|
||||
if char == " " or char == "\t" or char == "\n" or char == "\r":
|
||||
return True
|
||||
cat = unicodedata.category(char)
|
||||
if cat == "Zs":
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _is_control(char):
|
||||
"""Checks whether `chars` is a control character."""
|
||||
# These are technically control characters but we count them as whitespace
|
||||
# characters.
|
||||
if char == "\t" or char == "\n" or char == "\r":
|
||||
return False
|
||||
cat = unicodedata.category(char)
|
||||
if cat.startswith("C"):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _is_punctuation(char):
|
||||
"""Checks whether `chars` is a punctuation character."""
|
||||
cp = ord(char)
|
||||
# We treat all non-letter/number ASCII as punctuation.
|
||||
# Characters such as "^", "$", and "`" are not in the Unicode
|
||||
# Punctuation class but we treat them as punctuation anyways, for
|
||||
# consistency.
|
||||
if (cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126):
|
||||
return True
|
||||
cat = unicodedata.category(char)
|
||||
if cat.startswith("P"):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def download_asset(name, tag=None, no_cache=False, cache_dir=None):
|
||||
_tag = tag
|
||||
if _tag is None:
|
||||
_tag = "latest"
|
||||
if not cache_dir:
|
||||
cache_dir = os.path.join(pathlib.Path.home(), f".~DeBERTa/assets/{_tag}/")
|
||||
os.makedirs(cache_dir, exist_ok=True)
|
||||
output = os.path.join(cache_dir, name)
|
||||
if os.path.exists(output) and (not no_cache):
|
||||
return output
|
||||
|
||||
repo = "https://api.github.com/repos/microsoft/DeBERTa/releases"
|
||||
releases = requests.get(repo).json()
|
||||
if tag and tag != "latest":
|
||||
release = [r for r in releases if r["name"].lower() == tag.lower()]
|
||||
if len(release) != 1:
|
||||
raise Exception(f"{tag} can't be found in the repository.")
|
||||
else:
|
||||
release = releases[0]
|
||||
asset = [s for s in release["assets"] if s["name"].lower() == name.lower()]
|
||||
if len(asset) != 1:
|
||||
raise Exception(f"{name} can't be found in the release.")
|
||||
url = asset[0]["url"]
|
||||
headers = {}
|
||||
headers["Accept"] = "application/octet-stream"
|
||||
resp = requests.get(url, stream=True, headers=headers)
|
||||
if resp.status_code != 200:
|
||||
raise Exception(f"Request for {url} return {resp.status_code}, {resp.text}")
|
||||
try:
|
||||
with open(output, "wb") as fs:
|
||||
progress = tqdm(
|
||||
total=int(resp.headers["Content-Length"]) if "Content-Length" in resp.headers else -1,
|
||||
ncols=80,
|
||||
desc=f"Downloading {name}",
|
||||
)
|
||||
for c in resp.iter_content(chunk_size=1024 * 1024):
|
||||
fs.write(c)
|
||||
progress.update(len(c))
|
||||
progress.close()
|
||||
except Exception:
|
||||
os.remove(output)
|
||||
raise
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def load_vocab(name=None, tag=None, no_cache=False, cache_dir=None):
|
||||
import torch
|
||||
|
||||
if name is None:
|
||||
name = "bpe_encoder"
|
||||
|
||||
model_path = name
|
||||
if model_path and (not os.path.exists(model_path)) and not (("/" in model_path) or ("\\" in model_path)):
|
||||
_tag = tag
|
||||
if _tag is None:
|
||||
_tag = "latest"
|
||||
if not cache_dir:
|
||||
cache_dir = os.path.join(pathlib.Path.home(), f".~DeBERTa/assets/{_tag}/")
|
||||
os.makedirs(cache_dir, exist_ok=True)
|
||||
out_dir = os.path.join(cache_dir, name)
|
||||
model_path = os.path.join(out_dir, "bpe_encoder.bin")
|
||||
if (not os.path.exists(model_path)) or no_cache:
|
||||
asset = download_asset(name + ".zip", tag=tag, no_cache=no_cache, cache_dir=cache_dir)
|
||||
with ZipFile(asset, "r") as zipf:
|
||||
for zip_info in zipf.infolist():
|
||||
if zip_info.filename[-1] == "/":
|
||||
continue
|
||||
zip_info.filename = os.path.basename(zip_info.filename)
|
||||
zipf.extract(zip_info, out_dir)
|
||||
elif not model_path:
|
||||
return None, None
|
||||
|
||||
encoder_state = torch.load(model_path)
|
||||
return encoder_state
|
||||
|
||||
|
||||
class GPT2Tokenizer(object):
|
||||
"""
|
||||
A wrapper of GPT2 tokenizer with similar interface as BERT tokenizer
|
||||
|
||||
Args:
|
||||
vocab_file (:obj:`str`, optional):
|
||||
The local path of vocabulary package or the release name of vocabulary in `DeBERTa GitHub releases
|
||||
<https://github.com/microsoft/DeBERTa/releases>`_, e.g. "bpe_encoder", default: `None`.
|
||||
|
||||
If it's `None`, then it will download the vocabulary in the latest release from GitHub. The vocabulary file
|
||||
is a state dictionary with three items, "dict_map", "vocab", "encoder" which correspond to three files used
|
||||
in `RoBERTa`, i.e. `dict.txt`, `vocab.txt` and `encoder.json`. The difference between our wrapped GPT2
|
||||
tokenizer and RoBERTa wrapped tokenizer are,
|
||||
|
||||
- Special tokens, unlike `RoBERTa` which use `<s>`, `</s>` as the `start` token and `end` token of a
|
||||
sentence. We use `[CLS]` and `[SEP]` as the `start` and `end` token of input sentence which is the same
|
||||
as `BERT`.
|
||||
|
||||
- We remapped the token ids in our dictionary with regarding to the new special tokens, `[PAD]` => 0,
|
||||
`[CLS]` => 1, `[SEP]` => 2, `[UNK]` => 3, `[MASK]` => 50264
|
||||
|
||||
special_tokens (:obj:`list`, optional):
|
||||
List of special tokens to be added to the end of the vocabulary.
|
||||
"""
|
||||
|
||||
def __init__(self, vocab_file=None, special_tokens=None):
|
||||
self.pad_token = "[PAD]"
|
||||
self.sep_token = "[SEP]"
|
||||
self.unk_token = "[UNK]"
|
||||
self.cls_token = "[CLS]"
|
||||
|
||||
self.symbols = []
|
||||
self.count = []
|
||||
self.indices = {}
|
||||
self.pad_token_id = self.add_symbol(self.pad_token)
|
||||
self.cls_token_id = self.add_symbol(self.cls_token)
|
||||
self.sep_token_id = self.add_symbol(self.sep_token)
|
||||
self.unk_token_id = self.add_symbol(self.unk_token)
|
||||
|
||||
self.gpt2_encoder = load_vocab(vocab_file)
|
||||
self.bpe = get_encoder(self.gpt2_encoder["encoder"], self.gpt2_encoder["vocab"])
|
||||
for w, n in self.gpt2_encoder["dict_map"]:
|
||||
self.add_symbol(w, n)
|
||||
|
||||
self.mask_token = "[MASK]"
|
||||
self.mask_id = self.add_symbol(self.mask_token)
|
||||
self.special_tokens = ["[MASK]", "[SEP]", "[PAD]", "[UNK]", "[CLS]"]
|
||||
if special_tokens is not None:
|
||||
for t in special_tokens:
|
||||
self.add_special_token(t)
|
||||
|
||||
self.vocab = self.indices
|
||||
self.ids_to_tokens = self.symbols
|
||||
|
||||
def tokenize(self, text):
|
||||
"""
|
||||
Convert an input text to tokens.
|
||||
|
||||
Args:
|
||||
text (:obj:`str`): input text to be tokenized.
|
||||
|
||||
Returns:
|
||||
A list of byte tokens where each token represent the byte id in GPT2 byte dictionary
|
||||
|
||||
Example::
|
||||
>>> tokenizer = GPT2Tokenizer()
|
||||
>>> text = "Hello world!"
|
||||
>>> tokens = tokenizer.tokenize(text)
|
||||
>>> print(tokens)
|
||||
['15496', '995', '0']
|
||||
"""
|
||||
bpe = self._encode(text)
|
||||
|
||||
return [t for t in bpe.split(" ") if t]
|
||||
|
||||
def convert_tokens_to_ids(self, tokens):
|
||||
"""
|
||||
Convert list of tokens to ids
|
||||
|
||||
Args:
|
||||
tokens (:obj:`list<str>`): list of tokens
|
||||
|
||||
Returns:
|
||||
List of ids
|
||||
"""
|
||||
|
||||
return [self.vocab[t] for t in tokens]
|
||||
|
||||
def convert_ids_to_tokens(self, ids):
|
||||
"""
|
||||
Convert list of ids to tokens
|
||||
|
||||
Args:
|
||||
ids (:obj:`list<int>`): list of ids
|
||||
|
||||
Returns:
|
||||
List of tokens
|
||||
"""
|
||||
|
||||
tokens = []
|
||||
for i in ids:
|
||||
tokens.append(self.ids_to_tokens[i])
|
||||
return tokens
|
||||
|
||||
def split_to_words(self, text):
|
||||
return self.bpe.split_to_words(text)
|
||||
|
||||
def decode(self, tokens):
|
||||
"""
|
||||
Decode list of tokens to text strings
|
||||
|
||||
Args:
|
||||
tokens (:obj:`list<str>`): list of tokens.
|
||||
|
||||
Returns:
|
||||
Text string corresponds to the input tokens.
|
||||
|
||||
Example::
|
||||
>>> tokenizer = GPT2Tokenizer()
|
||||
>>> text = "Hello world!"
|
||||
>>> tokens = tokenizer.tokenize(text)
|
||||
>>> print(tokens)
|
||||
['15496', '995', '0']
|
||||
>>> tokenizer.decode(tokens)
|
||||
'Hello world!'
|
||||
"""
|
||||
return self.bpe.decode([int(t) for t in tokens if t not in self.special_tokens])
|
||||
|
||||
def add_special_token(self, token):
|
||||
"""
|
||||
Adds a special token to the dictionary
|
||||
|
||||
Args:
|
||||
token (:obj:`str`): Tthe new token/word to be added to the vocabulary.
|
||||
|
||||
Returns:
|
||||
The id of new token in the vocabulary.
|
||||
|
||||
"""
|
||||
self.special_tokens.append(token)
|
||||
return self.add_symbol(token)
|
||||
|
||||
def part_of_whole_word(self, token, is_bos=False):
|
||||
if is_bos:
|
||||
return True
|
||||
s = self._decode(token)
|
||||
if len(s) == 1 and (_is_whitespace(list(s)[0]) or _is_control(list(s)[0]) or _is_punctuation(list(s)[0])):
|
||||
return False
|
||||
|
||||
return not s.startswith(" ")
|
||||
|
||||
def sym(self, id):
|
||||
return self.ids_to_tokens[id]
|
||||
|
||||
def id(self, sym):
|
||||
return self.vocab[sym]
|
||||
|
||||
def _encode(self, x: str) -> str:
|
||||
return " ".join(map(str, self.bpe.encode(x)))
|
||||
|
||||
def _decode(self, x: str) -> str:
|
||||
return self.bpe.decode(map(int, x.split()))
|
||||
|
||||
def add_symbol(self, word, n=1):
|
||||
"""
|
||||
Adds a word to the dictionary
|
||||
|
||||
Args:
|
||||
word (:obj:`str`): Tthe new token/word to be added to the vocabulary.
|
||||
n (int, optional): The frequency of the word.
|
||||
|
||||
Returns:
|
||||
The id of the new word.
|
||||
|
||||
"""
|
||||
if word in self.indices:
|
||||
idx = self.indices[word]
|
||||
self.count[idx] = self.count[idx] + n
|
||||
return idx
|
||||
else:
|
||||
idx = len(self.symbols)
|
||||
self.indices[word] = idx
|
||||
self.symbols.append(word)
|
||||
self.count.append(n)
|
||||
return idx
|
||||
|
||||
def save_pretrained(self, path: str, filename_prefix: str = None):
|
||||
import torch
|
||||
|
||||
filename = VOCAB_FILES_NAMES[list(VOCAB_FILES_NAMES.keys())[0]]
|
||||
if filename_prefix is not None:
|
||||
filename = filename_prefix + "-" + filename
|
||||
full_path = os.path.join(path, filename)
|
||||
torch.save(self.gpt2_encoder, full_path)
|
||||
return (full_path,)
|
||||
|
||||
|
||||
class DebertaTokenizer(PreTrainedTokenizer):
|
||||
class DebertaTokenizer(GPT2Tokenizer):
|
||||
r"""
|
||||
Constructs a DeBERTa tokenizer, which runs end-to-end tokenization: punctuation splitting + wordpiece
|
||||
|
||||
@ -523,70 +87,52 @@ class DebertaTokenizer(PreTrainedTokenizer):
|
||||
|
||||
vocab_files_names = VOCAB_FILES_NAMES
|
||||
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
||||
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
|
||||
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
||||
model_input_names = ["input_ids", "attention_mask", "token_type_ids"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_file,
|
||||
do_lower_case=False,
|
||||
unk_token="[UNK]",
|
||||
merges_file,
|
||||
errors="replace",
|
||||
bos_token="[CLS]",
|
||||
eos_token="[SEP]",
|
||||
sep_token="[SEP]",
|
||||
pad_token="[PAD]",
|
||||
cls_token="[CLS]",
|
||||
unk_token="[UNK]",
|
||||
pad_token="[PAD]",
|
||||
mask_token="[MASK]",
|
||||
add_prefix_space=False,
|
||||
**kwargs
|
||||
):
|
||||
bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
|
||||
eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
|
||||
sep_token = AddedToken(sep_token, lstrip=False, rstrip=False) if isinstance(sep_token, str) else sep_token
|
||||
cls_token = AddedToken(cls_token, lstrip=False, rstrip=False) if isinstance(cls_token, str) else cls_token
|
||||
unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
|
||||
pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
|
||||
|
||||
# Mask token behave like a normal word, i.e. include the space before it
|
||||
mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
|
||||
|
||||
super().__init__(
|
||||
do_lower_case=do_lower_case,
|
||||
vocab_file=vocab_file,
|
||||
merges_file=merges_file,
|
||||
errors=errors,
|
||||
bos_token=bos_token,
|
||||
eos_token=eos_token,
|
||||
unk_token=unk_token,
|
||||
sep_token=sep_token,
|
||||
pad_token=pad_token,
|
||||
cls_token=cls_token,
|
||||
pad_token=pad_token,
|
||||
mask_token=mask_token,
|
||||
add_prefix_space=add_prefix_space,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if not os.path.isfile(vocab_file):
|
||||
raise ValueError(
|
||||
f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained "
|
||||
"model use `tokenizer = XxxTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
|
||||
)
|
||||
self.do_lower_case = do_lower_case
|
||||
self.gpt2_tokenizer = GPT2Tokenizer(vocab_file)
|
||||
|
||||
@property
|
||||
def vocab_size(self):
|
||||
return len(self.vocab)
|
||||
|
||||
@property
|
||||
def vocab(self):
|
||||
return self.gpt2_tokenizer.vocab
|
||||
|
||||
def get_vocab(self):
|
||||
vocab = self.vocab.copy()
|
||||
vocab.update(self.get_added_vocab())
|
||||
return vocab
|
||||
|
||||
def _tokenize(self, text):
|
||||
"""Take as input a string and return a list of strings (tokens) for words/sub-words"""
|
||||
if self.do_lower_case:
|
||||
text = text.lower()
|
||||
return self.gpt2_tokenizer.tokenize(text)
|
||||
|
||||
def _convert_token_to_id(self, token):
|
||||
""" Converts a token (str) in an id using the vocab. """
|
||||
return self.vocab.get(token, self.vocab.get(self.unk_token))
|
||||
|
||||
def _convert_id_to_token(self, index):
|
||||
"""Converts an index (integer) in a token (str) using the vocab."""
|
||||
return self.gpt2_tokenizer.sym(index) if index < self.vocab_size else self.unk_token
|
||||
|
||||
def convert_tokens_to_string(self, tokens):
|
||||
""" Converts a sequence of tokens (string) in a single string. """
|
||||
return self.gpt2_tokenizer.decode(tokens)
|
||||
|
||||
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
|
||||
def build_inputs_with_special_tokens(
|
||||
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
||||
) -> List[int]:
|
||||
"""
|
||||
Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
|
||||
adding special tokens. A DeBERTa sequence has the following format:
|
||||
@ -603,14 +149,15 @@ class DebertaTokenizer(PreTrainedTokenizer):
|
||||
Returns:
|
||||
:obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
|
||||
"""
|
||||
|
||||
if token_ids_1 is None:
|
||||
return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
|
||||
cls = [self.cls_token_id]
|
||||
sep = [self.sep_token_id]
|
||||
return cls + token_ids_0 + sep + token_ids_1 + sep
|
||||
|
||||
def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False):
|
||||
def get_special_tokens_mask(
|
||||
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
|
||||
) -> List[int]:
|
||||
"""
|
||||
Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
|
||||
special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.
|
||||
@ -626,25 +173,21 @@ class DebertaTokenizer(PreTrainedTokenizer):
|
||||
Returns:
|
||||
:obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
|
||||
"""
|
||||
|
||||
if already_has_special_tokens:
|
||||
if token_ids_1 is not None:
|
||||
raise ValueError(
|
||||
"You should not supply a second sequence if the provided sequence of "
|
||||
"ids is already formatted with special tokens for the model."
|
||||
)
|
||||
return list(
|
||||
map(
|
||||
lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0,
|
||||
token_ids_0,
|
||||
)
|
||||
)
|
||||
return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
|
||||
|
||||
if token_ids_1 is not None:
|
||||
return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
|
||||
return [1] + ([0] * len(token_ids_0)) + [1]
|
||||
if token_ids_1 is None:
|
||||
return [1] + ([0] * len(token_ids_0)) + [1]
|
||||
return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
|
||||
|
||||
def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None):
|
||||
def create_token_type_ids_from_sequences(
|
||||
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
||||
) -> List[int]:
|
||||
"""
|
||||
Create a mask from the two sequences passed to be used in a sequence-pair classification task. A DeBERTa
|
||||
sequence pair mask has the following format:
|
||||
@ -668,15 +211,13 @@ class DebertaTokenizer(PreTrainedTokenizer):
|
||||
"""
|
||||
sep = [self.sep_token_id]
|
||||
cls = [self.cls_token_id]
|
||||
|
||||
if token_ids_1 is None:
|
||||
return len(cls + token_ids_0 + sep) * [0]
|
||||
return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
|
||||
return len(cls + token_ids_0 + sep + token_ids_1 + sep) * [0]
|
||||
|
||||
def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs):
|
||||
add_prefix_space = kwargs.pop("add_prefix_space", False)
|
||||
if is_split_into_words or add_prefix_space:
|
||||
add_prefix_space = kwargs.pop("add_prefix_space", self.add_prefix_space)
|
||||
if (is_split_into_words or add_prefix_space) and (len(text) > 0 and not text[0].isspace()):
|
||||
text = " " + text
|
||||
return (text, kwargs)
|
||||
|
||||
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
||||
return self.gpt2_tokenizer.save_pretrained(save_directory, filename_prefix=filename_prefix)
|
||||
|
@ -1,5 +1,5 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 Microsoft, the Hugging Face Team.
|
||||
# Copyright 2019 Hugging Face inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@ -14,61 +14,144 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import re
|
||||
import json
|
||||
import os
|
||||
import unittest
|
||||
from typing import Tuple
|
||||
|
||||
from transformers.models.deberta.tokenization_deberta import DebertaTokenizer
|
||||
from transformers.testing_utils import require_torch
|
||||
from transformers import DebertaTokenizer
|
||||
from transformers.models.deberta.tokenization_deberta import VOCAB_FILES_NAMES
|
||||
from transformers.testing_utils import slow
|
||||
|
||||
from .test_tokenization_common import TokenizerTesterMixin
|
||||
|
||||
|
||||
@require_torch
|
||||
class DebertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
|
||||
tokenizer_class = DebertaTokenizer
|
||||
test_rust_tokenizer = False
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
|
||||
def get_tokenizer(self, name="microsoft/deberta-base", **kwargs):
|
||||
return DebertaTokenizer.from_pretrained(name, **kwargs)
|
||||
# Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt
|
||||
vocab = [
|
||||
"l",
|
||||
"o",
|
||||
"w",
|
||||
"e",
|
||||
"r",
|
||||
"s",
|
||||
"t",
|
||||
"i",
|
||||
"d",
|
||||
"n",
|
||||
"\u0120",
|
||||
"\u0120l",
|
||||
"\u0120n",
|
||||
"\u0120lo",
|
||||
"\u0120low",
|
||||
"er",
|
||||
"\u0120lowest",
|
||||
"\u0120newer",
|
||||
"\u0120wider",
|
||||
"[UNK]",
|
||||
]
|
||||
vocab_tokens = dict(zip(vocab, range(len(vocab))))
|
||||
merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""]
|
||||
self.special_tokens_map = {"unk_token": "[UNK]"}
|
||||
|
||||
self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
|
||||
self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["merges_file"])
|
||||
with open(self.vocab_file, "w", encoding="utf-8") as fp:
|
||||
fp.write(json.dumps(vocab_tokens) + "\n")
|
||||
with open(self.merges_file, "w", encoding="utf-8") as fp:
|
||||
fp.write("\n".join(merges))
|
||||
|
||||
def get_tokenizer(self, **kwargs):
|
||||
kwargs.update(self.special_tokens_map)
|
||||
return self.tokenizer_class.from_pretrained(self.tmpdirname, **kwargs)
|
||||
|
||||
def get_input_output_texts(self, tokenizer):
|
||||
input_text = "lower newer"
|
||||
output_text = "lower newer"
|
||||
return input_text, output_text
|
||||
|
||||
def get_clean_sequence(self, tokenizer, with_prefix_space=False, max_length=20) -> Tuple[str, list]:
|
||||
toks = [
|
||||
(i, tokenizer.decode([i], clean_up_tokenization_spaces=False))
|
||||
for i in range(5, min(len(tokenizer), 50260))
|
||||
]
|
||||
toks = list(filter(lambda t: re.match(r"^[ a-zA-Z]+$", t[1]), toks))
|
||||
toks = list(filter(lambda t: [t[0]] == tokenizer.encode(t[1], add_special_tokens=False), toks))
|
||||
if max_length is not None and len(toks) > max_length:
|
||||
toks = toks[:max_length]
|
||||
# toks_str = [t[1] for t in toks]
|
||||
toks_ids = [t[0] for t in toks]
|
||||
|
||||
# Ensure consistency
|
||||
output_txt = tokenizer.decode(toks_ids, clean_up_tokenization_spaces=False)
|
||||
if " " not in output_txt and len(toks_ids) > 1:
|
||||
output_txt = (
|
||||
tokenizer.decode([toks_ids[0]], clean_up_tokenization_spaces=False)
|
||||
+ " "
|
||||
+ tokenizer.decode(toks_ids[1:], clean_up_tokenization_spaces=False)
|
||||
)
|
||||
if with_prefix_space and not output_txt.startswith(" "):
|
||||
output_txt = " " + output_txt
|
||||
output_ids = tokenizer.encode(output_txt, add_special_tokens=False)
|
||||
return output_txt, output_ids
|
||||
|
||||
def test_full_tokenizer(self):
|
||||
tokenizer = self.get_tokenizer("microsoft/deberta-base")
|
||||
input_str = "UNwant\u00E9d,running"
|
||||
tokens = tokenizer.tokenize(input_str)
|
||||
token_ids = tokenizer.convert_tokens_to_ids(tokens)
|
||||
tokenizer = self.get_tokenizer()
|
||||
text = "lower newer"
|
||||
bpe_tokens = ["l", "o", "w", "er", "\u0120", "n", "e", "w", "er"]
|
||||
tokens = tokenizer.tokenize(text)
|
||||
self.assertListEqual(tokens, bpe_tokens)
|
||||
|
||||
self.assertEqual(tokenizer.decode(token_ids), input_str)
|
||||
input_tokens = tokens + [tokenizer.unk_token]
|
||||
input_bpe_tokens = [0, 1, 2, 15, 10, 9, 3, 2, 15, 19]
|
||||
self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
|
||||
|
||||
@slow
|
||||
def test_sequence_builders(self):
|
||||
tokenizer = self.tokenizer_class.from_pretrained("microsoft/deberta-base")
|
||||
|
||||
text = tokenizer.encode("sequence builders", add_special_tokens=False)
|
||||
text_2 = tokenizer.encode("multi-sequence build", add_special_tokens=False)
|
||||
|
||||
encoded_text_from_decode = tokenizer.encode(
|
||||
"sequence builders", add_special_tokens=True, add_prefix_space=False
|
||||
)
|
||||
encoded_pair_from_decode = tokenizer.encode(
|
||||
"sequence builders", "multi-sequence build", add_special_tokens=True, add_prefix_space=False
|
||||
)
|
||||
|
||||
encoded_sentence = tokenizer.build_inputs_with_special_tokens(text)
|
||||
encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2)
|
||||
|
||||
assert encoded_sentence == encoded_text_from_decode
|
||||
assert encoded_pair == encoded_pair_from_decode
|
||||
|
||||
@slow
|
||||
def test_tokenizer_integration(self):
|
||||
tokenizer_classes = [self.tokenizer_class]
|
||||
if self.test_rust_tokenizer:
|
||||
tokenizer_classes.append(self.rust_tokenizer_class)
|
||||
|
||||
for tokenizer_class in tokenizer_classes:
|
||||
tokenizer = tokenizer_class.from_pretrained("microsoft/deberta-base")
|
||||
|
||||
sequences = [
|
||||
"ALBERT: A Lite BERT for Self-supervised Learning of Language Representations",
|
||||
"ALBERT incorporates two parameter reduction techniques",
|
||||
"The first one is a factorized embedding parameterization. By decomposing the large vocabulary embedding matrix into two small matrices, we separate the size of the hidden layers from the size of vocabulary embedding.",
|
||||
]
|
||||
|
||||
encoding = tokenizer(sequences, padding=True)
|
||||
decoded_sequences = [tokenizer.decode(seq, skip_special_tokens=True) for seq in encoding["input_ids"]]
|
||||
|
||||
# fmt: off
|
||||
expected_encoding = {
|
||||
'input_ids': [
|
||||
[1, 2118, 11126, 565, 35, 83, 25191, 163, 18854, 13, 12156, 12, 16101, 25376, 13807, 9, 22205, 27893, 1635, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||
[1, 2118, 11126, 565, 24536, 80, 43797, 4878, 7373, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||
[1, 133, 78, 65, 16, 10, 3724, 1538, 33183, 11303, 43797, 1938, 4, 870, 24165, 29105, 5, 739, 32644, 33183, 11303, 36173, 88, 80, 650, 7821, 45940, 6, 52, 2559, 5, 1836, 9, 5, 7397, 13171, 31, 5, 1836, 9, 32644, 33183, 11303, 4, 2]
|
||||
],
|
||||
'token_type_ids': [
|
||||
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
|
||||
],
|
||||
'attention_mask': [
|
||||
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
|
||||
]
|
||||
}
|
||||
# fmt: on
|
||||
|
||||
expected_decoded_sequence = [
|
||||
"ALBERT: A Lite BERT for Self-supervised Learning of Language Representations",
|
||||
"ALBERT incorporates two parameter reduction techniques",
|
||||
"The first one is a factorized embedding parameterization. By decomposing the large vocabulary embedding matrix into two small matrices, we separate the size of the hidden layers from the size of vocabulary embedding.",
|
||||
]
|
||||
|
||||
self.assertDictEqual(encoding.data, expected_encoding)
|
||||
|
||||
for expected, decoded in zip(expected_decoded_sequence, decoded_sequences):
|
||||
self.assertEqual(expected, decoded)
|
||||
|
Loading…
Reference in New Issue
Block a user