Fix llama tokenizer (#22402)

* draft

* update tokenization limma and conversion script

* more udpates

* initial commit

* style

* default pad to None

* draft tokenization tests

* update test

* update tokenization tests

* nits

* update

* versioning test

* major fix

* fix more testst

* finish fixing special masks

* last nit

* more nits

* add encode decode tests

* add more

* fix token type ids

* style
This commit is contained in:
Arthur 2023-04-03 15:07:32 +02:00 committed by GitHub
parent 9eae4aa576
commit c0f99b4d2e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 480 additions and 62 deletions

View File

@ -20,7 +20,7 @@ import shutil
import torch
from transformers import LlamaConfig, LlamaForCausalLM
from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
"""
@ -233,19 +233,9 @@ def write_model(model_path, input_base_path, model_size):
def write_tokenizer(tokenizer_path, input_tokenizer_path):
print(f"Fetching the tokenizer from {input_tokenizer_path}.")
os.makedirs(tokenizer_path, exist_ok=True)
write_json({}, os.path.join(tokenizer_path, "special_tokens_map.json"))
write_json(
{
"bos_token": "",
"eos_token": "",
"model_max_length": int(1e30),
"tokenizer_class": "LlamaTokenizer",
"unk_token": "",
},
os.path.join(tokenizer_path, "tokenizer_config.json"),
)
shutil.copyfile(input_tokenizer_path, os.path.join(tokenizer_path, "tokenizer.model"))
# Initialize the tokenizer based on the `spm` model
tokenizer = LlamaTokenizer(input_tokenizer_path)
tokenizer.save_pretrained(tokenizer_path)
def main():

View File

@ -25,7 +25,7 @@ from typing import Any, Dict, List, Optional, Tuple
import sentencepiece as spm
from ...tokenization_utils import PreTrainedTokenizer
from ...tokenization_utils import AddedToken, PreTrainedTokenizer
from ...utils import logging
@ -33,7 +33,17 @@ logger = logging.get_logger(__name__)
VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}
PRETRAINED_VOCAB_FILES_MAP = {}
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
"hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model",
},
"tokenizer_file": {
"hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer_config.json",
},
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"hf-internal-testing/llama-tokenizer": 2048,
}
class LlamaTokenizer(PreTrainedTokenizer):
@ -47,6 +57,7 @@ class LlamaTokenizer(PreTrainedTokenizer):
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
model_input_names = ["input_ids", "attention_mask"]
def __init__(
@ -55,51 +66,50 @@ class LlamaTokenizer(PreTrainedTokenizer):
unk_token="<unk>",
bos_token="<s>",
eos_token="</s>",
pad_token=None,
sp_model_kwargs: Optional[Dict[str, Any]] = None,
add_bos_token=True,
add_eos_token=False,
decode_with_prefix_space=False,
clean_up_tokenization_spaces=False,
**kwargs,
):
self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
super().__init__(
bos_token=bos_token,
eos_token=eos_token,
unk_token=unk_token,
pad_token=pad_token,
add_bos_token=add_bos_token,
add_eos_token=add_eos_token,
sp_model_kwargs=self.sp_model_kwargs,
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
**kwargs,
)
self.vocab_file = vocab_file
self.add_bos_token = add_bos_token
self.add_eos_token = add_eos_token
self.decode_with_prefix_space = decode_with_prefix_space
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
self.sp_model.Load(vocab_file)
self._no_prefix_space_tokens = None
""" Initialisation"""
def __getstate__(self):
state = self.__dict__.copy()
state["sp_model"] = None
return state
@property
def no_prefix_space_tokens(self):
if self._no_prefix_space_tokens is None:
vocab = self.convert_ids_to_tokens(list(range(self.vocab_size)))
self._no_prefix_space_tokens = {i for i, tok in enumerate(vocab) if not tok.startswith("")}
return self._no_prefix_space_tokens
def __setstate__(self, d):
self.__dict__ = d
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
self.sp_model.Load(self.vocab_file)
@property
def vocab_size(self):
"""Returns vocab size"""
return self.sp_model.get_piece_size()
@property
def bos_token_id(self) -> Optional[int]:
return self.sp_model.bos_id()
@property
def eos_token_id(self) -> Optional[int]:
return self.sp_model.eos_id()
def get_vocab(self):
"""Returns vocab as a dict"""
vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
@ -119,21 +129,15 @@ class LlamaTokenizer(PreTrainedTokenizer):
token = self.sp_model.IdToPiece(index)
return token
def _maybe_add_prefix_space(self, tokens, decoded):
if tokens and tokens[0] not in self.no_prefix_space_tokens:
return " " + decoded
else:
return decoded
def convert_tokens_to_string(self, tokens):
"""Converts a sequence of tokens (string) in a single string."""
current_sub_tokens = []
out_string = ""
prev_is_special = False
for token in tokens:
for i, token in enumerate(tokens):
# make sure that special tokens are not decoded using sentencepiece model
if token in self.all_special_tokens:
if not prev_is_special:
if not prev_is_special and i != 0:
out_string += " "
out_string += self.sp_model.decode(current_sub_tokens) + token
prev_is_special = True
@ -142,7 +146,6 @@ class LlamaTokenizer(PreTrainedTokenizer):
current_sub_tokens.append(token)
prev_is_special = False
out_string += self.sp_model.decode(current_sub_tokens)
out_string = self._maybe_add_prefix_space(tokens=tokens, decoded=out_string)
return out_string
def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]:
@ -173,18 +176,13 @@ class LlamaTokenizer(PreTrainedTokenizer):
return (out_vocab_file,)
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
if self.add_bos_token:
bos_token_ids = [self.bos_token_id]
else:
bos_token_ids = []
bos_token_id = [self.bos_token_id] if self.add_bos_token else []
eos_token_id = [self.eos_token_id] if self.add_eos_token else []
output = bos_token_ids + token_ids_0
output = bos_token_id + token_ids_0 + eos_token_id
if token_ids_1 is not None:
output = output + token_ids_1
if self.add_eos_token:
output = output + [self.eos_token_id]
output = output + bos_token_id + token_ids_1 + eos_token_id
return output
@ -211,28 +209,46 @@ class LlamaTokenizer(PreTrainedTokenizer):
token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
)
bos_token_id = [1] if self.add_bos_token else []
eos_token_id = [1] if self.add_eos_token else []
if token_ids_1 is None:
return [1] + ([0] * len(token_ids_0)) + [1]
return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]
return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id
return (
bos_token_id
+ ([0] * len(token_ids_0))
+ eos_token_id
+ bos_token_id
+ ([0] * len(token_ids_1))
+ eos_token_id
)
def create_token_type_ids_from_sequences(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""
Create a mask from the two sequences passed to be used in a sequence-pair classification task. T5 does not make
use of token type ids, therefore a list of zeros is returned.
Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT
sequence pair mask has the following format:
```
0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
| first sequence | second sequence |
```
if token_ids_1 is None, only returns the first portion of the mask (0s).
Args:
token_ids_0 (`List[int]`):
List of IDs.
List of ids.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.
Returns:
`List[int]`: List of zeros.
`List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
"""
eos = [self.eos_token_id]
sep = [self.sep_token_id]
cls = [self.cls_token_id]
if token_ids_1 is None:
return len(token_ids_0 + eos) * [0]
return len(token_ids_0 + eos + token_ids_1 + eos) * [0]
return len(cls + token_ids_0 + sep) * [0]
return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]

