Merge type hints from microsoft/python-type-stubs (post dropping support for Python 3.8) (#38335)

* Merge type hints from microsoft/python-type-stubs (post Python 3.8)
* Remove mention of pylance
* Resolved conflict
* Merge type hints from microsoft/python-type-stubs (post Python 3.8)
* Remove mention of pylance
* Resolved conflict
* Update src/transformers/models/auto/configuration_auto.py

Co-authored-by: Avasam <samuel.06@hotmail.com>

---------

Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
parent 942c60956f
commit 2872e8bac5
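The diff below applies one recurring pattern: functions gain return annotations, `typing.List`/`Dict`/`Tuple` give way to built-in generics, and module-level mappings are constructed with explicit generic parameters so language servers resolve their value types without stubs. A minimal standalone sketch of that pattern (illustration only, not code from this commit; `EXAMPLE_MAPPING`, `lookup`, and `passthrough` are made-up names):

```python
from collections import OrderedDict
from collections.abc import Callable
from typing import Any, TypeVar, Union

_T = TypeVar("_T")
_CallableT = TypeVar("_CallableT", bound=Callable[..., Any])

# Explicit generic parameters: the mapping is OrderedDict[str, str] from construction,
# so completions on its keys and values resolve to str (same idea as CONFIG_MAPPING_NAMES).
EXAMPLE_MAPPING = OrderedDict[str, str]([("albert", "AlbertConfig")])


def lookup(mapping: OrderedDict[str, str], key: str, default: _T) -> Union[str, _T]:
    # A TypeVar'd default keeps the caller's type instead of widening to Any,
    # mirroring the new signature of _LazyAutoMapping.get below.
    try:
        return mapping[key]
    except KeyError:
        return default


def passthrough(fn: _CallableT) -> _CallableT:
    # A decorator typed with a bound TypeVar preserves the wrapped function's
    # signature, the shape replace_list_option_in_docstrings now returns.
    return fn
```

With Python 3.8 dropped, built-in generics (`list[str]`, `dict[str, str]`, `tuple[...]`) and parameterized `OrderedDict` are valid at runtime, which is what allows the `if TYPE_CHECKING:` workaround in the tokenizer mapping to be removed in the last file of this diff.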
@@ -17,8 +17,11 @@
import copy
import importlib
import json
import os
import warnings
from collections import OrderedDict
+from collections.abc import Iterator
+from typing import Any, TypeVar, Union

from ...configuration_utils import PretrainedConfig
from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
@@ -42,6 +45,9 @@ if is_torch_available():

logger = logging.get_logger(__name__)

+_T = TypeVar("_T")
+# Tokenizers will depend on packages installed, too much variance and there are no common base or Protocol
+_LazyAutoMappingValue = tuple[Union[type[Any], None], Union[type[Any], None]]

CLASS_DOCSTRING = """
This is a generic model class that will be instantiated as one of the model classes of the library when created
@@ -408,7 +414,7 @@ class _BaseAutoModelClass:
# Base class for auto models.
_model_mapping = None

-def __init__(self, *args, **kwargs):
+def __init__(self, *args, **kwargs) -> None:
raise EnvironmentError(
f"{self.__class__.__name__} is designed to be instantiated "
f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or "
@@ -456,7 +462,7 @@ class _BaseAutoModelClass:
return config

@classmethod
-def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike[str]], *model_args, **kwargs):
config = kwargs.pop("config", None)
trust_remote_code = kwargs.get("trust_remote_code", None)
kwargs["_from_auto"] = True
@@ -592,7 +598,7 @@ class _BaseAutoModelClass:
)

@classmethod
-def register(cls, config_class, model_class, exist_ok=False):
+def register(cls, config_class, model_class, exist_ok=False) -> None:
"""
Register a new model for this class.

@@ -650,7 +656,7 @@ class _BaseAutoBackboneClass(_BaseAutoModelClass):
return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)


-def insert_head_doc(docstring, head_doc=""):
+def insert_head_doc(docstring, head_doc: str = ""):
if len(head_doc) > 0:
return docstring.replace(
"one of the model classes of the library ",
@@ -661,7 +667,7 @@ def insert_head_doc(docstring, head_doc=""):
)


-def auto_class_update(cls, checkpoint_for_example="google-bert/bert-base-cased", head_doc=""):
+def auto_class_update(cls, checkpoint_for_example: str = "google-bert/bert-base-cased", head_doc: str = ""):
# Create a new class with the right name from the base class
model_mapping = cls._model_mapping
name = cls.__name__
@@ -759,7 +765,7 @@ def add_generation_mixin_to_remote_model(model_class):
return model_class


-class _LazyAutoMapping(OrderedDict):
+class _LazyAutoMapping(OrderedDict[type[PretrainedConfig], _LazyAutoMappingValue]):
"""
" A mapping config to object (model or tokenizer for instance) that will load keys and values when it is accessed.

@@ -768,7 +774,7 @@ class _LazyAutoMapping(OrderedDict):
- model_mapping: The map model type to model (or tokenizer) class
"""

-def __init__(self, config_mapping, model_mapping):
+def __init__(self, config_mapping, model_mapping) -> None:
self._config_mapping = config_mapping
self._reverse_config_mapping = {v: k for k, v in config_mapping.items()}
self._model_mapping = model_mapping
@@ -776,11 +782,11 @@ class _LazyAutoMapping(OrderedDict):
self._extra_content = {}
self._modules = {}

-def __len__(self):
+def __len__(self) -> int:
common_keys = set(self._config_mapping.keys()).intersection(self._model_mapping.keys())
return len(common_keys) + len(self._extra_content)

-def __getitem__(self, key):
+def __getitem__(self, key: type[PretrainedConfig]) -> _LazyAutoMappingValue:
if key in self._extra_content:
return self._extra_content[key]
model_type = self._reverse_config_mapping[key.__name__]
@@ -802,7 +808,7 @@ class _LazyAutoMapping(OrderedDict):
self._modules[module_name] = importlib.import_module(f".{module_name}", "transformers.models")
return getattribute_from_module(self._modules[module_name], attr)

-def keys(self):
+def keys(self) -> list[type[PretrainedConfig]]:
mapping_keys = [
self._load_attr_from_module(key, name)
for key, name in self._config_mapping.items()
@@ -810,16 +816,16 @@ class _LazyAutoMapping(OrderedDict):
]
return mapping_keys + list(self._extra_content.keys())

-def get(self, key, default):
+def get(self, key: type[PretrainedConfig], default: _T) -> Union[_LazyAutoMappingValue, _T]:
try:
return self.__getitem__(key)
except KeyError:
return default

-def __bool__(self):
+def __bool__(self) -> bool:
return bool(self.keys())

-def values(self):
+def values(self) -> list[_LazyAutoMappingValue]:
mapping_values = [
self._load_attr_from_module(key, name)
for key, name in self._model_mapping.items()
@@ -827,7 +833,7 @@ class _LazyAutoMapping(OrderedDict):
]
return mapping_values + list(self._extra_content.values())

-def items(self):
+def items(self) -> list[tuple[type[PretrainedConfig], _LazyAutoMappingValue]]:
mapping_items = [
(
self._load_attr_from_module(key, self._config_mapping[key]),
@@ -838,10 +844,10 @@ class _LazyAutoMapping(OrderedDict):
]
return mapping_items + list(self._extra_content.items())

-def __iter__(self):
+def __iter__(self) -> Iterator[type[PretrainedConfig]]:
return iter(self.keys())

-def __contains__(self, item):
+def __contains__(self, item: type) -> bool:
if item in self._extra_content:
return True
if not hasattr(item, "__name__") or item.__name__ not in self._reverse_config_mapping:
@@ -849,7 +855,7 @@ class _LazyAutoMapping(OrderedDict):
model_type = self._reverse_config_mapping[item.__name__]
return model_type in self._model_mapping

-def register(self, key, value, exist_ok=False):
+def register(self, key: type[PretrainedConfig], value: _LazyAutoMappingValue, exist_ok=False) -> None:
"""
Register a new model in this mapping.
"""
@@ -15,10 +15,12 @@
"""Auto Config class."""

import importlib
import os
import re
import warnings
from collections import OrderedDict
-from typing import List, Union
+from collections.abc import Callable, Iterator, KeysView, ValuesView
+from typing import Any, TypeVar, Union

from ...configuration_utils import PretrainedConfig
from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
@@ -28,7 +30,10 @@ from ...utils import CONFIG_NAME, logging
logger = logging.get_logger(__name__)


-CONFIG_MAPPING_NAMES = OrderedDict(
+_CallableT = TypeVar("_CallableT", bound=Callable[..., Any])


+CONFIG_MAPPING_NAMES = OrderedDict[str, str](
[
# Add configs here
("albert", "AlbertConfig"),
@@ -380,7 +385,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
)


-MODEL_NAMES_MAPPING = OrderedDict(
+MODEL_NAMES_MAPPING = OrderedDict[str, str](
[
# Add full (and cased) model names here
("albert", "ALBERT"),
@@ -795,7 +800,7 @@ DEPRECATED_MODELS = [
"xlm_prophetnet",
]

-SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict(
+SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict[str, str](
[
("openai-gpt", "openai"),
("data2vec-audio", "data2vec"),
@@ -827,7 +832,7 @@ SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict(
)


-def model_type_to_module_name(key):
+def model_type_to_module_name(key) -> str:
"""Converts a config key to the corresponding module."""
# Special treatment
if key in SPECIAL_MODEL_TYPE_TO_MODULE_NAME:
@@ -844,7 +849,7 @@ def model_type_to_module_name(key):
return key


-def config_class_to_model_type(config):
+def config_class_to_model_type(config) -> Union[str, None]:
"""Converts a config class name to the corresponding model type"""
for key, cls in CONFIG_MAPPING_NAMES.items():
if cls == config:
@@ -856,17 +861,17 @@ def config_class_to_model_type(config):
return None


-class _LazyConfigMapping(OrderedDict):
+class _LazyConfigMapping(OrderedDict[str, type[PretrainedConfig]]):
"""
A dictionary that lazily load its values when they are requested.
"""

-def __init__(self, mapping):
+def __init__(self, mapping) -> None:
self._mapping = mapping
self._extra_content = {}
self._modules = {}

-def __getitem__(self, key):
+def __getitem__(self, key: str) -> type[PretrainedConfig]:
if key in self._extra_content:
return self._extra_content[key]
if key not in self._mapping:
@@ -883,22 +888,22 @@ class _LazyConfigMapping(OrderedDict):
transformers_module = importlib.import_module("transformers")
return getattr(transformers_module, value)

-def keys(self):
+def keys(self) -> list[str]:
return list(self._mapping.keys()) + list(self._extra_content.keys())

-def values(self):
+def values(self) -> list[type[PretrainedConfig]]:
return [self[k] for k in self._mapping.keys()] + list(self._extra_content.values())

-def items(self):
+def items(self) -> list[tuple[str, type[PretrainedConfig]]]:
return [(k, self[k]) for k in self._mapping.keys()] + list(self._extra_content.items())

-def __iter__(self):
+def __iter__(self) -> Iterator[str]:
return iter(list(self._mapping.keys()) + list(self._extra_content.keys()))

-def __contains__(self, item):
+def __contains__(self, item: object) -> bool:
return item in self._mapping or item in self._extra_content

-def register(self, key, value, exist_ok=False):
+def register(self, key: str, value: type[PretrainedConfig], exist_ok=False) -> None:
"""
Register a new configuration in this mapping.
"""
@@ -910,7 +915,7 @@ class _LazyConfigMapping(OrderedDict):
CONFIG_MAPPING = _LazyConfigMapping(CONFIG_MAPPING_NAMES)


-class _LazyLoadAllMappings(OrderedDict):
+class _LazyLoadAllMappings(OrderedDict[str, str]):
"""
A mapping that will load all pairs of key values at the first access (either by indexing, requestions keys, values,
etc.)
@@ -940,28 +945,28 @@ class _LazyLoadAllMappings(OrderedDict):
self._initialize()
return self._data[key]

-def keys(self):
+def keys(self) -> KeysView[str]:
self._initialize()
return self._data.keys()

-def values(self):
+def values(self) -> ValuesView[str]:
self._initialize()
return self._data.values()

-def items(self):
+def items(self) -> KeysView[str]:
self._initialize()
return self._data.keys()

-def __iter__(self):
+def __iter__(self) -> Iterator[str]:
self._initialize()
return iter(self._data)

-def __contains__(self, item):
+def __contains__(self, item: object) -> bool:
self._initialize()
return item in self._data


-def _get_class_name(model_class: Union[str, List[str]]):
+def _get_class_name(model_class: Union[str, list[str]]):
if isinstance(model_class, (list, tuple)):
return " or ".join([f"[`{c}`]" for c in model_class if c is not None])
return f"[`{model_class}`]"
@@ -1000,7 +1005,9 @@ def _list_model_options(indent, config_to_class=None, use_model_types=True):
return "\n".join(lines)


-def replace_list_option_in_docstrings(config_to_class=None, use_model_types=True):
+def replace_list_option_in_docstrings(
+    config_to_class=None, use_model_types: bool = True
+) -> Callable[[_CallableT], _CallableT]:
def docstring_decorator(fn):
docstrings = fn.__doc__
if docstrings is None:
@@ -1035,14 +1042,14 @@ class AutoConfig:
This class cannot be instantiated directly using `__init__()` (throws an error).
"""

-def __init__(self):
+def __init__(self) -> None:
raise EnvironmentError(
"AutoConfig is designed to be instantiated "
"using the `AutoConfig.from_pretrained(pretrained_model_name_or_path)` method."
)

@classmethod
-def for_model(cls, model_type: str, *args, **kwargs):
+def for_model(cls, model_type: str, *args, **kwargs) -> PretrainedConfig:
if model_type in CONFIG_MAPPING:
config_class = CONFIG_MAPPING[model_type]
return config_class(*args, **kwargs)
@@ -1052,7 +1059,7 @@ class AutoConfig:

@classmethod
@replace_list_option_in_docstrings()
-def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
+def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike[str]], **kwargs):
r"""
Instantiate one of the configuration classes of the library from a pretrained model configuration.

@@ -1199,7 +1206,7 @@ class AutoConfig:
)

@staticmethod
-def register(model_type, config, exist_ok=False):
+def register(model_type, config, exist_ok=False) -> None:
"""
Register a new configuration for this class.
@@ -19,7 +19,7 @@ import json
import os
import warnings
from collections import OrderedDict
-from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
+from typing import Any, Optional, Union

from ...configuration_utils import PretrainedConfig
from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
@@ -53,12 +53,8 @@ else:

logger = logging.get_logger(__name__)

-if TYPE_CHECKING:
-    # This significantly improves completion suggestion performance when
-    # the transformers package is used with Microsoft's Pylance language server.
-    TOKENIZER_MAPPING_NAMES: OrderedDict[str, Tuple[Optional[str], Optional[str]]] = OrderedDict()
-else:
-    TOKENIZER_MAPPING_NAMES = OrderedDict(
+# Explicit rather than inferred generics to significantly improves completion suggestion performance for language servers.
+TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
[
(
"albert",
@@ -667,7 +663,7 @@ TOKENIZER_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TOKENIZER_MAPPING_NAM
CONFIG_TO_TYPE = {v: k for k, v in CONFIG_MAPPING_NAMES.items()}


-def tokenizer_class_from_name(class_name: str):
+def tokenizer_class_from_name(class_name: str) -> Union[type[Any], None]:
if class_name == "PreTrainedTokenizerFast":
return PreTrainedTokenizerFast

@@ -696,17 +692,17 @@ def tokenizer_class_from_name(class_name: str):


def get_tokenizer_config(
-    pretrained_model_name_or_path: Union[str, os.PathLike],
-    cache_dir: Optional[Union[str, os.PathLike]] = None,
+    pretrained_model_name_or_path: Union[str, os.PathLike[str]],
+    cache_dir: Optional[Union[str, os.PathLike[str]]] = None,
    force_download: bool = False,
    resume_download: Optional[bool] = None,
-    proxies: Optional[Dict[str, str]] = None,
+    proxies: Optional[dict[str, str]] = None,
    token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    local_files_only: bool = False,
    subfolder: str = "",
    **kwargs,
-):
+) -> dict[str, Any]:
"""
Loads the tokenizer configuration from a pretrained model tokenizer configuration.

@@ -728,7 +724,7 @@ def get_tokenizer_config(
resume_download:
Deprecated and ignored. All downloads are now resumed by default when possible.
Will be removed in v5 of Transformers.
-proxies (`Dict[str, str]`, *optional*):
+proxies (`dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
token (`str` or *bool*, *optional*):
@@ -751,7 +747,7 @@ def get_tokenizer_config(
</Tip>

Returns:
-    `Dict`: The configuration of the tokenizer.
+    `dict`: The configuration of the tokenizer.

Examples:

@@ -855,7 +851,7 @@ class AutoTokenizer:
resume_download:
Deprecated and ignored. All downloads are now resumed by default when possible.
Will be removed in v5 of Transformers.
-proxies (`Dict[str, str]`, *optional*):
+proxies (`dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
revision (`str`, *optional*, defaults to `"main"`):
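For reference, a hedged usage sketch of what the new annotations give callers (not part of the commit; the local checkpoint path is a placeholder):

```python
from pathlib import Path

from transformers import AutoConfig
from transformers.models.auto.configuration_auto import CONFIG_MAPPING_NAMES

# pretrained_model_name_or_path is now Union[str, os.PathLike[str]],
# so passing a pathlib.Path type-checks without a cast.
config = AutoConfig.from_pretrained(Path("./my-local-model"))  # placeholder path

# CONFIG_MAPPING_NAMES is an explicit OrderedDict[str, str]; both sides are typed str.
for model_type, config_class_name in CONFIG_MAPPING_NAMES.items():
    print(model_type, config_class_name)
```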