wip: internally add some methods not supported by mistral-common
This commit is contained in:
parent befe52ae55
commit ed80fefcaf
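The old code called `is_control_token`, `piece_to_id`, and `all_special_ids` on the mistral-common tokenizer objects; those helpers are not available for both backends in the pinned mistral-common version, so this commit reimplements them as private helpers on `MistralCommonTokenizer` and dispatches on whether the underlying tokenizer is SentencePiece ("spm") or Tekken. A minimal usage sketch of the public methods that now route through the new helpers; the import path, constructor call, and tokenizer file are illustrative assumptions, not part of this commit:

    # Sketch only: import path and tokenizer file are placeholders.
    from transformers.tokenization_mistral_common import MistralCommonTokenizer

    tok = MistralCommonTokenizer(tokenizer_path="path/to/tekken.json")     # spm or tekken, detected at init
    ids = tok.convert_tokens_to_ids(["Hello", "world", "!"])               # now routed through _piece_to_id
    mask = tok.get_special_tokens_mask(ids)                                # now routed through _all_special_ids
    tokens = tok.convert_ids_to_tokens(ids, skip_special_tokens=True)      # now routed through _is_control_token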
@@ -16,13 +16,15 @@ import os
 import shutil
 import warnings
 from collections.abc import Mapping, Sized
+from enum import Enum
 from pathlib import Path
-from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union, overload
+from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Set, Tuple, Union, overload
 
 from mistral_common.protocol.instruct.request import ChatCompletionRequest
 from mistral_common.protocol.instruct.validator import ValidationMode
 from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
 from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
+from mistral_common.tokens.tokenizers.tekken import Tekkenizer
 from mistral_common.tokens.tokenizers.utils import download_tokenizer_from_hf_hub
 
 from transformers.tokenization_utils_base import (
@@ -141,6 +143,13 @@ ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING = r"""
 """
 
 
+class MistralTokenizerType(str, Enum):
+    """Enum for the different types of tokenizer."""
+
+    spm = "spm"
+    tekken = "tekken"
+
+
 @requires(backends=("mistral-common",))
 class MistralCommonTokenizer(PushToHubMixin):
     """
@@ -232,6 +241,11 @@ class MistralCommonTokenizer(PushToHubMixin):
 
         self._tokenizer_path = Path(tokenizer_path)
         self.tokenizer: MistralTokenizer = MistralTokenizer.from_file(str(self._tokenizer_path), mode=mode)
+        self._tokenizer_type = (
+            MistralTokenizerType.tekken
+            if isinstance(self.tokenizer.instruct_tokenizer.tokenizer, Tekkenizer)
+            else MistralTokenizerType.spm
+        )
         self.truncation_side = truncation_side
         self.padding_side = padding_side
         self.model_max_length = model_max_length
@@ -476,6 +490,14 @@ class MistralCommonTokenizer(PushToHubMixin):
             for seq in sequences
         ]
 
+    def _is_control_token(self, token_id: int) -> bool:
+        if self._tokenizer_type == MistralTokenizerType.spm:
+            return token_id in self.tokenizer.instruct_tokenizer.tokenizer._control_tokens()
+        elif self._tokenizer_type == MistralTokenizerType.tekken:
+            return token_id < self.tokenizer.instruct_tokenizer.tokenizer.num_special_tokens
+        else:
+            raise ValueError(f"Unknown tokenizer type: {self._tokenizer_type}")
+
     @overload
     def convert_ids_to_tokens(self, ids: int, skip_special_tokens: bool = False) -> str: ...
     @overload
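The branch above encodes how each backend marks special tokens: the SentencePiece path checks membership in the wrapper's `_control_tokens()` set, while the Tekken path relies on the first `num_special_tokens` ids being reserved. A quick hedged check, assuming a constructed `tok` as in the earlier sketch and that mistral-common exposes `bos_id` on the inner tokenizer:

    inner = tok.tokenizer.instruct_tokenizer.tokenizer
    assert tok._is_control_token(inner.bos_id)                              # BOS counts as special on both backends
    assert not tok._is_control_token(tok.convert_tokens_to_ids("Hello"))    # ordinary piece is not special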
@@ -505,7 +527,7 @@ class MistralCommonTokenizer(PushToHubMixin):
 
         tokens: list[str] = []
         for token_id in ids:
-            if self.tokenizer.instruct_tokenizer.tokenizer.is_control_token(token_id) and skip_special_tokens:
+            if self._is_control_token(token_id) and skip_special_tokens:
                 continue
             tokens.append(self.tokenizer.instruct_tokenizer.tokenizer.id_to_piece(token_id))
 
@@ -516,6 +538,18 @@ class MistralCommonTokenizer(PushToHubMixin):
             return tokens[0]
         return tokens
 
+    def _piece_to_id(self, piece: str) -> int:
+        if self._tokenizer_type == MistralTokenizerType.spm:
+            return self.tokenizer.instruct_tokenizer.tokenizer._model.piece_to_id(piece)
+        elif self._tokenizer_type == MistralTokenizerType.tekken:
+            pieces = self.tokenizer.instruct_tokenizer.tokenizer._model.encode(
+                piece, allowed_special="all", disallowed_special=set()
+            )
+            assert len(pieces) == 1, f"Expected to decode 1 token, got {len(pieces)}"
+            return pieces[0]
+        else:
+            raise ValueError(f"Unknown tokenizer type: {self._tokenizer_type}")
+
     def convert_tokens_to_ids(self, tokens: Union[str, list[str]]) -> Union[int, list[int]]:
         """
         Converts a token string (or a sequence of tokens) in a single integer id (or a sequence of ids), using the
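Note the asymmetry: the spm branch is a direct vocabulary lookup, while the tekken branch re-encodes the piece and asserts that exactly one id comes back, so inputs that split into several tokens will trip that assert. A short sketch with the same assumed `tok`:

    tok.convert_tokens_to_ids("Hello")              # single piece -> single id (via _piece_to_id)
    tok.convert_tokens_to_ids(["Hello", "!"])       # list in, list out
    # tok.convert_tokens_to_ids("Hello world")      # likely several tokens on tekken -> fails the len == 1 assert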
@@ -536,7 +570,7 @@ class MistralCommonTokenizer(PushToHubMixin):
 
         ids: list[int] = []
         for token in tokens:
-            ids.append(self.tokenizer.instruct_tokenizer.tokenizer.piece_to_id(token))
+            ids.append(self._piece_to_id(token))
 
         if one_token:
             return ids[0]
@@ -685,6 +719,14 @@ class MistralCommonTokenizer(PushToHubMixin):
 
         return BatchEncoding(batch_outputs)
 
+    def _all_special_ids(self) -> Set[int]:
+        if self._tokenizer_type == MistralTokenizerType.tekken:
+            return {t["rank"] for t in self.tokenizer.instruct_tokenizer.tokenizer._all_special_tokens}
+        elif self._tokenizer_type == MistralTokenizerType.spm:
+            return self.tokenizer.instruct_tokenizer.tokenizer._control_tokens()
+        else:
+            raise ValueError(f"Unknown tokenizer type: {self._tokenizer_type}")
+
     def get_special_tokens_mask(
         self, token_ids_0: list, token_ids_1: None = None, already_has_special_tokens: bool = False
     ) -> list[int]:
@@ -712,7 +754,7 @@ class MistralCommonTokenizer(PushToHubMixin):
                 "`already_has_special_tokens` is not supported by `MistralCommonTokenizer` and should be `False`."
             )
 
-        all_special_ids = self.tokenizer.instruct_tokenizer.tokenizer.all_special_ids  # cache the property
+        all_special_ids = self._all_special_ids()  # cache the ids
 
        special_tokens_mask = [1 if token in all_special_ids else 0 for token in token_ids_0]
        return special_tokens_mask
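The remaining hunks update the test suite the same way: a precomputed `ref_special_ids` set and a `_ref_piece_to_id` helper stand in for the `all_special_ids` and `piece_to_id` calls on the reference tokenizer. The mask semantics are unchanged, as in this sketch mirroring the updated test (same assumed `tok`):

    ref = tok.tokenizer.instruct_tokenizer.tokenizer
    ids = ref.encode("Hello world!", bos=True, eos=True)                    # reference encoding with BOS/EOS
    mask = tok.get_special_tokens_mask(ids)                                 # 1 for special ids, 0 otherwise
    assert mask == [1 if i in tok._all_special_ids() else 0 for i in ids]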
@@ -39,6 +39,15 @@ class TestMistralCommonTokenizer(unittest.TestCase):
             for conversation in cls.fixture_conversations
         ]
 
+        cls.ref_special_ids = {t["rank"] for t in cls.ref_tokenizer.instruct_tokenizer.tokenizer._all_special_tokens}
+
+    def _ref_piece_to_id(self, piece: str) -> int:
+        pieces = self.ref_tokenizer.instruct_tokenizer.tokenizer._model.encode(
+            piece, allowed_special="all", disallowed_special=set()
+        )
+        assert len(pieces) == 1, f"Expected to decode 1 token, got {len(pieces)}"
+        return pieces[0]
+
     def test_vocab_size(self):
         self.assertEqual(self.tokenizer.vocab_size, self.ref_tokenizer.instruct_tokenizer.tokenizer.n_words)
 
@@ -223,7 +232,7 @@ class TestMistralCommonTokenizer(unittest.TestCase):
 
     def test_convert_tokens_to_ids(self):
         tokens = ["Hello", "world", "!"]
-        expected_ids = [self.ref_tokenizer.instruct_tokenizer.tokenizer.piece_to_id(token) for token in tokens]
+        expected_ids = [self._ref_piece_to_id(token) for token in tokens]
         # Test 1:
         # list of tokens
         ids = self.tokenizer.convert_tokens_to_ids(tokens)
@@ -253,9 +262,7 @@ class TestMistralCommonTokenizer(unittest.TestCase):
         # Test 1:
         # with skip_special_tokens=False
         ids = self.ref_tokenizer.instruct_tokenizer.tokenizer.encode("Hello world!", bos=True, eos=True)
-        expected_mask = [
-            1 if id in self.ref_tokenizer.instruct_tokenizer.tokenizer.all_special_ids else 0 for id in ids
-        ]
+        expected_mask = [1 if id in self.ref_special_ids else 0 for id in ids]
 
         mask = self.tokenizer.get_special_tokens_mask(ids)
         self.assertEqual(mask, expected_mask)
@@ -1267,13 +1274,7 @@ class TestMistralCommonTokenizer(unittest.TestCase):
         self.assertEqual(tokens["attention_mask"], [[1] * min(len(t), 10) for t in expected_tokens])
         self.assertEqual(
             tokens["special_tokens_mask"],
-            [
-                [
-                    1 if id in self.ref_tokenizer.instruct_tokenizer.tokenizer.all_special_ids else 0
-                    for id in ids[:10]
-                ]
-                for ids in expected_tokens
-            ],
+            [[1 if id in self.ref_special_ids else 0 for id in ids[:10]] for ids in expected_tokens],
         )
 
         # Test 2:
@@ -1313,13 +1314,7 @@ class TestMistralCommonTokenizer(unittest.TestCase):
         )
         self.assertEqual(
             tokens["special_tokens_mask"],
-            [
-                [
-                    1 if id in self.ref_tokenizer.instruct_tokenizer.tokenizer.all_special_ids else 0
-                    for id in ids[:10]
-                ]
-                for ids in expected_tokens
-            ],
+            [[1 if id in self.ref_special_ids else 0 for id in ids[:10]] for ids in expected_tokens],
         )
 
     def test_batch_call_with_padding(self):
@@ -1477,11 +1472,7 @@ class TestMistralCommonTokenizer(unittest.TestCase):
         self.assertEqual(
             tokens["special_tokens_mask"],
             [
-                num_padding[i] * [1]
-                + [
-                    1 if id in self.ref_tokenizer.instruct_tokenizer.tokenizer.all_special_ids else 0
-                    for id in ids[:10]
-                ]
+                num_padding[i] * [1] + [1 if id in self.ref_special_ids else 0 for id in ids[:10]]
                 for i, ids in enumerate(expected_tokens)
             ],
         )
@@ -1505,11 +1496,7 @@ class TestMistralCommonTokenizer(unittest.TestCase):
         self.assertEqual(
             tokens["special_tokens_mask"],
             [
-                num_padding[i] * [1]
-                + [
-                    1 if id in self.ref_tokenizer.instruct_tokenizer.tokenizer.all_special_ids else 0
-                    for id in ids
-                ]
+                num_padding[i] * [1] + [1 if id in self.ref_special_ids else 0 for id in ids]
                 for i, ids in enumerate(expected_tokens)
             ],
         )
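As a manual check of the invariant the tekken tests rely on, the tokenizer's internal set can be compared against a reference built directly from the Tekken special-token table, exactly as the test setup does (assumes a tekken-based `tok` as in the earlier sketches):

    inner = tok.tokenizer.instruct_tokenizer.tokenizer
    ref_special_ids = {t["rank"] for t in inner._all_special_tokens}        # same construction as the test setup
    assert tok._all_special_ids() == ref_special_ids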