Commit e7ab7105dc by Julien Denize, 2025-07-02 19:08:50 +02:00, committed by GitHub
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 3309 additions and 7 deletions

setup.py

@@ -204,6 +204,7 @@ _deps = [
     "opentelemetry-api",
     "opentelemetry-exporter-otlp",
     "opentelemetry-sdk",
+    "mistral-common[open-cv]>=1.6.3",
 ]

@@ -334,6 +335,7 @@ extras["video"] = deps_list("av")
 extras["num2words"] = deps_list("num2words")
 extras["sentencepiece"] = deps_list("sentencepiece", "protobuf")
 extras["tiktoken"] = deps_list("tiktoken", "blobfile")
+extras["mistral-common"] = deps_list("mistral-common[open-cv]")
 extras["testing"] = (
     deps_list(
         "pytest",

@@ -384,6 +386,7 @@ extras["all"] = (
     + extras["accelerate"]
     + extras["video"]
     + extras["num2words"]
    + extras["mistral-common"]
 )
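
Once this change is in an installed release, the optional dependency can be pulled in through the new extra, e.g. pip install "transformers[mistral-common]". A minimal sketch (assuming the extra is installed; packaging ships with transformers) of checking the lower bound pinned in _deps at runtime:

from importlib.metadata import version
from packaging.version import Version

# Verify the installed mistral-common satisfies the pin declared in _deps above.
assert Version(version("mistral-common")) >= Version("1.6.3")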

src/transformers/dependency_versions_table.py

@@ -106,4 +106,5 @@ deps = {
     "opentelemetry-api": "opentelemetry-api",
     "opentelemetry-exporter-otlp": "opentelemetry-exporter-otlp",
     "opentelemetry-sdk": "opentelemetry-sdk",
+    "mistral-common": "mistral-common>=1.6.3",
 }
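
With the table entry in place, the pin is also queryable at runtime and can be enforced with transformers' own version helper. A minimal sketch (require_version raises if mistral-common is absent or older than the pin):

from transformers.dependency_versions_table import deps
from transformers.utils.versions import require_version

print(deps["mistral-common"])  # "mistral-common>=1.6.3"
# Fails loudly if mistral-common is missing or does not satisfy the pinned lower bound.
require_version(deps["mistral-common"])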

src/transformers/models/auto/tokenization_auto.py

@@ -21,6 +21,8 @@ import warnings
 from collections import OrderedDict
 from typing import Any, Optional, Union

+from transformers.utils.import_utils import is_mistral_common_available
+
 from ...configuration_utils import PretrainedConfig
 from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
 from ...modeling_gguf_pytorch_utils import load_gguf_checkpoint

@@ -373,15 +375,19 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
         (
             "mistral",
             (
-                "LlamaTokenizer" if is_sentencepiece_available() else None,
-                "LlamaTokenizerFast" if is_tokenizers_available() else None,
+                "MistralCommonTokenizer"
+                if is_mistral_common_available()
+                else ("LlamaTokenizer" if is_sentencepiece_available() else None),
+                "LlamaTokenizerFast" if is_tokenizers_available() and not is_mistral_common_available() else None,
             ),
         ),
         (
             "mixtral",
             (
-                "LlamaTokenizer" if is_sentencepiece_available() else None,
-                "LlamaTokenizerFast" if is_tokenizers_available() else None,
+                "MistralCommonTokenizer"
+                if is_mistral_common_available()
+                else ("LlamaTokenizer" if is_sentencepiece_available() else None),
+                "LlamaTokenizerFast" if is_tokenizers_available() and not is_mistral_common_available() else None,
             ),
         ),
         ("mllama", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),

@@ -476,7 +482,15 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
         ("phimoe", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
         ("phobert", ("PhobertTokenizer", None)),
         ("pix2struct", ("T5Tokenizer", "T5TokenizerFast" if is_tokenizers_available() else None)),
-        ("pixtral", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
+        (
+            "pixtral",
+            (
+                None,
+                "MistralCommonTokenizer"
+                if is_mistral_common_available()
+                else ("PreTrainedTokenizerFast" if is_tokenizers_available() else None),
+            ),
+        ),
         ("plbart", ("PLBartTokenizer" if is_sentencepiece_available() else None, None)),
         ("prophetnet", ("ProphetNetTokenizer", None)),
         ("qdqbert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),

@@ -706,8 +720,10 @@ def tokenizer_class_from_name(class_name: str) -> Union[type[Any], None]:
     for module_name, tokenizers in TOKENIZER_MAPPING_NAMES.items():
         if class_name in tokenizers:
             module_name = model_type_to_module_name(module_name)
-            module = importlib.import_module(f".{module_name}", "transformers.models")
+            if module_name in ["mistral", "mixtral"] and class_name == "MistralCommonTokenizer":
+                module = importlib.import_module(".tokenization_mistral_common", "transformers")
+            else:
+                module = importlib.import_module(f".{module_name}", "transformers.models")
             try:
                 return getattr(module, class_name)
             except AttributeError:
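
A minimal sketch of the new resolution path (assuming mistral-common is installed in the environment): the helper modified above now imports MistralCommonTokenizer from the top-level transformers.tokenization_mistral_common module rather than from transformers.models.<model_type>.

from transformers.models.auto.tokenization_auto import tokenizer_class_from_name

# With mistral-common available, the mistral/mixtral/pixtral entries in
# TOKENIZER_MAPPING_NAMES point at "MistralCommonTokenizer", and this helper
# resolves that name via transformers.tokenization_mistral_common.
cls = tokenizer_class_from_name("MistralCommonTokenizer")
print(cls)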

File diff suppressed because it is too large

src/transformers/utils/import_utils.py

@@ -227,6 +227,7 @@ _spqr_available = _is_package_available("spqr_quant")
 _rich_available = _is_package_available("rich")
 _kernels_available = _is_package_available("kernels")
 _matplotlib_available = _is_package_available("matplotlib")
+_mistral_common_available = _is_package_available("mistral_common")

 _torch_version = "N/A"
 _torch_available = False

@@ -1566,6 +1567,10 @@ def is_matplotlib_available():
     return _matplotlib_available


+def is_mistral_common_available():
+    return _mistral_common_available
+
+
 def check_torch_load_is_safe():
     if not is_torch_greater_or_equal("2.6"):
         raise ValueError(

@@ -1970,6 +1975,11 @@ RICH_IMPORT_ERROR = """
 rich`. Please note that you may need to restart your runtime after installation.
 """

+MISTRAL_COMMON_IMPORT_ERROR = """
+{0} requires the mistral-common library but it was not found in your environment. You can install it with pip: `pip install mistral-common`. Please note that you may need to restart your runtime after installation.
+"""
+
 BACKENDS_MAPPING = OrderedDict(
     [
         ("av", (is_av_available, AV_IMPORT_ERROR)),

@@ -2022,6 +2032,7 @@ BACKENDS_MAPPING = OrderedDict(
         ("pydantic", (is_pydantic_available, PYDANTIC_IMPORT_ERROR)),
         ("fastapi", (is_fastapi_available, FASTAPI_IMPORT_ERROR)),
         ("uvicorn", (is_uvicorn_available, UVICORN_IMPORT_ERROR)),
+        ("mistral-common", (is_mistral_common_available, MISTRAL_COMMON_IMPORT_ERROR)),
     ]
 )
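
A minimal sketch of how the new mapping entry is consumed (the class below is purely illustrative; requires_backends is the existing helper in this module that looks up BACKENDS_MAPPING): anything gated on the "mistral-common" backend key now fails with MISTRAL_COMMON_IMPORT_ERROR when the package is missing.

from transformers.utils.import_utils import is_mistral_common_available, requires_backends

class NeedsMistralCommon:
    # Hypothetical example: refuse to construct unless mistral-common is installed.
    def __init__(self):
        requires_backends(self, "mistral-common")

print(is_mistral_common_available())  # True only if the mistral_common package is importable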

File diff suppressed because it is too large