DebertaTokenizer Rework closes #10258 (#10703)

* closes #10258

* typo

* reworked deberta test

* implemented the comments from BigBird01 regarding sequence pair encoding of deberta

* Update style

* VOCAB_FILES_NAMES is now a oneliner as suggested by @sgugger

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* added #fmt: on as requested by @sgugger

* Style

Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
cronoik 2021-04-01 19:53:53 +02:00 committed by GitHub
parent 30677dc743
commit 57c1749efa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 183 additions and 559 deletions

View File

@ -14,41 +14,34 @@
# limitations under the License.
""" Tokenization class for model DeBERTa."""
import os
import pathlib
import random
import unicodedata
from functools import lru_cache
from typing import Optional, Tuple
from zipfile import ZipFile
from typing import List, Optional
import tqdm
import requests
from ...tokenization_utils import PreTrainedTokenizer
from ...tokenization_utils import AddedToken
from ...utils import logging
try:
import regex as re
except ImportError:
raise ImportError("Please install regex with: pip install regex")
from ..gpt2.tokenization_gpt2 import GPT2Tokenizer
logger = logging.get_logger(__name__)
VOCAB_FILES_NAMES = {"vocab_file": "bpe_encoder.bin"}
VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt"}
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
"microsoft/deberta-base": "https://huggingface.co/microsoft/deberta-base/resolve/main/bpe_encoder.bin",
"microsoft/deberta-large": "https://huggingface.co/microsoft/deberta-large/resolve/main/bpe_encoder.bin",
"microsoft/deberta-xlarge": "https://huggingface.co/microsoft/deberta-xlarge/resolve/main/bpe_encoder.bin",
"microsoft/deberta-base-mnli": "https://huggingface.co/microsoft/deberta-base-mnli/resolve/main/bpe_encoder.bin",
"microsoft/deberta-large-mnli": "https://huggingface.co/microsoft/deberta-large-mnli/resolve/main/bpe_encoder.bin",
"microsoft/deberta-xlarge-mnli": "https://huggingface.co/microsoft/deberta-xlarge-mnli/resolve/main/bpe_encoder.bin",
}
"microsoft/deberta-base": "https://huggingface.co/microsoft/deberta-base/resolve/main/vocab.json",
"microsoft/deberta-large": "https://huggingface.co/microsoft/deberta-large/resolve/main/vocab.json",
"microsoft/deberta-xlarge": "https://huggingface.co/microsoft/deberta-xlarge/resolve/main/vocab.json",
"microsoft/deberta-base-mnli": "https://huggingface.co/microsoft/deberta-base-mnli/resolve/main/vocab.json",
"microsoft/deberta-large-mnli": "https://huggingface.co/microsoft/deberta-large-mnli/resolve/main/vocab.json",
"microsoft/deberta-xlarge-mnli": "https://huggingface.co/microsoft/deberta-xlarge-mnli/resolve/main/vocab.json",
},
"merges_file": {
"microsoft/deberta-base": "https://huggingface.co/microsoft/deberta-base/resolve/main/merges.txt",
"microsoft/deberta-large": "https://huggingface.co/microsoft/deberta-large/resolve/main/merges.txt",
"microsoft/deberta-xlarge": "https://huggingface.co/microsoft/deberta-xlarge/resolve/main/merges.txt",
"microsoft/deberta-base-mnli": "https://huggingface.co/microsoft/deberta-base-mnli/resolve/main/merges.txt",
"microsoft/deberta-large-mnli": "https://huggingface.co/microsoft/deberta-large-mnli/resolve/main/merges.txt",
"microsoft/deberta-xlarge-mnli": "https://huggingface.co/microsoft/deberta-xlarge-mnli/resolve/main/merges.txt",
},
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
@ -65,437 +58,8 @@ PRETRAINED_INIT_CONFIGURATION = {
"microsoft/deberta-large": {"do_lower_case": False},
}
__all__ = ["DebertaTokenizer"]
@lru_cache()
def bytes_to_unicode():
"""
Returns list of utf-8 byte and a corresponding list of unicode strings. The reversible bpe codes work on unicode
strings. This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. When you're
at something like a 10B token dataset you end up needing around 5K for decent coverage. This is a signficant
percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup tables between utf-8 bytes and unicode
strings. And avoids mapping to whitespace/control characters the bpe code barfs on.
"""
bs = (
list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
)
cs = bs[:]
n = 0
for b in range(2 ** 8):
if b not in bs:
bs.append(b)
cs.append(2 ** 8 + n)
n += 1
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))
def get_pairs(word):
"""
Return set of symbol pairs in a word. Word is represented as tuple of symbols (symbols being variable-length
strings).
"""
pairs = set()
prev_char = word[0]
for char in word[1:]:
pairs.add((prev_char, char))
prev_char = char
return pairs
class Encoder:
def __init__(self, encoder, bpe_merges, errors="replace"):
self.encoder = encoder
self.decoder = {v: k for k, v in self.encoder.items()}
self.errors = errors # how to handle errors in decoding
self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
self.bpe_ranks = dict(zip([tuple(k) for k in bpe_merges], range(len(bpe_merges))))
self.cache = {}
self.random = random.Random(0)
# Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
def bpe(self, token):
if token in self.cache:
return self.cache[token]
word = tuple(token)
pairs = get_pairs(word)
if not pairs:
return token
while True:
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
if bigram not in self.bpe_ranks:
break
first, second = bigram
new_word = []
i = 0
while i < len(word):
try:
j = word.index(first, i)
new_word.extend(word[i:j])
i = j
except Exception:
new_word.extend(word[i:])
break
if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
new_word.append(first + second)
i += 2
else:
new_word.append(word[i])
i += 1
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
break
else:
pairs = get_pairs(word)
word = " ".join(word)
self.cache[token] = word
return word
def split_to_words(self, text):
return list(re.findall(self.pat, text))
def encode(self, text):
bpe_tokens = []
for token in self.split_to_words(text):
token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" "))
return bpe_tokens
def decode(self, tokens):
text = "".join([self.decoder[token] for token in tokens])
text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
return text
def get_encoder(encoder, vocab):
return Encoder(
encoder=encoder,
bpe_merges=vocab,
)
def _is_whitespace(char):
"""Checks whether `chars` is a whitespace character."""
# \t, \n, and \r are technically contorl characters but we treat them
# as whitespace since they are generally considered as such.
if char == " " or char == "\t" or char == "\n" or char == "\r":
return True
cat = unicodedata.category(char)
if cat == "Zs":
return True
return False
def _is_control(char):
"""Checks whether `chars` is a control character."""
# These are technically control characters but we count them as whitespace
# characters.
if char == "\t" or char == "\n" or char == "\r":
return False
cat = unicodedata.category(char)
if cat.startswith("C"):
return True
return False
def _is_punctuation(char):
"""Checks whether `chars` is a punctuation character."""
cp = ord(char)
# We treat all non-letter/number ASCII as punctuation.
# Characters such as "^", "$", and "`" are not in the Unicode
# Punctuation class but we treat them as punctuation anyways, for
# consistency.
if (cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126):
return True
cat = unicodedata.category(char)
if cat.startswith("P"):
return True
return False
def download_asset(name, tag=None, no_cache=False, cache_dir=None):
_tag = tag
if _tag is None:
_tag = "latest"
if not cache_dir:
cache_dir = os.path.join(pathlib.Path.home(), f".~DeBERTa/assets/{_tag}/")
os.makedirs(cache_dir, exist_ok=True)
output = os.path.join(cache_dir, name)
if os.path.exists(output) and (not no_cache):
return output
repo = "https://api.github.com/repos/microsoft/DeBERTa/releases"
releases = requests.get(repo).json()
if tag and tag != "latest":
release = [r for r in releases if r["name"].lower() == tag.lower()]
if len(release) != 1:
raise Exception(f"{tag} can't be found in the repository.")
else:
release = releases[0]
asset = [s for s in release["assets"] if s["name"].lower() == name.lower()]
if len(asset) != 1:
raise Exception(f"{name} can't be found in the release.")
url = asset[0]["url"]
headers = {}
headers["Accept"] = "application/octet-stream"
resp = requests.get(url, stream=True, headers=headers)
if resp.status_code != 200:
raise Exception(f"Request for {url} return {resp.status_code}, {resp.text}")
try:
with open(output, "wb") as fs:
progress = tqdm(
total=int(resp.headers["Content-Length"]) if "Content-Length" in resp.headers else -1,
ncols=80,
desc=f"Downloading {name}",
)
for c in resp.iter_content(chunk_size=1024 * 1024):
fs.write(c)
progress.update(len(c))
progress.close()
except Exception:
os.remove(output)
raise
return output
def load_vocab(name=None, tag=None, no_cache=False, cache_dir=None):
import torch
if name is None:
name = "bpe_encoder"
model_path = name
if model_path and (not os.path.exists(model_path)) and not (("/" in model_path) or ("\\" in model_path)):
_tag = tag
if _tag is None:
_tag = "latest"
if not cache_dir:
cache_dir = os.path.join(pathlib.Path.home(), f".~DeBERTa/assets/{_tag}/")
os.makedirs(cache_dir, exist_ok=True)
out_dir = os.path.join(cache_dir, name)
model_path = os.path.join(out_dir, "bpe_encoder.bin")
if (not os.path.exists(model_path)) or no_cache:
asset = download_asset(name + ".zip", tag=tag, no_cache=no_cache, cache_dir=cache_dir)
with ZipFile(asset, "r") as zipf:
for zip_info in zipf.infolist():
if zip_info.filename[-1] == "/":
continue
zip_info.filename = os.path.basename(zip_info.filename)
zipf.extract(zip_info, out_dir)
elif not model_path:
return None, None
encoder_state = torch.load(model_path)
return encoder_state
class GPT2Tokenizer(object):
"""
A wrapper of GPT2 tokenizer with similar interface as BERT tokenizer
Args:
vocab_file (:obj:`str`, optional):
The local path of vocabulary package or the release name of vocabulary in `DeBERTa GitHub releases
<https://github.com/microsoft/DeBERTa/releases>`_, e.g. "bpe_encoder", default: `None`.
If it's `None`, then it will download the vocabulary in the latest release from GitHub. The vocabulary file
is a state dictionary with three items, "dict_map", "vocab", "encoder" which correspond to three files used
in `RoBERTa`, i.e. `dict.txt`, `vocab.txt` and `encoder.json`. The difference between our wrapped GPT2
tokenizer and RoBERTa wrapped tokenizer are,
- Special tokens, unlike `RoBERTa` which use `<s>`, `</s>` as the `start` token and `end` token of a
sentence. We use `[CLS]` and `[SEP]` as the `start` and `end` token of input sentence which is the same
as `BERT`.
- We remapped the token ids in our dictionary with regarding to the new special tokens, `[PAD]` => 0,
`[CLS]` => 1, `[SEP]` => 2, `[UNK]` => 3, `[MASK]` => 50264
special_tokens (:obj:`list`, optional):
List of special tokens to be added to the end of the vocabulary.
"""
def __init__(self, vocab_file=None, special_tokens=None):
self.pad_token = "[PAD]"
self.sep_token = "[SEP]"
self.unk_token = "[UNK]"
self.cls_token = "[CLS]"
self.symbols = []
self.count = []
self.indices = {}
self.pad_token_id = self.add_symbol(self.pad_token)
self.cls_token_id = self.add_symbol(self.cls_token)
self.sep_token_id = self.add_symbol(self.sep_token)
self.unk_token_id = self.add_symbol(self.unk_token)
self.gpt2_encoder = load_vocab(vocab_file)
self.bpe = get_encoder(self.gpt2_encoder["encoder"], self.gpt2_encoder["vocab"])
for w, n in self.gpt2_encoder["dict_map"]:
self.add_symbol(w, n)
self.mask_token = "[MASK]"
self.mask_id = self.add_symbol(self.mask_token)
self.special_tokens = ["[MASK]", "[SEP]", "[PAD]", "[UNK]", "[CLS]"]
if special_tokens is not None:
for t in special_tokens:
self.add_special_token(t)
self.vocab = self.indices
self.ids_to_tokens = self.symbols
def tokenize(self, text):
"""
Convert an input text to tokens.
Args:
text (:obj:`str`): input text to be tokenized.
Returns:
A list of byte tokens where each token represent the byte id in GPT2 byte dictionary
Example::
>>> tokenizer = GPT2Tokenizer()
>>> text = "Hello world!"
>>> tokens = tokenizer.tokenize(text)
>>> print(tokens)
['15496', '995', '0']
"""
bpe = self._encode(text)
return [t for t in bpe.split(" ") if t]
def convert_tokens_to_ids(self, tokens):
"""
Convert list of tokens to ids
Args:
tokens (:obj:`list<str>`): list of tokens
Returns:
List of ids
"""
return [self.vocab[t] for t in tokens]
def convert_ids_to_tokens(self, ids):
"""
Convert list of ids to tokens
Args:
ids (:obj:`list<int>`): list of ids
Returns:
List of tokens
"""
tokens = []
for i in ids:
tokens.append(self.ids_to_tokens[i])
return tokens
def split_to_words(self, text):
return self.bpe.split_to_words(text)
def decode(self, tokens):
"""
Decode list of tokens to text strings
Args:
tokens (:obj:`list<str>`): list of tokens.
Returns:
Text string corresponds to the input tokens.
Example::
>>> tokenizer = GPT2Tokenizer()
>>> text = "Hello world!"
>>> tokens = tokenizer.tokenize(text)
>>> print(tokens)
['15496', '995', '0']
>>> tokenizer.decode(tokens)
'Hello world!'
"""
return self.bpe.decode([int(t) for t in tokens if t not in self.special_tokens])
def add_special_token(self, token):
"""
Adds a special token to the dictionary
Args:
token (:obj:`str`): Tthe new token/word to be added to the vocabulary.
Returns:
The id of new token in the vocabulary.
"""
self.special_tokens.append(token)
return self.add_symbol(token)
def part_of_whole_word(self, token, is_bos=False):
if is_bos:
return True
s = self._decode(token)
if len(s) == 1 and (_is_whitespace(list(s)[0]) or _is_control(list(s)[0]) or _is_punctuation(list(s)[0])):
return False
return not s.startswith(" ")
def sym(self, id):
return self.ids_to_tokens[id]
def id(self, sym):
return self.vocab[sym]
def _encode(self, x: str) -> str:
return " ".join(map(str, self.bpe.encode(x)))
def _decode(self, x: str) -> str:
return self.bpe.decode(map(int, x.split()))
def add_symbol(self, word, n=1):
"""
Adds a word to the dictionary
Args:
word (:obj:`str`): Tthe new token/word to be added to the vocabulary.
n (int, optional): The frequency of the word.
Returns:
The id of the new word.
"""
if word in self.indices:
idx = self.indices[word]
self.count[idx] = self.count[idx] + n
return idx
else:
idx = len(self.symbols)
self.indices[word] = idx
self.symbols.append(word)
self.count.append(n)
return idx
def save_pretrained(self, path: str, filename_prefix: str = None):
import torch
filename = VOCAB_FILES_NAMES[list(VOCAB_FILES_NAMES.keys())[0]]
if filename_prefix is not None:
filename = filename_prefix + "-" + filename
full_path = os.path.join(path, filename)
torch.save(self.gpt2_encoder, full_path)
return (full_path,)
class DebertaTokenizer(PreTrainedTokenizer):
class DebertaTokenizer(GPT2Tokenizer):
r"""
Constructs a DeBERTa tokenizer, which runs end-to-end tokenization: punctuation splitting + wordpiece
@ -523,70 +87,52 @@ class DebertaTokenizer(PreTrainedTokenizer):
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
model_input_names = ["input_ids", "attention_mask", "token_type_ids"]
def __init__(
self,
vocab_file,
do_lower_case=False,
unk_token="[UNK]",
merges_file,
errors="replace",
bos_token="[CLS]",
eos_token="[SEP]",
sep_token="[SEP]",
pad_token="[PAD]",
cls_token="[CLS]",
unk_token="[UNK]",
pad_token="[PAD]",
mask_token="[MASK]",
add_prefix_space=False,
**kwargs
):
bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
sep_token = AddedToken(sep_token, lstrip=False, rstrip=False) if isinstance(sep_token, str) else sep_token
cls_token = AddedToken(cls_token, lstrip=False, rstrip=False) if isinstance(cls_token, str) else cls_token
unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
# Mask token behave like a normal word, i.e. include the space before it
mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
super().__init__(
do_lower_case=do_lower_case,
vocab_file=vocab_file,
merges_file=merges_file,
errors=errors,
bos_token=bos_token,
eos_token=eos_token,
unk_token=unk_token,
sep_token=sep_token,
pad_token=pad_token,
cls_token=cls_token,
pad_token=pad_token,
mask_token=mask_token,
add_prefix_space=add_prefix_space,
**kwargs,
)
if not os.path.isfile(vocab_file):
raise ValueError(
f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained "
"model use `tokenizer = XxxTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
)
self.do_lower_case = do_lower_case
self.gpt2_tokenizer = GPT2Tokenizer(vocab_file)
@property
def vocab_size(self):
return len(self.vocab)
@property
def vocab(self):
return self.gpt2_tokenizer.vocab
def get_vocab(self):
vocab = self.vocab.copy()
vocab.update(self.get_added_vocab())
return vocab
def _tokenize(self, text):
"""Take as input a string and return a list of strings (tokens) for words/sub-words"""
if self.do_lower_case:
text = text.lower()
return self.gpt2_tokenizer.tokenize(text)
def _convert_token_to_id(self, token):
""" Converts a token (str) in an id using the vocab. """
return self.vocab.get(token, self.vocab.get(self.unk_token))
def _convert_id_to_token(self, index):
"""Converts an index (integer) in a token (str) using the vocab."""
return self.gpt2_tokenizer.sym(index) if index < self.vocab_size else self.unk_token
def convert_tokens_to_string(self, tokens):
""" Converts a sequence of tokens (string) in a single string. """
return self.gpt2_tokenizer.decode(tokens)
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
def build_inputs_with_special_tokens(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""
Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
adding special tokens. A DeBERTa sequence has the following format:
@ -603,14 +149,15 @@ class DebertaTokenizer(PreTrainedTokenizer):
Returns:
:obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
"""
if token_ids_1 is None:
return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
cls = [self.cls_token_id]
sep = [self.sep_token_id]
return cls + token_ids_0 + sep + token_ids_1 + sep
def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False):
def get_special_tokens_mask(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
) -> List[int]:
"""
Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.
@ -626,25 +173,21 @@ class DebertaTokenizer(PreTrainedTokenizer):
Returns:
:obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
"""
if already_has_special_tokens:
if token_ids_1 is not None:
raise ValueError(
"You should not supply a second sequence if the provided sequence of "
"ids is already formatted with special tokens for the model."
)
return list(
map(
lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0,
token_ids_0,
)
)
return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
if token_ids_1 is not None:
return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
return [1] + ([0] * len(token_ids_0)) + [1]
if token_ids_1 is None:
return [1] + ([0] * len(token_ids_0)) + [1]
return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None):
def create_token_type_ids_from_sequences(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""
Create a mask from the two sequences passed to be used in a sequence-pair classification task. A DeBERTa
sequence pair mask has the following format:
@ -668,15 +211,13 @@ class DebertaTokenizer(PreTrainedTokenizer):
"""
sep = [self.sep_token_id]
cls = [self.cls_token_id]
if token_ids_1 is None:
return len(cls + token_ids_0 + sep) * [0]
return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
return len(cls + token_ids_0 + sep + token_ids_1 + sep) * [0]
def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs):
add_prefix_space = kwargs.pop("add_prefix_space", False)
if is_split_into_words or add_prefix_space:
add_prefix_space = kwargs.pop("add_prefix_space", self.add_prefix_space)
if (is_split_into_words or add_prefix_space) and (len(text) > 0 and not text[0].isspace()):
text = " " + text
return (text, kwargs)
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
return self.gpt2_tokenizer.save_pretrained(save_directory, filename_prefix=filename_prefix)

View File

@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2018 Microsoft, the Hugging Face Team.
# Copyright 2019 Hugging Face inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -14,61 +14,144 @@
# limitations under the License.
import re
import json
import os
import unittest
from typing import Tuple
from transformers.models.deberta.tokenization_deberta import DebertaTokenizer
from transformers.testing_utils import require_torch
from transformers import DebertaTokenizer
from transformers.models.deberta.tokenization_deberta import VOCAB_FILES_NAMES
from transformers.testing_utils import slow
from .test_tokenization_common import TokenizerTesterMixin
@require_torch
class DebertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer_class = DebertaTokenizer
test_rust_tokenizer = False
def setUp(self):
super().setUp()
def get_tokenizer(self, name="microsoft/deberta-base", **kwargs):
return DebertaTokenizer.from_pretrained(name, **kwargs)
# Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt
vocab = [
"l",
"o",
"w",
"e",
"r",
"s",
"t",
"i",
"d",
"n",
"\u0120",
"\u0120l",
"\u0120n",
"\u0120lo",
"\u0120low",
"er",
"\u0120lowest",
"\u0120newer",
"\u0120wider",
"[UNK]",
]
vocab_tokens = dict(zip(vocab, range(len(vocab))))
merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""]
self.special_tokens_map = {"unk_token": "[UNK]"}
self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["merges_file"])
with open(self.vocab_file, "w", encoding="utf-8") as fp:
fp.write(json.dumps(vocab_tokens) + "\n")
with open(self.merges_file, "w", encoding="utf-8") as fp:
fp.write("\n".join(merges))
def get_tokenizer(self, **kwargs):
kwargs.update(self.special_tokens_map)
return self.tokenizer_class.from_pretrained(self.tmpdirname, **kwargs)
def get_input_output_texts(self, tokenizer):
input_text = "lower newer"
output_text = "lower newer"
return input_text, output_text
def get_clean_sequence(self, tokenizer, with_prefix_space=False, max_length=20) -> Tuple[str, list]:
toks = [
(i, tokenizer.decode([i], clean_up_tokenization_spaces=False))
for i in range(5, min(len(tokenizer), 50260))
]
toks = list(filter(lambda t: re.match(r"^[ a-zA-Z]+$", t[1]), toks))
toks = list(filter(lambda t: [t[0]] == tokenizer.encode(t[1], add_special_tokens=False), toks))
if max_length is not None and len(toks) > max_length:
toks = toks[:max_length]
# toks_str = [t[1] for t in toks]
toks_ids = [t[0] for t in toks]
# Ensure consistency
output_txt = tokenizer.decode(toks_ids, clean_up_tokenization_spaces=False)
if " " not in output_txt and len(toks_ids) > 1:
output_txt = (
tokenizer.decode([toks_ids[0]], clean_up_tokenization_spaces=False)
+ " "
+ tokenizer.decode(toks_ids[1:], clean_up_tokenization_spaces=False)
)
if with_prefix_space and not output_txt.startswith(" "):
output_txt = " " + output_txt
output_ids = tokenizer.encode(output_txt, add_special_tokens=False)
return output_txt, output_ids
def test_full_tokenizer(self):
tokenizer = self.get_tokenizer("microsoft/deberta-base")
input_str = "UNwant\u00E9d,running"
tokens = tokenizer.tokenize(input_str)
token_ids = tokenizer.convert_tokens_to_ids(tokens)
tokenizer = self.get_tokenizer()
text = "lower newer"
bpe_tokens = ["l", "o", "w", "er", "\u0120", "n", "e", "w", "er"]
tokens = tokenizer.tokenize(text)
self.assertListEqual(tokens, bpe_tokens)
self.assertEqual(tokenizer.decode(token_ids), input_str)
input_tokens = tokens + [tokenizer.unk_token]
input_bpe_tokens = [0, 1, 2, 15, 10, 9, 3, 2, 15, 19]
self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
@slow
def test_sequence_builders(self):
tokenizer = self.tokenizer_class.from_pretrained("microsoft/deberta-base")
text = tokenizer.encode("sequence builders", add_special_tokens=False)
text_2 = tokenizer.encode("multi-sequence build", add_special_tokens=False)
encoded_text_from_decode = tokenizer.encode(
"sequence builders", add_special_tokens=True, add_prefix_space=False
)
encoded_pair_from_decode = tokenizer.encode(
"sequence builders", "multi-sequence build", add_special_tokens=True, add_prefix_space=False
)
encoded_sentence = tokenizer.build_inputs_with_special_tokens(text)
encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2)
assert encoded_sentence == encoded_text_from_decode
assert encoded_pair == encoded_pair_from_decode
@slow
def test_tokenizer_integration(self):
tokenizer_classes = [self.tokenizer_class]
if self.test_rust_tokenizer:
tokenizer_classes.append(self.rust_tokenizer_class)
for tokenizer_class in tokenizer_classes:
tokenizer = tokenizer_class.from_pretrained("microsoft/deberta-base")
sequences = [
"ALBERT: A Lite BERT for Self-supervised Learning of Language Representations",
"ALBERT incorporates two parameter reduction techniques",
"The first one is a factorized embedding parameterization. By decomposing the large vocabulary embedding matrix into two small matrices, we separate the size of the hidden layers from the size of vocabulary embedding.",
]
encoding = tokenizer(sequences, padding=True)
decoded_sequences = [tokenizer.decode(seq, skip_special_tokens=True) for seq in encoding["input_ids"]]
# fmt: off
expected_encoding = {
'input_ids': [
[1, 2118, 11126, 565, 35, 83, 25191, 163, 18854, 13, 12156, 12, 16101, 25376, 13807, 9, 22205, 27893, 1635, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 2118, 11126, 565, 24536, 80, 43797, 4878, 7373, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 133, 78, 65, 16, 10, 3724, 1538, 33183, 11303, 43797, 1938, 4, 870, 24165, 29105, 5, 739, 32644, 33183, 11303, 36173, 88, 80, 650, 7821, 45940, 6, 52, 2559, 5, 1836, 9, 5, 7397, 13171, 31, 5, 1836, 9, 32644, 33183, 11303, 4, 2]
],
'token_type_ids': [
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
],
'attention_mask': [
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
]
}
# fmt: on
expected_decoded_sequence = [
"ALBERT: A Lite BERT for Self-supervised Learning of Language Representations",
"ALBERT incorporates two parameter reduction techniques",
"The first one is a factorized embedding parameterization. By decomposing the large vocabulary embedding matrix into two small matrices, we separate the size of the hidden layers from the size of vocabulary embedding.",
]
self.assertDictEqual(encoding.data, expected_encoding)
for expected, decoded in zip(expected_decoded_sequence, decoded_sequences):
self.assertEqual(expected, decoded)