View File

@ -0,0 +1,412 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import shutil
import tempfile
import unittest
from datasets import load_dataset
from transformers import (
SPIECE_UNDERLINE,
AddedToken,
LlamaTokenizer,
is_torch_available,
)
from transformers.testing_utils import (
get_tests_dir,
nested_simplify,
require_sentencepiece,
require_tokenizers,
require_torch,
slow,
)
from ...test_tokenization_common import TokenizerTesterMixin
SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")
if is_torch_available():
pass
@require_sentencepiece
@require_tokenizers
class LlamaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer_class = LlamaTokenizer
test_rust_tokenizer = False
test_sentencepiece = True
from_pretrained_kwargs = {}
def setUp(self):
super().setUp()
# We have a SentencePiece fixture for testing
tokenizer = LlamaTokenizer(SAMPLE_VOCAB, keep_accents=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.save_pretrained(self.tmpdirname)
def test_full_tokenizer(self):
tokenizer = LlamaTokenizer(SAMPLE_VOCAB, keep_accents=True)
tokens = tokenizer.tokenize("This is a test")
self.assertListEqual(tokens, ["▁This", "▁is", "▁a", "▁t", "est"])
self.assertListEqual(
tokenizer.convert_tokens_to_ids(tokens),
[285, 46, 10, 170, 382],
)
tokens = tokenizer.tokenize("I was born in 92000, and this is falsé.")
self.assertListEqual(
tokens,
[
SPIECE_UNDERLINE + "I",
SPIECE_UNDERLINE + "was",
SPIECE_UNDERLINE + "b",
"or",
"n",
SPIECE_UNDERLINE + "in",
SPIECE_UNDERLINE + "",
"9",
"2",
"0",
"0",
"0",
",",
SPIECE_UNDERLINE + "and",
SPIECE_UNDERLINE + "this",
SPIECE_UNDERLINE + "is",
SPIECE_UNDERLINE + "f",
"al",
"s",
"é",
".",
],
)
ids = tokenizer.convert_tokens_to_ids(tokens)
self.assertListEqual(
ids,
[8, 21, 84, 55, 24, 19, 7, 0, 602, 347, 347, 347, 3, 12, 66, 46, 72, 80, 6, 0, 4],
)
back_tokens = tokenizer.convert_ids_to_tokens(ids)
self.assertListEqual(
back_tokens,
[
SPIECE_UNDERLINE + "I",
SPIECE_UNDERLINE + "was",
SPIECE_UNDERLINE + "b",
"or",
"n",
SPIECE_UNDERLINE + "in",
SPIECE_UNDERLINE + "",
"<unk>",
"2",
"0",
"0",
"0",
",",
SPIECE_UNDERLINE + "and",
SPIECE_UNDERLINE + "this",
SPIECE_UNDERLINE + "is",
SPIECE_UNDERLINE + "f",
"al",
"s",
"<unk>",
".",
],
)
@unittest.skip("Let's wait for the fast tokenizer!")
def test_save_pretrained(self):
self.tokenizers_list += (self.rust_tokenizer_class, "hf-internal-testing/llama-tokenizer", {})
for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs)
tokenizer_p = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs)
tmpdirname2 = tempfile.mkdtemp()
tokenizer_r_files = tokenizer_r.save_pretrained(tmpdirname2)
tokenizer_p_files = tokenizer_p.save_pretrained(tmpdirname2)
# Checks it save with the same files + the tokenizer.json file for the fast one
self.assertTrue(any("tokenizer.json" in f for f in tokenizer_r_files))
tokenizer_r_files = tuple(f for f in tokenizer_r_files if "tokenizer.json" not in f)
self.assertSequenceEqual(tokenizer_r_files, tokenizer_p_files)
# Checks everything loads correctly in the same way
tokenizer_rp = tokenizer_r.from_pretrained(tmpdirname2)
tokenizer_pp = tokenizer_p.from_pretrained(tmpdirname2)
# Check special tokens are set accordingly on Rust and Python
for key in tokenizer_pp.special_tokens_map:
self.assertTrue(hasattr(tokenizer_rp, key))
shutil.rmtree(tmpdirname2)
# Save tokenizer rust, legacy_format=True
tmpdirname2 = tempfile.mkdtemp()
tokenizer_r_files = tokenizer_r.save_pretrained(tmpdirname2, legacy_format=True)
tokenizer_p_files = tokenizer_p.save_pretrained(tmpdirname2)
# Checks it save with the same files
self.assertSequenceEqual(tokenizer_r_files, tokenizer_p_files)
# Checks everything loads correctly in the same way
tokenizer_rp = tokenizer_r.from_pretrained(tmpdirname2)
tokenizer_pp = tokenizer_p.from_pretrained(tmpdirname2)
# Check special tokens are set accordingly on Rust and Python
for key in tokenizer_pp.special_tokens_map:
self.assertTrue(hasattr(tokenizer_rp, key))
shutil.rmtree(tmpdirname2)
# Save tokenizer rust, legacy_format=False
tmpdirname2 = tempfile.mkdtemp()
tokenizer_r_files = tokenizer_r.save_pretrained(tmpdirname2, legacy_format=False)
tokenizer_p_files = tokenizer_p.save_pretrained(tmpdirname2)
# Checks it saved the tokenizer.json file
self.assertTrue(any("tokenizer.json" in f for f in tokenizer_r_files))
# Checks everything loads correctly in the same way
tokenizer_rp = tokenizer_r.from_pretrained(tmpdirname2)
tokenizer_pp = tokenizer_p.from_pretrained(tmpdirname2)
# Check special tokens are set accordingly on Rust and Python
for key in tokenizer_pp.special_tokens_map:
self.assertTrue(hasattr(tokenizer_rp, key))
shutil.rmtree(tmpdirname2)
@require_torch
def test_batch_tokenization(self):
if not self.test_seq2seq:
return
tokenizers = self.get_tokenizers()
for tokenizer in tokenizers:
with self.subTest(f"{tokenizer.__class__.__name__}"):
# Longer text that will definitely require truncation.
text = [
" UN Chief Says There Is No Military Solution in Syria",
" Secretary-General Ban Ki-moon says his response to Russia's stepped up military support for"
" Syria is that 'there is no military solution' to the nearly five-year conflict and more weapons"
" will only worsen the violence and misery for millions of people.",
]
try:
batch = tokenizer(
text=text,
max_length=3,
max_target_length=10,
return_tensors="pt",
)
except NotImplementedError:
return
self.assertEqual(batch.input_ids.shape[1], 3)
# max_target_length will default to max_length if not specified
batch = tokenizer(text, max_length=3, return_tensors="pt")
self.assertEqual(batch.input_ids.shape[1], 3)
batch_encoder_only = tokenizer(text=text, max_length=3, max_target_length=10, return_tensors="pt")
self.assertEqual(batch_encoder_only.input_ids.shape[1], 3)
self.assertEqual(batch_encoder_only.attention_mask.shape[1], 3)
self.assertNotIn("decoder_input_ids", batch_encoder_only)
@unittest.skip("Unfortunately way too slow to build a BPE with SentencePiece.")
def test_save_slow_from_fast_and_reload_fast(self):
pass
def test_special_tokens_initialization(self):
for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
added_tokens = [AddedToken("<special>", lstrip=True)]
tokenizer_r = self.rust_tokenizer_class.from_pretrained(
pretrained_name, additional_special_tokens=added_tokens, **kwargs
)
r_output = tokenizer_r.encode("Hey this is a <special> token")
special_token_id = tokenizer_r.encode("<special>", add_special_tokens=False)[0]
self.assertTrue(special_token_id in r_output)
if self.test_slow_tokenizer:
tokenizer_cr = self.rust_tokenizer_class.from_pretrained(
pretrained_name,
additional_special_tokens=added_tokens,
**kwargs, # , from_slow=True <- unfortunately too slow to convert
)
tokenizer_p = self.tokenizer_class.from_pretrained(
pretrained_name, additional_special_tokens=added_tokens, **kwargs
)
p_output = tokenizer_p.encode("Hey this is a <special> token")
cr_output = tokenizer_cr.encode("Hey this is a <special> token")
self.assertEqual(p_output, r_output)
self.assertEqual(cr_output, r_output)
self.assertTrue(special_token_id in p_output)
self.assertTrue(special_token_id in cr_output)
@slow
def test_tokenizer_integration(self):
# fmt: off
expected_encoding = {'input_ids': [[1, 4103, 689, 414, 313, 24784, 368, 2998, 408, 282, 3637, 25350, 29899, 9067, 414, 322, 282, 3637, 25350, 29899, 1457, 3018, 1312, 29899, 2151, 29897, 8128, 2498, 29899, 15503, 4220, 6956, 1973, 313, 13635, 29911, 29892, 402, 7982, 29899, 29906, 29892, 1528, 13635, 29911, 29874, 29892, 1060, 26369, 29892, 6652, 309, 29933, 814, 29892, 1060, 29931, 6779, 11410, 363, 18385, 17088, 7634, 11235, 313, 25103, 29965, 29897, 322, 18385, 17088, 28203, 313, 25103, 29954, 29897, 411, 975, 29871, 29941, 29906, 29974, 758, 3018, 1312, 4733, 297, 29871, 29896, 29900, 29900, 29974, 10276, 322, 6483, 1006, 3372, 3097, 1546, 435, 1165, 29892, 10772, 29911, 25350, 322, 323, 6073, 17907, 29889], [1, 350, 20161, 338, 8688, 304, 758, 29899, 14968, 6483, 21000, 8684, 284, 22540, 515, 443, 29880, 24025, 1426, 491, 14002, 368, 4195, 292, 373, 1716, 2175, 322, 1492, 3030, 297, 599, 15359, 29889], [1, 450, 4996, 17354, 1701, 29916, 432, 17204, 975, 278, 17366, 11203, 29889]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]}
# fmt: on
self.tokenizer_integration_test_util(
expected_encoding=expected_encoding,
model_name="hf-internal-testing/llama-tokenizer",
revision="0984d03108b1a041ed679bd253b6519b7e1a4778",
padding=False,
)
@require_torch
@require_sentencepiece
@require_tokenizers
class LlamaIntegrationTest(unittest.TestCase):
checkpoint_name = "hf-internal-testing/llama-tokenizer"
@classmethod
def setUpClass(cls):
cls.tokenizer: LlamaTokenizer = LlamaTokenizer.from_pretrained(cls.checkpoint_name)
cls.rust_tokenizer = cls.tokenizer # TODO @narsil replace with the rust one
cls.pad_token_id = 1
return cls
@require_torch
def integration_tests(self):
inputs = self.tokenizer(
["The following string should be properly encoded: Hello.", "But ird and ปี ird ด"],
return_tensors="pt",
)
self.assertEqual(
nested_simplify(inputs),
{
"input_ids": [
[1, 450, 1494, 1347, 881, 367, 6284, 18511, 29901, 15043, 29889],
[1, 1205, 29871, 1823, 322, 29871, 31010, 30691, 1678, 1823, 1678, 30718],
],
"attention_mask": [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],
},
)
def test_simple_encode_decode(self):
pyth_tokenizer = self.tokenizer
rust_tokenizer = self.rust_tokenizer
self.assertEqual(pyth_tokenizer.encode("This is a test"), [1, 910, 338, 263, 1243])
self.assertEqual(rust_tokenizer.encode("This is a test"), [1, 910, 338, 263, 1243])
self.assertEqual(pyth_tokenizer.decode([1, 910, 338, 263, 1243], skip_special_tokens=True), "This is a test")
self.assertEqual(rust_tokenizer.decode([1, 910, 338, 263, 1243], skip_special_tokens=True), "This is a test")
# bytefallback showcase
self.assertEqual(pyth_tokenizer.encode("生活的真谛是"), [1, 29871, 30486, 31704, 30210, 30848, 235, 179, 158, 30392])
self.assertEqual(rust_tokenizer.encode("生活的真谛是"), [1, 29871, 30486, 31704, 30210, 30848, 235, 179, 158, 30392])
self.assertEqual(
pyth_tokenizer.decode(
[1, 29871, 30486, 31704, 30210, 30848, 235, 179, 158, 30392], skip_special_tokens=True
),
"生活的真谛是",
)
self.assertEqual(
rust_tokenizer.decode(
[1, 29871, 30486, 31704, 30210, 30848, 235, 179, 158, 30392], skip_special_tokens=True
),
"生活的真谛是",
)
# Inner spaces showcase
self.assertEqual(pyth_tokenizer.encode("Hi Hello"), [1, 6324, 29871, 15043])
self.assertEqual(rust_tokenizer.encode("Hi Hello"), [1, 6324, 29871, 15043])
self.assertEqual(pyth_tokenizer.decode([1, 6324, 29871, 15043], skip_special_tokens=True), "Hi Hello")
self.assertEqual(rust_tokenizer.decode([1, 6324, 29871, 15043], skip_special_tokens=True), "Hi Hello")
self.assertEqual(pyth_tokenizer.encode("Hi Hello"), [1, 6324, 259, 15043])
self.assertEqual(rust_tokenizer.encode("Hi Hello"), [1, 6324, 259, 15043])
self.assertEqual(pyth_tokenizer.decode([1, 6324, 259, 15043], skip_special_tokens=True), "Hi Hello")
self.assertEqual(rust_tokenizer.decode([1, 6324, 259, 15043], skip_special_tokens=True), "Hi Hello")
self.assertEqual(pyth_tokenizer.encode(""), [1])
self.assertEqual(rust_tokenizer.encode(""), [1])
self.assertEqual(pyth_tokenizer.encode(" "), [1, 259])
self.assertEqual(rust_tokenizer.encode(" "), [1, 259])
self.assertEqual(pyth_tokenizer.encode(" "), [1, 1678])
self.assertEqual(rust_tokenizer.encode(" "), [1, 1678])
self.assertEqual(pyth_tokenizer.encode(" Hello"), [1, 29871, 15043])
self.assertEqual(rust_tokenizer.encode(" Hello"), [1, 29871, 15043])
self.assertEqual(pyth_tokenizer.encode("<s>"), [1, 1])
self.assertEqual(rust_tokenizer.encode("<s>"), [1, 1])
self.assertEqual(pyth_tokenizer.encode(""), [1])
self.assertEqual(rust_tokenizer.encode(""), [1])
self.assertEqual(pyth_tokenizer.decode([869]), ".")
self.assertEqual(rust_tokenizer.decode([869]), ".")
self.assertEqual(pyth_tokenizer.decode([30112, 869]), "ا .")
self.assertEqual(rust_tokenizer.decode([30112, 869]), "ا .")
@unittest.skipIf(
os.getenv("RUN_TOKENIZER_INTEGRATION", "0") == "0",
"RUN_TOKENIZER_INTEGRATION=1 to run tokenizer integration tests",
)
def test_integration_test_xnli(self):
import tqdm
pyth_tokenizer = self.tokenizer
rust_tokenizer = self.rust_tokenizer
dataset = load_dataset("code_x_glue_ct_code_to_text", "go")
for item in tqdm.tqdm(dataset["validation"]):
string = item["code"]
encoded1 = pyth_tokenizer.encode(string)
encoded2 = rust_tokenizer.encode(string)
self.assertEqual(encoded1, encoded2)
decoded1 = pyth_tokenizer.decode(encoded1)
decoded2 = rust_tokenizer.decode(encoded2)
self.assertEqual(decoded1, decoded2)
dataset = load_dataset("xnli", "all_languages")
for item in tqdm.tqdm(dataset["train"]):
for string in item["premise"].values():
encoded1 = pyth_tokenizer.encode(string)
encoded2 = rust_tokenizer.encode(string)
self.assertEqual(encoded1, encoded2)
decoded1 = pyth_tokenizer.decode(encoded1)
decoded2 = rust_tokenizer.decode(encoded2)
self.assertEqual(decoded1, decoded2)

View File

@ -3932,7 +3932,7 @@ class TokenizerTesterMixin:
tokenizer_fast.save_pretrained(tmp_dir_2)
tokenizer = BertTokenizer.from_pretrained(tmp_dir_2)
assert tokenizer_fast.clean_up_tokenization_spaces is False
assert tokenizer.clean_up_tokenization_spaces is False
decoded = tokenizer.decode(tokens)
assert decoded == "[CLS] this shouldn ' t be ! he ' ll go . [SEP]"