Merge type hints from microsoft/python-type-stubs (post dropping support for Python 3.8) (#38335)

* Merge type hints from microsoft/python-type-stubs (post Python 3.8)

* Remove mention of pylance

* Resolved conflict

* Merge type hints from microsoft/python-type-stubs (post Python 3.8)

* Remove mention of pylance

* Resolved conflict

* Update src/transformers/models/auto/configuration_auto.py

Co-authored-by: Avasam <samuel.06@hotmail.com>

---------

Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
This commit is contained in:
Avasam 2025-05-28 12:21:40 -04:00 committed by GitHub
parent 942c60956f
commit 2872e8bac5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 670 additions and 661 deletions

View File

@ -17,8 +17,11 @@
import copy import copy
import importlib import importlib
import json import json
import os
import warnings import warnings
from collections import OrderedDict from collections import OrderedDict
from collections.abc import Iterator
from typing import Any, TypeVar, Union
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
@ -42,6 +45,9 @@ if is_torch_available():
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
_T = TypeVar("_T")
# Tokenizers depend on which packages are installed: there is too much variance and no common base class or Protocol
_LazyAutoMappingValue = tuple[Union[type[Any], None], Union[type[Any], None]]
CLASS_DOCSTRING = """ CLASS_DOCSTRING = """
This is a generic model class that will be instantiated as one of the model classes of the library when created This is a generic model class that will be instantiated as one of the model classes of the library when created
@ -408,7 +414,7 @@ class _BaseAutoModelClass:
# Base class for auto models. # Base class for auto models.
_model_mapping = None _model_mapping = None
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs) -> None:
raise EnvironmentError( raise EnvironmentError(
f"{self.__class__.__name__} is designed to be instantiated " f"{self.__class__.__name__} is designed to be instantiated "
f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or " f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or "
@ -456,7 +462,7 @@ class _BaseAutoModelClass:
return config return config
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike[str]], *model_args, **kwargs):
config = kwargs.pop("config", None) config = kwargs.pop("config", None)
trust_remote_code = kwargs.get("trust_remote_code", None) trust_remote_code = kwargs.get("trust_remote_code", None)
kwargs["_from_auto"] = True kwargs["_from_auto"] = True
@ -592,7 +598,7 @@ class _BaseAutoModelClass:
) )
@classmethod @classmethod
def register(cls, config_class, model_class, exist_ok=False): def register(cls, config_class, model_class, exist_ok=False) -> None:
""" """
Register a new model for this class. Register a new model for this class.
@ -650,7 +656,7 @@ class _BaseAutoBackboneClass(_BaseAutoModelClass):
return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
def insert_head_doc(docstring, head_doc=""): def insert_head_doc(docstring, head_doc: str = ""):
if len(head_doc) > 0: if len(head_doc) > 0:
return docstring.replace( return docstring.replace(
"one of the model classes of the library ", "one of the model classes of the library ",
@ -661,7 +667,7 @@ def insert_head_doc(docstring, head_doc=""):
) )
def auto_class_update(cls, checkpoint_for_example="google-bert/bert-base-cased", head_doc=""): def auto_class_update(cls, checkpoint_for_example: str = "google-bert/bert-base-cased", head_doc: str = ""):
# Create a new class with the right name from the base class # Create a new class with the right name from the base class
model_mapping = cls._model_mapping model_mapping = cls._model_mapping
name = cls.__name__ name = cls.__name__
@ -759,7 +765,7 @@ def add_generation_mixin_to_remote_model(model_class):
return model_class return model_class
class _LazyAutoMapping(OrderedDict): class _LazyAutoMapping(OrderedDict[type[PretrainedConfig], _LazyAutoMappingValue]):
""" """
" A mapping config to object (model or tokenizer for instance) that will load keys and values when it is accessed. " A mapping config to object (model or tokenizer for instance) that will load keys and values when it is accessed.
@ -768,7 +774,7 @@ class _LazyAutoMapping(OrderedDict):
- model_mapping: The map model type to model (or tokenizer) class - model_mapping: The map model type to model (or tokenizer) class
""" """
def __init__(self, config_mapping, model_mapping): def __init__(self, config_mapping, model_mapping) -> None:
self._config_mapping = config_mapping self._config_mapping = config_mapping
self._reverse_config_mapping = {v: k for k, v in config_mapping.items()} self._reverse_config_mapping = {v: k for k, v in config_mapping.items()}
self._model_mapping = model_mapping self._model_mapping = model_mapping
@ -776,11 +782,11 @@ class _LazyAutoMapping(OrderedDict):
self._extra_content = {} self._extra_content = {}
self._modules = {} self._modules = {}
def __len__(self): def __len__(self) -> int:
common_keys = set(self._config_mapping.keys()).intersection(self._model_mapping.keys()) common_keys = set(self._config_mapping.keys()).intersection(self._model_mapping.keys())
return len(common_keys) + len(self._extra_content) return len(common_keys) + len(self._extra_content)
def __getitem__(self, key): def __getitem__(self, key: type[PretrainedConfig]) -> _LazyAutoMappingValue:
if key in self._extra_content: if key in self._extra_content:
return self._extra_content[key] return self._extra_content[key]
model_type = self._reverse_config_mapping[key.__name__] model_type = self._reverse_config_mapping[key.__name__]
@ -802,7 +808,7 @@ class _LazyAutoMapping(OrderedDict):
self._modules[module_name] = importlib.import_module(f".{module_name}", "transformers.models") self._modules[module_name] = importlib.import_module(f".{module_name}", "transformers.models")
return getattribute_from_module(self._modules[module_name], attr) return getattribute_from_module(self._modules[module_name], attr)
def keys(self): def keys(self) -> list[type[PretrainedConfig]]:
mapping_keys = [ mapping_keys = [
self._load_attr_from_module(key, name) self._load_attr_from_module(key, name)
for key, name in self._config_mapping.items() for key, name in self._config_mapping.items()
@ -810,16 +816,16 @@ class _LazyAutoMapping(OrderedDict):
] ]
return mapping_keys + list(self._extra_content.keys()) return mapping_keys + list(self._extra_content.keys())
def get(self, key, default): def get(self, key: type[PretrainedConfig], default: _T) -> Union[_LazyAutoMappingValue, _T]:
try: try:
return self.__getitem__(key) return self.__getitem__(key)
except KeyError: except KeyError:
return default return default
def __bool__(self): def __bool__(self) -> bool:
return bool(self.keys()) return bool(self.keys())
def values(self): def values(self) -> list[_LazyAutoMappingValue]:
mapping_values = [ mapping_values = [
self._load_attr_from_module(key, name) self._load_attr_from_module(key, name)
for key, name in self._model_mapping.items() for key, name in self._model_mapping.items()
@ -827,7 +833,7 @@ class _LazyAutoMapping(OrderedDict):
] ]
return mapping_values + list(self._extra_content.values()) return mapping_values + list(self._extra_content.values())
def items(self): def items(self) -> list[tuple[type[PretrainedConfig], _LazyAutoMappingValue]]:
mapping_items = [ mapping_items = [
( (
self._load_attr_from_module(key, self._config_mapping[key]), self._load_attr_from_module(key, self._config_mapping[key]),
@ -838,10 +844,10 @@ class _LazyAutoMapping(OrderedDict):
] ]
return mapping_items + list(self._extra_content.items()) return mapping_items + list(self._extra_content.items())
def __iter__(self): def __iter__(self) -> Iterator[type[PretrainedConfig]]:
return iter(self.keys()) return iter(self.keys())
def __contains__(self, item): def __contains__(self, item: type) -> bool:
if item in self._extra_content: if item in self._extra_content:
return True return True
if not hasattr(item, "__name__") or item.__name__ not in self._reverse_config_mapping: if not hasattr(item, "__name__") or item.__name__ not in self._reverse_config_mapping:
@ -849,7 +855,7 @@ class _LazyAutoMapping(OrderedDict):
model_type = self._reverse_config_mapping[item.__name__] model_type = self._reverse_config_mapping[item.__name__]
return model_type in self._model_mapping return model_type in self._model_mapping
def register(self, key, value, exist_ok=False): def register(self, key: type[PretrainedConfig], value: _LazyAutoMappingValue, exist_ok=False) -> None:
""" """
Register a new model in this mapping. Register a new model in this mapping.
""" """

View File

@ -15,10 +15,12 @@
"""Auto Config class.""" """Auto Config class."""
import importlib import importlib
import os
import re import re
import warnings import warnings
from collections import OrderedDict from collections import OrderedDict
from typing import List, Union from collections.abc import Callable, Iterator, KeysView, ValuesView
from typing import Any, TypeVar, Union
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
@ -28,7 +30,10 @@ from ...utils import CONFIG_NAME, logging
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
CONFIG_MAPPING_NAMES = OrderedDict( _CallableT = TypeVar("_CallableT", bound=Callable[..., Any])
CONFIG_MAPPING_NAMES = OrderedDict[str, str](
[ [
# Add configs here # Add configs here
("albert", "AlbertConfig"), ("albert", "AlbertConfig"),
@ -380,7 +385,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
) )
MODEL_NAMES_MAPPING = OrderedDict( MODEL_NAMES_MAPPING = OrderedDict[str, str](
[ [
# Add full (and cased) model names here # Add full (and cased) model names here
("albert", "ALBERT"), ("albert", "ALBERT"),
@ -795,7 +800,7 @@ DEPRECATED_MODELS = [
"xlm_prophetnet", "xlm_prophetnet",
] ]
SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict( SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict[str, str](
[ [
("openai-gpt", "openai"), ("openai-gpt", "openai"),
("data2vec-audio", "data2vec"), ("data2vec-audio", "data2vec"),
@ -827,7 +832,7 @@ SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict(
) )
def model_type_to_module_name(key): def model_type_to_module_name(key) -> str:
"""Converts a config key to the corresponding module.""" """Converts a config key to the corresponding module."""
# Special treatment # Special treatment
if key in SPECIAL_MODEL_TYPE_TO_MODULE_NAME: if key in SPECIAL_MODEL_TYPE_TO_MODULE_NAME:
@ -844,7 +849,7 @@ def model_type_to_module_name(key):
return key return key
def config_class_to_model_type(config): def config_class_to_model_type(config) -> Union[str, None]:
"""Converts a config class name to the corresponding model type""" """Converts a config class name to the corresponding model type"""
for key, cls in CONFIG_MAPPING_NAMES.items(): for key, cls in CONFIG_MAPPING_NAMES.items():
if cls == config: if cls == config:
@ -856,17 +861,17 @@ def config_class_to_model_type(config):
return None return None
class _LazyConfigMapping(OrderedDict): class _LazyConfigMapping(OrderedDict[str, type[PretrainedConfig]]):
""" """
A dictionary that lazily loads its values when they are requested. A dictionary that lazily loads its values when they are requested.
""" """
def __init__(self, mapping): def __init__(self, mapping) -> None:
self._mapping = mapping self._mapping = mapping
self._extra_content = {} self._extra_content = {}
self._modules = {} self._modules = {}
def __getitem__(self, key): def __getitem__(self, key: str) -> type[PretrainedConfig]:
if key in self._extra_content: if key in self._extra_content:
return self._extra_content[key] return self._extra_content[key]
if key not in self._mapping: if key not in self._mapping:
@ -883,22 +888,22 @@ class _LazyConfigMapping(OrderedDict):
transformers_module = importlib.import_module("transformers") transformers_module = importlib.import_module("transformers")
return getattr(transformers_module, value) return getattr(transformers_module, value)
def keys(self): def keys(self) -> list[str]:
return list(self._mapping.keys()) + list(self._extra_content.keys()) return list(self._mapping.keys()) + list(self._extra_content.keys())
def values(self): def values(self) -> list[type[PretrainedConfig]]:
return [self[k] for k in self._mapping.keys()] + list(self._extra_content.values()) return [self[k] for k in self._mapping.keys()] + list(self._extra_content.values())
def items(self): def items(self) -> list[tuple[str, type[PretrainedConfig]]]:
return [(k, self[k]) for k in self._mapping.keys()] + list(self._extra_content.items()) return [(k, self[k]) for k in self._mapping.keys()] + list(self._extra_content.items())
def __iter__(self): def __iter__(self) -> Iterator[str]:
return iter(list(self._mapping.keys()) + list(self._extra_content.keys())) return iter(list(self._mapping.keys()) + list(self._extra_content.keys()))
def __contains__(self, item): def __contains__(self, item: object) -> bool:
return item in self._mapping or item in self._extra_content return item in self._mapping or item in self._extra_content
def register(self, key, value, exist_ok=False): def register(self, key: str, value: type[PretrainedConfig], exist_ok=False) -> None:
""" """
Register a new configuration in this mapping. Register a new configuration in this mapping.
""" """
@ -910,7 +915,7 @@ class _LazyConfigMapping(OrderedDict):
CONFIG_MAPPING = _LazyConfigMapping(CONFIG_MAPPING_NAMES) CONFIG_MAPPING = _LazyConfigMapping(CONFIG_MAPPING_NAMES)
class _LazyLoadAllMappings(OrderedDict): class _LazyLoadAllMappings(OrderedDict[str, str]):
""" """
A mapping that will load all pairs of key values at the first access (either by indexing, requesting keys, values, A mapping that will load all pairs of key values at the first access (either by indexing, requesting keys, values,
etc.) etc.)
@ -940,28 +945,28 @@ class _LazyLoadAllMappings(OrderedDict):
self._initialize() self._initialize()
return self._data[key] return self._data[key]
def keys(self): def keys(self) -> KeysView[str]:
self._initialize() self._initialize()
return self._data.keys() return self._data.keys()
def values(self): def values(self) -> ValuesView[str]:
self._initialize() self._initialize()
return self._data.values() return self._data.values()
def items(self): def items(self) -> KeysView[str]:
self._initialize() self._initialize()
return self._data.keys() return self._data.keys()
def __iter__(self): def __iter__(self) -> Iterator[str]:
self._initialize() self._initialize()
return iter(self._data) return iter(self._data)
def __contains__(self, item): def __contains__(self, item: object) -> bool:
self._initialize() self._initialize()
return item in self._data return item in self._data
def _get_class_name(model_class: Union[str, List[str]]): def _get_class_name(model_class: Union[str, list[str]]):
if isinstance(model_class, (list, tuple)): if isinstance(model_class, (list, tuple)):
return " or ".join([f"[`{c}`]" for c in model_class if c is not None]) return " or ".join([f"[`{c}`]" for c in model_class if c is not None])
return f"[`{model_class}`]" return f"[`{model_class}`]"
@ -1000,7 +1005,9 @@ def _list_model_options(indent, config_to_class=None, use_model_types=True):
return "\n".join(lines) return "\n".join(lines)
def replace_list_option_in_docstrings(config_to_class=None, use_model_types=True): def replace_list_option_in_docstrings(
config_to_class=None, use_model_types: bool = True
) -> Callable[[_CallableT], _CallableT]:
def docstring_decorator(fn): def docstring_decorator(fn):
docstrings = fn.__doc__ docstrings = fn.__doc__
if docstrings is None: if docstrings is None:
@ -1035,14 +1042,14 @@ class AutoConfig:
This class cannot be instantiated directly using `__init__()` (throws an error). This class cannot be instantiated directly using `__init__()` (throws an error).
""" """
def __init__(self): def __init__(self) -> None:
raise EnvironmentError( raise EnvironmentError(
"AutoConfig is designed to be instantiated " "AutoConfig is designed to be instantiated "
"using the `AutoConfig.from_pretrained(pretrained_model_name_or_path)` method." "using the `AutoConfig.from_pretrained(pretrained_model_name_or_path)` method."
) )
@classmethod @classmethod
def for_model(cls, model_type: str, *args, **kwargs): def for_model(cls, model_type: str, *args, **kwargs) -> PretrainedConfig:
if model_type in CONFIG_MAPPING: if model_type in CONFIG_MAPPING:
config_class = CONFIG_MAPPING[model_type] config_class = CONFIG_MAPPING[model_type]
return config_class(*args, **kwargs) return config_class(*args, **kwargs)
@ -1052,7 +1059,7 @@ class AutoConfig:
@classmethod @classmethod
@replace_list_option_in_docstrings() @replace_list_option_in_docstrings()
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike[str]], **kwargs):
r""" r"""
Instantiate one of the configuration classes of the library from a pretrained model configuration. Instantiate one of the configuration classes of the library from a pretrained model configuration.
@ -1199,7 +1206,7 @@ class AutoConfig:
) )
@staticmethod @staticmethod
def register(model_type, config, exist_ok=False): def register(model_type, config, exist_ok=False) -> None:
""" """
Register a new configuration for this class. Register a new configuration for this class.

View File

@ -19,7 +19,7 @@ import json
import os import os
import warnings import warnings
from collections import OrderedDict from collections import OrderedDict
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union from typing import Any, Optional, Union
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
@ -53,12 +53,8 @@ else:
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
if TYPE_CHECKING: # Explicit rather than inferred generics significantly improve completion suggestion performance for language servers.
# This significantly improves completion suggestion performance when TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
# the transformers package is used with Microsoft's Pylance language server.
TOKENIZER_MAPPING_NAMES: OrderedDict[str, Tuple[Optional[str], Optional[str]]] = OrderedDict()
else:
TOKENIZER_MAPPING_NAMES = OrderedDict(
[ [
( (
"albert", "albert",
@ -667,7 +663,7 @@ TOKENIZER_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TOKENIZER_MAPPING_NAM
CONFIG_TO_TYPE = {v: k for k, v in CONFIG_MAPPING_NAMES.items()} CONFIG_TO_TYPE = {v: k for k, v in CONFIG_MAPPING_NAMES.items()}
def tokenizer_class_from_name(class_name: str): def tokenizer_class_from_name(class_name: str) -> Union[type[Any], None]:
if class_name == "PreTrainedTokenizerFast": if class_name == "PreTrainedTokenizerFast":
return PreTrainedTokenizerFast return PreTrainedTokenizerFast
@ -696,17 +692,17 @@ def tokenizer_class_from_name(class_name: str):
def get_tokenizer_config( def get_tokenizer_config(
pretrained_model_name_or_path: Union[str, os.PathLike], pretrained_model_name_or_path: Union[str, os.PathLike[str]],
cache_dir: Optional[Union[str, os.PathLike]] = None, cache_dir: Optional[Union[str, os.PathLike[str]]] = None,
force_download: bool = False, force_download: bool = False,
resume_download: Optional[bool] = None, resume_download: Optional[bool] = None,
proxies: Optional[Dict[str, str]] = None, proxies: Optional[dict[str, str]] = None,
token: Optional[Union[bool, str]] = None, token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None, revision: Optional[str] = None,
local_files_only: bool = False, local_files_only: bool = False,
subfolder: str = "", subfolder: str = "",
**kwargs, **kwargs,
): ) -> dict[str, Any]:
""" """
Loads the tokenizer configuration from a pretrained model tokenizer configuration. Loads the tokenizer configuration from a pretrained model tokenizer configuration.
@ -728,7 +724,7 @@ def get_tokenizer_config(
resume_download: resume_download:
Deprecated and ignored. All downloads are now resumed by default when possible. Deprecated and ignored. All downloads are now resumed by default when possible.
Will be removed in v5 of Transformers. Will be removed in v5 of Transformers.
proxies (`Dict[str, str]`, *optional*): proxies (`dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request. 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
token (`str` or *bool*, *optional*): token (`str` or *bool*, *optional*):
@ -751,7 +747,7 @@ def get_tokenizer_config(
</Tip> </Tip>
Returns: Returns:
`Dict`: The configuration of the tokenizer. `dict`: The configuration of the tokenizer.
Examples: Examples:
@ -855,7 +851,7 @@ class AutoTokenizer:
resume_download: resume_download:
Deprecated and ignored. All downloads are now resumed by default when possible. Deprecated and ignored. All downloads are now resumed by default when possible.
Will be removed in v5 of Transformers. Will be removed in v5 of Transformers.
proxies (`Dict[str, str]`, *optional*): proxies (`dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
revision (`str`, *optional*, defaults to `"main"`): revision (`str`, *optional*, defaults to `"main"`):