mirror of https://github.com/huggingface/transformers.git
synced 2025-07-31 18:22:34 +06:00
* closes #10258
* typo
* reworked deberta test
* implemented the comments from BigBird01 regarding sequence pair encoding of deberta
* Update style
* VOCAB_FILES_NAMES is now a oneliner as suggested by @sgugger
* added # fmt: on as requested by @sgugger
* Style

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
parent 30677dc743
commit 57c1749efa
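The net effect of this commit: DebertaTokenizer is reworked to subclass the GPT-2 byte-level BPE tokenizer and to load a vocab.json/merges.txt pair instead of the old bpe_encoder.bin blob. A minimal usage sketch, not part of the diff below, assuming the microsoft/deberta-base files are available on the Hub:

# --- illustrative sketch, not part of this commit ---
from transformers import DebertaTokenizer

# Loads vocab.json / merges.txt, as declared in the new VOCAB_FILES_NAMES below.
tokenizer = DebertaTokenizer.from_pretrained("microsoft/deberta-base")

enc = tokenizer("sequence builders", "multi-sequence build")
print(enc["input_ids"])       # [CLS] A [SEP] B [SEP], per build_inputs_with_special_tokens below
print(enc["token_type_ids"])  # all zeros, per the new create_token_type_ids_from_sequences below
# --- end sketch ---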
@@ -14,41 +14,34 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" Tokenization class for model DeBERTa."""
|
""" Tokenization class for model DeBERTa."""
|
||||||
|
|
||||||
import os
|
from typing import List, Optional
|
||||||
import pathlib
|
|
||||||
import random
|
|
||||||
import unicodedata
|
|
||||||
from functools import lru_cache
|
|
||||||
from typing import Optional, Tuple
|
|
||||||
from zipfile import ZipFile
|
|
||||||
|
|
||||||
import tqdm
|
from ...tokenization_utils import AddedToken
|
||||||
|
|
||||||
import requests
|
|
||||||
|
|
||||||
from ...tokenization_utils import PreTrainedTokenizer
|
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
|
from ..gpt2.tokenization_gpt2 import GPT2Tokenizer
|
||||||
|
|
||||||
try:
|
|
||||||
import regex as re
|
|
||||||
except ImportError:
|
|
||||||
raise ImportError("Please install regex with: pip install regex")
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
VOCAB_FILES_NAMES = {"vocab_file": "bpe_encoder.bin"}
|
VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt"}
|
||||||
|
|
||||||
PRETRAINED_VOCAB_FILES_MAP = {
|
PRETRAINED_VOCAB_FILES_MAP = {
|
||||||
"vocab_file": {
|
"vocab_file": {
|
||||||
"microsoft/deberta-base": "https://huggingface.co/microsoft/deberta-base/resolve/main/bpe_encoder.bin",
|
"microsoft/deberta-base": "https://huggingface.co/microsoft/deberta-base/resolve/main/vocab.json",
|
||||||
"microsoft/deberta-large": "https://huggingface.co/microsoft/deberta-large/resolve/main/bpe_encoder.bin",
|
"microsoft/deberta-large": "https://huggingface.co/microsoft/deberta-large/resolve/main/vocab.json",
|
||||||
"microsoft/deberta-xlarge": "https://huggingface.co/microsoft/deberta-xlarge/resolve/main/bpe_encoder.bin",
|
"microsoft/deberta-xlarge": "https://huggingface.co/microsoft/deberta-xlarge/resolve/main/vocab.json",
|
||||||
"microsoft/deberta-base-mnli": "https://huggingface.co/microsoft/deberta-base-mnli/resolve/main/bpe_encoder.bin",
|
"microsoft/deberta-base-mnli": "https://huggingface.co/microsoft/deberta-base-mnli/resolve/main/vocab.json",
|
||||||
"microsoft/deberta-large-mnli": "https://huggingface.co/microsoft/deberta-large-mnli/resolve/main/bpe_encoder.bin",
|
"microsoft/deberta-large-mnli": "https://huggingface.co/microsoft/deberta-large-mnli/resolve/main/vocab.json",
|
||||||
"microsoft/deberta-xlarge-mnli": "https://huggingface.co/microsoft/deberta-xlarge-mnli/resolve/main/bpe_encoder.bin",
|
"microsoft/deberta-xlarge-mnli": "https://huggingface.co/microsoft/deberta-xlarge-mnli/resolve/main/vocab.json",
|
||||||
}
|
},
|
||||||
|
"merges_file": {
|
||||||
|
"microsoft/deberta-base": "https://huggingface.co/microsoft/deberta-base/resolve/main/merges.txt",
|
||||||
|
"microsoft/deberta-large": "https://huggingface.co/microsoft/deberta-large/resolve/main/merges.txt",
|
||||||
|
"microsoft/deberta-xlarge": "https://huggingface.co/microsoft/deberta-xlarge/resolve/main/merges.txt",
|
||||||
|
"microsoft/deberta-base-mnli": "https://huggingface.co/microsoft/deberta-base-mnli/resolve/main/merges.txt",
|
||||||
|
"microsoft/deberta-large-mnli": "https://huggingface.co/microsoft/deberta-large-mnli/resolve/main/merges.txt",
|
||||||
|
"microsoft/deberta-xlarge-mnli": "https://huggingface.co/microsoft/deberta-xlarge-mnli/resolve/main/merges.txt",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
|
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
|
||||||
@@ -65,437 +58,8 @@ PRETRAINED_INIT_CONFIGURATION = {
|
|||||||
"microsoft/deberta-large": {"do_lower_case": False},
|
"microsoft/deberta-large": {"do_lower_case": False},
|
||||||
}
|
}
|
||||||
|
|
||||||
__all__ = ["DebertaTokenizer"]
|
|
||||||
|
|
||||||
|
class DebertaTokenizer(GPT2Tokenizer):
|
||||||
@lru_cache()
|
|
||||||
def bytes_to_unicode():
|
|
||||||
"""
|
|
||||||
Returns a list of utf-8 bytes and a corresponding list of unicode strings. The reversible bpe codes work on unicode
|
|
||||||
strings. This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. When you're
|
|
||||||
at something like a 10B token dataset you end up needing around 5K for decent coverage. This is a significant
|
|
||||||
percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup tables between utf-8 bytes and unicode
|
|
||||||
strings. And avoids mapping to whitespace/control characters the bpe code barfs on.
|
|
||||||
"""
|
|
||||||
bs = (
|
|
||||||
list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
|
|
||||||
)
|
|
||||||
cs = bs[:]
|
|
||||||
n = 0
|
|
||||||
for b in range(2 ** 8):
|
|
||||||
if b not in bs:
|
|
||||||
bs.append(b)
|
|
||||||
cs.append(2 ** 8 + n)
|
|
||||||
n += 1
|
|
||||||
cs = [chr(n) for n in cs]
|
|
||||||
return dict(zip(bs, cs))
|
|
||||||
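An illustrative check, not in the original file, of what the mapping above is for: it turns arbitrary UTF-8 bytes into printable unicode characters and back without loss.

# --- illustrative sketch, not part of this commit ---
byte_encoder = bytes_to_unicode()                       # function defined above
byte_decoder = {v: k for k, v in byte_encoder.items()}

token = "café"
mapped = "".join(byte_encoder[b] for b in token.encode("utf-8"))
restored = bytearray(byte_decoder[c] for c in mapped).decode("utf-8")
assert restored == token                                # lossless round trip
# --- end sketch ---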
|
|
||||||
|
|
||||||
def get_pairs(word):
|
|
||||||
"""
|
|
||||||
Return set of symbol pairs in a word. Word is represented as tuple of symbols (symbols being variable-length
|
|
||||||
strings).
|
|
||||||
"""
|
|
||||||
pairs = set()
|
|
||||||
prev_char = word[0]
|
|
||||||
for char in word[1:]:
|
|
||||||
pairs.add((prev_char, char))
|
|
||||||
prev_char = char
|
|
||||||
return pairs
|
|
||||||
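For example (illustrative only), the word ("l", "o", "w") yields the adjacent bigrams that the BPE merge loop below considers:

# --- illustrative sketch, not part of this commit ---
assert get_pairs(("l", "o", "w")) == {("l", "o"), ("o", "w")}
# --- end sketch ---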
|
|
||||||
|
|
||||||
class Encoder:
|
|
||||||
def __init__(self, encoder, bpe_merges, errors="replace"):
|
|
||||||
self.encoder = encoder
|
|
||||||
self.decoder = {v: k for k, v in self.encoder.items()}
|
|
||||||
self.errors = errors # how to handle errors in decoding
|
|
||||||
self.byte_encoder = bytes_to_unicode()
|
|
||||||
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
|
|
||||||
self.bpe_ranks = dict(zip([tuple(k) for k in bpe_merges], range(len(bpe_merges))))
|
|
||||||
self.cache = {}
|
|
||||||
self.random = random.Random(0)
|
|
||||||
|
|
||||||
# Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
|
|
||||||
self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
|
|
||||||
|
|
||||||
def bpe(self, token):
|
|
||||||
if token in self.cache:
|
|
||||||
return self.cache[token]
|
|
||||||
word = tuple(token)
|
|
||||||
pairs = get_pairs(word)
|
|
||||||
|
|
||||||
if not pairs:
|
|
||||||
return token
|
|
||||||
|
|
||||||
while True:
|
|
||||||
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
|
|
||||||
if bigram not in self.bpe_ranks:
|
|
||||||
break
|
|
||||||
first, second = bigram
|
|
||||||
new_word = []
|
|
||||||
i = 0
|
|
||||||
while i < len(word):
|
|
||||||
try:
|
|
||||||
j = word.index(first, i)
|
|
||||||
new_word.extend(word[i:j])
|
|
||||||
i = j
|
|
||||||
except Exception:
|
|
||||||
new_word.extend(word[i:])
|
|
||||||
break
|
|
||||||
|
|
||||||
if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
|
|
||||||
new_word.append(first + second)
|
|
||||||
i += 2
|
|
||||||
else:
|
|
||||||
new_word.append(word[i])
|
|
||||||
i += 1
|
|
||||||
new_word = tuple(new_word)
|
|
||||||
word = new_word
|
|
||||||
if len(word) == 1:
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
pairs = get_pairs(word)
|
|
||||||
word = " ".join(word)
|
|
||||||
self.cache[token] = word
|
|
||||||
return word
|
|
||||||
|
|
||||||
def split_to_words(self, text):
|
|
||||||
return list(re.findall(self.pat, text))
|
|
||||||
|
|
||||||
def encode(self, text):
|
|
||||||
bpe_tokens = []
|
|
||||||
for token in self.split_to_words(text):
|
|
||||||
token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
|
|
||||||
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" "))
|
|
||||||
return bpe_tokens
|
|
||||||
|
|
||||||
def decode(self, tokens):
|
|
||||||
text = "".join([self.decoder[token] for token in tokens])
|
|
||||||
text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
|
|
||||||
return text
|
|
||||||
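A tiny round trip through the Encoder class above, using a toy three-entry vocabulary (illustrative only; the real encoder dict comes from the downloaded bpe_encoder state):

# --- illustrative sketch, not part of this commit ---
toy = Encoder(encoder={"H": 0, "i": 1, "!": 2}, bpe_merges=[])
ids = toy.encode("Hi!")            # ["H", "i", "!"] after byte mapping and (empty) BPE
assert ids == [0, 1, 2]
assert toy.decode(ids) == "Hi!"    # decode() inverts encode()
# --- end sketch ---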
|
|
||||||
|
|
||||||
def get_encoder(encoder, vocab):
|
|
||||||
return Encoder(
|
|
||||||
encoder=encoder,
|
|
||||||
bpe_merges=vocab,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _is_whitespace(char):
|
|
||||||
"""Checks whether `chars` is a whitespace character."""
|
|
||||||
# \t, \n, and \r are technically control characters but we treat them
|
|
||||||
# as whitespace since they are generally considered as such.
|
|
||||||
if char == " " or char == "\t" or char == "\n" or char == "\r":
|
|
||||||
return True
|
|
||||||
cat = unicodedata.category(char)
|
|
||||||
if cat == "Zs":
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def _is_control(char):
|
|
||||||
"""Checks whether `chars` is a control character."""
|
|
||||||
# These are technically control characters but we count them as whitespace
|
|
||||||
# characters.
|
|
||||||
if char == "\t" or char == "\n" or char == "\r":
|
|
||||||
return False
|
|
||||||
cat = unicodedata.category(char)
|
|
||||||
if cat.startswith("C"):
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def _is_punctuation(char):
|
|
||||||
"""Checks whether `chars` is a punctuation character."""
|
|
||||||
cp = ord(char)
|
|
||||||
# We treat all non-letter/number ASCII as punctuation.
|
|
||||||
# Characters such as "^", "$", and "`" are not in the Unicode
|
|
||||||
# Punctuation class but we treat them as punctuation anyways, for
|
|
||||||
# consistency.
|
|
||||||
if (cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126):
|
|
||||||
return True
|
|
||||||
cat = unicodedata.category(char)
|
|
||||||
if cat.startswith("P"):
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
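Quick sanity checks for the three character classifiers above (illustrative only):

# --- illustrative sketch, not part of this commit ---
assert _is_whitespace(" ") and _is_whitespace("\t")
assert _is_control("\x00") and not _is_control("\n")   # \n is counted as whitespace
assert _is_punctuation("$") and not _is_punctuation("A")
# --- end sketch ---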
|
|
||||||
|
|
||||||
def download_asset(name, tag=None, no_cache=False, cache_dir=None):
|
|
||||||
_tag = tag
|
|
||||||
if _tag is None:
|
|
||||||
_tag = "latest"
|
|
||||||
if not cache_dir:
|
|
||||||
cache_dir = os.path.join(pathlib.Path.home(), f".~DeBERTa/assets/{_tag}/")
|
|
||||||
os.makedirs(cache_dir, exist_ok=True)
|
|
||||||
output = os.path.join(cache_dir, name)
|
|
||||||
if os.path.exists(output) and (not no_cache):
|
|
||||||
return output
|
|
||||||
|
|
||||||
repo = "https://api.github.com/repos/microsoft/DeBERTa/releases"
|
|
||||||
releases = requests.get(repo).json()
|
|
||||||
if tag and tag != "latest":
|
|
||||||
release = [r for r in releases if r["name"].lower() == tag.lower()]
|
|
||||||
if len(release) != 1:
|
|
||||||
raise Exception(f"{tag} can't be found in the repository.")
|
|
||||||
else:
|
|
||||||
release = releases[0]
|
|
||||||
asset = [s for s in release["assets"] if s["name"].lower() == name.lower()]
|
|
||||||
if len(asset) != 1:
|
|
||||||
raise Exception(f"{name} can't be found in the release.")
|
|
||||||
url = asset[0]["url"]
|
|
||||||
headers = {}
|
|
||||||
headers["Accept"] = "application/octet-stream"
|
|
||||||
resp = requests.get(url, stream=True, headers=headers)
|
|
||||||
if resp.status_code != 200:
|
|
||||||
raise Exception(f"Request for {url} return {resp.status_code}, {resp.text}")
|
|
||||||
try:
|
|
||||||
with open(output, "wb") as fs:
|
|
||||||
progress = tqdm(
|
|
||||||
total=int(resp.headers["Content-Length"]) if "Content-Length" in resp.headers else -1,
|
|
||||||
ncols=80,
|
|
||||||
desc=f"Downloading {name}",
|
|
||||||
)
|
|
||||||
for c in resp.iter_content(chunk_size=1024 * 1024):
|
|
||||||
fs.write(c)
|
|
||||||
progress.update(len(c))
|
|
||||||
progress.close()
|
|
||||||
except Exception:
|
|
||||||
os.remove(output)
|
|
||||||
raise
|
|
||||||
|
|
||||||
return output
|
|
||||||
|
|
||||||
|
|
||||||
def load_vocab(name=None, tag=None, no_cache=False, cache_dir=None):
|
|
||||||
import torch
|
|
||||||
|
|
||||||
if name is None:
|
|
||||||
name = "bpe_encoder"
|
|
||||||
|
|
||||||
model_path = name
|
|
||||||
if model_path and (not os.path.exists(model_path)) and not (("/" in model_path) or ("\\" in model_path)):
|
|
||||||
_tag = tag
|
|
||||||
if _tag is None:
|
|
||||||
_tag = "latest"
|
|
||||||
if not cache_dir:
|
|
||||||
cache_dir = os.path.join(pathlib.Path.home(), f".~DeBERTa/assets/{_tag}/")
|
|
||||||
os.makedirs(cache_dir, exist_ok=True)
|
|
||||||
out_dir = os.path.join(cache_dir, name)
|
|
||||||
model_path = os.path.join(out_dir, "bpe_encoder.bin")
|
|
||||||
if (not os.path.exists(model_path)) or no_cache:
|
|
||||||
asset = download_asset(name + ".zip", tag=tag, no_cache=no_cache, cache_dir=cache_dir)
|
|
||||||
with ZipFile(asset, "r") as zipf:
|
|
||||||
for zip_info in zipf.infolist():
|
|
||||||
if zip_info.filename[-1] == "/":
|
|
||||||
continue
|
|
||||||
zip_info.filename = os.path.basename(zip_info.filename)
|
|
||||||
zipf.extract(zip_info, out_dir)
|
|
||||||
elif not model_path:
|
|
||||||
return None, None
|
|
||||||
|
|
||||||
encoder_state = torch.load(model_path)
|
|
||||||
return encoder_state
|
|
||||||
|
|
||||||
|
|
||||||
class GPT2Tokenizer(object):
|
|
||||||
"""
|
|
||||||
A wrapper around the GPT2 tokenizer with a similar interface to the BERT tokenizer
|
|
||||||
|
|
||||||
Args:
|
|
||||||
vocab_file (:obj:`str`, optional):
|
|
||||||
The local path of the vocabulary package, or the release name of the vocabulary in `DeBERTa GitHub releases
|
|
||||||
<https://github.com/microsoft/DeBERTa/releases>`_, e.g. "bpe_encoder", default: `None`.
|
|
||||||
|
|
||||||
If it's `None`, then it will download the vocabulary in the latest release from GitHub. The vocabulary file
|
|
||||||
is a state dictionary with three items, "dict_map", "vocab", "encoder" which correspond to three files used
|
|
||||||
in `RoBERTa`, i.e. `dict.txt`, `vocab.txt` and `encoder.json`. The differences between our wrapped GPT2
|
|
||||||
tokenizer and the RoBERTa wrapped tokenizer are:
|
|
||||||
|
|
||||||
- Special tokens: unlike `RoBERTa`, which uses `<s>` and `</s>` as the `start` and `end` tokens of a
|
|
||||||
sentence, we use `[CLS]` and `[SEP]` as the `start` and `end` tokens of the input sentence, which is the same
|
|
||||||
as `BERT`.
|
|
||||||
|
|
||||||
- We remapped the token ids in our dictionary with regard to the new special tokens: `[PAD]` => 0,
|
|
||||||
`[CLS]` => 1, `[SEP]` => 2, `[UNK]` => 3, `[MASK]` => 50264
|
|
||||||
|
|
||||||
special_tokens (:obj:`list`, optional):
|
|
||||||
List of special tokens to be added to the end of the vocabulary.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, vocab_file=None, special_tokens=None):
|
|
||||||
self.pad_token = "[PAD]"
|
|
||||||
self.sep_token = "[SEP]"
|
|
||||||
self.unk_token = "[UNK]"
|
|
||||||
self.cls_token = "[CLS]"
|
|
||||||
|
|
||||||
self.symbols = []
|
|
||||||
self.count = []
|
|
||||||
self.indices = {}
|
|
||||||
self.pad_token_id = self.add_symbol(self.pad_token)
|
|
||||||
self.cls_token_id = self.add_symbol(self.cls_token)
|
|
||||||
self.sep_token_id = self.add_symbol(self.sep_token)
|
|
||||||
self.unk_token_id = self.add_symbol(self.unk_token)
|
|
||||||
|
|
||||||
self.gpt2_encoder = load_vocab(vocab_file)
|
|
||||||
self.bpe = get_encoder(self.gpt2_encoder["encoder"], self.gpt2_encoder["vocab"])
|
|
||||||
for w, n in self.gpt2_encoder["dict_map"]:
|
|
||||||
self.add_symbol(w, n)
|
|
||||||
|
|
||||||
self.mask_token = "[MASK]"
|
|
||||||
self.mask_id = self.add_symbol(self.mask_token)
|
|
||||||
self.special_tokens = ["[MASK]", "[SEP]", "[PAD]", "[UNK]", "[CLS]"]
|
|
||||||
if special_tokens is not None:
|
|
||||||
for t in special_tokens:
|
|
||||||
self.add_special_token(t)
|
|
||||||
|
|
||||||
self.vocab = self.indices
|
|
||||||
self.ids_to_tokens = self.symbols
|
|
||||||
|
|
||||||
def tokenize(self, text):
|
|
||||||
"""
|
|
||||||
Convert an input text to tokens.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
text (:obj:`str`): input text to be tokenized.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A list of byte tokens where each token represents a byte id in the GPT2 byte dictionary
|
|
||||||
|
|
||||||
Example::
|
|
||||||
>>> tokenizer = GPT2Tokenizer()
|
|
||||||
>>> text = "Hello world!"
|
|
||||||
>>> tokens = tokenizer.tokenize(text)
|
|
||||||
>>> print(tokens)
|
|
||||||
['15496', '995', '0']
|
|
||||||
"""
|
|
||||||
bpe = self._encode(text)
|
|
||||||
|
|
||||||
return [t for t in bpe.split(" ") if t]
|
|
||||||
|
|
||||||
def convert_tokens_to_ids(self, tokens):
|
|
||||||
"""
|
|
||||||
Convert list of tokens to ids
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tokens (:obj:`list<str>`): list of tokens
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of ids
|
|
||||||
"""
|
|
||||||
|
|
||||||
return [self.vocab[t] for t in tokens]
|
|
||||||
|
|
||||||
def convert_ids_to_tokens(self, ids):
|
|
||||||
"""
|
|
||||||
Convert list of ids to tokens
|
|
||||||
|
|
||||||
Args:
|
|
||||||
ids (:obj:`list<int>`): list of ids
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of tokens
|
|
||||||
"""
|
|
||||||
|
|
||||||
tokens = []
|
|
||||||
for i in ids:
|
|
||||||
tokens.append(self.ids_to_tokens[i])
|
|
||||||
return tokens
|
|
||||||
|
|
||||||
def split_to_words(self, text):
|
|
||||||
return self.bpe.split_to_words(text)
|
|
||||||
|
|
||||||
def decode(self, tokens):
|
|
||||||
"""
|
|
||||||
Decode list of tokens to text strings
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tokens (:obj:`list<str>`): list of tokens.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Text string corresponds to the input tokens.
|
|
||||||
|
|
||||||
Example::
|
|
||||||
>>> tokenizer = GPT2Tokenizer()
|
|
||||||
>>> text = "Hello world!"
|
|
||||||
>>> tokens = tokenizer.tokenize(text)
|
|
||||||
>>> print(tokens)
|
|
||||||
['15496', '995', '0']
|
|
||||||
>>> tokenizer.decode(tokens)
|
|
||||||
'Hello world!'
|
|
||||||
"""
|
|
||||||
return self.bpe.decode([int(t) for t in tokens if t not in self.special_tokens])
|
|
||||||
|
|
||||||
def add_special_token(self, token):
|
|
||||||
"""
|
|
||||||
Adds a special token to the dictionary
|
|
||||||
|
|
||||||
Args:
|
|
||||||
token (:obj:`str`): The new token/word to be added to the vocabulary.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The id of new token in the vocabulary.
|
|
||||||
|
|
||||||
"""
|
|
||||||
self.special_tokens.append(token)
|
|
||||||
return self.add_symbol(token)
|
|
||||||
|
|
||||||
def part_of_whole_word(self, token, is_bos=False):
|
|
||||||
if is_bos:
|
|
||||||
return True
|
|
||||||
s = self._decode(token)
|
|
||||||
if len(s) == 1 and (_is_whitespace(list(s)[0]) or _is_control(list(s)[0]) or _is_punctuation(list(s)[0])):
|
|
||||||
return False
|
|
||||||
|
|
||||||
return not s.startswith(" ")
|
|
||||||
|
|
||||||
def sym(self, id):
|
|
||||||
return self.ids_to_tokens[id]
|
|
||||||
|
|
||||||
def id(self, sym):
|
|
||||||
return self.vocab[sym]
|
|
||||||
|
|
||||||
def _encode(self, x: str) -> str:
|
|
||||||
return " ".join(map(str, self.bpe.encode(x)))
|
|
||||||
|
|
||||||
def _decode(self, x: str) -> str:
|
|
||||||
return self.bpe.decode(map(int, x.split()))
|
|
||||||
|
|
||||||
def add_symbol(self, word, n=1):
|
|
||||||
"""
|
|
||||||
Adds a word to the dictionary
|
|
||||||
|
|
||||||
Args:
|
|
||||||
word (:obj:`str`): The new token/word to be added to the vocabulary.
|
|
||||||
n (int, optional): The frequency of the word.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The id of the new word.
|
|
||||||
|
|
||||||
"""
|
|
||||||
if word in self.indices:
|
|
||||||
idx = self.indices[word]
|
|
||||||
self.count[idx] = self.count[idx] + n
|
|
||||||
return idx
|
|
||||||
else:
|
|
||||||
idx = len(self.symbols)
|
|
||||||
self.indices[word] = idx
|
|
||||||
self.symbols.append(word)
|
|
||||||
self.count.append(n)
|
|
||||||
return idx
|
|
||||||
|
|
||||||
def save_pretrained(self, path: str, filename_prefix: str = None):
|
|
||||||
import torch
|
|
||||||
|
|
||||||
filename = VOCAB_FILES_NAMES[list(VOCAB_FILES_NAMES.keys())[0]]
|
|
||||||
if filename_prefix is not None:
|
|
||||||
filename = filename_prefix + "-" + filename
|
|
||||||
full_path = os.path.join(path, filename)
|
|
||||||
torch.save(self.gpt2_encoder, full_path)
|
|
||||||
return (full_path,)
|
|
||||||
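An illustrative sketch of the special-token id remapping documented in the wrapper's docstring above; instantiating with no arguments downloads the "bpe_encoder" release asset, so it needs network access, torch, and the regex/requests/tqdm imports at the top of this old file:

# --- illustrative sketch, not part of this commit ---
tok = GPT2Tokenizer()                # the wrapper class above, not transformers' GPT2Tokenizer
print(tok.pad_token_id, tok.cls_token_id, tok.sep_token_id, tok.unk_token_id)  # 0 1 2 3
print(tok.id(tok.mask_token))        # 50264 according to the docstring above
print(tok.tokenize("Hello world!"))  # byte-level BPE ids as strings, e.g. ['15496', '995', '0']
# --- end sketch ---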
|
|
||||||
|
|
||||||
class DebertaTokenizer(PreTrainedTokenizer):
|
|
||||||
r"""
|
r"""
|
||||||
Constructs a DeBERTa tokenizer, which runs end-to-end tokenization: punctuation splitting + wordpiece
|
Constructs a DeBERTa tokenizer, which runs end-to-end tokenization: punctuation splitting + wordpiece
|
||||||
|
|
||||||
@@ -523,70 +87,52 @@ class DebertaTokenizer(PreTrainedTokenizer):
|
|||||||
|
|
||||||
vocab_files_names = VOCAB_FILES_NAMES
|
vocab_files_names = VOCAB_FILES_NAMES
|
||||||
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
||||||
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
|
|
||||||
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
||||||
|
model_input_names = ["input_ids", "attention_mask", "token_type_ids"]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
vocab_file,
|
vocab_file,
|
||||||
do_lower_case=False,
|
merges_file,
|
||||||
unk_token="[UNK]",
|
errors="replace",
|
||||||
|
bos_token="[CLS]",
|
||||||
|
eos_token="[SEP]",
|
||||||
sep_token="[SEP]",
|
sep_token="[SEP]",
|
||||||
pad_token="[PAD]",
|
|
||||||
cls_token="[CLS]",
|
cls_token="[CLS]",
|
||||||
|
unk_token="[UNK]",
|
||||||
|
pad_token="[PAD]",
|
||||||
mask_token="[MASK]",
|
mask_token="[MASK]",
|
||||||
|
add_prefix_space=False,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
|
bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
|
||||||
|
eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
|
||||||
|
sep_token = AddedToken(sep_token, lstrip=False, rstrip=False) if isinstance(sep_token, str) else sep_token
|
||||||
|
cls_token = AddedToken(cls_token, lstrip=False, rstrip=False) if isinstance(cls_token, str) else cls_token
|
||||||
|
unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
|
||||||
|
pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
|
||||||
|
|
||||||
|
# Mask token behave like a normal word, i.e. include the space before it
|
||||||
|
mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
do_lower_case=do_lower_case,
|
vocab_file=vocab_file,
|
||||||
|
merges_file=merges_file,
|
||||||
|
errors=errors,
|
||||||
|
bos_token=bos_token,
|
||||||
|
eos_token=eos_token,
|
||||||
unk_token=unk_token,
|
unk_token=unk_token,
|
||||||
sep_token=sep_token,
|
sep_token=sep_token,
|
||||||
pad_token=pad_token,
|
|
||||||
cls_token=cls_token,
|
cls_token=cls_token,
|
||||||
|
pad_token=pad_token,
|
||||||
mask_token=mask_token,
|
mask_token=mask_token,
|
||||||
|
add_prefix_space=add_prefix_space,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not os.path.isfile(vocab_file):
|
def build_inputs_with_special_tokens(
|
||||||
raise ValueError(
|
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
||||||
f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained "
|
) -> List[int]:
|
||||||
"model use `tokenizer = XxxTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
|
|
||||||
)
|
|
||||||
self.do_lower_case = do_lower_case
|
|
||||||
self.gpt2_tokenizer = GPT2Tokenizer(vocab_file)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def vocab_size(self):
|
|
||||||
return len(self.vocab)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def vocab(self):
|
|
||||||
return self.gpt2_tokenizer.vocab
|
|
||||||
|
|
||||||
def get_vocab(self):
|
|
||||||
vocab = self.vocab.copy()
|
|
||||||
vocab.update(self.get_added_vocab())
|
|
||||||
return vocab
|
|
||||||
|
|
||||||
def _tokenize(self, text):
|
|
||||||
"""Take as input a string and return a list of strings (tokens) for words/sub-words"""
|
|
||||||
if self.do_lower_case:
|
|
||||||
text = text.lower()
|
|
||||||
return self.gpt2_tokenizer.tokenize(text)
|
|
||||||
|
|
||||||
def _convert_token_to_id(self, token):
|
|
||||||
""" Converts a token (str) in an id using the vocab. """
|
|
||||||
return self.vocab.get(token, self.vocab.get(self.unk_token))
|
|
||||||
|
|
||||||
def _convert_id_to_token(self, index):
|
|
||||||
"""Converts an index (integer) in a token (str) using the vocab."""
|
|
||||||
return self.gpt2_tokenizer.sym(index) if index < self.vocab_size else self.unk_token
|
|
||||||
|
|
||||||
def convert_tokens_to_string(self, tokens):
|
|
||||||
""" Converts a sequence of tokens (string) in a single string. """
|
|
||||||
return self.gpt2_tokenizer.decode(tokens)
|
|
||||||
|
|
||||||
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
|
|
||||||
"""
|
"""
|
||||||
Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating and
|
Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating and
|
||||||
adding special tokens. A DeBERTa sequence has the following format:
|
adding special tokens. A DeBERTa sequence has the following format:
|
||||||
@@ -603,14 +149,15 @@ class DebertaTokenizer(PreTrainedTokenizer):
|
|||||||
Returns:
|
Returns:
|
||||||
:obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
|
:obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if token_ids_1 is None:
|
if token_ids_1 is None:
|
||||||
return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
|
return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
|
||||||
cls = [self.cls_token_id]
|
cls = [self.cls_token_id]
|
||||||
sep = [self.sep_token_id]
|
sep = [self.sep_token_id]
|
||||||
return cls + token_ids_0 + sep + token_ids_1 + sep
|
return cls + token_ids_0 + sep + token_ids_1 + sep
|
||||||
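A short sketch of the two layouts produced by the method above (assumes microsoft/deberta-base is available):

# --- illustrative sketch, not part of this commit ---
from transformers import DebertaTokenizer

tok = DebertaTokenizer.from_pretrained("microsoft/deberta-base")
ids_a = tok.encode("sequence builders", add_special_tokens=False)
ids_b = tok.encode("multi-sequence build", add_special_tokens=False)

assert tok.build_inputs_with_special_tokens(ids_a) == [tok.cls_token_id] + ids_a + [tok.sep_token_id]
assert tok.build_inputs_with_special_tokens(ids_a, ids_b) == (
    [tok.cls_token_id] + ids_a + [tok.sep_token_id] + ids_b + [tok.sep_token_id]
)
# --- end sketch ---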
|
|
||||||
def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False):
|
def get_special_tokens_mask(
|
||||||
|
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
|
||||||
|
) -> List[int]:
|
||||||
"""
|
"""
|
||||||
Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
|
Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
|
||||||
special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.
|
special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.
|
||||||
@@ -626,25 +173,21 @@ class DebertaTokenizer(PreTrainedTokenizer):
|
|||||||
Returns:
|
Returns:
|
||||||
:obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
|
:obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if already_has_special_tokens:
|
if already_has_special_tokens:
|
||||||
if token_ids_1 is not None:
|
if token_ids_1 is not None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"You should not supply a second sequence if the provided sequence of "
|
"You should not supply a second sequence if the provided sequence of "
|
||||||
"ids is already formatted with special tokens for the model."
|
"ids is already formatted with special tokens for the model."
|
||||||
)
|
)
|
||||||
return list(
|
return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
|
||||||
map(
|
|
||||||
lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0,
|
|
||||||
token_ids_0,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if token_ids_1 is not None:
|
if token_ids_1 is None:
|
||||||
return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
|
return [1] + ([0] * len(token_ids_0)) + [1]
|
||||||
return [1] + ([0] * len(token_ids_0)) + [1]
|
return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
|
||||||
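The mask returned by the method above marks only the [CLS]/[SEP] positions with 1 (sketch, same assumptions as the previous one):

# --- illustrative sketch, not part of this commit ---
from transformers import DebertaTokenizer

tok = DebertaTokenizer.from_pretrained("microsoft/deberta-base")
ids_a = tok.encode("sequence builders", add_special_tokens=False)
ids_b = tok.encode("multi-sequence build", add_special_tokens=False)

assert tok.get_special_tokens_mask(ids_a, ids_b) == [1] + [0] * len(ids_a) + [1] + [0] * len(ids_b) + [1]
# --- end sketch ---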
|
|
||||||
def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None):
|
def create_token_type_ids_from_sequences(
|
||||||
|
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
||||||
|
) -> List[int]:
|
||||||
"""
|
"""
|
||||||
Create a mask from the two sequences passed to be used in a sequence-pair classification task. A DeBERTa
|
Create a mask from the two sequences passed to be used in a sequence-pair classification task. A DeBERTa
|
||||||
sequence pair mask has the following format:
|
sequence pair mask has the following format:
|
||||||
@@ -668,15 +211,13 @@ class DebertaTokenizer(PreTrainedTokenizer):
|
|||||||
"""
|
"""
|
||||||
sep = [self.sep_token_id]
|
sep = [self.sep_token_id]
|
||||||
cls = [self.cls_token_id]
|
cls = [self.cls_token_id]
|
||||||
|
|
||||||
if token_ids_1 is None:
|
if token_ids_1 is None:
|
||||||
return len(cls + token_ids_0 + sep) * [0]
|
return len(cls + token_ids_0 + sep) * [0]
|
||||||
return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
|
return len(cls + token_ids_0 + sep + token_ids_1 + sep) * [0]
|
||||||
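Note the changed return for the pair case above: both segments now get token type 0, which is the sequence-pair encoding change referenced in the commit message. Sketch, under the same assumptions as the previous ones:

# --- illustrative sketch, not part of this commit ---
from transformers import DebertaTokenizer

tok = DebertaTokenizer.from_pretrained("microsoft/deberta-base")
ids_a = tok.encode("sequence builders", add_special_tokens=False)
ids_b = tok.encode("multi-sequence build", add_special_tokens=False)

type_ids = tok.create_token_type_ids_from_sequences(ids_a, ids_b)
assert type_ids == [0] * (len(ids_a) + len(ids_b) + 3)   # +3 for [CLS] and the two [SEP]
# --- end sketch ---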
|
|
||||||
def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs):
|
def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs):
|
||||||
add_prefix_space = kwargs.pop("add_prefix_space", False)
|
add_prefix_space = kwargs.pop("add_prefix_space", self.add_prefix_space)
|
||||||
if is_split_into_words or add_prefix_space:
|
if (is_split_into_words or add_prefix_space) and (len(text) > 0 and not text[0].isspace()):
|
||||||
text = " " + text
|
text = " " + text
|
||||||
return (text, kwargs)
|
return (text, kwargs)
|
||||||
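Sketch of the add_prefix_space behaviour above; the space is only prepended when the text does not already start with one:

# --- illustrative sketch, not part of this commit ---
from transformers import DebertaTokenizer

tok = DebertaTokenizer.from_pretrained("microsoft/deberta-base")
print(tok.tokenize("Hello world"))                         # first word has no leading-space marker
print(tok.tokenize("Hello world", add_prefix_space=True))  # first token now carries the "Ġ" prefix
# --- end sketch ---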
|
|
||||||
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
|
||||||
return self.gpt2_tokenizer.save_pretrained(save_directory, filename_prefix=filename_prefix)
|
|
||||||
|
@@ -1,5 +1,5 @@
|
|||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
# Copyright 2018 Microsoft, the Hugging Face Team.
|
# Copyright 2019 Hugging Face inc.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@@ -14,61 +14,144 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
import re
|
import json
|
||||||
|
import os
|
||||||
import unittest
|
import unittest
|
||||||
from typing import Tuple
|
|
||||||
|
|
||||||
from transformers.models.deberta.tokenization_deberta import DebertaTokenizer
|
from transformers import DebertaTokenizer
|
||||||
from transformers.testing_utils import require_torch
|
from transformers.models.deberta.tokenization_deberta import VOCAB_FILES_NAMES
|
||||||
|
from transformers.testing_utils import slow
|
||||||
|
|
||||||
from .test_tokenization_common import TokenizerTesterMixin
|
from .test_tokenization_common import TokenizerTesterMixin
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
|
||||||
class DebertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
class DebertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||||
|
|
||||||
tokenizer_class = DebertaTokenizer
|
tokenizer_class = DebertaTokenizer
|
||||||
|
test_rust_tokenizer = False
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
super().setUp()
|
super().setUp()
|
||||||
|
|
||||||
def get_tokenizer(self, name="microsoft/deberta-base", **kwargs):
|
# Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt
|
||||||
return DebertaTokenizer.from_pretrained(name, **kwargs)
|
vocab = [
|
||||||
|
"l",
|
||||||
|
"o",
|
||||||
|
"w",
|
||||||
|
"e",
|
||||||
|
"r",
|
||||||
|
"s",
|
||||||
|
"t",
|
||||||
|
"i",
|
||||||
|
"d",
|
||||||
|
"n",
|
||||||
|
"\u0120",
|
||||||
|
"\u0120l",
|
||||||
|
"\u0120n",
|
||||||
|
"\u0120lo",
|
||||||
|
"\u0120low",
|
||||||
|
"er",
|
||||||
|
"\u0120lowest",
|
||||||
|
"\u0120newer",
|
||||||
|
"\u0120wider",
|
||||||
|
"[UNK]",
|
||||||
|
]
|
||||||
|
vocab_tokens = dict(zip(vocab, range(len(vocab))))
|
||||||
|
merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""]
|
||||||
|
self.special_tokens_map = {"unk_token": "[UNK]"}
|
||||||
|
|
||||||
|
self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
|
||||||
|
self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["merges_file"])
|
||||||
|
with open(self.vocab_file, "w", encoding="utf-8") as fp:
|
||||||
|
fp.write(json.dumps(vocab_tokens) + "\n")
|
||||||
|
with open(self.merges_file, "w", encoding="utf-8") as fp:
|
||||||
|
fp.write("\n".join(merges))
|
||||||
|
|
||||||
|
def get_tokenizer(self, **kwargs):
|
||||||
|
kwargs.update(self.special_tokens_map)
|
||||||
|
return self.tokenizer_class.from_pretrained(self.tmpdirname, **kwargs)
|
||||||
|
|
||||||
def get_input_output_texts(self, tokenizer):
|
def get_input_output_texts(self, tokenizer):
|
||||||
input_text = "lower newer"
|
input_text = "lower newer"
|
||||||
output_text = "lower newer"
|
output_text = "lower newer"
|
||||||
return input_text, output_text
|
return input_text, output_text
|
||||||
|
|
||||||
def get_clean_sequence(self, tokenizer, with_prefix_space=False, max_length=20) -> Tuple[str, list]:
|
|
||||||
toks = [
|
|
||||||
(i, tokenizer.decode([i], clean_up_tokenization_spaces=False))
|
|
||||||
for i in range(5, min(len(tokenizer), 50260))
|
|
||||||
]
|
|
||||||
toks = list(filter(lambda t: re.match(r"^[ a-zA-Z]+$", t[1]), toks))
|
|
||||||
toks = list(filter(lambda t: [t[0]] == tokenizer.encode(t[1], add_special_tokens=False), toks))
|
|
||||||
if max_length is not None and len(toks) > max_length:
|
|
||||||
toks = toks[:max_length]
|
|
||||||
# toks_str = [t[1] for t in toks]
|
|
||||||
toks_ids = [t[0] for t in toks]
|
|
||||||
|
|
||||||
# Ensure consistency
|
|
||||||
output_txt = tokenizer.decode(toks_ids, clean_up_tokenization_spaces=False)
|
|
||||||
if " " not in output_txt and len(toks_ids) > 1:
|
|
||||||
output_txt = (
|
|
||||||
tokenizer.decode([toks_ids[0]], clean_up_tokenization_spaces=False)
|
|
||||||
+ " "
|
|
||||||
+ tokenizer.decode(toks_ids[1:], clean_up_tokenization_spaces=False)
|
|
||||||
)
|
|
||||||
if with_prefix_space and not output_txt.startswith(" "):
|
|
||||||
output_txt = " " + output_txt
|
|
||||||
output_ids = tokenizer.encode(output_txt, add_special_tokens=False)
|
|
||||||
return output_txt, output_ids
|
|
||||||
|
|
||||||
def test_full_tokenizer(self):
|
def test_full_tokenizer(self):
|
||||||
tokenizer = self.get_tokenizer("microsoft/deberta-base")
|
tokenizer = self.get_tokenizer()
|
||||||
input_str = "UNwant\u00E9d,running"
|
text = "lower newer"
|
||||||
tokens = tokenizer.tokenize(input_str)
|
bpe_tokens = ["l", "o", "w", "er", "\u0120", "n", "e", "w", "er"]
|
||||||
token_ids = tokenizer.convert_tokens_to_ids(tokens)
|
tokens = tokenizer.tokenize(text)
|
||||||
|
self.assertListEqual(tokens, bpe_tokens)
|
||||||
|
|
||||||
self.assertEqual(tokenizer.decode(token_ids), input_str)
|
input_tokens = tokens + [tokenizer.unk_token]
|
||||||
|
input_bpe_tokens = [0, 1, 2, 15, 10, 9, 3, 2, 15, 19]
|
||||||
|
self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_sequence_builders(self):
|
||||||
|
tokenizer = self.tokenizer_class.from_pretrained("microsoft/deberta-base")
|
||||||
|
|
||||||
|
text = tokenizer.encode("sequence builders", add_special_tokens=False)
|
||||||
|
text_2 = tokenizer.encode("multi-sequence build", add_special_tokens=False)
|
||||||
|
|
||||||
|
encoded_text_from_decode = tokenizer.encode(
|
||||||
|
"sequence builders", add_special_tokens=True, add_prefix_space=False
|
||||||
|
)
|
||||||
|
encoded_pair_from_decode = tokenizer.encode(
|
||||||
|
"sequence builders", "multi-sequence build", add_special_tokens=True, add_prefix_space=False
|
||||||
|
)
|
||||||
|
|
||||||
|
encoded_sentence = tokenizer.build_inputs_with_special_tokens(text)
|
||||||
|
encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2)
|
||||||
|
|
||||||
|
assert encoded_sentence == encoded_text_from_decode
|
||||||
|
assert encoded_pair == encoded_pair_from_decode
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_tokenizer_integration(self):
|
||||||
|
tokenizer_classes = [self.tokenizer_class]
|
||||||
|
if self.test_rust_tokenizer:
|
||||||
|
tokenizer_classes.append(self.rust_tokenizer_class)
|
||||||
|
|
||||||
|
for tokenizer_class in tokenizer_classes:
|
||||||
|
tokenizer = tokenizer_class.from_pretrained("microsoft/deberta-base")
|
||||||
|
|
||||||
|
sequences = [
|
||||||
|
"ALBERT: A Lite BERT for Self-supervised Learning of Language Representations",
|
||||||
|
"ALBERT incorporates two parameter reduction techniques",
|
||||||
|
"The first one is a factorized embedding parameterization. By decomposing the large vocabulary embedding matrix into two small matrices, we separate the size of the hidden layers from the size of vocabulary embedding.",
|
||||||
|
]
|
||||||
|
|
||||||
|
encoding = tokenizer(sequences, padding=True)
|
||||||
|
decoded_sequences = [tokenizer.decode(seq, skip_special_tokens=True) for seq in encoding["input_ids"]]
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
expected_encoding = {
|
||||||
|
'input_ids': [
|
||||||
|
[1, 2118, 11126, 565, 35, 83, 25191, 163, 18854, 13, 12156, 12, 16101, 25376, 13807, 9, 22205, 27893, 1635, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||||
|
[1, 2118, 11126, 565, 24536, 80, 43797, 4878, 7373, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||||
|
[1, 133, 78, 65, 16, 10, 3724, 1538, 33183, 11303, 43797, 1938, 4, 870, 24165, 29105, 5, 739, 32644, 33183, 11303, 36173, 88, 80, 650, 7821, 45940, 6, 52, 2559, 5, 1836, 9, 5, 7397, 13171, 31, 5, 1836, 9, 32644, 33183, 11303, 4, 2]
|
||||||
|
],
|
||||||
|
'token_type_ids': [
|
||||||
|
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||||
|
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||||
|
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
|
||||||
|
],
|
||||||
|
'attention_mask': [
|
||||||
|
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||||
|
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||||
|
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
|
||||||
|
]
|
||||||
|
}
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
expected_decoded_sequence = [
|
||||||
|
"ALBERT: A Lite BERT for Self-supervised Learning of Language Representations",
|
||||||
|
"ALBERT incorporates two parameter reduction techniques",
|
||||||
|
"The first one is a factorized embedding parameterization. By decomposing the large vocabulary embedding matrix into two small matrices, we separate the size of the hidden layers from the size of vocabulary embedding.",
|
||||||
|
]
|
||||||
|
|
||||||
|
self.assertDictEqual(encoding.data, expected_encoding)
|
||||||
|
|
||||||
|
for expected, decoded in zip(expected_decoded_sequence, decoded_sequences):
|
||||||
|
self.assertEqual(expected, decoded)
|
||||||
|
Loading…