VLM: special multimodal Tokenizer (#34461)

* kinda works

* update

* add tests

* update

* use special tokens in processors

* typo

* fix copies

* fix

* fix moshi after rebase

* update

* fix tests

* update

* Update docs/source/en/main_classes/tokenizer.md

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* update docs

* test for load time adding tokens

* fix some more tests which are now fetched better

* one more fix

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Raushan Turganbay 2024-11-04 16:37:51 +01:00 committed by GitHub
parent ef976a7e18
commit 187439c3fa
35 changed files with 248 additions and 335 deletions

View File

@@ -51,6 +51,25 @@ token space (e.g., getting the index of the token comprising a given character o
 to a given token).

+## Multimodal Tokenizer
+
+Apart from that, each tokenizer can be a "multimodal" tokenizer, which means that the tokenizer will hold all relevant special tokens
+as part of tokenizer attributes for easier access. For example, if the tokenizer is loaded from a vision-language model like LLaVA, you will
+be able to access `tokenizer.image_token_id` to obtain the special image token used as a placeholder.
+
+To enable extra special tokens for any type of tokenizer, you have to add the following lines and save the tokenizer. Extra special tokens do not
+have to be modality related and can be anything that the model often needs access to. In the code below, the tokenizer saved at `output_dir` will have direct access
+to three more special tokens.
+
+```python
+from transformers import AutoTokenizer
+
+vision_tokenizer = AutoTokenizer.from_pretrained(
+    "llava-hf/llava-1.5-7b-hf",
+    extra_special_tokens={"image_token": "<image>", "boi_token": "<image_start>", "eoi_token": "<image_end>"}
+)
+print(vision_tokenizer.image_token, vision_tokenizer.image_token_id)
+<image> 32000
+```
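The save step is what makes these attributes stick: the extra tokens are written to `tokenizer_config.json` and restored on the next `from_pretrained`. A minimal sketch of the round trip (the `output_dir` path is illustrative; the behavior follows the `test_extra_special_tokens_multimodal` test added later in this commit):

```python
from transformers import AutoTokenizer

vision_tokenizer = AutoTokenizer.from_pretrained(
    "llava-hf/llava-1.5-7b-hf",
    extra_special_tokens={"image_token": "<image>", "boi_token": "<image_start>", "eoi_token": "<image_end>"},
)
vision_tokenizer.save_pretrained("output_dir")  # extra tokens land in tokenizer_config.json

# No need to pass `extra_special_tokens` again: they are restored from the saved config
reloaded = AutoTokenizer.from_pretrained("output_dir")
print(reloaded.image_token, reloaded.image_token_id)  # <image> 32000
```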
 ## PreTrainedTokenizer

 [[autodoc]] PreTrainedTokenizer

View File

@@ -443,7 +443,7 @@ def _torch_collate_batch(examples, tokenizer, pad_to_multiple_of: Optional[int]
        return torch.stack(examples, dim=0)

    # If yes, check if we have a `pad_token`.
-   if tokenizer._pad_token is None:
+   if tokenizer.pad_token is None:
        raise ValueError(
            "You are attempting to pad samples but the tokenizer you are using"
            f" ({tokenizer.__class__.__name__}) does not have a pad token."
@@ -477,7 +477,7 @@ def _tf_collate_batch(examples, tokenizer, pad_to_multiple_of: Optional[int] = N
        return tf.stack(examples, axis=0)

    # If yes, check if we have a `pad_token`.
-   if tokenizer._pad_token is None:
+   if tokenizer.pad_token is None:
        raise ValueError(
            "You are attempting to pad samples but the tokenizer you are using"
            f" ({tokenizer.__class__.__name__}) does not have a pad token."
@@ -513,7 +513,7 @@ def _numpy_collate_batch(examples, tokenizer, pad_to_multiple_of: Optional[int]
        return np.stack(examples, axis=0)

    # If yes, check if we have a `pad_token`.
-   if tokenizer._pad_token is None:
+   if tokenizer.pad_token is None:
        raise ValueError(
            "You are attempting to pad samples but the tokenizer you are using"
            f" ({tokenizer.__class__.__name__}) does not have a pad token."
@@ -1090,7 +1090,7 @@ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
            self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
        ]
        probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
-       if self.tokenizer._pad_token is not None:
+       if self.tokenizer.pad_token is not None:
            padding_mask = labels.eq(self.tokenizer.pad_token_id)
            probability_matrix.masked_fill_(padding_mask, value=0.0)
@@ -1131,7 +1131,7 @@ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
            self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels
        ]
        masked_indices = masked_indices & ~tf.cast(special_tokens_mask, dtype=tf.bool)
-       if self.tokenizer._pad_token is not None:
+       if self.tokenizer.pad_token is not None:
            padding_mask = inputs == self.tokenizer.pad_token_id
            masked_indices = masked_indices & ~padding_mask
@@ -1170,7 +1170,7 @@ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
            self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
        ]
        masked_indices[np.array(special_tokens_mask, dtype=bool)] = 0
-       if self.tokenizer._pad_token is not None:
+       if self.tokenizer.pad_token is not None:
            padding_mask = labels == self.tokenizer.pad_token_id
            masked_indices[padding_mask] = 0
@@ -1251,13 +1251,13 @@ class DataCollatorForSOP(DataCollatorForLanguageModeling):
            self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
        ]
        probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
-       if self.tokenizer._pad_token is not None:
+       if self.tokenizer.pad_token is not None:
            padding_mask = labels.eq(self.tokenizer.pad_token_id)
            probability_matrix.masked_fill_(padding_mask, value=0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool()
        # probability be `1` (masked), however in albert model attention mask `0` means masked, revert the value
        attention_mask = (~masked_indices).float()
-       if self.tokenizer._pad_token is not None:
+       if self.tokenizer.pad_token is not None:
            attention_padding_mask = labels.eq(self.tokenizer.pad_token_id)
            attention_mask.masked_fill_(attention_padding_mask, value=1.0)
        labels[~masked_indices] = -100  # We only compute loss on masked tokens, -100 is default for CE compute
@@ -1367,7 +1367,7 @@ class DataCollatorForPermutationLanguageModeling(DataCollatorMixin):
            dtype=torch.bool,
        )
        masked_indices.masked_fill_(special_tokens_mask, value=0.0)
-       if self.tokenizer._pad_token is not None:
+       if self.tokenizer.pad_token is not None:
            padding_mask = labels.eq(self.tokenizer.pad_token_id)
            masked_indices.masked_fill_(padding_mask, value=0.0)
@@ -1471,7 +1471,7 @@ class DataCollatorForPermutationLanguageModeling(DataCollatorMixin):
        )
        special_tokens_mask = tf.cast(special_tokens_mask, dtype=tf.bool)
        masked_indices = masked_indices & ~special_tokens_mask
-       if self.tokenizer._pad_token is not None:
+       if self.tokenizer.pad_token is not None:
            padding_mask = labels == self.tokenizer.pad_token_id
            masked_indices = masked_indices & ~padding_mask
@@ -1571,7 +1571,7 @@ class DataCollatorForPermutationLanguageModeling(DataCollatorMixin):
            dtype=bool,
        )
        masked_indices[special_tokens_mask] = 0
-       if self.tokenizer._pad_token is not None:
+       if self.tokenizer.pad_token is not None:
            padding_mask = labels == self.tokenizer.pad_token_id
            masked_indices[padding_mask] = 0.0

View File

@@ -74,8 +74,11 @@ class Blip2Processor(ProcessorMixin):
    def __init__(self, image_processor, tokenizer, num_query_tokens=None, **kwargs):
        tokenizer.return_token_type_ids = False
        self.current_processor = image_processor
-       self.image_token = AddedToken("<image>", normalized=False, special=True)
-       tokenizer.add_tokens([self.image_token], special_tokens=True)
+       if not hasattr(tokenizer, "image_token"):
+           self.image_token = AddedToken("<image>", normalized=False, special=True)
+           tokenizer.add_tokens([self.image_token], special_tokens=True)
+       else:
+           self.image_token = tokenizer.image_token
        self.num_query_tokens = num_query_tokens
        super().__init__(image_processor, tokenizer)
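The same back-compat idiom recurs across the processor diffs below: prefer the token a multimodal tokenizer already carries, otherwise fall back to registering the hardcoded default. A schematic sketch of the pattern (the `MyProcessor` class and `default_image_token` parameter are hypothetical, for illustration only):

```python
from transformers import AddedToken

class MyProcessor:
    """Hypothetical processor showing the fallback idiom used throughout this commit."""

    def __init__(self, tokenizer, default_image_token="<image>"):
        if not hasattr(tokenizer, "image_token"):
            # Old-style tokenizer: register the placeholder ourselves
            self.image_token = AddedToken(default_image_token, normalized=False, special=True)
            tokenizer.add_tokens([self.image_token], special_tokens=True)
        else:
            # Multimodal tokenizer: reuse the special token it already holds
            self.image_token = tokenizer.image_token
```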

View File

@@ -66,9 +66,12 @@ class ChameleonProcessor(ProcessorMixin):
    def __init__(self, image_processor, tokenizer, image_seq_length: int = 1024, image_token: str = "<image>"):
        self.image_seq_length = image_seq_length
-       self.image_token = image_token
-       self.image_start_token = "<racm3:break>"  # fixed tokens for start and end, so can hardcode
-       self.image_end_token = "<eoss>"
+       self.image_token = tokenizer.image_token if hasattr(tokenizer, "image_token") else image_token
+       self.image_start_token = (
+           tokenizer.boi_token if hasattr(tokenizer, "boi_token") else "<racm3:break>"
+       )  # fixed tokens for start and end, so can hardcode
+       self.image_end_token = tokenizer.eoi_token if hasattr(tokenizer, "eoi_token") else "<eoss>"
        super().__init__(image_processor, tokenizer)

    def __call__(

View File

@@ -138,7 +138,7 @@ class GemmaTokenizer(PreTrainedTokenizer):
        return state

    def __setstate__(self, d):
-       self.__dict__ = d
+       self.__dict__.update(d)
        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.LoadFromSerializedProto(self.sp_model_proto)

View File

@@ -219,7 +219,11 @@ class IdeficsProcessor(ProcessorMixin):
        super().__init__(image_processor, tokenizer)
        self.current_processor = self.image_processor
-       self.image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
+       self.image_token_id = (
+           tokenizer.image_token_id
+           if hasattr(tokenizer, "image_token")
+           else tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
+       )

        self.default_image_dims = (
            self.image_processor.image_num_channels,

View File

@@ -95,15 +95,18 @@ class Idefics2Processor(ProcessorMixin):
        if tokenizer is None:
            raise ValueError("You need to specify a `tokenizer`.")

-       self.fake_image_token = AddedToken("<fake_token_around_image>", normalized=False, special=True)
-       self.image_token = AddedToken("<image>", normalized=False, special=True)
-       self.end_of_utterance_token = AddedToken("<end_of_utterance>", normalized=False, special=True)
-       self.image_seq_len = image_seq_len
-
-       tokens_to_add = {
-           "additional_special_tokens": [self.fake_image_token, self.image_token, self.end_of_utterance_token]
-       }
-       tokenizer.add_special_tokens(tokens_to_add)
+       if not hasattr(tokenizer, "image_token"):
+           self.fake_image_token = AddedToken("<fake_token_around_image>", normalized=False, special=True)
+           self.image_token = AddedToken("<image>", normalized=False, special=True)
+           tokens_to_add = {"additional_special_tokens": [self.fake_image_token, self.image_token]}
+           tokenizer.add_special_tokens(tokens_to_add)
+       else:
+           self.fake_image_token = tokenizer.image_boundary_token
+           self.image_token = tokenizer.image_token
+
+       self.end_of_utterance_token = AddedToken("<end_of_utterance>", normalized=False, special=True)
+       tokenizer.add_special_tokens({"additional_special_tokens": [self.end_of_utterance_token]})
+       self.image_seq_len = image_seq_len

        super().__init__(image_processor, tokenizer, chat_template=chat_template)

View File

@@ -78,8 +78,11 @@ class InstructBlipProcessor(ProcessorMixin):
    qformer_tokenizer_class = "AutoTokenizer"

    def __init__(self, image_processor, tokenizer, qformer_tokenizer, num_query_tokens=None, **kwargs):
-       self.image_token = AddedToken("<image>", normalized=False, special=True)
-       tokenizer.add_tokens([self.image_token], special_tokens=True)
+       if not hasattr(tokenizer, "image_token"):
+           self.image_token = AddedToken("<image>", normalized=False, special=True)
+           tokenizer.add_tokens([self.image_token], special_tokens=True)
+       else:
+           self.image_token = tokenizer.image_token
        self.num_query_tokens = num_query_tokens
        super().__init__(image_processor, tokenizer, qformer_tokenizer)

View File

@@ -63,8 +63,11 @@ class InstructBlipVideoProcessor(ProcessorMixin):
    qformer_tokenizer_class = "AutoTokenizer"

    def __init__(self, image_processor, tokenizer, qformer_tokenizer, num_query_tokens=None, **kwargs):
-       self.video_token = AddedToken("<video>", normalized=False, special=True)
-       tokenizer.add_tokens([self.video_token], special_tokens=True)
+       if not hasattr(tokenizer, "video_token"):
+           self.video_token = AddedToken("<video>", normalized=False, special=True)
+           tokenizer.add_tokens([self.video_token], special_tokens=True)
+       else:
+           self.video_token = tokenizer.video_token
        self.num_query_tokens = num_query_tokens
        super().__init__(image_processor, tokenizer, qformer_tokenizer)

View File

@@ -297,7 +297,7 @@ class LayoutXLMTokenizer(PreTrainedTokenizer):
        return state

    def __setstate__(self, d):
-       self.__dict__ = d
+       self.__dict__.update(d)

        # for backward compatibility
        if not hasattr(self, "sp_model_kwargs"):

View File

@@ -214,7 +214,7 @@ class LlamaTokenizer(PreTrainedTokenizer):
        return state

    def __setstate__(self, d):
-       self.__dict__ = d
+       self.__dict__.update(d)
        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.LoadFromSerializedProto(self.sp_model_proto)

View File

@@ -77,7 +77,7 @@ class LlavaProcessor(ProcessorMixin):
    ):
        self.patch_size = patch_size
        self.vision_feature_select_strategy = vision_feature_select_strategy
-       self.image_token = image_token
+       self.image_token = tokenizer.image_token if hasattr(tokenizer, "image_token") else image_token
        super().__init__(image_processor, tokenizer, chat_template=chat_template)

    def __call__(

View File

@@ -80,7 +80,7 @@ class LlavaNextProcessor(ProcessorMixin):
    ):
        self.patch_size = patch_size
        self.vision_feature_select_strategy = vision_feature_select_strategy
-       self.image_token = image_token
+       self.image_token = tokenizer.image_token if hasattr(tokenizer, "image_token") else image_token
        super().__init__(image_processor, tokenizer, chat_template=chat_template)

    def __call__(

View File

@@ -82,8 +82,8 @@ class LlavaNextVideoProcessor(ProcessorMixin):
    ):
        self.patch_size = patch_size
        self.vision_feature_select_strategy = vision_feature_select_strategy
-       self.image_token = image_token
-       self.video_token = video_token
+       self.image_token = tokenizer.image_token if hasattr(tokenizer, "image_token") else image_token
+       self.video_token = tokenizer.video_token if hasattr(tokenizer, "video_token") else video_token
        super().__init__(video_processor, image_processor, tokenizer, chat_template=chat_template)

    def __call__(

View File

@@ -96,8 +96,8 @@ class LlavaOnevisionProcessor(ProcessorMixin):
    ):
        self.num_image_tokens = num_image_tokens
        self.vision_feature_select_strategy = vision_feature_select_strategy
-       self.image_token = image_token
-       self.video_token = video_token
+       self.image_token = tokenizer.image_token if hasattr(tokenizer, "image_token") else image_token
+       self.video_token = tokenizer.video_token if hasattr(tokenizer, "video_token") else video_token
        super().__init__(image_processor, tokenizer, video_processor, chat_template=chat_template)

    def __call__(

View File

@@ -213,8 +213,13 @@ class MllamaProcessor(ProcessorMixin):
    tokenizer_class = "PreTrainedTokenizerFast"

    def __init__(self, image_processor, tokenizer):
-       self.image_token = "<|image|>"
-       self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
+       if not hasattr(tokenizer, "image_token"):
+           self.image_token = "<|image|>"
+           self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
+       else:
+           self.image_token = tokenizer.image_token
+           self.image_token_id = tokenizer.image_token_id
        self.python_token = "<|python_tag|>"
        self.python_token_id = tokenizer.convert_tokens_to_ids(self.python_token)
        self.bos_token = tokenizer.bos_token

View File

@@ -160,11 +160,15 @@ class PaliGemmaProcessor(ProcessorMixin):
        self.image_seq_length = image_processor.image_seq_length

-       image_token = AddedToken(IMAGE_TOKEN, normalized=False, special=True)
-       tokens_to_add = {"additional_special_tokens": [image_token]}
-       tokenizer.add_special_tokens(tokens_to_add)
+       if not hasattr(tokenizer, "image_token"):
+           image_token = AddedToken(IMAGE_TOKEN, normalized=False, special=True)
+           tokens_to_add = {"additional_special_tokens": [image_token]}
+           tokenizer.add_special_tokens(tokens_to_add)
+           self.image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
+       else:
+           self.image_token_id = tokenizer.image_token_id
+
        tokenizer.add_tokens(EXTRA_TOKENS)
-       self.image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
        tokenizer.add_bos_token = False
        tokenizer.add_eos_token = False

View File

@@ -61,6 +61,8 @@ class Qwen2VLProcessor(ProcessorMixin):
    tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")

    def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs):
+       self.image_token = "<|image_pad|>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token
+       self.video_token = "<|video_pad|>" if not hasattr(tokenizer, "video_token") else tokenizer.video_token
        super().__init__(image_processor, tokenizer, chat_template=chat_template)

    def __call__(
@@ -132,23 +134,23 @@ class Qwen2VLProcessor(ProcessorMixin):
            merge_length = self.image_processor.merge_size**2
            index = 0
            for i in range(len(text)):
-               while "<|image_pad|>" in text[i]:
+               while self.image_token in text[i]:
                    text[i] = text[i].replace(
-                       "<|image_pad|>", "<|placeholder|>" * (image_grid_thw[index].prod() // merge_length), 1
+                       self.image_token, "<|placeholder|>" * (image_grid_thw[index].prod() // merge_length), 1
                    )
                    index += 1
-               text[i] = text[i].replace("<|placeholder|>", "<|image_pad|>")
+               text[i] = text[i].replace("<|placeholder|>", self.image_token)

        if video_grid_thw is not None:
            merge_length = self.image_processor.merge_size**2
            index = 0
            for i in range(len(text)):
-               while "<|video_pad|>" in text[i]:
+               while self.video_token in text[i]:
                    text[i] = text[i].replace(
-                       "<|video_pad|>", "<|placeholder|>" * (video_grid_thw[index].prod() // merge_length), 1
+                       self.video_token, "<|placeholder|>" * (video_grid_thw[index].prod() // merge_length), 1
                    )
                    index += 1
-               text[i] = text[i].replace("<|placeholder|>", "<|video_pad|>")
+               text[i] = text[i].replace("<|placeholder|>", self.video_token)

        text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])

View File

@@ -412,7 +412,7 @@ class UdopTokenizer(PreTrainedTokenizer):
        return state

    def __setstate__(self, d):
-       self.__dict__ = d
+       self.__dict__.update(d)
        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.Load(self.vocab_file)

View File

@@ -71,8 +71,8 @@ class VideoLlavaProcessor(ProcessorMixin):
    ):
        self.patch_size = patch_size
        self.vision_feature_select_strategy = vision_feature_select_strategy
-       self.image_token = image_token
-       self.video_token = video_token
+       self.image_token = tokenizer.image_token if hasattr(tokenizer, "image_token") else image_token
+       self.video_token = tokenizer.video_token if hasattr(tokenizer, "video_token") else video_token
        super().__init__(image_processor, tokenizer, chat_template=chat_template)

    def __call__(

View File

@@ -577,7 +577,7 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
                    token_index = current_vocab[token.content]

                if token.special and str(token) not in self.all_special_tokens:
-                   self._additional_special_tokens.append(token)
+                   self._special_tokens_map["additional_special_tokens"].append(token)
                # the setter automatically updates the reverse map
                self._added_tokens_decoder[token_index] = token
                self._added_tokens_encoder[token.content] = token_index

View File

@@ -861,16 +861,10 @@ class SpecialTokensMixin:
    ]

    def __init__(self, verbose=False, **kwargs):
-       self._bos_token = None
-       self._eos_token = None
-       self._unk_token = None
-       self._sep_token = None
-       self._pad_token = None
-       self._cls_token = None
-       self._mask_token = None
        self._pad_token_type_id = 0
-       self._additional_special_tokens = []
        self.verbose = verbose
+       self._special_tokens_map = {attr: None for attr in self.SPECIAL_TOKENS_ATTRIBUTES}
+       self._special_tokens_map["additional_special_tokens"] = []  # for BC where it defaults to empty list

        # We directly set the hidden value to allow initialization with special tokens
        # which are not yet in the vocabulary. Necessary for serialization/de-serialization
@@ -932,7 +926,7 @@ class SpecialTokensMixin:
                assign the index of the `unk_token` to them).
            replace_additional_special_tokens (`bool`, *optional*,, defaults to `True`):
                If `True`, the existing list of additional special tokens will be replaced by the list provided in
-               `special_tokens_dict`. Otherwise, `self._additional_special_tokens` is just extended. In the former
+               `special_tokens_dict`. Otherwise, `self._special_tokens_map["additional_special_tokens"]` is just extended. In the former
                case, the tokens will NOT be removed from the tokenizer's full vocabulary - they are only being flagged
                as non-special tokens. Remember, this only affects which tokens are skipped during decoding, not the
                `added_tokens_encoder` and `added_tokens_decoder`. This means that the previous
@@ -983,7 +977,7 @@ class SpecialTokensMixin:
                if replace_additional_special_tokens and len(to_add) > 0:
                    setattr(self, key, list(to_add))
                else:
-                   self._additional_special_tokens.extend(to_add)
+                   self._special_tokens_map["additional_special_tokens"].extend(to_add)
                added_tokens += to_add
            else:
@@ -1053,192 +1047,6 @@ class SpecialTokensMixin:
    def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens: bool = False) -> int:
        raise NotImplementedError

-   @property
-   def bos_token(self) -> str:
-       """
-       `str`: Beginning of sentence token. Log an error if used while not having been set.
-       """
-       if self._bos_token is None:
-           if self.verbose:
-               logger.error("Using bos_token, but it is not set yet.")
-           return None
-       return str(self._bos_token)
-
-   @property
-   def eos_token(self) -> str:
-       """
-       `str`: End of sentence token. Log an error if used while not having been set.
-       """
-       if self._eos_token is None:
-           if self.verbose:
-               logger.error("Using eos_token, but it is not set yet.")
-           return None
-       return str(self._eos_token)
-
-   @property
-   def unk_token(self) -> str:
-       """
-       `str`: Unknown token. Log an error if used while not having been set.
-       """
-       if self._unk_token is None:
-           if self.verbose:
-               logger.error("Using unk_token, but it is not set yet.")
-           return None
-       return str(self._unk_token)
-
-   @property
-   def sep_token(self) -> str:
-       """
-       `str`: Separation token, to separate context and query in an input sequence. Log an error if used while not
-       having been set.
-       """
-       if self._sep_token is None:
-           if self.verbose:
-               logger.error("Using sep_token, but it is not set yet.")
-           return None
-       return str(self._sep_token)
-
-   @property
-   def pad_token(self) -> str:
-       """
-       `str`: Padding token. Log an error if used while not having been set.
-       """
-       if self._pad_token is None:
-           if self.verbose:
-               logger.error("Using pad_token, but it is not set yet.")
-           return None
-       return str(self._pad_token)
-
-   @property
-   def cls_token(self) -> str:
-       """
-       `str`: Classification token, to extract a summary of an input sequence leveraging self-attention along the full
-       depth of the model. Log an error if used while not having been set.
-       """
-       if self._cls_token is None:
-           if self.verbose:
-               logger.error("Using cls_token, but it is not set yet.")
-           return None
-       return str(self._cls_token)
-
-   @property
-   def mask_token(self) -> str:
-       """
-       `str`: Mask token, to use when training a model with masked-language modeling. Log an error if used while not
-       having been set.
-       """
-       if self._mask_token is None:
-           if self.verbose:
-               logger.error("Using mask_token, but it is not set yet.")
-           return None
-       return str(self._mask_token)
-
-   @property
-   def additional_special_tokens(self) -> List[str]:
-       """
-       `List[str]`: All the additional special tokens you may want to use. Log an error if used while not having been
-       set.
-       """
-       if self._additional_special_tokens is None:
-           if self.verbose:
-               logger.error("Using additional_special_tokens, but it is not set yet.")
-           return None
-       return [str(tok) for tok in self._additional_special_tokens]
-
-   @bos_token.setter
-   def bos_token(self, value):
-       if not isinstance(value, (str, AddedToken)) and value is not None:
-           raise ValueError("Cannot set a non-string value as the BOS token")
-       self._bos_token = value
-
-   @eos_token.setter
-   def eos_token(self, value):
-       if not isinstance(value, (str, AddedToken)) and value is not None:
-           raise ValueError("Cannot set a non-string value as the EOS token")
-       self._eos_token = value
-
-   @unk_token.setter
-   def unk_token(self, value):
-       if not isinstance(value, (str, AddedToken)) and value is not None:
-           raise ValueError("Cannot set a non-string value as the UNK token")
-       self._unk_token = value
-
-   @sep_token.setter
-   def sep_token(self, value):
-       if not isinstance(value, (str, AddedToken)) and value is not None:
-           raise ValueError("Cannot set a non-string value as the SEP token")
-       self._sep_token = value
-
-   @pad_token.setter
-   def pad_token(self, value):
-       if not isinstance(value, (str, AddedToken)) and value is not None:
-           raise ValueError("Cannot set a non-string value as the PAD token")
-       self._pad_token = value
-
-   @cls_token.setter
-   def cls_token(self, value):
-       if not isinstance(value, (str, AddedToken)) and value is not None:
-           raise ValueError("Cannot set a non-string value as the CLS token")
-       self._cls_token = value
-
-   @mask_token.setter
-   def mask_token(self, value):
-       if not isinstance(value, (str, AddedToken)) and value is not None:
-           raise ValueError("Cannot set a non-string value as the MASK token")
-       self._mask_token = value
-
-   @additional_special_tokens.setter
-   def additional_special_tokens(self, value):
-       self._additional_special_tokens = value if value is not None else None
-
-   @property
-   def bos_token_id(self) -> Optional[int]:
-       """
-       `Optional[int]`: Id of the beginning of sentence token in the vocabulary. Returns `None` if the token has not
-       been set.
-       """
-       if self._bos_token is None:
-           return None
-       return self.convert_tokens_to_ids(self.bos_token)
-
-   @property
-   def eos_token_id(self) -> Optional[int]:
-       """
-       `Optional[int]`: Id of the end of sentence token in the vocabulary. Returns `None` if the token has not been
-       set.
-       """
-       if self._eos_token is None:
-           return None
-       return self.convert_tokens_to_ids(self.eos_token)
-
-   @property
-   def unk_token_id(self) -> Optional[int]:
-       """
-       `Optional[int]`: Id of the unknown token in the vocabulary. Returns `None` if the token has not been set.
-       """
-       if self._unk_token is None:
-           return None
-       return self.convert_tokens_to_ids(self.unk_token)
-
-   @property
-   def sep_token_id(self) -> Optional[int]:
-       """
-       `Optional[int]`: Id of the separation token in the vocabulary, to separate context and query in an input
-       sequence. Returns `None` if the token has not been set.
-       """
-       if self._sep_token is None:
-           return None
-       return self.convert_tokens_to_ids(self.sep_token)
-
-   @property
-   def pad_token_id(self) -> Optional[int]:
-       """
-       `Optional[int]`: Id of the padding token in the vocabulary. Returns `None` if the token has not been set.
-       """
-       if self._pad_token is None:
-           return None
-       return self.convert_tokens_to_ids(self.pad_token)
-
    @property
    def pad_token_type_id(self) -> int:
        """
@@ -1246,67 +1054,55 @@ class SpecialTokensMixin:
        """
        return self._pad_token_type_id

-   @property
-   def cls_token_id(self) -> Optional[int]:
-       """
-       `Optional[int]`: Id of the classification token in the vocabulary, to extract a summary of an input sequence
-       leveraging self-attention along the full depth of the model.
-
-       Returns `None` if the token has not been set.
-       """
-       if self._cls_token is None:
-           return None
-       return self.convert_tokens_to_ids(self.cls_token)
-
-   @property
-   def mask_token_id(self) -> Optional[int]:
-       """
-       `Optional[int]`: Id of the mask token in the vocabulary, used when training a model with masked-language
-       modeling. Returns `None` if the token has not been set.
-       """
-       if self._mask_token is None:
-           return None
-       return self.convert_tokens_to_ids(self.mask_token)
-
-   @property
-   def additional_special_tokens_ids(self) -> List[int]:
-       """
-       `List[int]`: Ids of all the additional special tokens in the vocabulary. Log an error if used while not having
-       been set.
-       """
-       return self.convert_tokens_to_ids(self.additional_special_tokens)
-
-   @bos_token_id.setter
-   def bos_token_id(self, value):
-       self._bos_token = self.convert_ids_to_tokens(value) if value is not None else None
-
-   @eos_token_id.setter
-   def eos_token_id(self, value):
-       self._eos_token = self.convert_ids_to_tokens(value) if value is not None else None
-
-   @unk_token_id.setter
-   def unk_token_id(self, value):
-       self._unk_token = self.convert_ids_to_tokens(value) if value is not None else None
-
-   @sep_token_id.setter
-   def sep_token_id(self, value):
-       self._sep_token = self.convert_ids_to_tokens(value) if value is not None else None
-
-   @pad_token_id.setter
-   def pad_token_id(self, value):
-       self._pad_token = self.convert_ids_to_tokens(value) if value is not None else None
-
-   @cls_token_id.setter
-   def cls_token_id(self, value):
-       self._cls_token = self.convert_ids_to_tokens(value) if value is not None else None
-
-   @mask_token_id.setter
-   def mask_token_id(self, value):
-       self._mask_token = self.convert_ids_to_tokens(value) if value is not None else None
-
-   @additional_special_tokens_ids.setter
-   def additional_special_tokens_ids(self, values):
-       self._additional_special_tokens = [self.convert_ids_to_tokens(value) for value in values]
+   def __setattr__(self, key, value):
+       key_without_id = key
+       key_is_special_id = key.endswith("_id") or key.endswith("_ids")
+       if key_is_special_id:
+           key_without_id = key[:-3] if not key.endswith("_ids") else key[:-4]
+
+       if self.__dict__.get("_special_tokens_map", None) is not None and any(
+           name in self.__dict__["_special_tokens_map"] for name in [key, key_without_id]
+       ):
+           if key_is_special_id:
+               if value is not None:
+                   value = (
+                       self.convert_ids_to_tokens(value)
+                       if key != "additional_special_tokens"
+                       else [self.convert_ids_to_tokens(val) for val in value]
+                   )
+               key = key_without_id
+
+           if key != "additional_special_tokens" and not isinstance(value, (str, AddedToken)) and value is not None:
+               raise ValueError(f"Cannot set a non-string value as the {key}")
+           self._special_tokens_map[key] = value
+       else:
+           super().__setattr__(key, value)
+
+   def __getattr__(self, key):
+       key_without_id = key
+       key_is_special_id = key.endswith("_id") or key.endswith("_ids")
+       if key_is_special_id:
+           key_without_id = key[:-3] if not key.endswith("_ids") else key[:-4]
+
+       if self.__dict__.get("_special_tokens_map", None) is not None and any(
+           name in self.__dict__["_special_tokens_map"] for name in [key, key_without_id]
+       ):
+           _special_tokens_map = self.__dict__["_special_tokens_map"]
+           if not key_is_special_id:
+               if _special_tokens_map[key] is None:
+                   if self.verbose:
+                       logger.error(f"Using {key}, but it is not set yet.")
+                   return None
+               value = _special_tokens_map[key]
+               return str(value) if key != "additional_special_tokens" else [str(tok) for tok in value]
+           else:
+               attr_as_tokens = getattr(self, key_without_id)
+               return self.convert_tokens_to_ids(attr_as_tokens) if attr_as_tokens is not None else None
+
+       if key not in self.__dict__:
+           raise AttributeError(f"{self.__class__.__name__} has no attribute {key}")
+       else:
+           return super().__getattr__(key)

    @property
    def special_tokens_map(self) -> Dict[str, Union[str, List[str]]]:
@@ -1334,7 +1130,7 @@ class SpecialTokensMixin:
        """
        set_attr = {}
        for attr in self.SPECIAL_TOKENS_ATTRIBUTES:
-           attr_value = getattr(self, "_" + attr)
+           attr_value = self._special_tokens_map[attr]
            if attr_value:
                set_attr[attr] = attr_value
        return set_attr
@@ -1379,6 +1175,20 @@ class SpecialTokensMixin:
        all_ids = self.convert_tokens_to_ids(all_toks)
        return all_ids

+   def _set_model_specific_special_tokens(self, special_tokens: Dict[str, Union[str, AddedToken]]):
+       """
+       Adds new special tokens to the "SPECIAL_TOKENS_ATTRIBUTES" list which will be part
+       of "self.special_tokens" and saved as a special token in the tokenizer's config.
+       This allows us to dynamically add new model-type specific tokens after initializing the tokenizer.
+       For example: if the model's tokenizer is multimodal, we can support special image or audio tokens.
+       """
+       self.SPECIAL_TOKENS_ATTRIBUTES = self.SPECIAL_TOKENS_ATTRIBUTES + list(special_tokens.keys())
+       for key, value in special_tokens.items():
+           if isinstance(value, (str, AddedToken)):
+               self._special_tokens_map[key] = value
+           else:
+               raise TypeError(f"Special token {key} has to be either str or AddedToken but got: {type(value)}")

ENCODE_KWARGS_DOCSTRING = r"""
    add_special_tokens (`bool`, *optional*, defaults to `True`):
@@ -1633,6 +1443,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
        super().__init__(**kwargs)

+       self.extra_special_tokens = kwargs.pop("extra_special_tokens", {})
+       self._set_model_specific_special_tokens(special_tokens=self.extra_special_tokens)
+
    @property
    def max_len_single_sentence(self) -> int:
        """
@@ -2591,8 +2404,11 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
            if hasattr(self, k):
                tokenizer_config[k] = getattr(self, k)

-       # Let's make sure we properly save the special tokens.
+       # Let's make sure we properly save the special tokens
        tokenizer_config.update(self.special_tokens_map)
+       if "extra_special_tokens" not in tokenizer_config:
+           tokenizer_config["extra_special_tokens"] = self.extra_special_tokens
+       tokenizer_config.update(self.extra_special_tokens)

        if self.chat_template is not None:
            if isinstance(self.chat_template, dict):
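Taken together, `_special_tokens_map` plus the `__setattr__`/`__getattr__` pair replace the per-token properties deleted above: any `*_token` attribute is served from the map, and the `*_token_id` variants are converted on the fly. A minimal sketch of the resulting behavior under this commit (the model name is just an example):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("bert-base-uncased")

# `pad_token` is no longer a property; __getattr__ reads it from _special_tokens_map
print(tok.pad_token)     # "[PAD]"
print(tok.pad_token_id)  # same as tok.convert_tokens_to_ids("[PAD]")

# Setting the `_id` variant routes through __setattr__, which converts the id
# back to its token string before storing it in the map
tok.pad_token_id = tok.convert_tokens_to_ids("[SEP]")
print(tok.pad_token)     # "[SEP]"

# Clearing also works, as the data-collator tests now do with `tokenizer.pad_token = None`
tok.pad_token = None
print(tok.pad_token_id)  # None
```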

View File

@@ -863,13 +863,12 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
        special_tokens_list = SpecialTokensMixin.SPECIAL_TOKENS_ATTRIBUTES.copy()
        special_tokens_list.remove("additional_special_tokens")
        for token in special_tokens_list:
-           # Get the private one to avoid unnecessary warnings.
-           if getattr(self, f"_{token}") is not None:
+           if getattr(self, token) is not None:
                special_token = getattr(self, token)
                if special_tokens_map is not None and special_token in special_tokens_map:
                    special_token = special_tokens_map[special_token]

-               special_token_full = getattr(self, f"_{token}")
+               special_token_full = self._special_tokens_map.get(token, None)
                if isinstance(special_token_full, AddedToken):
                    # Create an added token with the same parameters except the content
                    kwargs[token] = AddedToken(

View File

@@ -154,7 +154,7 @@ class CamembertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
                tokenizer = self.tokenizer_class.from_pretrained(pretrained_name, eos_token=new_eos)
                EXPECTED_ADDED_TOKENS_DECODER = tokenizer.added_tokens_decoder
                with self.subTest("Hub -> Slow: Test loading a slow tokenizer from the hub)"):
-                   self.assertEqual(tokenizer._eos_token, new_eos)
+                   self.assertEqual(tokenizer._special_tokens_map["eos_token"], new_eos)
                    self.assertIn(new_eos, list(tokenizer.added_tokens_decoder.values()))

                with tempfile.TemporaryDirectory() as tmp_dir_2:
@@ -194,7 +194,7 @@ class CamembertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
                        tokenizer_fast = self.rust_tokenizer_class.from_pretrained(
                            pretrained_name, eos_token=new_eos, from_slow=True
                        )
-                       self.assertEqual(tokenizer_fast._eos_token, new_eos)
+                       self.assertEqual(tokenizer_fast._special_tokens_map["eos_token"], new_eos)
                        self.assertIn(new_eos, list(tokenizer_fast.added_tokens_decoder.values()))
                        # We can't test the following because for BC we kept the default rstrip lstrip in slow not fast. Will comment once normalization is alright
                        with self.subTest("Hub -> Fast == Hub -> Slow: make sure slow and fast tokenizer match"):

View File

@@ -1659,7 +1659,7 @@ class LayoutLMv2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
                special_tokens_map = {}
                for token in special_tokens_list:
                    # Get the private one to avoid unnecessary warnings.
-                   if getattr(tokenizer, f"_{token}") is not None:
+                   if getattr(tokenizer, token) is not None:
                        special_token = getattr(tokenizer, token)
                        special_tokens_map[special_token] = f"{special_token}a"
@@ -1671,7 +1671,7 @@ class LayoutLMv2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
                # Check the changes
                for token in special_tokens_list:
                    # Get the private one to avoid unnecessary warnings.
-                   if getattr(tokenizer, f"_{token}") is None:
+                   if getattr(tokenizer, token) is None:
                        continue
                    special_token = getattr(tokenizer, token)
                    if special_token in special_tokens_map:

View File

@@ -1537,7 +1537,7 @@ class LayoutLMv3TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
                special_tokens_map = {}
                for token in special_tokens_list:
                    # Get the private one to avoid unnecessary warnings.
-                   if getattr(tokenizer, f"_{token}") is not None:
+                   if getattr(tokenizer, token) is not None:
                        special_token = getattr(tokenizer, token)
                        special_tokens_map[special_token] = f"{special_token}a"
@@ -1549,7 +1549,7 @@ class LayoutLMv3TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
                # Check the changes
                for token in special_tokens_list:
                    # Get the private one to avoid unnecessary warnings.
-                   if getattr(tokenizer, f"_{token}") is None:
+                   if getattr(tokenizer, token) is None:
                        continue
                    special_token = getattr(tokenizer, token)
                    if special_token in special_tokens_map:

View File

@@ -1588,7 +1588,7 @@ class LayoutXLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
                special_tokens_map = {}
                for token in special_tokens_list:
                    # Get the private one to avoid unnecessary warnings.
-                   if getattr(tokenizer, f"_{token}") is not None:
+                   if getattr(tokenizer, token) is not None:
                        special_token = getattr(tokenizer, token)
                        special_tokens_map[special_token] = f"{special_token}a"
@@ -1600,7 +1600,7 @@ class LayoutXLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
                # Check the changes
                for token in special_tokens_list:
                    # Get the private one to avoid unnecessary warnings.
-                   if getattr(tokenizer, f"_{token}") is None:
+                   if getattr(tokenizer, token) is None:
                        continue
                    special_token = getattr(tokenizer, token)
                    if special_token in special_tokens_map:

View File

@@ -385,6 +385,7 @@ class LlamaIntegrationTest(unittest.TestCase):
        assert fast == [1, 319, 4559, 1243]

        fast_tokenizer.add_eos_token = True
+       print(fast_tokenizer.add_eos_token)
        fast = fast_tokenizer.encode("A sample test", add_special_tokens=True)
        assert fast == [1, 319, 4559, 1243, 2]

View File

@@ -1435,7 +1435,7 @@ class MarkupLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
                special_tokens_map = {}
                for token in special_tokens_list:
                    # Get the private one to avoid unnecessary warnings.
-                   if getattr(tokenizer, f"_{token}") is not None:
+                   if getattr(tokenizer, token) is not None:
                        special_token = getattr(tokenizer, token)
                        special_tokens_map[special_token] = f"{special_token}a"
@@ -1447,7 +1447,7 @@ class MarkupLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
                # Check the changes
                for token in special_tokens_list:
                    # Get the private one to avoid unnecessary warnings.
-                   if getattr(tokenizer, f"_{token}") is None:
+                   if getattr(tokenizer, token) is None:
                        continue
                    special_token = getattr(tokenizer, token)
                    if special_token in special_tokens_map:

View File

@@ -237,7 +237,7 @@ class MoshiTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
                special_tokens_map = {}
                for token in special_tokens_list:
                    # Get the private one to avoid unnecessary warnings.
-                   if getattr(tokenizer, f"_{token}") is not None:
+                   if getattr(tokenizer, token) is not None:
                        special_token = getattr(tokenizer, token)
                        special_tokens_map[special_token] = f"{special_token}a"
@@ -249,7 +249,7 @@ class MoshiTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
                # Check the changes
                for token in special_tokens_list:
                    # Get the private one to avoid unnecessary warnings.
-                   if getattr(tokenizer, f"_{token}") is None:
+                   if getattr(tokenizer, token) is None:
                        continue
                    special_token = getattr(tokenizer, token)
                    if special_token in special_tokens_map:

View File

@@ -185,7 +185,7 @@ class RemBertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
                )
                EXPECTED_ADDED_TOKENS_DECODER = tokenizer.added_tokens_decoder
                with self.subTest("Hub -> Slow: Test loading a slow tokenizer from the hub)"):
-                   self.assertEqual(tokenizer._eos_token, new_eos)
+                   self.assertEqual(tokenizer._special_tokens_map["eos_token"], new_eos)
                    self.assertIn(new_eos, list(tokenizer.added_tokens_decoder.values()))

                with tempfile.TemporaryDirectory() as tmp_dir_2:
@@ -223,7 +223,7 @@ class RemBertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
                with self.subTest("Hub -> Fast: Test loading a fast tokenizer from the hub)"):
                    if self.rust_tokenizer_class is not None:
                        tokenizer_fast = self.rust_tokenizer_class.from_pretrained(pretrained_name, eos_token=new_eos)
-                       self.assertEqual(tokenizer_fast._eos_token, new_eos)
+                       self.assertEqual(tokenizer_fast._special_tokens_map["eos_token"], new_eos)
                        self.assertIn(new_eos, list(tokenizer_fast.added_tokens_decoder.values()))
                        # We can't test the following because for BC we kept the default rstrip lstrip in slow not fast. Will comment once normalization is alright
                        with self.subTest("Hub -> Fast == Hub -> Slow: make sure slow and fast tokenizer match"):

View File

@@ -1538,7 +1538,7 @@ class UdopTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
                special_tokens_map = {}
                for token in special_tokens_list:
                    # Get the private one to avoid unnecessary warnings.
-                   if getattr(tokenizer, f"_{token}") is not None:
+                   if getattr(tokenizer, token) is not None:
                        special_token = getattr(tokenizer, token)
                        special_tokens_map[special_token] = f"{special_token}a"
@@ -1550,7 +1550,7 @@ class UdopTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
                # Check the changes
                for token in special_tokens_list:
                    # Get the private one to avoid unnecessary warnings.
-                   if getattr(tokenizer, f"_{token}") is None:
+                   if getattr(tokenizer, token) is None:
                        continue
                    special_token = getattr(tokenizer, token)
                    if special_token in special_tokens_map:

View File

@@ -4156,8 +4156,7 @@ class TokenizerTesterMixin:
                special_tokens_list.remove("additional_special_tokens")
                special_tokens_map = {}
                for token in special_tokens_list:
-                   # Get the private one to avoid unnecessary warnings.
-                   if getattr(tokenizer, f"_{token}") is not None:
+                   if getattr(tokenizer, token) is not None:
                        special_token = getattr(tokenizer, token)
                        special_tokens_map[special_token] = f"{special_token}a"
@@ -4169,7 +4168,7 @@ class TokenizerTesterMixin:
                # Check the changes
                for token in special_tokens_list:
                    # Get the private one to avoid unnecessary warnings.
-                   if getattr(tokenizer, f"_{token}") is None:
+                   if getattr(tokenizer, token) is None:
                        continue
                    special_token = getattr(tokenizer, token)
                    if special_token in special_tokens_map:
@@ -4411,7 +4410,7 @@ class TokenizerTesterMixin:
                tokenizer = self.tokenizer_class.from_pretrained(pretrained_name, eos_token=new_eos)
                EXPECTED_ADDED_TOKENS_DECODER = tokenizer.added_tokens_decoder
                with self.subTest("Hub -> Slow: Test loading a slow tokenizer from the hub)"):
-                   self.assertEqual(tokenizer._eos_token, new_eos)
+                   self.assertEqual(tokenizer._special_tokens_map["eos_token"], new_eos)
                    self.assertIn(new_eos, list(tokenizer.added_tokens_decoder.values()))

                with tempfile.TemporaryDirectory() as tmp_dir_2:
@@ -4449,7 +4448,7 @@ class TokenizerTesterMixin:
                with self.subTest("Hub -> Fast: Test loading a fast tokenizer from the hub)"):
                    if self.rust_tokenizer_class is not None:
                        tokenizer_fast = self.rust_tokenizer_class.from_pretrained(pretrained_name, eos_token=new_eos)
-                       self.assertEqual(tokenizer_fast._eos_token, new_eos)
+                       self.assertEqual(tokenizer_fast._special_tokens_map["eos_token"], new_eos)
                        self.assertIn(new_eos, list(tokenizer_fast.added_tokens_decoder.values()))
                        # We can't test the following because for BC we kept the default rstrip lstrip in slow not fast. Will comment once normalization is alright
                        with self.subTest("Hub -> Fast == Hub -> Slow: make sure slow and fast tokenizer match"):

View File

@@ -28,6 +28,7 @@ from transformers import (
    BatchEncoding,
    BertTokenizer,
    BertTokenizerFast,
+   LlamaTokenizerFast,
    PreTrainedTokenizer,
    PreTrainedTokenizerFast,
    TensorType,
@@ -280,6 +281,54 @@ class TokenizerUtilsTest(unittest.TestCase):
        self.assertEqual(decoded_flat, "##")
        self.assertEqual(decoded_list, "##")

+   def test_extra_special_tokens_multimodal(self):
+       special_tokens_list = [
+           "bos_token",
+           "eos_token",
+           "unk_token",
+           "sep_token",
+           "pad_token",
+           "cls_token",
+           "mask_token",
+           "additional_special_tokens",
+       ]
+       llama_tokenizer = LlamaTokenizerFast.from_pretrained("huggyllama/llama-7b")
+       llama_tokenizer.extra_special_tokens = {
+           "boi_token": "<image_start>",
+           "eoi_token": "<image_end>",
+           "image_token": "<image>",
+       }
+       self.assertListEqual(llama_tokenizer.SPECIAL_TOKENS_ATTRIBUTES, special_tokens_list)
+       with tempfile.TemporaryDirectory() as tmpdirname:
+           llama_tokenizer.save_pretrained(tmpdirname)
+
+           # load back and check we have extra special tokens set
+           loaded_tokenizer = LlamaTokenizerFast.from_pretrained(tmpdirname)
+           multimodal_special_tokens_list = special_tokens_list + ["boi_token", "eoi_token", "image_token"]
+           self.assertListEqual(loaded_tokenizer.SPECIAL_TOKENS_ATTRIBUTES, multimodal_special_tokens_list)
+
+           # We set an image_token_id before, so we can get an "image_token" as str that matches the id
+           self.assertTrue(loaded_tokenizer.image_token == "<image>")
+           self.assertTrue(loaded_tokenizer.image_token_id == loaded_tokenizer.convert_tokens_to_ids("<image>"))
+
+       # save one more time and make sure the image token can get loaded back
+       with tempfile.TemporaryDirectory() as tmpdirname:
+           loaded_tokenizer.save_pretrained(tmpdirname)
+           loaded_tokenizer_with_extra_tokens = LlamaTokenizerFast.from_pretrained(tmpdirname)
+           self.assertTrue(loaded_tokenizer_with_extra_tokens.image_token == "<image>")
+
+       # test that we can also indicate extra tokens during load time
+       extra_special_tokens = {
+           "boi_token": "<image_start>",
+           "eoi_token": "<image_end>",
+           "image_token": "<image>",
+       }
+       tokenizer = LlamaTokenizerFast.from_pretrained(
+           "huggyllama/llama-7b", extra_special_tokens=extra_special_tokens
+       )
+       self.assertTrue(tokenizer.image_token == "<image>")
+       self.assertTrue(tokenizer.image_token_id == loaded_tokenizer.convert_tokens_to_ids("<image>"))

    @require_tokenizers
    def test_decoding_skip_special_tokens(self):
        for tokenizer_class in [BertTokenizer, BertTokenizerFast]:

View File

@@ -299,7 +299,7 @@ class DataCollatorIntegrationTest(unittest.TestCase):
        self.assertEqual(batch["input_ids"].shape, torch.Size((2, 16)))
        self.assertEqual(batch["labels"].shape, torch.Size((2, 16)))

-       tokenizer._pad_token = None
+       tokenizer.pad_token = None
        data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
        with self.assertRaises(ValueError):
            # Expect error due to padding token missing
@@ -978,7 +978,7 @@ class TFDataCollatorIntegrationTest(unittest.TestCase):
        self.assertEqual(batch["input_ids"].shape.as_list(), [2, 16])
        self.assertEqual(batch["labels"].shape.as_list(), [2, 16])

-       tokenizer._pad_token = None
+       tokenizer.pad_token = None
        data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False, return_tensors="tf")
        with self.assertRaises(ValueError):
            # Expect error due to padding token missing
@@ -1673,7 +1673,7 @@ class NumpyDataCollatorIntegrationTest(unittest.TestCase):
        self.assertEqual(batch["input_ids"].shape, (2, 16))
        self.assertEqual(batch["labels"].shape, (2, 16))

-       tokenizer._pad_token = None
+       tokenizer.pad_token = None
        data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False, return_tensors="np")
        with self.assertRaises(ValueError):
            # Expect error due to padding token missing