Pavel Iakubovskii 2025-07-02 20:19:43 +02:00 committed by GitHub
commit f28ef6ee50
11 changed files with 36 additions and 27 deletions

View File

@@ -18,7 +18,7 @@ import copy
 import json
 import os
 import warnings
-from typing import Any, Optional, Union
+from typing import Any, Optional, TypeVar, Union

 from packaging import version
@@ -42,6 +42,10 @@ from .utils.generic import is_timm_config_dict
 logger = logging.get_logger(__name__)

+# type hinting: specifying the type of config class that inherits from PretrainedConfig
+SpecificPretrainedConfigType = TypeVar("SpecificPretrainedConfigType", bound="PretrainedConfig")

 class PretrainedConfig(PushToHubMixin):
     # no-format
     r"""
@@ -191,7 +195,7 @@ class PretrainedConfig(PushToHubMixin):
     model_type: str = ""
     base_config_key: str = ""
-    sub_configs: dict[str, "PretrainedConfig"] = {}
+    sub_configs: dict[str, type["PretrainedConfig"]] = {}
     has_no_defaults_at_init: bool = False
     attribute_map: dict[str, str] = {}
     base_model_tp_plan: Optional[dict[str, Any]] = None
@@ -474,7 +478,7 @@ class PretrainedConfig(PushToHubMixin):
     @classmethod
     def from_pretrained(
-        cls,
+        cls: type[SpecificPretrainedConfigType],
         pretrained_model_name_or_path: Union[str, os.PathLike],
         cache_dir: Optional[Union[str, os.PathLike]] = None,
         force_download: bool = False,
@@ -482,7 +486,7 @@ class PretrainedConfig(PushToHubMixin):
         token: Optional[Union[str, bool]] = None,
         revision: str = "main",
         **kwargs,
-    ) -> "PretrainedConfig":
+    ) -> SpecificPretrainedConfigType:
         r"""
         Instantiate a [`PretrainedConfig`] (or a derived class) from a pretrained model configuration.
@@ -717,7 +721,9 @@ class PretrainedConfig(PushToHubMixin):
         return config_dict, kwargs

     @classmethod
-    def from_dict(cls, config_dict: dict[str, Any], **kwargs) -> "PretrainedConfig":
+    def from_dict(
+        cls: type[SpecificPretrainedConfigType], config_dict: dict[str, Any], **kwargs
+    ) -> SpecificPretrainedConfigType:
         """
         Instantiates a [`PretrainedConfig`] from a Python dictionary of parameters.
@@ -778,7 +784,9 @@ class PretrainedConfig(PushToHubMixin):
         return config

     @classmethod
-    def from_json_file(cls, json_file: Union[str, os.PathLike]) -> "PretrainedConfig":
+    def from_json_file(
+        cls: type[SpecificPretrainedConfigType], json_file: Union[str, os.PathLike]
+    ) -> SpecificPretrainedConfigType:
         """
         Instantiates a [`PretrainedConfig`] from the path to a JSON file of parameters.
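The `PretrainedConfig` changes above follow the standard bound-TypeVar idiom for typed classmethod constructors. A minimal, self-contained sketch of the pattern (names here are illustrative, not the library code):

from typing import TypeVar

# Bound TypeVar: any subclass of BaseConfig is an acceptable substitution.
SpecificConfigType = TypeVar("SpecificConfigType", bound="BaseConfig")


class BaseConfig:
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

    @classmethod
    def from_dict(cls: type[SpecificConfigType], config_dict: dict) -> SpecificConfigType:
        # `cls` is the concrete class the method was called on, so the
        # declared return type follows that class rather than the base.
        return cls(**config_dict)


class MyConfig(BaseConfig):
    pass


config = MyConfig.from_dict({"hidden_size": 8})
# A type checker now infers `config: MyConfig`; with a plain `-> "BaseConfig"`
# style annotation it could only infer the base class.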

View File

@@ -20,7 +20,7 @@ import json
 import os
 import warnings
 from collections import UserDict
-from typing import TYPE_CHECKING, Any, Optional, Union
+from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union

 import numpy as np
@@ -55,6 +55,9 @@ logger = logging.get_logger(__name__)
 PreTrainedFeatureExtractor = Union["SequenceFeatureExtractor"]  # noqa: F821

+# type hinting: specifying the type of feature extractor class that inherits from FeatureExtractionMixin
+SpecificFeatureExtractorType = TypeVar("SpecificFeatureExtractorType", bound="FeatureExtractionMixin")

 class BatchFeature(UserDict):
     r"""
@@ -270,7 +273,7 @@ class FeatureExtractionMixin(PushToHubMixin):
     @classmethod
     def from_pretrained(
-        cls,
+        cls: type[SpecificFeatureExtractorType],
         pretrained_model_name_or_path: Union[str, os.PathLike],
         cache_dir: Optional[Union[str, os.PathLike]] = None,
         force_download: bool = False,
@@ -278,7 +281,7 @@ class FeatureExtractionMixin(PushToHubMixin):
         token: Optional[Union[str, bool]] = None,
         revision: str = "main",
         **kwargs,
-    ):
+    ) -> SpecificFeatureExtractorType:
         r"""
         Instantiate a type of [`~feature_extraction_utils.FeatureExtractionMixin`] from a feature extractor, *e.g.* a
         derived class of [`SequenceFeatureExtractor`].
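For callers, the typed `cls` parameter means `from_pretrained` now resolves to the concrete subclass it is invoked on instead of returning an unannotated result. A hedged usage sketch (the class and checkpoint names are examples only, not part of this diff):

from transformers import Wav2Vec2FeatureExtractor

# With the bound TypeVar on `cls`, a type checker resolves the return type of
# from_pretrained to Wav2Vec2FeatureExtractor, so this annotation needs no cast
# and a mismatched assignment would be flagged statically.
extractor: Wav2Vec2FeatureExtractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-base")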

View File

@@ -2616,7 +2616,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMixin):
         return config

     @classmethod
-    def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> PretrainedConfig:
+    def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False):
         """
         Checks the availability of SDPA for a given model.

View File

@@ -131,7 +131,7 @@ class ClvpEncoderConfig(PretrainedConfig):
     @classmethod
     def from_pretrained(
         cls, pretrained_model_name_or_path: Union[str, os.PathLike], config_type: str = "text_config", **kwargs
-    ) -> "PretrainedConfig":
+    ):
         cls._set_token_in_kwargs(kwargs)

         config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
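Model-level overrides such as `ClvpEncoderConfig.from_pretrained` drop the explicit `-> "PretrainedConfig"` return annotation, since that annotation widened the result back to the base class even when the method was called on a subclass. A short sketch continuing the `BaseConfig` example shown earlier (illustrative only; how far an unannotated override's return type is inferred from its body depends on the type checker):

class EncoderConfig(BaseConfig):
    @classmethod
    def from_nested(cls, data: dict):
        # Hypothetical helper, not a library method. No explicit "-> BaseConfig"
        # annotation: cls.from_dict already yields the calling class via the
        # bound TypeVar, so re-annotating with the base class would lose precision.
        return cls.from_dict(data["text_config"])


encoder = EncoderConfig.from_nested({"text_config": {"hidden_size": 8}})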

View File

@@ -345,9 +345,7 @@ class JukeboxPriorConfig(PretrainedConfig):
         self.zero_out = zero_out

     @classmethod
-    def from_pretrained(
-        cls, pretrained_model_name_or_path: Union[str, os.PathLike], level=0, **kwargs
-    ) -> "PretrainedConfig":
+    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], level=0, **kwargs):
         cls._set_token_in_kwargs(kwargs)

         config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
@@ -470,7 +468,7 @@ class JukeboxVQVAEConfig(PretrainedConfig):
         self.zero_out = zero_out

     @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
+    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs):
         cls._set_token_in_kwargs(kwargs)

         config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

View File

@@ -15,7 +15,7 @@
 """PyTorch Falcon model."""

 import math
-from typing import TYPE_CHECKING, Optional, Union
+from typing import Optional, Union

 import torch
 import torch.utils.checkpoint
@@ -47,9 +47,6 @@ from ...utils import (
 from .configuration_falcon import FalconConfig

-if TYPE_CHECKING:
-    from ...configuration_utils import PretrainedConfig

 if is_flash_attn_available():
     from ...modeling_flash_attention_utils import _flash_attention_forward
@@ -688,7 +685,7 @@ class FalconPreTrainedModel(PreTrainedModel):
     # Adapted from transformers.modeling_utils.PreTrainedModel._check_and_enable_sdpa
     @classmethod
-    def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> "PretrainedConfig":
+    def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False):
         _is_bettertransformer = getattr(cls, "use_bettertransformer", False)

         if _is_bettertransformer:
             return config

View File

@@ -1074,7 +1074,7 @@ class Qwen2_5OmniConfig(PretrainedConfig):
         super().__init__(**kwargs)

-    def get_text_config(self, decoder=False) -> "PretrainedConfig":
+    def get_text_config(self, decoder=False):
         """
         Returns the config that is meant to be used with text IO. On most models, it is the original config instance
         itself. On specific composite models, it is under a set of valid names.

View File

@@ -1114,7 +1114,7 @@ class Qwen2_5OmniConfig(PretrainedConfig):
         super().__init__(**kwargs)

-    def get_text_config(self, decoder=False) -> "PretrainedConfig":
+    def get_text_config(self, decoder=False):
         """
         Returns the config that is meant to be used with text IO. On most models, it is the original config instance
         itself. On specific composite models, it is under a set of valid names.

View File

@@ -324,7 +324,7 @@ class T5GemmaConfig(PretrainedConfig):
             setattr(self.decoder, key, value)
         super().__setattr__(key, value)

-    def get_text_config(self, decoder=False) -> "PretrainedConfig":
+    def get_text_config(self, decoder=False):
         # Always return self, regardless of the decoder option.
         del decoder
         return self

View File

@@ -213,7 +213,7 @@ class T5GemmaConfig(PretrainedConfig):
             setattr(self.decoder, key, value)
         super().__setattr__(key, value)

-    def get_text_config(self, decoder=False) -> "PretrainedConfig":
+    def get_text_config(self, decoder=False):
         # Always return self, regardless of the decoder option.
         del decoder
         return self

View File

@@ -24,7 +24,7 @@ import typing
 import warnings
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Any, Optional, TypedDict, Union
+from typing import Any, Optional, TypedDict, TypeVar, Union

 import numpy as np
 import typing_extensions
@@ -75,6 +75,9 @@ if is_torch_available():
 logger = logging.get_logger(__name__)

+# type hinting: specifying the type of processor class that inherits from ProcessorMixin
+SpecificProcessorType = TypeVar("SpecificProcessorType", bound="ProcessorMixin")

 # Dynamically import the Transformers module to grab the attribute classes of the processor from their names.
 transformers_module = direct_transformers_import(Path(__file__).parent)
@@ -1246,7 +1249,7 @@ class ProcessorMixin(PushToHubMixin):
     @classmethod
     def from_pretrained(
-        cls,
+        cls: type[SpecificProcessorType],
         pretrained_model_name_or_path: Union[str, os.PathLike],
         cache_dir: Optional[Union[str, os.PathLike]] = None,
         force_download: bool = False,
@@ -1254,7 +1257,7 @@ class ProcessorMixin(PushToHubMixin):
         token: Optional[Union[str, bool]] = None,
         revision: str = "main",
         **kwargs,
-    ):
+    ) -> SpecificProcessorType:
         r"""
         Instantiate a processor associated with a pretrained model.