diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py
index 54fa7c7e267..b5fe34e1414 100755
--- a/src/transformers/configuration_utils.py
+++ b/src/transformers/configuration_utils.py
@@ -18,7 +18,7 @@ import copy
 import json
 import os
 import warnings
-from typing import Any, Optional, Union
+from typing import Any, Optional, TypeVar, Union
 
 from packaging import version
 
@@ -42,6 +42,10 @@ from .utils.generic import is_timm_config_dict
 logger = logging.get_logger(__name__)
 
 
+# type hinting: specifying the type of config class that inherits from PretrainedConfig
+SpecificPretrainedConfigType = TypeVar("SpecificPretrainedConfigType", bound="PretrainedConfig")
+
+
 class PretrainedConfig(PushToHubMixin):
     # no-format
     r"""
@@ -191,7 +195,7 @@ class PretrainedConfig(PushToHubMixin):
 
     model_type: str = ""
     base_config_key: str = ""
-    sub_configs: dict[str, "PretrainedConfig"] = {}
+    sub_configs: dict[str, type["PretrainedConfig"]] = {}
     has_no_defaults_at_init: bool = False
     attribute_map: dict[str, str] = {}
     base_model_tp_plan: Optional[dict[str, Any]] = None
@@ -474,7 +478,7 @@ class PretrainedConfig(PushToHubMixin):
 
     @classmethod
     def from_pretrained(
-        cls,
+        cls: type[SpecificPretrainedConfigType],
         pretrained_model_name_or_path: Union[str, os.PathLike],
         cache_dir: Optional[Union[str, os.PathLike]] = None,
         force_download: bool = False,
@@ -482,7 +486,7 @@
         token: Optional[Union[str, bool]] = None,
         revision: str = "main",
         **kwargs,
-    ) -> "PretrainedConfig":
+    ) -> SpecificPretrainedConfigType:
         r"""
         Instantiate a [`PretrainedConfig`] (or a derived class) from a pretrained model configuration.
 
@@ -717,7 +721,9 @@ class PretrainedConfig(PushToHubMixin):
         return config_dict, kwargs
 
     @classmethod
-    def from_dict(cls, config_dict: dict[str, Any], **kwargs) -> "PretrainedConfig":
+    def from_dict(
+        cls: type[SpecificPretrainedConfigType], config_dict: dict[str, Any], **kwargs
+    ) -> SpecificPretrainedConfigType:
         """
         Instantiates a [`PretrainedConfig`] from a Python dictionary of parameters.
 
@@ -778,7 +784,9 @@ class PretrainedConfig(PushToHubMixin):
         return config
 
     @classmethod
-    def from_json_file(cls, json_file: Union[str, os.PathLike]) -> "PretrainedConfig":
+    def from_json_file(
+        cls: type[SpecificPretrainedConfigType], json_file: Union[str, os.PathLike]
+    ) -> SpecificPretrainedConfigType:
         """
         Instantiates a [`PretrainedConfig`] from the path to a JSON file of parameters.
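For context on the pattern applied above: annotating `cls` as `type[SpecificPretrainedConfigType]` and returning the same TypeVar lets static checkers carry the concrete subclass through `from_pretrained`, `from_dict`, and `from_json_file`. A minimal standalone sketch of the idea follows; it is not transformers code, and `BaseConfig`/`MyConfig` are hypothetical stand-ins.

```python
from typing import Any, TypeVar

SpecificConfigType = TypeVar("SpecificConfigType", bound="BaseConfig")


class BaseConfig:
    """Hypothetical stand-in for PretrainedConfig."""

    @classmethod
    def from_dict(cls: type[SpecificConfigType], config_dict: dict[str, Any], **kwargs) -> SpecificConfigType:
        # `cls` is the concrete subclass the caller used, so the instance built
        # here is typed as that subclass rather than as BaseConfig.
        config = cls()
        for key, value in config_dict.items():
            setattr(config, key, value)
        return config


class MyConfig(BaseConfig):
    """Hypothetical subclass, standing in for a model-specific config."""

    hidden_size: int = 768


config = MyConfig.from_dict({"hidden_size": 1024})
# A type checker now infers `config` as MyConfig; with a plain `-> "BaseConfig"`
# annotation (the old spelling) it would have been widened to the base class.
```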
diff --git a/src/transformers/feature_extraction_utils.py b/src/transformers/feature_extraction_utils.py
index 0d08f225e5e..c539288cbed 100644
--- a/src/transformers/feature_extraction_utils.py
+++ b/src/transformers/feature_extraction_utils.py
@@ -20,7 +20,7 @@ import json
 import os
 import warnings
 from collections import UserDict
-from typing import TYPE_CHECKING, Any, Optional, Union
+from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union
 
 import numpy as np
 
@@ -55,6 +55,9 @@ logger = logging.get_logger(__name__)
 
 PreTrainedFeatureExtractor = Union["SequenceFeatureExtractor"]  # noqa: F821
 
+# type hinting: specifying the type of feature extractor class that inherits from FeatureExtractionMixin
+SpecificFeatureExtractorType = TypeVar("SpecificFeatureExtractorType", bound="FeatureExtractionMixin")
+
 
 class BatchFeature(UserDict):
     r"""
@@ -270,7 +273,7 @@ class FeatureExtractionMixin(PushToHubMixin):
 
     @classmethod
     def from_pretrained(
-        cls,
+        cls: type[SpecificFeatureExtractorType],
         pretrained_model_name_or_path: Union[str, os.PathLike],
         cache_dir: Optional[Union[str, os.PathLike]] = None,
         force_download: bool = False,
@@ -278,7 +281,7 @@
         token: Optional[Union[str, bool]] = None,
         revision: str = "main",
         **kwargs,
-    ):
+    ) -> SpecificFeatureExtractorType:
         r"""
         Instantiate a type of [`~feature_extraction_utils.FeatureExtractionMixin`] from a feature extractor, *e.g.* a
         derived class of [`SequenceFeatureExtractor`].
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index f7f34c75ea5..fc0a249850c 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -2616,7 +2616,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
         return config
 
     @classmethod
-    def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> PretrainedConfig:
+    def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False):
         """
         Checks the availability of SDPA for a given model.
 
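Assuming the change lands as written, the downstream effect should look roughly like the sketch below when the snippet is analyzed by a static checker such as mypy or pyright. The class names are real transformers classes and the checkpoint IDs are just illustrative Hub examples; the inferred types in the comments are the expected result of the annotation change, not verified checker output, and running the snippet fetches two small config files from the Hub.

```python
from typing import TYPE_CHECKING

from transformers import BertConfig, Wav2Vec2FeatureExtractor
from typing_extensions import reveal_type  # typing.reveal_type on Python 3.11+

config = BertConfig.from_pretrained("bert-base-uncased")
extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-base")

reveal_type(config)     # expected: BertConfig (previously widened to PretrainedConfig)
reveal_type(extractor)  # expected: Wav2Vec2FeatureExtractor (previously unannotated)
```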
diff --git a/src/transformers/models/clvp/configuration_clvp.py b/src/transformers/models/clvp/configuration_clvp.py
index d1ee5c9fb79..e5d9957cbd4 100644
--- a/src/transformers/models/clvp/configuration_clvp.py
+++ b/src/transformers/models/clvp/configuration_clvp.py
@@ -131,7 +131,7 @@ class ClvpEncoderConfig(PretrainedConfig):
     @classmethod
     def from_pretrained(
         cls, pretrained_model_name_or_path: Union[str, os.PathLike], config_type: str = "text_config", **kwargs
-    ) -> "PretrainedConfig":
+    ):
         cls._set_token_in_kwargs(kwargs)
 
         config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
diff --git a/src/transformers/models/deprecated/jukebox/configuration_jukebox.py b/src/transformers/models/deprecated/jukebox/configuration_jukebox.py
index 2088a695baf..067d61fe2e0 100644
--- a/src/transformers/models/deprecated/jukebox/configuration_jukebox.py
+++ b/src/transformers/models/deprecated/jukebox/configuration_jukebox.py
@@ -345,9 +345,7 @@ class JukeboxPriorConfig(PretrainedConfig):
         self.zero_out = zero_out
 
     @classmethod
-    def from_pretrained(
-        cls, pretrained_model_name_or_path: Union[str, os.PathLike], level=0, **kwargs
-    ) -> "PretrainedConfig":
+    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], level=0, **kwargs):
         cls._set_token_in_kwargs(kwargs)
 
         config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
@@ -470,7 +468,7 @@ class JukeboxVQVAEConfig(PretrainedConfig):
         self.zero_out = zero_out
 
     @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
+    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs):
         cls._set_token_in_kwargs(kwargs)
 
         config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py
index d5924567922..a3606e0b11c 100644
--- a/src/transformers/models/falcon/modeling_falcon.py
+++ b/src/transformers/models/falcon/modeling_falcon.py
@@ -15,7 +15,7 @@
 """PyTorch Falcon model."""
 
 import math
-from typing import TYPE_CHECKING, Optional, Union
+from typing import Optional, Union
 
 import torch
 import torch.utils.checkpoint
@@ -47,9 +47,6 @@ from ...utils import (
 from .configuration_falcon import FalconConfig
 
 
-if TYPE_CHECKING:
-    from ...configuration_utils import PretrainedConfig
-
 if is_flash_attn_available():
     from ...modeling_flash_attention_utils import _flash_attention_forward
 
@@ -688,7 +685,7 @@ class FalconPreTrainedModel(PreTrainedModel):
 
     # Adapted from transformers.modeling_utils.PreTrainedModel._check_and_enable_sdpa
     @classmethod
-    def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> "PretrainedConfig":
+    def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False):
         _is_bettertransformer = getattr(cls, "use_bettertransformer", False)
         if _is_bettertransformer:
             return config
diff --git a/src/transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py
index 0aa4490e60b..51897b40681 100644
--- a/src/transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py
+++ b/src/transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py
@@ -1074,7 +1074,7 @@ class Qwen2_5OmniConfig(PretrainedConfig):
 
         super().__init__(**kwargs)
 
-    def get_text_config(self, decoder=False) -> "PretrainedConfig":
+    def get_text_config(self, decoder=False):
         """
         Returns the config that is meant to be used with text IO. On most models, it is the original config instance
         itself. On specific composite models, it is under a set of valid names.
diff --git a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py
index 645a5fb837f..b3d4ae90e88 100644
--- a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py
+++ b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py
@@ -1114,7 +1114,7 @@ class Qwen2_5OmniConfig(PretrainedConfig):
 
         super().__init__(**kwargs)
 
-    def get_text_config(self, decoder=False) -> "PretrainedConfig":
+    def get_text_config(self, decoder=False):
         """
         Returns the config that is meant to be used with text IO. On most models, it is the original config instance
         itself. On specific composite models, it is under a set of valid names.
diff --git a/src/transformers/models/t5gemma/configuration_t5gemma.py b/src/transformers/models/t5gemma/configuration_t5gemma.py
index b3aa23d0bec..ad4213d3fb0 100644
--- a/src/transformers/models/t5gemma/configuration_t5gemma.py
+++ b/src/transformers/models/t5gemma/configuration_t5gemma.py
@@ -324,7 +324,7 @@ class T5GemmaConfig(PretrainedConfig):
             setattr(self.decoder, key, value)
         super().__setattr__(key, value)
 
-    def get_text_config(self, decoder=False) -> "PretrainedConfig":
+    def get_text_config(self, decoder=False):
         # Always return self, regardless of the decoder option.
         del decoder
         return self
diff --git a/src/transformers/models/t5gemma/modular_t5gemma.py b/src/transformers/models/t5gemma/modular_t5gemma.py
index ae69ae99100..970816fc38f 100644
--- a/src/transformers/models/t5gemma/modular_t5gemma.py
+++ b/src/transformers/models/t5gemma/modular_t5gemma.py
@@ -213,7 +213,7 @@ class T5GemmaConfig(PretrainedConfig):
             setattr(self.decoder, key, value)
         super().__setattr__(key, value)
 
-    def get_text_config(self, decoder=False) -> "PretrainedConfig":
+    def get_text_config(self, decoder=False):
         # Always return self, regardless of the decoder option.
         del decoder
         return self
diff --git a/src/transformers/processing_utils.py b/src/transformers/processing_utils.py
index 9dd9d9ce008..25add6dbe75 100644
--- a/src/transformers/processing_utils.py
+++ b/src/transformers/processing_utils.py
@@ -24,7 +24,7 @@ import typing
 import warnings
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Any, Optional, TypedDict, Union
+from typing import Any, Optional, TypedDict, TypeVar, Union
 
 import numpy as np
 import typing_extensions
@@ -75,6 +75,9 @@ if is_torch_available():
 logger = logging.get_logger(__name__)
 
+# type hinting: specifying the type of processor class that inherits from ProcessorMixin
+SpecificProcessorType = TypeVar("SpecificProcessorType", bound="ProcessorMixin")
+
 # Dynamically import the Transformers module to grab the attribute classes of the processor from their names.
 transformers_module = direct_transformers_import(Path(__file__).parent)
 
@@ -1246,7 +1249,7 @@ class ProcessorMixin(PushToHubMixin):
 
     @classmethod
     def from_pretrained(
-        cls,
+        cls: type[SpecificProcessorType],
        pretrained_model_name_or_path: Union[str, os.PathLike],
        cache_dir: Optional[Union[str, os.PathLike]] = None,
        force_download: bool = False,
@@ -1254,7 +1257,7 @@
        token: Optional[Union[str, bool]] = None,
        revision: str = "main",
        **kwargs,
-    ):
+    ) -> SpecificProcessorType:
        r"""
        Instantiate a processor associated with a pretrained model.
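Not part of this diff, but worth noting as a design alternative: the same classmethod-return problem can also be spelled with `Self` (PEP 673), available from `typing_extensions` on older Pythons (which `processing_utils.py` already imports) or from `typing` on 3.11+. For callers it is broadly equivalent to the explicit TypeVar spelling used above. A hedged sketch with hypothetical classes, not the real `ProcessorMixin`:

```python
from typing_extensions import Self  # typing.Self on Python 3.11+


class ProcessorBase:
    """Hypothetical stand-in for ProcessorMixin; not the real class."""

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs) -> Self:
        # `Self` means "an instance of whichever class this was called on",
        # without declaring a module-level TypeVar.
        processor = cls()
        # ... in the real method, attributes would be loaded from the checkpoint ...
        return processor


class MyProcessor(ProcessorBase):
    """Hypothetical subclass."""


proc = MyProcessor.from_pretrained("org/checkpoint")  # inferred as MyProcessor
```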