wip: internally add some methods not supported by mistral-common

Julien Denize 2025-06-25 16:21:24 +02:00
parent befe52ae55
commit ed80fefcaf
2 changed files with 61 additions and 32 deletions

View File

@@ -16,13 +16,15 @@ import os
import shutil
import warnings
from collections.abc import Mapping, Sized
from enum import Enum
from pathlib import Path
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union, overload
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Set, Tuple, Union, overload
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.protocol.instruct.validator import ValidationMode
from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
from mistral_common.tokens.tokenizers.utils import download_tokenizer_from_hf_hub
from transformers.tokenization_utils_base import (
@@ -141,6 +143,13 @@ ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING = r"""
"""
class MistralTokenizerType(str, Enum):
"""Enum for the different type of tokenizer."""
spm = "spm"
tekken = "tekken"
@requires(backends=("mistral-common",))
class MistralCommonTokenizer(PushToHubMixin):
"""
@@ -232,6 +241,11 @@ class MistralCommonTokenizer(PushToHubMixin):
self._tokenizer_path = Path(tokenizer_path)
self.tokenizer: MistralTokenizer = MistralTokenizer.from_file(str(self._tokenizer_path), mode=mode)
self._tokenizer_type = (
MistralTokenizerType.tekken
if isinstance(self.tokenizer.instruct_tokenizer.tokenizer, Tekkenizer)
else MistralTokenizerType.spm
)
self.truncation_side = truncation_side
self.padding_side = padding_side
self.model_max_length = model_max_length
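For context, a minimal sketch (not part of the diff) of the backend detection the constructor now performs; it assumes mistral-common is installed, that the `MistralTokenizerType` enum added in this commit is in scope, and that `tekken.json` is a hypothetical local Tekken tokenizer file.

    # Illustrative sketch: decide between the Tekken and SentencePiece backends.
    from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
    from mistral_common.tokens.tokenizers.tekken import Tekkenizer

    mistral_tokenizer = MistralTokenizer.from_file("tekken.json")  # hypothetical path
    backend = mistral_tokenizer.instruct_tokenizer.tokenizer
    tokenizer_type = (
        MistralTokenizerType.tekken if isinstance(backend, Tekkenizer) else MistralTokenizerType.spm
    )
    print(tokenizer_type)  # MistralTokenizerType.tekken for Tekken files, .spm for SentencePiece files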
@@ -476,6 +490,14 @@ class MistralCommonTokenizer(PushToHubMixin):
for seq in sequences
]
def _is_control_token(self, token_id: int) -> bool:
if self._tokenizer_type == MistralTokenizerType.spm:
return token_id in self.tokenizer.instruct_tokenizer.tokenizer._control_tokens()
elif self._tokenizer_type == MistralTokenizerType.tekken:
return token_id < self.tokenizer.instruct_tokenizer.tokenizer.num_special_tokens
else:
raise ValueError(f"Unknown tokenizer type: {self._tokenizer_type}")
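The branch is needed because the SentencePiece backend exposes its control-token ids via `_control_tokens()`, while Tekken reserves the lowest `num_special_tokens` ids for its special tokens. A minimal sketch (not part of the diff) of the behaviour this enables, assuming `tokenizer_path` is the constructor's first positional argument and `tekken.json` is a hypothetical local file:

    # Illustrative sketch: control tokens are only dropped when skip_special_tokens=True.
    tok = MistralCommonTokenizer("tekken.json")  # hypothetical local path
    ids = tok.tokenizer.instruct_tokenizer.tokenizer.encode("Hello world!", bos=True, eos=True)
    print(tok.convert_ids_to_tokens(ids))                            # keeps the BOS/EOS pieces
    print(tok.convert_ids_to_tokens(ids, skip_special_tokens=True))  # control tokens filtered out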
@overload
def convert_ids_to_tokens(self, ids: int, skip_special_tokens: bool = False) -> str: ...
@overload
@@ -505,7 +527,7 @@ class MistralCommonTokenizer(PushToHubMixin):
tokens: list[str] = []
for token_id in ids:
if self.tokenizer.instruct_tokenizer.tokenizer.is_control_token(token_id) and skip_special_tokens:
if self._is_control_token(token_id) and skip_special_tokens:
continue
tokens.append(self.tokenizer.instruct_tokenizer.tokenizer.id_to_piece(token_id))
@@ -516,6 +538,18 @@ class MistralCommonTokenizer(PushToHubMixin):
return tokens[0]
return tokens
def _piece_to_id(self, piece: str) -> int:
if self._tokenizer_type == MistralTokenizerType.spm:
return self.tokenizer.instruct_tokenizer.tokenizer._model.piece_to_id(piece)
elif self._tokenizer_type == MistralTokenizerType.tekken:
pieces = self.tokenizer.instruct_tokenizer.tokenizer._model.encode(
piece, allowed_special="all", disallowed_special=set()
)
assert len(pieces) == 1, f"Expected the piece to encode to 1 token, got {len(pieces)}"
return pieces[0]
else:
raise ValueError(f"Unknown tokenizer type: {self._tokenizer_type}")
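SentencePiece can map a piece to its id directly, but Tekken's tiktoken model only exposes `encode`, so the helper encodes the piece with special tokens allowed and asserts that exactly one id comes back. A short round-trip sketch (not part of the diff), reusing the `tok` instance assumed above:

    # Illustrative sketch: pieces round-trip through id_to_piece and the new _piece_to_id helper.
    piece = tok.convert_ids_to_tokens(1000)        # id -> piece (1000 is an arbitrary non-special id)
    round_trip = tok.convert_tokens_to_ids(piece)  # piece -> id, via _piece_to_id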
def convert_tokens_to_ids(self, tokens: Union[str, list[str]]) -> Union[int, list[int]]:
"""
Converts a token string (or a sequence of tokens) in a single integer id (or a sequence of ids), using the
@@ -536,7 +570,7 @@ class MistralCommonTokenizer(PushToHubMixin):
ids: list[int] = []
for token in tokens:
ids.append(self.tokenizer.instruct_tokenizer.tokenizer.piece_to_id(token))
ids.append(self._piece_to_id(token))
if one_token:
return ids[0]
@@ -685,6 +719,14 @@ class MistralCommonTokenizer(PushToHubMixin):
return BatchEncoding(batch_outputs)
def _all_special_ids(self) -> Set[int]:
if self._tokenizer_type == MistralTokenizerType.tekken:
return {t["rank"] for t in self.tokenizer.instruct_tokenizer.tokenizer._all_special_tokens}
elif self._tokenizer_type == MistralTokenizerType.spm:
return self.tokenizer.instruct_tokenizer.tokenizer._control_tokens()
else:
raise ValueError(f"Unknown tokenizer type: {self._tokenizer_type}")
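Tekken stores its special tokens as dicts whose "rank" is the token id, whereas the SentencePiece backend reuses `_control_tokens()`; either way the result is a plain set of ids that `get_special_tokens_mask` consumes below. A sketch (not part of the diff), again assuming the `tok` instance from above:

    # Illustrative sketch: the special ids are a set of ints (the lowest ids on a Tekken tokenizer).
    special_ids = tok._all_special_ids()
    print(sorted(special_ids)[:5])  # e.g. the ids of BOS, EOS, [INST], [/INST], ...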
def get_special_tokens_mask(
self, token_ids_0: list, token_ids_1: None = None, already_has_special_tokens: bool = False
) -> list[int]:
@@ -712,7 +754,7 @@ class MistralCommonTokenizer(PushToHubMixin):
"`already_has_special_tokens` is not supported by `MistralCommonTokenizer` and should be `False`."
)
all_special_ids = self.tokenizer.instruct_tokenizer.tokenizer.all_special_ids # cache the property
all_special_ids = self._all_special_ids() # cache the ids
special_tokens_mask = [1 if token in all_special_ids else 0 for token in token_ids_0]
return special_tokens_mask
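A usage sketch (not part of the diff) mirroring the updated test below, with the same `tok` assumption:

    # Illustrative sketch: 1 marks special/control tokens, 0 marks regular tokens.
    ids = tok.tokenizer.instruct_tokenizer.tokenizer.encode("Hello world!", bos=True, eos=True)
    mask = tok.get_special_tokens_mask(ids)
    # With bos and eos requested, the mask starts and ends with 1, e.g. [1, 0, ..., 0, 1].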

View File

@@ -39,6 +39,15 @@ class TestMistralCommonTokenizer(unittest.TestCase):
for conversation in cls.fixture_conversations
]
cls.ref_special_ids = {t["rank"] for t in cls.ref_tokenizer.instruct_tokenizer.tokenizer._all_special_tokens}
def _ref_piece_to_id(self, piece: str) -> int:
pieces = self.ref_tokenizer.instruct_tokenizer.tokenizer._model.encode(
piece, allowed_special="all", disallowed_special=set()
)
assert len(pieces) == 1, f"Expected the piece to encode to 1 token, got {len(pieces)}"
return pieces[0]
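The helper mirrors the Tekken branch of `_piece_to_id` so the expected ids are computed independently of the class under test. A sketch (not part of the diff) of how the rewritten assertion uses it:

    # Illustrative sketch: reference ids from the raw tekken model, actual ids from the wrapper.
    tokens = ["Hello", "world", "!"]
    expected_ids = [self._ref_piece_to_id(t) for t in tokens]
    self.assertEqual(self.tokenizer.convert_tokens_to_ids(tokens), expected_ids)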
def test_vocab_size(self):
self.assertEqual(self.tokenizer.vocab_size, self.ref_tokenizer.instruct_tokenizer.tokenizer.n_words)
@@ -223,7 +232,7 @@ class TestMistralCommonTokenizer(unittest.TestCase):
def test_convert_tokens_to_ids(self):
tokens = ["Hello", "world", "!"]
expected_ids = [self.ref_tokenizer.instruct_tokenizer.tokenizer.piece_to_id(token) for token in tokens]
expected_ids = [self._ref_piece_to_id(token) for token in tokens]
# Test 1:
# list of tokens
ids = self.tokenizer.convert_tokens_to_ids(tokens)
@@ -253,9 +262,7 @@ class TestMistralCommonTokenizer(unittest.TestCase):
# Test 1:
# with skip_special_tokens=False
ids = self.ref_tokenizer.instruct_tokenizer.tokenizer.encode("Hello world!", bos=True, eos=True)
expected_mask = [
1 if id in self.ref_tokenizer.instruct_tokenizer.tokenizer.all_special_ids else 0 for id in ids
]
expected_mask = [1 if id in self.ref_special_ids else 0 for id in ids]
mask = self.tokenizer.get_special_tokens_mask(ids)
self.assertEqual(mask, expected_mask)
@@ -1267,13 +1274,7 @@ class TestMistralCommonTokenizer(unittest.TestCase):
self.assertEqual(tokens["attention_mask"], [[1] * min(len(t), 10) for t in expected_tokens])
self.assertEqual(
tokens["special_tokens_mask"],
[
[
1 if id in self.ref_tokenizer.instruct_tokenizer.tokenizer.all_special_ids else 0
for id in ids[:10]
]
for ids in expected_tokens
],
[[1 if id in self.ref_special_ids else 0 for id in ids[:10]] for ids in expected_tokens],
)
# Test 2:
@@ -1313,13 +1314,7 @@ class TestMistralCommonTokenizer(unittest.TestCase):
)
self.assertEqual(
tokens["special_tokens_mask"],
[
[
1 if id in self.ref_tokenizer.instruct_tokenizer.tokenizer.all_special_ids else 0
for id in ids[:10]
]
for ids in expected_tokens
],
[[1 if id in self.ref_special_ids else 0 for id in ids[:10]] for ids in expected_tokens],
)
def test_batch_call_with_padding(self):
@@ -1477,11 +1472,7 @@ class TestMistralCommonTokenizer(unittest.TestCase):
self.assertEqual(
tokens["special_tokens_mask"],
[
num_padding[i] * [1]
+ [
1 if id in self.ref_tokenizer.instruct_tokenizer.tokenizer.all_special_ids else 0
for id in ids[:10]
]
num_padding[i] * [1] + [1 if id in self.ref_special_ids else 0 for id in ids[:10]]
for i, ids in enumerate(expected_tokens)
],
)
@@ -1505,11 +1496,7 @@ class TestMistralCommonTokenizer(unittest.TestCase):
self.assertEqual(
tokens["special_tokens_mask"],
[
num_padding[i] * [1]
+ [
1 if id in self.ref_tokenizer.instruct_tokenizer.tokenizer.all_special_ids else 0
for id in ids
]
num_padding[i] * [1] + [1 if id in self.ref_special_ids else 0 for id in ids]
for i, ids in enumerate(expected_tokens)
],
)