mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-04 05:10:06 +06:00
Merge type hints from microsoft/python-type-stubs
(post dropping support for Python 3.8) (#38335)
* Merge type hints from microsoft/python-type-stubs (post Python 3.8)
* Remove mention of pylance
* Resolved conflict
* Merge type hints from microsoft/python-type-stubs (post Python 3.8)
* Remove mention of pylance
* Resolved conflict
* Update src/transformers/models/auto/configuration_auto.py

Co-authored-by: Avasam <samuel.06@hotmail.com>

---------

Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
This commit is contained in:
parent
942c60956f
commit
2872e8bac5
@ -17,8 +17,11 @@
|
|||||||
import copy
|
import copy
|
||||||
import importlib
|
import importlib
|
||||||
import json
|
import json
|
||||||
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
from collections.abc import Iterator
|
||||||
|
from typing import Any, TypeVar, Union
|
||||||
|
|
||||||
from ...configuration_utils import PretrainedConfig
|
from ...configuration_utils import PretrainedConfig
|
||||||
from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
|
from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
|
||||||
@ -42,6 +45,9 @@ if is_torch_available():
|
|||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
_T = TypeVar("_T")
|
||||||
|
# Tokenizers depend on which packages are installed: there is too much variance, and there is no common base class or Protocol
|
||||||
|
_LazyAutoMappingValue = tuple[Union[type[Any], None], Union[type[Any], None]]
|
||||||
|
|
||||||
CLASS_DOCSTRING = """
|
CLASS_DOCSTRING = """
|
||||||
This is a generic model class that will be instantiated as one of the model classes of the library when created
|
This is a generic model class that will be instantiated as one of the model classes of the library when created
|
||||||
@ -408,7 +414,7 @@ class _BaseAutoModelClass:
|
|||||||
# Base class for auto models.
|
# Base class for auto models.
|
||||||
_model_mapping = None
|
_model_mapping = None
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs) -> None:
|
||||||
raise EnvironmentError(
|
raise EnvironmentError(
|
||||||
f"{self.__class__.__name__} is designed to be instantiated "
|
f"{self.__class__.__name__} is designed to be instantiated "
|
||||||
f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or "
|
f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or "
|
||||||
@ -456,7 +462,7 @@ class _BaseAutoModelClass:
|
|||||||
return config
|
return config
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike[str]], *model_args, **kwargs):
|
||||||
config = kwargs.pop("config", None)
|
config = kwargs.pop("config", None)
|
||||||
trust_remote_code = kwargs.get("trust_remote_code", None)
|
trust_remote_code = kwargs.get("trust_remote_code", None)
|
||||||
kwargs["_from_auto"] = True
|
kwargs["_from_auto"] = True
|
||||||
@ -592,7 +598,7 @@ class _BaseAutoModelClass:
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def register(cls, config_class, model_class, exist_ok=False):
|
def register(cls, config_class, model_class, exist_ok=False) -> None:
|
||||||
"""
|
"""
|
||||||
Register a new model for this class.
|
Register a new model for this class.
|
||||||
|
|
||||||
@ -650,7 +656,7 @@ class _BaseAutoBackboneClass(_BaseAutoModelClass):
|
|||||||
return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def insert_head_doc(docstring, head_doc=""):
|
def insert_head_doc(docstring, head_doc: str = ""):
|
||||||
if len(head_doc) > 0:
|
if len(head_doc) > 0:
|
||||||
return docstring.replace(
|
return docstring.replace(
|
||||||
"one of the model classes of the library ",
|
"one of the model classes of the library ",
|
||||||
@ -661,7 +667,7 @@ def insert_head_doc(docstring, head_doc=""):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def auto_class_update(cls, checkpoint_for_example="google-bert/bert-base-cased", head_doc=""):
|
def auto_class_update(cls, checkpoint_for_example: str = "google-bert/bert-base-cased", head_doc: str = ""):
|
||||||
# Create a new class with the right name from the base class
|
# Create a new class with the right name from the base class
|
||||||
model_mapping = cls._model_mapping
|
model_mapping = cls._model_mapping
|
||||||
name = cls.__name__
|
name = cls.__name__
|
||||||
@ -759,7 +765,7 @@ def add_generation_mixin_to_remote_model(model_class):
|
|||||||
return model_class
|
return model_class
|
||||||
|
|
||||||
|
|
||||||
class _LazyAutoMapping(OrderedDict):
|
class _LazyAutoMapping(OrderedDict[type[PretrainedConfig], _LazyAutoMappingValue]):
|
||||||
"""
|
"""
|
|
A mapping from config to object (model or tokenizer for instance) that will load keys and values when it is accessed.
|
A mapping from config to object (model or tokenizer for instance) that will load keys and values when it is accessed.
|
||||||
|
|
||||||
@ -768,7 +774,7 @@ class _LazyAutoMapping(OrderedDict):
|
|||||||
- model_mapping: The map model type to model (or tokenizer) class
|
- model_mapping: The map model type to model (or tokenizer) class
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config_mapping, model_mapping):
|
def __init__(self, config_mapping, model_mapping) -> None:
|
||||||
self._config_mapping = config_mapping
|
self._config_mapping = config_mapping
|
||||||
self._reverse_config_mapping = {v: k for k, v in config_mapping.items()}
|
self._reverse_config_mapping = {v: k for k, v in config_mapping.items()}
|
||||||
self._model_mapping = model_mapping
|
self._model_mapping = model_mapping
|
||||||
@ -776,11 +782,11 @@ class _LazyAutoMapping(OrderedDict):
|
|||||||
self._extra_content = {}
|
self._extra_content = {}
|
||||||
self._modules = {}
|
self._modules = {}
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self) -> int:
|
||||||
common_keys = set(self._config_mapping.keys()).intersection(self._model_mapping.keys())
|
common_keys = set(self._config_mapping.keys()).intersection(self._model_mapping.keys())
|
||||||
return len(common_keys) + len(self._extra_content)
|
return len(common_keys) + len(self._extra_content)
|
||||||
|
|
||||||
def __getitem__(self, key):
|
def __getitem__(self, key: type[PretrainedConfig]) -> _LazyAutoMappingValue:
|
||||||
if key in self._extra_content:
|
if key in self._extra_content:
|
||||||
return self._extra_content[key]
|
return self._extra_content[key]
|
||||||
model_type = self._reverse_config_mapping[key.__name__]
|
model_type = self._reverse_config_mapping[key.__name__]
|
||||||
@ -802,7 +808,7 @@ class _LazyAutoMapping(OrderedDict):
|
|||||||
self._modules[module_name] = importlib.import_module(f".{module_name}", "transformers.models")
|
self._modules[module_name] = importlib.import_module(f".{module_name}", "transformers.models")
|
||||||
return getattribute_from_module(self._modules[module_name], attr)
|
return getattribute_from_module(self._modules[module_name], attr)
|
||||||
|
|
||||||
def keys(self):
|
def keys(self) -> list[type[PretrainedConfig]]:
|
||||||
mapping_keys = [
|
mapping_keys = [
|
||||||
self._load_attr_from_module(key, name)
|
self._load_attr_from_module(key, name)
|
||||||
for key, name in self._config_mapping.items()
|
for key, name in self._config_mapping.items()
|
||||||
@ -810,16 +816,16 @@ class _LazyAutoMapping(OrderedDict):
|
|||||||
]
|
]
|
||||||
return mapping_keys + list(self._extra_content.keys())
|
return mapping_keys + list(self._extra_content.keys())
|
||||||
|
|
||||||
def get(self, key, default):
|
def get(self, key: type[PretrainedConfig], default: _T) -> Union[_LazyAutoMappingValue, _T]:
|
||||||
try:
|
try:
|
||||||
return self.__getitem__(key)
|
return self.__getitem__(key)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
return default
|
return default
|
||||||
|
|
||||||
def __bool__(self):
|
def __bool__(self) -> bool:
|
||||||
return bool(self.keys())
|
return bool(self.keys())
|
||||||
|
|
||||||
def values(self):
|
def values(self) -> list[_LazyAutoMappingValue]:
|
||||||
mapping_values = [
|
mapping_values = [
|
||||||
self._load_attr_from_module(key, name)
|
self._load_attr_from_module(key, name)
|
||||||
for key, name in self._model_mapping.items()
|
for key, name in self._model_mapping.items()
|
||||||
@ -827,7 +833,7 @@ class _LazyAutoMapping(OrderedDict):
|
|||||||
]
|
]
|
||||||
return mapping_values + list(self._extra_content.values())
|
return mapping_values + list(self._extra_content.values())
|
||||||
|
|
||||||
def items(self):
|
def items(self) -> list[tuple[type[PretrainedConfig], _LazyAutoMappingValue]]:
|
||||||
mapping_items = [
|
mapping_items = [
|
||||||
(
|
(
|
||||||
self._load_attr_from_module(key, self._config_mapping[key]),
|
self._load_attr_from_module(key, self._config_mapping[key]),
|
||||||
@ -838,10 +844,10 @@ class _LazyAutoMapping(OrderedDict):
|
|||||||
]
|
]
|
||||||
return mapping_items + list(self._extra_content.items())
|
return mapping_items + list(self._extra_content.items())
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self) -> Iterator[type[PretrainedConfig]]:
|
||||||
return iter(self.keys())
|
return iter(self.keys())
|
||||||
|
|
||||||
def __contains__(self, item):
|
def __contains__(self, item: type) -> bool:
|
||||||
if item in self._extra_content:
|
if item in self._extra_content:
|
||||||
return True
|
return True
|
||||||
if not hasattr(item, "__name__") or item.__name__ not in self._reverse_config_mapping:
|
if not hasattr(item, "__name__") or item.__name__ not in self._reverse_config_mapping:
|
||||||
@ -849,7 +855,7 @@ class _LazyAutoMapping(OrderedDict):
|
|||||||
model_type = self._reverse_config_mapping[item.__name__]
|
model_type = self._reverse_config_mapping[item.__name__]
|
||||||
return model_type in self._model_mapping
|
return model_type in self._model_mapping
|
||||||
|
|
||||||
def register(self, key, value, exist_ok=False):
|
def register(self, key: type[PretrainedConfig], value: _LazyAutoMappingValue, exist_ok=False) -> None:
|
||||||
"""
|
"""
|
||||||
Register a new model in this mapping.
|
Register a new model in this mapping.
|
||||||
"""
|
"""
|
||||||
|
@ -15,10 +15,12 @@
|
|||||||
"""Auto Config class."""
|
"""Auto Config class."""
|
||||||
|
|
||||||
import importlib
|
import importlib
|
||||||
|
import os
|
||||||
import re
|
import re
|
||||||
import warnings
|
import warnings
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from typing import List, Union
|
from collections.abc import Callable, Iterator, KeysView, ValuesView
|
||||||
|
from typing import Any, TypeVar, Union
|
||||||
|
|
||||||
from ...configuration_utils import PretrainedConfig
|
from ...configuration_utils import PretrainedConfig
|
||||||
from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
|
from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
|
||||||
@ -28,7 +30,10 @@ from ...utils import CONFIG_NAME, logging
|
|||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
CONFIG_MAPPING_NAMES = OrderedDict(
|
_CallableT = TypeVar("_CallableT", bound=Callable[..., Any])
|
||||||
|
|
||||||
|
|
||||||
|
CONFIG_MAPPING_NAMES = OrderedDict[str, str](
|
||||||
[
|
[
|
||||||
# Add configs here
|
# Add configs here
|
||||||
("albert", "AlbertConfig"),
|
("albert", "AlbertConfig"),
|
||||||
@ -380,7 +385,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
MODEL_NAMES_MAPPING = OrderedDict(
|
MODEL_NAMES_MAPPING = OrderedDict[str, str](
|
||||||
[
|
[
|
||||||
# Add full (and cased) model names here
|
# Add full (and cased) model names here
|
||||||
("albert", "ALBERT"),
|
("albert", "ALBERT"),
|
||||||
@ -795,7 +800,7 @@ DEPRECATED_MODELS = [
|
|||||||
"xlm_prophetnet",
|
"xlm_prophetnet",
|
||||||
]
|
]
|
||||||
|
|
||||||
SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict(
|
SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict[str, str](
|
||||||
[
|
[
|
||||||
("openai-gpt", "openai"),
|
("openai-gpt", "openai"),
|
||||||
("data2vec-audio", "data2vec"),
|
("data2vec-audio", "data2vec"),
|
||||||
@ -827,7 +832,7 @@ SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def model_type_to_module_name(key):
|
def model_type_to_module_name(key) -> str:
|
||||||
"""Converts a config key to the corresponding module."""
|
"""Converts a config key to the corresponding module."""
|
||||||
# Special treatment
|
# Special treatment
|
||||||
if key in SPECIAL_MODEL_TYPE_TO_MODULE_NAME:
|
if key in SPECIAL_MODEL_TYPE_TO_MODULE_NAME:
|
||||||
@ -844,7 +849,7 @@ def model_type_to_module_name(key):
|
|||||||
return key
|
return key
|
||||||
|
|
||||||
|
|
||||||
def config_class_to_model_type(config):
|
def config_class_to_model_type(config) -> Union[str, None]:
|
||||||
"""Converts a config class name to the corresponding model type"""
|
"""Converts a config class name to the corresponding model type"""
|
||||||
for key, cls in CONFIG_MAPPING_NAMES.items():
|
for key, cls in CONFIG_MAPPING_NAMES.items():
|
||||||
if cls == config:
|
if cls == config:
|
||||||
@ -856,17 +861,17 @@ def config_class_to_model_type(config):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
class _LazyConfigMapping(OrderedDict):
|
class _LazyConfigMapping(OrderedDict[str, type[PretrainedConfig]]):
|
||||||
"""
|
"""
|
||||||
A dictionary that lazily loads its values when they are requested.
|
A dictionary that lazily loads its values when they are requested.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, mapping):
|
def __init__(self, mapping) -> None:
|
||||||
self._mapping = mapping
|
self._mapping = mapping
|
||||||
self._extra_content = {}
|
self._extra_content = {}
|
||||||
self._modules = {}
|
self._modules = {}
|
||||||
|
|
||||||
def __getitem__(self, key):
|
def __getitem__(self, key: str) -> type[PretrainedConfig]:
|
||||||
if key in self._extra_content:
|
if key in self._extra_content:
|
||||||
return self._extra_content[key]
|
return self._extra_content[key]
|
||||||
if key not in self._mapping:
|
if key not in self._mapping:
|
||||||
@ -883,22 +888,22 @@ class _LazyConfigMapping(OrderedDict):
|
|||||||
transformers_module = importlib.import_module("transformers")
|
transformers_module = importlib.import_module("transformers")
|
||||||
return getattr(transformers_module, value)
|
return getattr(transformers_module, value)
|
||||||
|
|
||||||
def keys(self):
|
def keys(self) -> list[str]:
|
||||||
return list(self._mapping.keys()) + list(self._extra_content.keys())
|
return list(self._mapping.keys()) + list(self._extra_content.keys())
|
||||||
|
|
||||||
def values(self):
|
def values(self) -> list[type[PretrainedConfig]]:
|
||||||
return [self[k] for k in self._mapping.keys()] + list(self._extra_content.values())
|
return [self[k] for k in self._mapping.keys()] + list(self._extra_content.values())
|
||||||
|
|
||||||
def items(self):
|
def items(self) -> list[tuple[str, type[PretrainedConfig]]]:
|
||||||
return [(k, self[k]) for k in self._mapping.keys()] + list(self._extra_content.items())
|
return [(k, self[k]) for k in self._mapping.keys()] + list(self._extra_content.items())
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self) -> Iterator[str]:
|
||||||
return iter(list(self._mapping.keys()) + list(self._extra_content.keys()))
|
return iter(list(self._mapping.keys()) + list(self._extra_content.keys()))
|
||||||
|
|
||||||
def __contains__(self, item):
|
def __contains__(self, item: object) -> bool:
|
||||||
return item in self._mapping or item in self._extra_content
|
return item in self._mapping or item in self._extra_content
|
||||||
|
|
||||||
def register(self, key, value, exist_ok=False):
|
def register(self, key: str, value: type[PretrainedConfig], exist_ok=False) -> None:
|
||||||
"""
|
"""
|
||||||
Register a new configuration in this mapping.
|
Register a new configuration in this mapping.
|
||||||
"""
|
"""
|
||||||
@ -910,7 +915,7 @@ class _LazyConfigMapping(OrderedDict):
|
|||||||
CONFIG_MAPPING = _LazyConfigMapping(CONFIG_MAPPING_NAMES)
|
CONFIG_MAPPING = _LazyConfigMapping(CONFIG_MAPPING_NAMES)
|
||||||
|
|
||||||
|
|
||||||
class _LazyLoadAllMappings(OrderedDict):
|
class _LazyLoadAllMappings(OrderedDict[str, str]):
|
||||||
"""
|
"""
|
||||||
A mapping that will load all pairs of key values at the first access (either by indexing, requesting keys, values,
|
A mapping that will load all pairs of key values at the first access (either by indexing, requesting keys, values,
|
||||||
etc.)
|
etc.)
|
||||||
@ -940,28 +945,28 @@ class _LazyLoadAllMappings(OrderedDict):
|
|||||||
self._initialize()
|
self._initialize()
|
||||||
return self._data[key]
|
return self._data[key]
|
||||||
|
|
||||||
def keys(self):
|
def keys(self) -> KeysView[str]:
|
||||||
self._initialize()
|
self._initialize()
|
||||||
return self._data.keys()
|
return self._data.keys()
|
||||||
|
|
||||||
def values(self):
|
def values(self) -> ValuesView[str]:
|
||||||
self._initialize()
|
self._initialize()
|
||||||
return self._data.values()
|
return self._data.values()
|
||||||
|
|
||||||
def items(self):
|
def items(self) -> KeysView[str]:
|
||||||
self._initialize()
|
self._initialize()
|
||||||
return self._data.keys()
|
return self._data.keys()
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self) -> Iterator[str]:
|
||||||
self._initialize()
|
self._initialize()
|
||||||
return iter(self._data)
|
return iter(self._data)
|
||||||
|
|
||||||
def __contains__(self, item):
|
def __contains__(self, item: object) -> bool:
|
||||||
self._initialize()
|
self._initialize()
|
||||||
return item in self._data
|
return item in self._data
|
||||||
|
|
||||||
|
|
||||||
def _get_class_name(model_class: Union[str, List[str]]):
|
def _get_class_name(model_class: Union[str, list[str]]):
|
||||||
if isinstance(model_class, (list, tuple)):
|
if isinstance(model_class, (list, tuple)):
|
||||||
return " or ".join([f"[`{c}`]" for c in model_class if c is not None])
|
return " or ".join([f"[`{c}`]" for c in model_class if c is not None])
|
||||||
return f"[`{model_class}`]"
|
return f"[`{model_class}`]"
|
||||||
@ -1000,7 +1005,9 @@ def _list_model_options(indent, config_to_class=None, use_model_types=True):
|
|||||||
return "\n".join(lines)
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
def replace_list_option_in_docstrings(config_to_class=None, use_model_types=True):
|
def replace_list_option_in_docstrings(
|
||||||
|
config_to_class=None, use_model_types: bool = True
|
||||||
|
) -> Callable[[_CallableT], _CallableT]:
|
||||||
def docstring_decorator(fn):
|
def docstring_decorator(fn):
|
||||||
docstrings = fn.__doc__
|
docstrings = fn.__doc__
|
||||||
if docstrings is None:
|
if docstrings is None:
|
||||||
@ -1035,14 +1042,14 @@ class AutoConfig:
|
|||||||
This class cannot be instantiated directly using `__init__()` (throws an error).
|
This class cannot be instantiated directly using `__init__()` (throws an error).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
raise EnvironmentError(
|
raise EnvironmentError(
|
||||||
"AutoConfig is designed to be instantiated "
|
"AutoConfig is designed to be instantiated "
|
||||||
"using the `AutoConfig.from_pretrained(pretrained_model_name_or_path)` method."
|
"using the `AutoConfig.from_pretrained(pretrained_model_name_or_path)` method."
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def for_model(cls, model_type: str, *args, **kwargs):
|
def for_model(cls, model_type: str, *args, **kwargs) -> PretrainedConfig:
|
||||||
if model_type in CONFIG_MAPPING:
|
if model_type in CONFIG_MAPPING:
|
||||||
config_class = CONFIG_MAPPING[model_type]
|
config_class = CONFIG_MAPPING[model_type]
|
||||||
return config_class(*args, **kwargs)
|
return config_class(*args, **kwargs)
|
||||||
@ -1052,7 +1059,7 @@ class AutoConfig:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@replace_list_option_in_docstrings()
|
@replace_list_option_in_docstrings()
|
||||||
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
|
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike[str]], **kwargs):
|
||||||
r"""
|
r"""
|
||||||
Instantiate one of the configuration classes of the library from a pretrained model configuration.
|
Instantiate one of the configuration classes of the library from a pretrained model configuration.
|
||||||
|
|
||||||
@ -1199,7 +1206,7 @@ class AutoConfig:
|
|||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def register(model_type, config, exist_ok=False):
|
def register(model_type, config, exist_ok=False) -> None:
|
||||||
"""
|
"""
|
||||||
Register a new configuration for this class.
|
Register a new configuration for this class.
|
||||||
|
|
||||||
|
@ -19,7 +19,7 @@ import json
|
|||||||
import os
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
from ...configuration_utils import PretrainedConfig
|
from ...configuration_utils import PretrainedConfig
|
||||||
from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
|
from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
|
||||||
@ -53,12 +53,8 @@ else:
|
|||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
# Explicit rather than inferred generics to significantly improve completion suggestion performance for language servers.
|
||||||
# This significantly improves completion suggestion performance when
|
TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
|
||||||
# the transformers package is used with Microsoft's Pylance language server.
|
|
||||||
TOKENIZER_MAPPING_NAMES: OrderedDict[str, Tuple[Optional[str], Optional[str]]] = OrderedDict()
|
|
||||||
else:
|
|
||||||
TOKENIZER_MAPPING_NAMES = OrderedDict(
|
|
||||||
[
|
[
|
||||||
(
|
(
|
||||||
"albert",
|
"albert",
|
||||||
@ -667,7 +663,7 @@ TOKENIZER_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TOKENIZER_MAPPING_NAM
|
|||||||
CONFIG_TO_TYPE = {v: k for k, v in CONFIG_MAPPING_NAMES.items()}
|
CONFIG_TO_TYPE = {v: k for k, v in CONFIG_MAPPING_NAMES.items()}
|
||||||
|
|
||||||
|
|
||||||
def tokenizer_class_from_name(class_name: str):
|
def tokenizer_class_from_name(class_name: str) -> Union[type[Any], None]:
|
||||||
if class_name == "PreTrainedTokenizerFast":
|
if class_name == "PreTrainedTokenizerFast":
|
||||||
return PreTrainedTokenizerFast
|
return PreTrainedTokenizerFast
|
||||||
|
|
||||||
@ -696,17 +692,17 @@ def tokenizer_class_from_name(class_name: str):
|
|||||||
|
|
||||||
|
|
||||||
def get_tokenizer_config(
|
def get_tokenizer_config(
|
||||||
pretrained_model_name_or_path: Union[str, os.PathLike],
|
pretrained_model_name_or_path: Union[str, os.PathLike[str]],
|
||||||
cache_dir: Optional[Union[str, os.PathLike]] = None,
|
cache_dir: Optional[Union[str, os.PathLike[str]]] = None,
|
||||||
force_download: bool = False,
|
force_download: bool = False,
|
||||||
resume_download: Optional[bool] = None,
|
resume_download: Optional[bool] = None,
|
||||||
proxies: Optional[Dict[str, str]] = None,
|
proxies: Optional[dict[str, str]] = None,
|
||||||
token: Optional[Union[bool, str]] = None,
|
token: Optional[Union[bool, str]] = None,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
local_files_only: bool = False,
|
local_files_only: bool = False,
|
||||||
subfolder: str = "",
|
subfolder: str = "",
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Loads the tokenizer configuration from a pretrained model tokenizer configuration.
|
Loads the tokenizer configuration from a pretrained model tokenizer configuration.
|
||||||
|
|
||||||
@ -728,7 +724,7 @@ def get_tokenizer_config(
|
|||||||
resume_download:
|
resume_download:
|
||||||
Deprecated and ignored. All downloads are now resumed by default when possible.
|
Deprecated and ignored. All downloads are now resumed by default when possible.
|
||||||
Will be removed in v5 of Transformers.
|
Will be removed in v5 of Transformers.
|
||||||
proxies (`Dict[str, str]`, *optional*):
|
proxies (`dict[str, str]`, *optional*):
|
||||||
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
|
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
|
||||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||||
token (`str` or `bool`, *optional*):
|
token (`str` or `bool`, *optional*):
|
||||||
@ -751,7 +747,7 @@ def get_tokenizer_config(
|
|||||||
</Tip>
|
</Tip>
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
`Dict`: The configuration of the tokenizer.
|
`dict`: The configuration of the tokenizer.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
|
|
||||||
@ -855,7 +851,7 @@ class AutoTokenizer:
|
|||||||
resume_download:
|
resume_download:
|
||||||
Deprecated and ignored. All downloads are now resumed by default when possible.
|
Deprecated and ignored. All downloads are now resumed by default when possible.
|
||||||
Will be removed in v5 of Transformers.
|
Will be removed in v5 of Transformers.
|
||||||
proxies (`Dict[str, str]`, *optional*):
|
proxies (`dict[str, str]`, *optional*):
|
||||||
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
|
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
|
||||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||||
revision (`str`, *optional*, defaults to `"main"`):
|
revision (`str`, *optional*, defaults to `"main"`):
|
||||||
|
Loading…
Reference in New Issue
Block a user