mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
fix train_new_from_iterator in the case of byte-level tokenizers (#17549)
This commit is contained in:
parent 264128cb9d
commit ae7bae8fe7
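The call path this commit fixes is PreTrainedTokenizerFast.train_new_from_iterator when the underlying tokenizer uses a ByteLevel pre-tokenizer (GPT-2-style byte-level BPE). A minimal sketch of that usage, with an illustrative corpus (before the fix, the retrained tokenizer presumably lacked the full 256-byte initial alphabet, so bytes unseen in the training data could not be encoded):

from transformers import AutoTokenizer

# gpt2 ships a byte-level BPE tokenizer, the case this commit targets
old_tokenizer = AutoTokenizer.from_pretrained("gpt2")

corpus = (line for line in ["some text", "more text"])  # any iterator of strings
new_tokenizer = old_tokenizer.train_new_from_iterator(corpus, vocab_size=1000)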
@@ -21,6 +21,7 @@ import os
 from collections import defaultdict
 from typing import Any, Dict, List, Optional, Tuple, Union
 
+import tokenizers.pre_tokenizers as pre_tokenizers_fast
 from tokenizers import Encoding as EncodingFast
 from tokenizers import Tokenizer as TokenizerFast
 from tokenizers.decoders import Decoder as DecoderFast
@@ -699,6 +700,8 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
                 kwargs["end_of_word_suffix"] = tokenizer_json["model"]["end_of_word_suffix"]
         if tokenizer_json["model"]["type"] == "Unigram" and unk_token is not None:
             kwargs["unk_token"] = unk_token
+        if tokenizer_json["pre_tokenizer"]["type"] == "ByteLevel":
+            kwargs["initial_alphabet"] = pre_tokenizers_fast.ByteLevel.alphabet()
 
         trainer_class = MODEL_TO_TRAINER_MAPPING[tokenizer_json["model"]["type"]]
         trainer = trainer_class(vocab_size=vocab_size, special_tokens=special_tokens, **kwargs)
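The fix itself: when the serialized tokenizer declares a ByteLevel pre-tokenizer, the trainer is seeded with the full byte-level alphabet via initial_alphabet, so all 256 byte symbols get a token even if they never occur in the training iterator. A small sketch of the equivalent direct usage of the tokenizers library (parameter values are illustrative):

from tokenizers import pre_tokenizers, trainers

alphabet = pre_tokenizers.ByteLevel.alphabet()
print(len(alphabet))  # 256 byte-level symbols

# Seed a BPE trainer with the byte alphabet, as the patched code does via kwargs
trainer = trainers.BpeTrainer(vocab_size=1000, initial_alphabet=alphabet)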
@@ -150,6 +150,7 @@ class BartModelTester:
     def get_pipeline_config(self):
         config = self.get_config()
         config.max_position_embeddings = 100
+        config.vocab_size = 300
         return config
 
     def prepare_config_and_inputs_for_common(self):
@@ -140,6 +140,7 @@ class BlenderbotModelTester:
     def get_pipeline_config(self):
         config = self.get_config()
         config.max_position_embeddings = 100
+        config.vocab_size = 300
         return config
 
     def prepare_config_and_inputs_for_common(self):
@@ -130,6 +130,11 @@ class DebertaModelTester(object):
             pos_att_type=self.pos_att_type,
         )
 
+    def get_pipeline_config(self):
+        config = self.get_config()
+        config.vocab_size = 300
+        return config
+
     def check_loss_output(self, result):
         self.parent.assertListEqual(list(result.loss.size()), [])
 
@@ -166,6 +166,11 @@ class GPT2ModelTester:
             reorder_and_upcast_attn=reorder_and_upcast_attn,
         )
 
+    def get_pipeline_config(self):
+        config = self.get_config()
+        config.vocab_size = 300
+        return config
+
     def prepare_config_and_inputs_for_decoder(self):
         (
             config,
@@ -151,6 +151,11 @@ class GPTNeoModelTester:
             attention_types=self.attention_types,
         )
 
+    def get_pipeline_config(self):
+        config = self.get_config()
+        config.vocab_size = 300
+        return config
+
     def prepare_config_and_inputs_for_decoder(self):
         (
             config,
@@ -155,6 +155,11 @@ class GPTJModelTester:
             rotary_dim=self.rotary_dim,
         )
 
+    def get_pipeline_config(self):
+        config = self.get_config()
+        config.vocab_size = 300
+        return config
+
     def prepare_config_and_inputs_for_decoder(self):
         (
             config,
@@ -116,6 +116,11 @@ class IBertModelTester:
             quant_mode=True,
         )
 
+    def get_pipeline_config(self):
+        config = self.get_config()
+        config.vocab_size = 300
+        return config
+
     def create_and_check_model(
         self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
     ):
@@ -163,6 +163,7 @@ class LEDModelTester:
     def get_pipeline_config(self):
         config = self.get_config()
         config.max_position_embeddings = 100
+        config.vocab_size = 300
         return config
 
     def prepare_config_and_inputs_for_common(self):
@@ -113,6 +113,11 @@ class LongformerModelTester:
             attention_window=self.attention_window,
         )
 
+    def get_pipeline_config(self):
+        config = self.get_config()
+        config.vocab_size = 300
+        return config
+
     def create_and_check_attention_mask_determinism(
         self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
     ):
@@ -112,6 +112,11 @@ class RobertaModelTester:
             initializer_range=self.initializer_range,
         )
 
+    def get_pipeline_config(self):
+        config = self.get_config()
+        config.vocab_size = 300
+        return config
+
     def prepare_config_and_inputs_for_decoder(self):
         (
             config,
@@ -126,6 +126,11 @@ class YosoModelTester:
             initializer_range=self.initializer_range,
        )
 
+    def get_pipeline_config(self):
+        config = self.get_config()
+        config.vocab_size = 300
+        return config
+
     def prepare_config_and_inputs_for_decoder(self):
         (
             config,
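Each model tester above gains (or extends) a get_pipeline_config() override that keeps vocab_size at 300. A plausible reading: a byte-level tokenizer always carries the 256 single-byte tokens plus specials, so any tiny model paired with one in the pipeline tests needs an embedding table that covers those ids, and 300 leaves a little headroom. A self-contained sketch of the selection pattern (the class and helper below are illustrative stand-ins, not the actual test-suite code):

class DummyTester:
    # Illustrative stand-in for the model testers patched above.
    def get_config(self):
        return {"vocab_size": 99}  # default tiny config

    def get_pipeline_config(self):
        config = dict(self.get_config())
        config["vocab_size"] = 300  # room for 256 byte tokens + special tokens
        return config

def pick_pipeline_config(tester):
    # Prefer the pipeline-specific config when a tester provides one.
    getter = getattr(tester, "get_pipeline_config", tester.get_config)
    return getter()

assert pick_pipeline_config(DummyTester())["vocab_size"] == 300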
@@ -39,6 +39,7 @@ class PreTrainedTokenizationFastTest(TokenizerTesterMixin, unittest.TestCase):
         self.test_rust_tokenizer = True
 
         model_paths = ["robot-test/dummy-tokenizer-fast", "robot-test/dummy-tokenizer-wordlevel"]
+        self.bytelevel_bpe_model_name = "SaulLu/dummy-tokenizer-bytelevel-bpe"
 
         # Inclusion of 2 tokenizers to test different types of models (Unigram and WordLevel for the moment)
         self.tokenizers_list = [(PreTrainedTokenizerFast, model_path, {}) for model_path in model_paths]
@@ -99,6 +100,15 @@ class PreTrainedTokenizationFastTest(TokenizerTesterMixin, unittest.TestCase):
         shutil.rmtree(self.tmpdirname)
         self.tmpdirname = tmpdirname_orig
 
+    def test_training_new_tokenizer_with_bytelevel(self):
+        tokenizer = self.rust_tokenizer_class.from_pretrained(self.bytelevel_bpe_model_name)
+
+        toy_text_iterator = ("a" for _ in range(1000))
+        new_tokenizer = tokenizer.train_new_from_iterator(text_iterator=toy_text_iterator, length=1000, vocab_size=50)
+
+        encoding_ids = new_tokenizer.encode("a🤗")
+        self.assertEqual(encoding_ids, [64, 172, 253, 97, 245])
+
 
 @require_tokenizers
 class TokenizerVersioningTest(unittest.TestCase):
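The expected ids in the new test follow from UTF-8: "a" is one byte and "🤗" is four, and since the toy corpus only ever contains "a", no merges cover the emoji, so the byte-level tokenizer falls back to five single-byte tokens. The exact ids depend on the dummy model's vocabulary, but the token count can be checked directly:

text = "a🤗"
print(len(text.encode("utf-8")))  # 5 bytes -> 5 byte-level tokens when no merges apply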