Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-04 05:10:06 +06:00
VLM: special multimodal Tokenizer (#34461)
* kinda works
* update
* add tests
* update
* use special tokens in processors
* typo
* fix copies
* fix
* fix moshi after rebase
* update
* fix tests
* update
* Update docs/source/en/main_classes/tokenizer.md
  Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* update docs
* test for load time adding tokens
* fix some more tests which are now fetched better
* one more fix

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
parent ef976a7e18
commit 187439c3fa
@@ -51,6 +51,25 @@ token space (e.g., getting the index of the token comprising a given character o
 to a given token).
 
+# Multimodal Tokenizer
+
+Apart from that, each tokenizer can be a "multimodal" tokenizer, which means that the tokenizer will hold all relevant special tokens
+as part of tokenizer attributes for easier access. For example, if the tokenizer is loaded from a vision-language model like LLaVA, you will
+be able to access `tokenizer.image_token_id` to obtain the special image token used as a placeholder.
+
+To enable extra special tokens for any type of tokenizer, you have to add the following lines and save the tokenizer. Extra special tokens do not
+have to be modality related and can be anything that the model often needs access to. In the code below, the tokenizer at `output_dir` will have direct access
+to three more special tokens.
+
+```python
+vision_tokenizer = AutoTokenizer.from_pretrained(
+    "llava-hf/llava-1.5-7b-hf",
+    extra_special_tokens={"image_token": "<image>", "boi_token": "<image_start>", "eoi_token": "<image_end>"}
+)
+print(vision_tokenizer.image_token, vision_tokenizer.image_token_id)
+("<image>", 32000)
+```
+
 ## PreTrainedTokenizer
 
 [[autodoc]] PreTrainedTokenizer
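For context, the documentation above relies on saving the tokenizer so the extra special tokens persist. A minimal sketch of that round trip (illustrative, not part of this diff; `output_dir` is a placeholder path):

```python
from transformers import AutoTokenizer

# Load with extra special tokens registered as attributes (as in the docs example above).
vision_tokenizer = AutoTokenizer.from_pretrained(
    "llava-hf/llava-1.5-7b-hf",
    extra_special_tokens={"image_token": "<image>", "boi_token": "<image_start>", "eoi_token": "<image_end>"},
)

# Saving writes the extra special tokens into the tokenizer config...
vision_tokenizer.save_pretrained("output_dir")  # placeholder directory

# ...so a reload exposes them again without passing extra_special_tokens.
reloaded = AutoTokenizer.from_pretrained("output_dir")
print(reloaded.image_token, reloaded.image_token_id)
```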
@@ -443,7 +443,7 @@ def _torch_collate_batch(examples, tokenizer, pad_to_multiple_of: Optional[int]
         return torch.stack(examples, dim=0)
 
     # If yes, check if we have a `pad_token`.
-    if tokenizer._pad_token is None:
+    if tokenizer.pad_token is None:
         raise ValueError(
             "You are attempting to pad samples but the tokenizer you are using"
             f" ({tokenizer.__class__.__name__}) does not have a pad token."
@@ -477,7 +477,7 @@ def _tf_collate_batch(examples, tokenizer, pad_to_multiple_of: Optional[int] = N
         return tf.stack(examples, axis=0)
 
     # If yes, check if we have a `pad_token`.
-    if tokenizer._pad_token is None:
+    if tokenizer.pad_token is None:
         raise ValueError(
             "You are attempting to pad samples but the tokenizer you are using"
             f" ({tokenizer.__class__.__name__}) does not have a pad token."
@@ -513,7 +513,7 @@ def _numpy_collate_batch(examples, tokenizer, pad_to_multiple_of: Optional[int]
         return np.stack(examples, axis=0)
 
     # If yes, check if we have a `pad_token`.
-    if tokenizer._pad_token is None:
+    if tokenizer.pad_token is None:
         raise ValueError(
             "You are attempting to pad samples but the tokenizer you are using"
             f" ({tokenizer.__class__.__name__}) does not have a pad token."
@@ -1090,7 +1090,7 @@ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
            self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
        ]
        probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
-        if self.tokenizer._pad_token is not None:
+        if self.tokenizer.pad_token is not None:
            padding_mask = labels.eq(self.tokenizer.pad_token_id)
            probability_matrix.masked_fill_(padding_mask, value=0.0)
 
@@ -1131,7 +1131,7 @@ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
            self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels
        ]
        masked_indices = masked_indices & ~tf.cast(special_tokens_mask, dtype=tf.bool)
-        if self.tokenizer._pad_token is not None:
+        if self.tokenizer.pad_token is not None:
            padding_mask = inputs == self.tokenizer.pad_token_id
            masked_indices = masked_indices & ~padding_mask
 
@@ -1170,7 +1170,7 @@ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
            self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
        ]
        masked_indices[np.array(special_tokens_mask, dtype=bool)] = 0
-        if self.tokenizer._pad_token is not None:
+        if self.tokenizer.pad_token is not None:
            padding_mask = labels == self.tokenizer.pad_token_id
            masked_indices[padding_mask] = 0
 
@@ -1251,13 +1251,13 @@ class DataCollatorForSOP(DataCollatorForLanguageModeling):
            self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
        ]
        probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
-        if self.tokenizer._pad_token is not None:
+        if self.tokenizer.pad_token is not None:
            padding_mask = labels.eq(self.tokenizer.pad_token_id)
            probability_matrix.masked_fill_(padding_mask, value=0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool()
        # probability be `1` (masked), however in albert model attention mask `0` means masked, revert the value
        attention_mask = (~masked_indices).float()
-        if self.tokenizer._pad_token is not None:
+        if self.tokenizer.pad_token is not None:
            attention_padding_mask = labels.eq(self.tokenizer.pad_token_id)
            attention_mask.masked_fill_(attention_padding_mask, value=1.0)
        labels[~masked_indices] = -100  # We only compute loss on masked tokens, -100 is default for CE compute
@@ -1367,7 +1367,7 @@ class DataCollatorForPermutationLanguageModeling(DataCollatorMixin):
            dtype=torch.bool,
        )
        masked_indices.masked_fill_(special_tokens_mask, value=0.0)
-        if self.tokenizer._pad_token is not None:
+        if self.tokenizer.pad_token is not None:
            padding_mask = labels.eq(self.tokenizer.pad_token_id)
            masked_indices.masked_fill_(padding_mask, value=0.0)
 
@@ -1471,7 +1471,7 @@ class DataCollatorForPermutationLanguageModeling(DataCollatorMixin):
        )
        special_tokens_mask = tf.cast(special_tokens_mask, dtype=tf.bool)
        masked_indices = masked_indices & ~special_tokens_mask
-        if self.tokenizer._pad_token is not None:
+        if self.tokenizer.pad_token is not None:
            padding_mask = labels == self.tokenizer.pad_token_id
            masked_indices = masked_indices & ~padding_mask
 
@@ -1571,7 +1571,7 @@ class DataCollatorForPermutationLanguageModeling(DataCollatorMixin):
            dtype=bool,
        )
        masked_indices[special_tokens_mask] = 0
-        if self.tokenizer._pad_token is not None:
+        if self.tokenizer.pad_token is not None:
            padding_mask = labels == self.tokenizer.pad_token_id
            masked_indices[padding_mask] = 0.0
 
@@ -74,8 +74,11 @@ class Blip2Processor(ProcessorMixin):
     def __init__(self, image_processor, tokenizer, num_query_tokens=None, **kwargs):
         tokenizer.return_token_type_ids = False
         self.current_processor = image_processor
-        self.image_token = AddedToken("<image>", normalized=False, special=True)
-        tokenizer.add_tokens([self.image_token], special_tokens=True)
+        if not hasattr(tokenizer, "image_token"):
+            self.image_token = AddedToken("<image>", normalized=False, special=True)
+            tokenizer.add_tokens([self.image_token], special_tokens=True)
+        else:
+            self.image_token = tokenizer.image_token
         self.num_query_tokens = num_query_tokens
 
         super().__init__(image_processor, tokenizer)
@@ -66,9 +66,12 @@ class ChameleonProcessor(ProcessorMixin):
 
     def __init__(self, image_processor, tokenizer, image_seq_length: int = 1024, image_token: str = "<image>"):
         self.image_seq_length = image_seq_length
-        self.image_token = image_token
-        self.image_start_token = "<racm3:break>"  # fixed tokens for start and end, so can hardcode
-        self.image_end_token = "<eoss>"
+        self.image_token = tokenizer.image_token if hasattr(tokenizer, "image_token") else image_token
+        self.image_start_token = (
+            tokenizer.boi_token if hasattr(tokenizer, "boi_token") else "<racm3:break>"
+        )  # fixed tokens for start and end, so can hardcode
+        self.image_end_token = tokenizer.eoi_token if hasattr(tokenizer, "eoi_token") else "<eoss>"
 
         super().__init__(image_processor, tokenizer)
 
     def __call__(
@@ -138,7 +138,7 @@ class GemmaTokenizer(PreTrainedTokenizer):
         return state
 
     def __setstate__(self, d):
-        self.__dict__ = d
+        self.__dict__.update(d)
         self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
         self.sp_model.LoadFromSerializedProto(self.sp_model_proto)
 
@@ -219,7 +219,11 @@ class IdeficsProcessor(ProcessorMixin):
 
         super().__init__(image_processor, tokenizer)
         self.current_processor = self.image_processor
-        self.image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
+        self.image_token_id = (
+            tokenizer.image_token_id
+            if hasattr(tokenizer, "image_token")
+            else tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
+        )
 
         self.default_image_dims = (
             self.image_processor.image_num_channels,
@@ -95,15 +95,18 @@ class Idefics2Processor(ProcessorMixin):
         if tokenizer is None:
             raise ValueError("You need to specify a `tokenizer`.")
 
-        self.fake_image_token = AddedToken("<fake_token_around_image>", normalized=False, special=True)
-        self.image_token = AddedToken("<image>", normalized=False, special=True)
-        self.end_of_utterance_token = AddedToken("<end_of_utterance>", normalized=False, special=True)
-        self.image_seq_len = image_seq_len
-
-        tokens_to_add = {
-            "additional_special_tokens": [self.fake_image_token, self.image_token, self.end_of_utterance_token]
-        }
-        tokenizer.add_special_tokens(tokens_to_add)
+        if not hasattr(tokenizer, "image_token"):
+            self.fake_image_token = AddedToken("<fake_token_around_image>", normalized=False, special=True)
+            self.image_token = AddedToken("<image>", normalized=False, special=True)
+            tokens_to_add = {"additional_special_tokens": [self.fake_image_token, self.image_token]}
+            tokenizer.add_special_tokens(tokens_to_add)
+        else:
+            self.fake_image_token = tokenizer.image_boundary_token
+            self.image_token = tokenizer.image_token
+
+        self.end_of_utterance_token = AddedToken("<end_of_utterance>", normalized=False, special=True)
+        tokenizer.add_special_tokens({"additional_special_tokens": [self.end_of_utterance_token]})
+        self.image_seq_len = image_seq_len
 
         super().__init__(image_processor, tokenizer, chat_template=chat_template)
 
@@ -78,8 +78,11 @@ class InstructBlipProcessor(ProcessorMixin):
     qformer_tokenizer_class = "AutoTokenizer"
 
     def __init__(self, image_processor, tokenizer, qformer_tokenizer, num_query_tokens=None, **kwargs):
-        self.image_token = AddedToken("<image>", normalized=False, special=True)
-        tokenizer.add_tokens([self.image_token], special_tokens=True)
+        if not hasattr(tokenizer, "image_token"):
+            self.image_token = AddedToken("<image>", normalized=False, special=True)
+            tokenizer.add_tokens([self.image_token], special_tokens=True)
+        else:
+            self.image_token = tokenizer.image_token
         self.num_query_tokens = num_query_tokens
         super().__init__(image_processor, tokenizer, qformer_tokenizer)
 
@@ -63,8 +63,11 @@ class InstructBlipVideoProcessor(ProcessorMixin):
     qformer_tokenizer_class = "AutoTokenizer"
 
     def __init__(self, image_processor, tokenizer, qformer_tokenizer, num_query_tokens=None, **kwargs):
-        self.video_token = AddedToken("<video>", normalized=False, special=True)
-        tokenizer.add_tokens([self.video_token], special_tokens=True)
+        if not hasattr(tokenizer, "video_token"):
+            self.video_token = AddedToken("<video>", normalized=False, special=True)
+            tokenizer.add_tokens([self.video_token], special_tokens=True)
+        else:
+            self.video_token = tokenizer.video_token
         self.num_query_tokens = num_query_tokens
         super().__init__(image_processor, tokenizer, qformer_tokenizer)
 
@@ -297,7 +297,7 @@ class LayoutXLMTokenizer(PreTrainedTokenizer):
         return state
 
     def __setstate__(self, d):
-        self.__dict__ = d
+        self.__dict__.update(d)
 
         # for backward compatibility
         if not hasattr(self, "sp_model_kwargs"):
@@ -214,7 +214,7 @@ class LlamaTokenizer(PreTrainedTokenizer):
         return state
 
     def __setstate__(self, d):
-        self.__dict__ = d
+        self.__dict__.update(d)
         self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
         self.sp_model.LoadFromSerializedProto(self.sp_model_proto)
 
@@ -77,7 +77,7 @@ class LlavaProcessor(ProcessorMixin):
     ):
         self.patch_size = patch_size
         self.vision_feature_select_strategy = vision_feature_select_strategy
-        self.image_token = image_token
+        self.image_token = tokenizer.image_token if hasattr(tokenizer, "image_token") else image_token
         super().__init__(image_processor, tokenizer, chat_template=chat_template)
 
     def __call__(
@@ -80,7 +80,7 @@ class LlavaNextProcessor(ProcessorMixin):
     ):
         self.patch_size = patch_size
         self.vision_feature_select_strategy = vision_feature_select_strategy
-        self.image_token = image_token
+        self.image_token = tokenizer.image_token if hasattr(tokenizer, "image_token") else image_token
         super().__init__(image_processor, tokenizer, chat_template=chat_template)
 
     def __call__(
@@ -82,8 +82,8 @@ class LlavaNextVideoProcessor(ProcessorMixin):
     ):
         self.patch_size = patch_size
         self.vision_feature_select_strategy = vision_feature_select_strategy
-        self.image_token = image_token
-        self.video_token = video_token
+        self.image_token = tokenizer.image_token if hasattr(tokenizer, "image_token") else image_token
+        self.video_token = tokenizer.video_token if hasattr(tokenizer, "video_token") else video_token
         super().__init__(video_processor, image_processor, tokenizer, chat_template=chat_template)
 
     def __call__(
@@ -96,8 +96,8 @@ class LlavaOnevisionProcessor(ProcessorMixin):
     ):
         self.num_image_tokens = num_image_tokens
         self.vision_feature_select_strategy = vision_feature_select_strategy
-        self.image_token = image_token
-        self.video_token = video_token
+        self.image_token = tokenizer.image_token if hasattr(tokenizer, "image_token") else image_token
+        self.video_token = tokenizer.video_token if hasattr(tokenizer, "video_token") else video_token
         super().__init__(image_processor, tokenizer, video_processor, chat_template=chat_template)
 
     def __call__(
@@ -213,8 +213,13 @@ class MllamaProcessor(ProcessorMixin):
     tokenizer_class = "PreTrainedTokenizerFast"
 
     def __init__(self, image_processor, tokenizer):
-        self.image_token = "<|image|>"
-        self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
+        if not hasattr(tokenizer, "image_token"):
+            self.image_token = "<|image|>"
+            self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
+        else:
+            self.image_token = tokenizer.image_token
+            self.image_token_id = tokenizer.image_token_id
+
         self.python_token = "<|python_tag|>"
         self.python_token_id = tokenizer.convert_tokens_to_ids(self.python_token)
         self.bos_token = tokenizer.bos_token
@@ -160,11 +160,15 @@ class PaliGemmaProcessor(ProcessorMixin):
 
         self.image_seq_length = image_processor.image_seq_length
 
-        image_token = AddedToken(IMAGE_TOKEN, normalized=False, special=True)
-        tokens_to_add = {"additional_special_tokens": [image_token]}
-        tokenizer.add_special_tokens(tokens_to_add)
+        if not hasattr(tokenizer, "image_token"):
+            image_token = AddedToken(IMAGE_TOKEN, normalized=False, special=True)
+            tokens_to_add = {"additional_special_tokens": [image_token]}
+            tokenizer.add_special_tokens(tokens_to_add)
+            self.image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
+        else:
+            self.image_token_id = tokenizer.image_token_id
 
         tokenizer.add_tokens(EXTRA_TOKENS)
-        self.image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
         tokenizer.add_bos_token = False
         tokenizer.add_eos_token = False
 
@@ -61,6 +61,8 @@ class Qwen2VLProcessor(ProcessorMixin):
     tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")
 
     def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs):
+        self.image_token = "<|image_pad|>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token
+        self.video_token = "<|video_pad|>" if not hasattr(tokenizer, "video_token") else tokenizer.video_token
         super().__init__(image_processor, tokenizer, chat_template=chat_template)
 
     def __call__(
@@ -132,23 +134,23 @@ class Qwen2VLProcessor(ProcessorMixin):
             merge_length = self.image_processor.merge_size**2
             index = 0
             for i in range(len(text)):
-                while "<|image_pad|>" in text[i]:
+                while self.image_token in text[i]:
                     text[i] = text[i].replace(
-                        "<|image_pad|>", "<|placeholder|>" * (image_grid_thw[index].prod() // merge_length), 1
+                        self.image_token, "<|placeholder|>" * (image_grid_thw[index].prod() // merge_length), 1
                     )
                     index += 1
-                text[i] = text[i].replace("<|placeholder|>", "<|image_pad|>")
+                text[i] = text[i].replace("<|placeholder|>", self.image_token)
 
         if video_grid_thw is not None:
             merge_length = self.image_processor.merge_size**2
             index = 0
             for i in range(len(text)):
-                while "<|video_pad|>" in text[i]:
+                while self.video_token in text[i]:
                     text[i] = text[i].replace(
-                        "<|video_pad|>", "<|placeholder|>" * (video_grid_thw[index].prod() // merge_length), 1
+                        self.video_token, "<|placeholder|>" * (video_grid_thw[index].prod() // merge_length), 1
                     )
                     index += 1
-                text[i] = text[i].replace("<|placeholder|>", "<|video_pad|>")
+                text[i] = text[i].replace("<|placeholder|>", self.video_token)
 
         text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
 
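To make the token-expansion loop above concrete, here is a rough standalone sketch with invented numbers (illustrative, not part of this diff): each image placeholder in the text is expanded to `grid_thw.prod() // merge_length` copies before tokenization.

```python
# Hypothetical values: a 1x24x24 patch grid with merge_size=2 gives 24*24 // 4 = 144 image tokens.
image_token = "<|image_pad|>"
text = [f"Describe this image: {image_token}"]
grid_thw = (1, 24, 24)      # (temporal, height, width) patches, made up for illustration
merge_length = 2 ** 2       # merge_size ** 2

num_image_tokens = (grid_thw[0] * grid_thw[1] * grid_thw[2]) // merge_length
text[0] = text[0].replace(image_token, "<|placeholder|>" * num_image_tokens, 1)
text[0] = text[0].replace("<|placeholder|>", image_token)
print(text[0].count(image_token))  # 144
```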
@@ -412,7 +412,7 @@ class UdopTokenizer(PreTrainedTokenizer):
         return state
 
     def __setstate__(self, d):
-        self.__dict__ = d
+        self.__dict__.update(d)
         self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
         self.sp_model.Load(self.vocab_file)
 
@@ -71,8 +71,8 @@ class VideoLlavaProcessor(ProcessorMixin):
     ):
         self.patch_size = patch_size
         self.vision_feature_select_strategy = vision_feature_select_strategy
-        self.image_token = image_token
-        self.video_token = video_token
+        self.image_token = tokenizer.image_token if hasattr(tokenizer, "image_token") else image_token
+        self.video_token = tokenizer.video_token if hasattr(tokenizer, "video_token") else video_token
         super().__init__(image_processor, tokenizer, chat_template=chat_template)
 
     def __call__(
@@ -577,7 +577,7 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
                 token_index = current_vocab[token.content]
 
             if token.special and str(token) not in self.all_special_tokens:
-                self._additional_special_tokens.append(token)
+                self._special_tokens_map["additional_special_tokens"].append(token)
             # the setter automatically updates the reverse map
             self._added_tokens_decoder[token_index] = token
             self._added_tokens_encoder[token.content] = token_index
@@ -861,16 +861,10 @@ class SpecialTokensMixin:
     ]
 
     def __init__(self, verbose=False, **kwargs):
-        self._bos_token = None
-        self._eos_token = None
-        self._unk_token = None
-        self._sep_token = None
-        self._pad_token = None
-        self._cls_token = None
-        self._mask_token = None
         self._pad_token_type_id = 0
-        self._additional_special_tokens = []
         self.verbose = verbose
+        self._special_tokens_map = {attr: None for attr in self.SPECIAL_TOKENS_ATTRIBUTES}
+        self._special_tokens_map["additional_special_tokens"] = []  # for BC where it defaults to empty list
 
         # We directly set the hidden value to allow initialization with special tokens
         # which are not yet in the vocabulary. Necessary for serialization/de-serialization
@@ -932,7 +926,7 @@ class SpecialTokensMixin:
             assign the index of the `unk_token` to them).
         replace_additional_special_tokens (`bool`, *optional*,, defaults to `True`):
             If `True`, the existing list of additional special tokens will be replaced by the list provided in
-            `special_tokens_dict`. Otherwise, `self._additional_special_tokens` is just extended. In the former
+            `special_tokens_dict`. Otherwise, `self._special_tokens_map["additional_special_tokens"]` is just extended. In the former
             case, the tokens will NOT be removed from the tokenizer's full vocabulary - they are only being flagged
             as non-special tokens. Remember, this only affects which tokens are skipped during decoding, not the
             `added_tokens_encoder` and `added_tokens_decoder`. This means that the previous
@@ -983,7 +977,7 @@ class SpecialTokensMixin:
                 if replace_additional_special_tokens and len(to_add) > 0:
                     setattr(self, key, list(to_add))
                 else:
-                    self._additional_special_tokens.extend(to_add)
+                    self._special_tokens_map["additional_special_tokens"].extend(to_add)
                     added_tokens += to_add
 
             else:
@@ -1053,192 +1047,6 @@ class SpecialTokensMixin:
     def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens: bool = False) -> int:
         raise NotImplementedError
 
-    @property
-    def bos_token(self) -> str:
-        """
-        `str`: Beginning of sentence token. Log an error if used while not having been set.
-        """
-        if self._bos_token is None:
-            if self.verbose:
-                logger.error("Using bos_token, but it is not set yet.")
-            return None
-        return str(self._bos_token)
-
-    @property
-    def eos_token(self) -> str:
-        """
-        `str`: End of sentence token. Log an error if used while not having been set.
-        """
-        if self._eos_token is None:
-            if self.verbose:
-                logger.error("Using eos_token, but it is not set yet.")
-            return None
-        return str(self._eos_token)
-
-    @property
-    def unk_token(self) -> str:
-        """
-        `str`: Unknown token. Log an error if used while not having been set.
-        """
-        if self._unk_token is None:
-            if self.verbose:
-                logger.error("Using unk_token, but it is not set yet.")
-            return None
-        return str(self._unk_token)
-
-    @property
-    def sep_token(self) -> str:
-        """
-        `str`: Separation token, to separate context and query in an input sequence. Log an error if used while not
-        having been set.
-        """
-        if self._sep_token is None:
-            if self.verbose:
-                logger.error("Using sep_token, but it is not set yet.")
-            return None
-        return str(self._sep_token)
-
-    @property
-    def pad_token(self) -> str:
-        """
-        `str`: Padding token. Log an error if used while not having been set.
-        """
-        if self._pad_token is None:
-            if self.verbose:
-                logger.error("Using pad_token, but it is not set yet.")
-            return None
-        return str(self._pad_token)
-
-    @property
-    def cls_token(self) -> str:
-        """
-        `str`: Classification token, to extract a summary of an input sequence leveraging self-attention along the full
-        depth of the model. Log an error if used while not having been set.
-        """
-        if self._cls_token is None:
-            if self.verbose:
-                logger.error("Using cls_token, but it is not set yet.")
-            return None
-        return str(self._cls_token)
-
-    @property
-    def mask_token(self) -> str:
-        """
-        `str`: Mask token, to use when training a model with masked-language modeling. Log an error if used while not
-        having been set.
-        """
-        if self._mask_token is None:
-            if self.verbose:
-                logger.error("Using mask_token, but it is not set yet.")
-            return None
-        return str(self._mask_token)
-
-    @property
-    def additional_special_tokens(self) -> List[str]:
-        """
-        `List[str]`: All the additional special tokens you may want to use. Log an error if used while not having been
-        set.
-        """
-        if self._additional_special_tokens is None:
-            if self.verbose:
-                logger.error("Using additional_special_tokens, but it is not set yet.")
-            return None
-        return [str(tok) for tok in self._additional_special_tokens]
-
-    @bos_token.setter
-    def bos_token(self, value):
-        if not isinstance(value, (str, AddedToken)) and value is not None:
-            raise ValueError("Cannot set a non-string value as the BOS token")
-        self._bos_token = value
-
-    @eos_token.setter
-    def eos_token(self, value):
-        if not isinstance(value, (str, AddedToken)) and value is not None:
-            raise ValueError("Cannot set a non-string value as the EOS token")
-        self._eos_token = value
-
-    @unk_token.setter
-    def unk_token(self, value):
-        if not isinstance(value, (str, AddedToken)) and value is not None:
-            raise ValueError("Cannot set a non-string value as the UNK token")
-        self._unk_token = value
-
-    @sep_token.setter
-    def sep_token(self, value):
-        if not isinstance(value, (str, AddedToken)) and value is not None:
-            raise ValueError("Cannot set a non-string value as the SEP token")
-        self._sep_token = value
-
-    @pad_token.setter
-    def pad_token(self, value):
-        if not isinstance(value, (str, AddedToken)) and value is not None:
-            raise ValueError("Cannot set a non-string value as the PAD token")
-        self._pad_token = value
-
-    @cls_token.setter
-    def cls_token(self, value):
-        if not isinstance(value, (str, AddedToken)) and value is not None:
-            raise ValueError("Cannot set a non-string value as the CLS token")
-        self._cls_token = value
-
-    @mask_token.setter
-    def mask_token(self, value):
-        if not isinstance(value, (str, AddedToken)) and value is not None:
-            raise ValueError("Cannot set a non-string value as the MASK token")
-        self._mask_token = value
-
-    @additional_special_tokens.setter
-    def additional_special_tokens(self, value):
-        self._additional_special_tokens = value if value is not None else None
-
-    @property
-    def bos_token_id(self) -> Optional[int]:
-        """
-        `Optional[int]`: Id of the beginning of sentence token in the vocabulary. Returns `None` if the token has not
-        been set.
-        """
-        if self._bos_token is None:
-            return None
-        return self.convert_tokens_to_ids(self.bos_token)
-
-    @property
-    def eos_token_id(self) -> Optional[int]:
-        """
-        `Optional[int]`: Id of the end of sentence token in the vocabulary. Returns `None` if the token has not been
-        set.
-        """
-        if self._eos_token is None:
-            return None
-        return self.convert_tokens_to_ids(self.eos_token)
-
-    @property
-    def unk_token_id(self) -> Optional[int]:
-        """
-        `Optional[int]`: Id of the unknown token in the vocabulary. Returns `None` if the token has not been set.
-        """
-        if self._unk_token is None:
-            return None
-        return self.convert_tokens_to_ids(self.unk_token)
-
-    @property
-    def sep_token_id(self) -> Optional[int]:
-        """
-        `Optional[int]`: Id of the separation token in the vocabulary, to separate context and query in an input
-        sequence. Returns `None` if the token has not been set.
-        """
-        if self._sep_token is None:
-            return None
-        return self.convert_tokens_to_ids(self.sep_token)
-
-    @property
-    def pad_token_id(self) -> Optional[int]:
-        """
-        `Optional[int]`: Id of the padding token in the vocabulary. Returns `None` if the token has not been set.
-        """
-        if self._pad_token is None:
-            return None
-        return self.convert_tokens_to_ids(self.pad_token)
-
     @property
     def pad_token_type_id(self) -> int:
         """
@@ -1246,67 +1054,55 @@ class SpecialTokensMixin:
         """
         return self._pad_token_type_id
 
-    @property
-    def cls_token_id(self) -> Optional[int]:
-        """
-        `Optional[int]`: Id of the classification token in the vocabulary, to extract a summary of an input sequence
-        leveraging self-attention along the full depth of the model.
-
-        Returns `None` if the token has not been set.
-        """
-        if self._cls_token is None:
-            return None
-        return self.convert_tokens_to_ids(self.cls_token)
-
-    @property
-    def mask_token_id(self) -> Optional[int]:
-        """
-        `Optional[int]`: Id of the mask token in the vocabulary, used when training a model with masked-language
-        modeling. Returns `None` if the token has not been set.
-        """
-        if self._mask_token is None:
-            return None
-        return self.convert_tokens_to_ids(self.mask_token)
-
-    @property
-    def additional_special_tokens_ids(self) -> List[int]:
-        """
-        `List[int]`: Ids of all the additional special tokens in the vocabulary. Log an error if used while not having
-        been set.
-        """
-        return self.convert_tokens_to_ids(self.additional_special_tokens)
-
-    @bos_token_id.setter
-    def bos_token_id(self, value):
-        self._bos_token = self.convert_ids_to_tokens(value) if value is not None else None
-
-    @eos_token_id.setter
-    def eos_token_id(self, value):
-        self._eos_token = self.convert_ids_to_tokens(value) if value is not None else None
-
-    @unk_token_id.setter
-    def unk_token_id(self, value):
-        self._unk_token = self.convert_ids_to_tokens(value) if value is not None else None
-
-    @sep_token_id.setter
-    def sep_token_id(self, value):
-        self._sep_token = self.convert_ids_to_tokens(value) if value is not None else None
-
-    @pad_token_id.setter
-    def pad_token_id(self, value):
-        self._pad_token = self.convert_ids_to_tokens(value) if value is not None else None
-
-    @cls_token_id.setter
-    def cls_token_id(self, value):
-        self._cls_token = self.convert_ids_to_tokens(value) if value is not None else None
-
-    @mask_token_id.setter
-    def mask_token_id(self, value):
-        self._mask_token = self.convert_ids_to_tokens(value) if value is not None else None
-
-    @additional_special_tokens_ids.setter
-    def additional_special_tokens_ids(self, values):
-        self._additional_special_tokens = [self.convert_ids_to_tokens(value) for value in values]
-
+    def __setattr__(self, key, value):
+        key_without_id = key
+        key_is_special_id = key.endswith("_id") or key.endswith("_ids")
+        if key_is_special_id:
+            key_without_id = key[:-3] if not key.endswith("_ids") else key[:-4]
+
+        if self.__dict__.get("_special_tokens_map", None) is not None and any(
+            name in self.__dict__["_special_tokens_map"] for name in [key, key_without_id]
+        ):
+            if key_is_special_id:
+                if value is not None:
+                    value = (
+                        self.convert_ids_to_tokens(value)
+                        if key != "additional_special_tokens"
+                        else [self.convert_ids_to_tokens(val) for val in value]
+                    )
+                key = key_without_id
+
+            if key != "additional_special_tokens" and not isinstance(value, (str, AddedToken)) and value is not None:
+                raise ValueError(f"Cannot set a non-string value as the {key}")
+            self._special_tokens_map[key] = value
+        else:
+            super().__setattr__(key, value)
+
+    def __getattr__(self, key):
+        key_without_id = key
+        key_is_special_id = key.endswith("_id") or key.endswith("_ids")
+        if key_is_special_id:
+            key_without_id = key[:-3] if not key.endswith("_ids") else key[:-4]
+
+        if self.__dict__.get("_special_tokens_map", None) is not None and any(
+            name in self.__dict__["_special_tokens_map"] for name in [key, key_without_id]
+        ):
+            _special_tokens_map = self.__dict__["_special_tokens_map"]
+            if not key_is_special_id:
+                if _special_tokens_map[key] is None:
+                    if self.verbose:
+                        logger.error(f"Using {key}, but it is not set yet.")
+                    return None
+                value = _special_tokens_map[key]
+                return str(value) if key != "additional_special_tokens" else [str(tok) for tok in value]
+            else:
+                attr_as_tokens = getattr(self, key_without_id)
+                return self.convert_tokens_to_ids(attr_as_tokens) if attr_as_tokens is not None else None
+
+        if key not in self.__dict__:
+            raise AttributeError(f"{self.__class__.__name__} has no attribute {key}")
+        else:
+            return super().__getattr__(key)
+
     @property
     def special_tokens_map(self) -> Dict[str, Union[str, List[str]]]:
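In practice the refactor keeps the public attribute API intact; a quick behavioral sketch (illustrative, not part of this diff, assuming an already-loaded tokenizer `tok`):

```python
# Setting a special token is routed by __setattr__ into tok._special_tokens_map.
tok.pad_token = "<pad>"
print(tok.pad_token)        # "<pad>", resolved by __getattr__
print(tok.pad_token_id)     # id resolved via convert_tokens_to_ids, as before

# *_id setters convert the id back to its token string before storing it.
tok.eos_token_id = 2
print(tok.eos_token)        # the token string that maps to id 2
```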
@@ -1334,7 +1130,7 @@ class SpecialTokensMixin:
         """
         set_attr = {}
         for attr in self.SPECIAL_TOKENS_ATTRIBUTES:
-            attr_value = getattr(self, "_" + attr)
+            attr_value = self._special_tokens_map[attr]
            if attr_value:
                set_attr[attr] = attr_value
        return set_attr
@@ -1379,6 +1175,20 @@ class SpecialTokensMixin:
         all_ids = self.convert_tokens_to_ids(all_toks)
         return all_ids
 
+    def _set_model_specific_special_tokens(self, special_tokens: List[str]):
+        """
+        Adds new special tokens to the "SPECIAL_TOKENS_ATTRIBUTES" list which will be part
+        of "self.special_tokens" and saved as a special token in the tokenizer's config.
+        This allows us to dynamically add new model-type specific tokens after initializing the tokenizer.
+        For example: if the model tokenizer is multimodal, we can support special image or audio tokens.
+        """
+        self.SPECIAL_TOKENS_ATTRIBUTES = self.SPECIAL_TOKENS_ATTRIBUTES + list(special_tokens.keys())
+        for key, value in special_tokens.items():
+            if isinstance(value, (str, AddedToken)):
+                self._special_tokens_map[key] = value
+            else:
+                raise TypeError(f"Special token {key} has to be either str or AddedToken but got: {type(value)}")
+
 
 ENCODE_KWARGS_DOCSTRING = r"""
     add_special_tokens (`bool`, *optional*, defaults to `True`):
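The helper above is what turns `extra_special_tokens` keys into first-class tokenizer attributes; a minimal sketch of the effect (illustrative, not part of this diff):

```python
from transformers import AutoTokenizer

# "<image>" already exists in the LLaVA vocabulary, so the id lookup resolves to a real id;
# a value that is not a str or AddedToken would raise the TypeError shown above.
tok = AutoTokenizer.from_pretrained(
    "llava-hf/llava-1.5-7b-hf",
    extra_special_tokens={"image_token": "<image>"},
)
print(tok.image_token)      # "<image>"
print(tok.image_token_id)   # resolved through __getattr__ -> convert_tokens_to_ids
```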
@@ -1633,6 +1443,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
 
         super().__init__(**kwargs)
 
+        self.extra_special_tokens = kwargs.pop("extra_special_tokens", {})
+        self._set_model_specific_special_tokens(special_tokens=self.extra_special_tokens)
+
     @property
     def max_len_single_sentence(self) -> int:
         """
|
|||||||
if hasattr(self, k):
|
if hasattr(self, k):
|
||||||
tokenizer_config[k] = getattr(self, k)
|
tokenizer_config[k] = getattr(self, k)
|
||||||
|
|
||||||
# Let's make sure we properly save the special tokens.
|
# Let's make sure we properly save the special tokens
|
||||||
tokenizer_config.update(self.special_tokens_map)
|
tokenizer_config.update(self.special_tokens_map)
|
||||||
|
if "extra_special_tokens" not in tokenizer_config:
|
||||||
|
tokenizer_config["extra_special_tokens"] = self.extra_special_tokens
|
||||||
|
tokenizer_config.update(self.extra_special_tokens)
|
||||||
|
|
||||||
if self.chat_template is not None:
|
if self.chat_template is not None:
|
||||||
if isinstance(self.chat_template, dict):
|
if isinstance(self.chat_template, dict):
|
||||||
|
@@ -863,13 +863,12 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
         special_tokens_list = SpecialTokensMixin.SPECIAL_TOKENS_ATTRIBUTES.copy()
         special_tokens_list.remove("additional_special_tokens")
         for token in special_tokens_list:
-            # Get the private one to avoid unnecessary warnings.
-            if getattr(self, f"_{token}") is not None:
+            if getattr(self, token) is not None:
                 special_token = getattr(self, token)
                 if special_tokens_map is not None and special_token in special_tokens_map:
                     special_token = special_tokens_map[special_token]
 
-                special_token_full = getattr(self, f"_{token}")
+                special_token_full = self._special_tokens_map.get(token, None)
                 if isinstance(special_token_full, AddedToken):
                     # Create an added token with the same parameters except the content
                     kwargs[token] = AddedToken(
@@ -154,7 +154,7 @@ class CamembertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
            tokenizer = self.tokenizer_class.from_pretrained(pretrained_name, eos_token=new_eos)
            EXPECTED_ADDED_TOKENS_DECODER = tokenizer.added_tokens_decoder
            with self.subTest("Hub -> Slow: Test loading a slow tokenizer from the hub)"):
-                self.assertEqual(tokenizer._eos_token, new_eos)
+                self.assertEqual(tokenizer._special_tokens_map["eos_token"], new_eos)
                self.assertIn(new_eos, list(tokenizer.added_tokens_decoder.values()))
 
            with tempfile.TemporaryDirectory() as tmp_dir_2:
@@ -194,7 +194,7 @@ class CamembertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
                tokenizer_fast = self.rust_tokenizer_class.from_pretrained(
                    pretrained_name, eos_token=new_eos, from_slow=True
                )
-                self.assertEqual(tokenizer_fast._eos_token, new_eos)
+                self.assertEqual(tokenizer_fast._special_tokens_map["eos_token"], new_eos)
                self.assertIn(new_eos, list(tokenizer_fast.added_tokens_decoder.values()))
                # We can't test the following because for BC we kept the default rstrip lstrip in slow not fast. Will comment once normalization is alright
                with self.subTest("Hub -> Fast == Hub -> Slow: make sure slow and fast tokenizer match"):
|
|||||||
special_tokens_map = {}
|
special_tokens_map = {}
|
||||||
for token in special_tokens_list:
|
for token in special_tokens_list:
|
||||||
# Get the private one to avoid unnecessary warnings.
|
# Get the private one to avoid unnecessary warnings.
|
||||||
if getattr(tokenizer, f"_{token}") is not None:
|
if getattr(tokenizer, token) is not None:
|
||||||
special_token = getattr(tokenizer, token)
|
special_token = getattr(tokenizer, token)
|
||||||
special_tokens_map[special_token] = f"{special_token}a"
|
special_tokens_map[special_token] = f"{special_token}a"
|
||||||
|
|
||||||
@ -1671,7 +1671,7 @@ class LayoutLMv2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
# Check the changes
|
# Check the changes
|
||||||
for token in special_tokens_list:
|
for token in special_tokens_list:
|
||||||
# Get the private one to avoid unnecessary warnings.
|
# Get the private one to avoid unnecessary warnings.
|
||||||
if getattr(tokenizer, f"_{token}") is None:
|
if getattr(tokenizer, token) is None:
|
||||||
continue
|
continue
|
||||||
special_token = getattr(tokenizer, token)
|
special_token = getattr(tokenizer, token)
|
||||||
if special_token in special_tokens_map:
|
if special_token in special_tokens_map:
|
||||||
|
@@ -1537,7 +1537,7 @@ class LayoutLMv3TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
        special_tokens_map = {}
        for token in special_tokens_list:
            # Get the private one to avoid unnecessary warnings.
-            if getattr(tokenizer, f"_{token}") is not None:
+            if getattr(tokenizer, token) is not None:
                special_token = getattr(tokenizer, token)
                special_tokens_map[special_token] = f"{special_token}a"
 
@@ -1549,7 +1549,7 @@ class LayoutLMv3TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
        # Check the changes
        for token in special_tokens_list:
            # Get the private one to avoid unnecessary warnings.
-            if getattr(tokenizer, f"_{token}") is None:
+            if getattr(tokenizer, token) is None:
                continue
            special_token = getattr(tokenizer, token)
            if special_token in special_tokens_map:
@@ -1588,7 +1588,7 @@ class LayoutXLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
        special_tokens_map = {}
        for token in special_tokens_list:
            # Get the private one to avoid unnecessary warnings.
-            if getattr(tokenizer, f"_{token}") is not None:
+            if getattr(tokenizer, token) is not None:
                special_token = getattr(tokenizer, token)
                special_tokens_map[special_token] = f"{special_token}a"
 
@@ -1600,7 +1600,7 @@ class LayoutXLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
        # Check the changes
        for token in special_tokens_list:
            # Get the private one to avoid unnecessary warnings.
-            if getattr(tokenizer, f"_{token}") is None:
+            if getattr(tokenizer, token) is None:
                continue
            special_token = getattr(tokenizer, token)
            if special_token in special_tokens_map:
@@ -385,6 +385,7 @@ class LlamaIntegrationTest(unittest.TestCase):
        assert fast == [1, 319, 4559, 1243]
 
        fast_tokenizer.add_eos_token = True
+        print(fast_tokenizer.add_eos_token)
        fast = fast_tokenizer.encode("A sample test", add_special_tokens=True)
        assert fast == [1, 319, 4559, 1243, 2]
 
@@ -1435,7 +1435,7 @@ class MarkupLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
        special_tokens_map = {}
        for token in special_tokens_list:
            # Get the private one to avoid unnecessary warnings.
-            if getattr(tokenizer, f"_{token}") is not None:
+            if getattr(tokenizer, token) is not None:
                special_token = getattr(tokenizer, token)
                special_tokens_map[special_token] = f"{special_token}a"
 
@@ -1447,7 +1447,7 @@ class MarkupLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
        # Check the changes
        for token in special_tokens_list:
            # Get the private one to avoid unnecessary warnings.
-            if getattr(tokenizer, f"_{token}") is None:
+            if getattr(tokenizer, token) is None:
                continue
            special_token = getattr(tokenizer, token)
            if special_token in special_tokens_map:
@@ -237,7 +237,7 @@ class MoshiTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
 special_tokens_map = {}
 for token in special_tokens_list:
     # Get the private one to avoid unnecessary warnings.
-    if getattr(tokenizer, f"_{token}") is not None:
+    if getattr(tokenizer, token) is not None:
         special_token = getattr(tokenizer, token)
         special_tokens_map[special_token] = f"{special_token}a"

@@ -249,7 +249,7 @@ class MoshiTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
 # Check the changes
 for token in special_tokens_list:
     # Get the private one to avoid unnecessary warnings.
-    if getattr(tokenizer, f"_{token}") is None:
+    if getattr(tokenizer, token) is None:
         continue
     special_token = getattr(tokenizer, token)
     if special_token in special_tokens_map:
@@ -185,7 +185,7 @@ class RemBertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
 )
 EXPECTED_ADDED_TOKENS_DECODER = tokenizer.added_tokens_decoder
 with self.subTest("Hub -> Slow: Test loading a slow tokenizer from the hub)"):
-    self.assertEqual(tokenizer._eos_token, new_eos)
+    self.assertEqual(tokenizer._special_tokens_map["eos_token"], new_eos)
     self.assertIn(new_eos, list(tokenizer.added_tokens_decoder.values()))

 with tempfile.TemporaryDirectory() as tmp_dir_2:

@@ -223,7 +223,7 @@ class RemBertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
 with self.subTest("Hub -> Fast: Test loading a fast tokenizer from the hub)"):
     if self.rust_tokenizer_class is not None:
         tokenizer_fast = self.rust_tokenizer_class.from_pretrained(pretrained_name, eos_token=new_eos)
-        self.assertEqual(tokenizer_fast._eos_token, new_eos)
+        self.assertEqual(tokenizer_fast._special_tokens_map["eos_token"], new_eos)
         self.assertIn(new_eos, list(tokenizer_fast.added_tokens_decoder.values()))
         # We can't test the following because for BC we kept the default rstrip lstrip in slow not fast. Will comment once normalization is alright
         with self.subTest("Hub -> Fast == Hub -> Slow: make sure slow and fast tokenizer match"):
@@ -1538,7 +1538,7 @@ class UdopTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
 special_tokens_map = {}
 for token in special_tokens_list:
     # Get the private one to avoid unnecessary warnings.
-    if getattr(tokenizer, f"_{token}") is not None:
+    if getattr(tokenizer, token) is not None:
         special_token = getattr(tokenizer, token)
         special_tokens_map[special_token] = f"{special_token}a"

@@ -1550,7 +1550,7 @@ class UdopTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
 # Check the changes
 for token in special_tokens_list:
     # Get the private one to avoid unnecessary warnings.
-    if getattr(tokenizer, f"_{token}") is None:
+    if getattr(tokenizer, token) is None:
         continue
     special_token = getattr(tokenizer, token)
     if special_token in special_tokens_map:
@@ -4156,8 +4156,7 @@ class TokenizerTesterMixin:
 special_tokens_list.remove("additional_special_tokens")
 special_tokens_map = {}
 for token in special_tokens_list:
-    # Get the private one to avoid unnecessary warnings.
-    if getattr(tokenizer, f"_{token}") is not None:
+    if getattr(tokenizer, token) is not None:
         special_token = getattr(tokenizer, token)
         special_tokens_map[special_token] = f"{special_token}a"

@@ -4169,7 +4168,7 @@ class TokenizerTesterMixin:
 # Check the changes
 for token in special_tokens_list:
     # Get the private one to avoid unnecessary warnings.
-    if getattr(tokenizer, f"_{token}") is None:
+    if getattr(tokenizer, token) is None:
         continue
     special_token = getattr(tokenizer, token)
     if special_token in special_tokens_map:
@@ -4411,7 +4410,7 @@ class TokenizerTesterMixin:
 tokenizer = self.tokenizer_class.from_pretrained(pretrained_name, eos_token=new_eos)
 EXPECTED_ADDED_TOKENS_DECODER = tokenizer.added_tokens_decoder
 with self.subTest("Hub -> Slow: Test loading a slow tokenizer from the hub)"):
-    self.assertEqual(tokenizer._eos_token, new_eos)
+    self.assertEqual(tokenizer._special_tokens_map["eos_token"], new_eos)
     self.assertIn(new_eos, list(tokenizer.added_tokens_decoder.values()))

 with tempfile.TemporaryDirectory() as tmp_dir_2:

@@ -4449,7 +4448,7 @@ class TokenizerTesterMixin:
 with self.subTest("Hub -> Fast: Test loading a fast tokenizer from the hub)"):
     if self.rust_tokenizer_class is not None:
         tokenizer_fast = self.rust_tokenizer_class.from_pretrained(pretrained_name, eos_token=new_eos)
-        self.assertEqual(tokenizer_fast._eos_token, new_eos)
+        self.assertEqual(tokenizer_fast._special_tokens_map["eos_token"], new_eos)
         self.assertIn(new_eos, list(tokenizer_fast.added_tokens_decoder.values()))
         # We can't test the following because for BC we kept the default rstrip lstrip in slow not fast. Will comment once normalization is alright
         with self.subTest("Hub -> Fast == Hub -> Slow: make sure slow and fast tokenizer match"):
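The hunks above drop the private `_<token>` reads (`tokenizer._eos_token`, `getattr(tokenizer, f"_{token}")`) in favour of the public attributes and the internal `_special_tokens_map` dict. A minimal sketch of that access pattern, assuming the `huggyllama/llama-7b` checkpoint used elsewhere in this commit; the `_special_tokens_map` lookup mirrors the assertions above and is an internal detail rather than public API:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")

# Public attributes are the supported way to read special tokens; an unset
# token simply comes back as None, so the tests no longer reach for `_pad_token`.
print(tokenizer.eos_token)  # "</s>"
print(tokenizer.pad_token)  # None for this checkpoint

# Internal storage the assertions above compare against (implementation detail).
print(tokenizer._special_tokens_map["eos_token"])
```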
@@ -28,6 +28,7 @@ from transformers import (
     BatchEncoding,
     BertTokenizer,
     BertTokenizerFast,
+    LlamaTokenizerFast,
     PreTrainedTokenizer,
     PreTrainedTokenizerFast,
     TensorType,
@@ -280,6 +281,54 @@ class TokenizerUtilsTest(unittest.TestCase):
         self.assertEqual(decoded_flat, "##:")
         self.assertEqual(decoded_list, "##:")

+    def test_extra_special_tokens_multimodal(self):
+        special_tokens_list = [
+            "bos_token",
+            "eos_token",
+            "unk_token",
+            "sep_token",
+            "pad_token",
+            "cls_token",
+            "mask_token",
+            "additional_special_tokens",
+        ]
+        llama_tokenizer = LlamaTokenizerFast.from_pretrained("huggyllama/llama-7b")
+        llama_tokenizer.extra_special_tokens = {
+            "boi_token": "<image_start>",
+            "eoi_token": "<image_end>",
+            "image_token": "<image>",
+        }
+        self.assertListEqual(llama_tokenizer.SPECIAL_TOKENS_ATTRIBUTES, special_tokens_list)
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            llama_tokenizer.save_pretrained(tmpdirname)
+
+            # load back and check we have extra special tokens set
+            loaded_tokenizer = LlamaTokenizerFast.from_pretrained(tmpdirname)
+            multimodal_special_tokens_list = special_tokens_list + ["boi_token", "eoi_token", "image_token"]
+            self.assertListEqual(loaded_tokenizer.SPECIAL_TOKENS_ATTRIBUTES, multimodal_special_tokens_list)
+
+            # We set an image_token_id before, so we can get an "image_token" as str that matches the id
+            self.assertTrue(loaded_tokenizer.image_token == "<image>")
+            self.assertTrue(loaded_tokenizer.image_token_id == loaded_tokenizer.convert_tokens_to_ids("<image>"))
+
+        # save one more time and make sure the image token can get loaded back
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            loaded_tokenizer.save_pretrained(tmpdirname)
+            loaded_tokenizer_with_extra_tokens = LlamaTokenizerFast.from_pretrained(tmpdirname)
+            self.assertTrue(loaded_tokenizer_with_extra_tokens.image_token == "<image>")
+
+        # test that we can also indicate extra tokens during load time
+        extra_special_tokens = {
+            "boi_token": "<image_start>",
+            "eoi_token": "<image_end>",
+            "image_token": "<image>",
+        }
+        tokenizer = LlamaTokenizerFast.from_pretrained(
+            "huggyllama/llama-7b", extra_special_tokens=extra_special_tokens
+        )
+        self.assertTrue(tokenizer.image_token == "<image>")
+        self.assertTrue(tokenizer.image_token_id == loaded_tokenizer.convert_tokens_to_ids("<image>"))
+
     @require_tokenizers
     def test_decoding_skip_special_tokens(self):
         for tokenizer_class in [BertTokenizer, BertTokenizerFast]:
@@ -299,7 +299,7 @@ class DataCollatorIntegrationTest(unittest.TestCase):
 self.assertEqual(batch["input_ids"].shape, torch.Size((2, 16)))
 self.assertEqual(batch["labels"].shape, torch.Size((2, 16)))

-tokenizer._pad_token = None
+tokenizer.pad_token = None
 data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
 with self.assertRaises(ValueError):
     # Expect error due to padding token missing

@@ -978,7 +978,7 @@ class TFDataCollatorIntegrationTest(unittest.TestCase):
 self.assertEqual(batch["input_ids"].shape.as_list(), [2, 16])
 self.assertEqual(batch["labels"].shape.as_list(), [2, 16])

-tokenizer._pad_token = None
+tokenizer.pad_token = None
 data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False, return_tensors="tf")
 with self.assertRaises(ValueError):
     # Expect error due to padding token missing

@@ -1673,7 +1673,7 @@ class NumpyDataCollatorIntegrationTest(unittest.TestCase):
 self.assertEqual(batch["input_ids"].shape, (2, 16))
 self.assertEqual(batch["labels"].shape, (2, 16))

-tokenizer._pad_token = None
+tokenizer.pad_token = None
 data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False, return_tensors="np")
 with self.assertRaises(ValueError):
     # Expect error due to padding token missing
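The three collator hunks above make the same one-line switch from the private `_pad_token` to the public `pad_token` attribute. A minimal sketch of the behaviour they exercise, assuming the publicly available `gpt2` checkpoint purely for illustration (it ships without a padding token); the actual tests use their own tokenizers and fixtures:

```python
from transformers import AutoTokenizer, DataCollatorForLanguageModeling

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = None  # clear the public attribute, as the updated tests do

data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
try:
    # Padding is needed to batch sequences of different lengths.
    data_collator([tokenizer("short"), tokenizer("a noticeably longer sample sentence")])
except ValueError as err:
    print(err)  # complains that the tokenizer does not have a pad token
```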