This commit is contained in:
Julien Denize 2025-07-01 14:40:32 +02:00
parent ed80fefcaf
commit a4ddafb190
3 changed files with 34 additions and 8 deletions

View File

@ -482,7 +482,15 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
("phimoe", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
("phobert", ("PhobertTokenizer", None)),
("pix2struct", ("T5Tokenizer", "T5TokenizerFast" if is_tokenizers_available() else None)),
("pixtral", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
(
"pixtral",
(
None,
"MistralCommonTokenizer"
if is_mistral_common_available()
else ("PreTrainedTokenizerFast" if is_tokenizers_available() else None),
),
),
("plbart", ("PLBartTokenizer" if is_sentencepiece_available() else None, None)),
("prophetnet", ("ProphetNetTokenizer", None)),
("qdqbert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),

View File

@ -263,6 +263,8 @@ class MistralCommonTokenizer(PushToHubMixin):
)
self.model_input_names = model_input_names
self._cache_get_vocab: Optional[Dict[str, int]] = None
@property
def bos_token_id(self) -> int:
"""
@ -338,7 +340,11 @@ class MistralCommonTokenizer(PushToHubMixin):
Returns:
`Dict[str, int]`: The vocabulary.
"""
return {token: idx for idx, token in enumerate(self.tokenizer.instruct_tokenizer.tokenizer.vocab())}
if self._cache_get_vocab is None:
self._cache_get_vocab = {
token: idx for idx, token in enumerate(self.tokenizer.instruct_tokenizer.tokenizer.vocab())
}
return self._cache_get_vocab
def __len__(self):
"""
@ -576,7 +582,7 @@ class MistralCommonTokenizer(PushToHubMixin):
return ids[0]
return ids
def _tokenize_ids(self, text: TextInput, add_special_tokens: bool) -> list[int]:
def _text_to_ids(self, text: TextInput, add_special_tokens: bool) -> list[int]:
"""
Converts a string into a sequence of tokens ids, using the tokenizer.
"""
@ -604,9 +610,7 @@ class MistralCommonTokenizer(PushToHubMixin):
if kwargs:
raise ValueError(f"Kwargs {list(kwargs.keys())} are not supported by `MistralCommonTokenizer.tokenize`.")
return self.convert_ids_to_tokens(
self._tokenize_ids(text, add_special_tokens=False), skip_special_tokens=False
)
return self.convert_ids_to_tokens(self._text_to_ids(text, add_special_tokens=False), skip_special_tokens=False)
def _encode_plus(
self,
@ -633,7 +637,7 @@ class MistralCommonTokenizer(PushToHubMixin):
def get_input_ids(text):
if isinstance(text, str):
return self._tokenize_ids(text, add_special_tokens)
return self._text_to_ids(text, add_special_tokens)
elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int):
return text
else:
@ -683,7 +687,7 @@ class MistralCommonTokenizer(PushToHubMixin):
) -> BatchEncoding:
def get_input_ids(text):
if isinstance(text, str):
return self._tokenize_ids(text, add_special_tokens)
return self._text_to_ids(text, add_special_tokens)
elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int):
return text
else:

View File

@ -1,3 +1,17 @@
# Copyright 2025 Mistral AI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import unittest