mirror of https://github.com/huggingface/transformers.git
synced 2025-07-04 05:10:06 +06:00

Make sure all submodules are properly registered (#15144)

* Make sure all submodules are properly registered
* Try to fix tests
* Fix tests

parent c4f7eb124b
commit 7cbf8429d9
src/transformers/__init__.py

@@ -51,7 +51,12 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 # Base objects, independent of any specific backend
 _import_structure = {
+    "benchmark": [],
+    "commands": [],
     "configuration_utils": ["PretrainedConfig"],
+    "convert_graph_to_onnx": [],
+    "convert_slow_tokenizers_checkpoints_to_fast": [],
+    "convert_tf_hub_seq_to_seq_bert_to_pytorch": [],
     "data": [
         "DataProcessor",
         "InputExample",
@@ -84,6 +89,11 @@ _import_structure = {
         "DefaultDataCollator",
         "default_data_collator",
     ],
+    "data.metrics": [],
+    "data.processors": [],
+    "debug_utils": [],
+    "dependency_versions_check": [],
+    "dependency_versions_table": [],
     "feature_extraction_sequence_utils": ["SequenceFeatureExtractor"],
     "feature_extraction_utils": ["BatchFeature"],
     "file_utils": [
@@ -179,6 +189,7 @@ _import_structure = {
         "BlenderbotSmallConfig",
         "BlenderbotSmallTokenizer",
     ],
+    "models.bort": [],
     "models.byt5": ["ByT5Tokenizer"],
     "models.camembert": ["CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "CamembertConfig"],
     "models.canine": ["CANINE_PRETRAINED_CONFIG_ARCHIVE_MAP", "CanineConfig", "CanineTokenizer"],
@@ -196,6 +207,7 @@ _import_structure = {
     "models.deberta_v2": ["DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP", "DebertaV2Config"],
     "models.deit": ["DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "DeiTConfig"],
     "models.detr": ["DETR_PRETRAINED_CONFIG_ARCHIVE_MAP", "DetrConfig"],
+    "models.dialogpt": [],
     "models.distilbert": ["DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "DistilBertConfig", "DistilBertTokenizer"],
     "models.dpr": [
         "DPR_PRETRAINED_CONFIG_ARCHIVE_MAP",
@@ -236,6 +248,7 @@ _import_structure = {
     "models.mbart": ["MBartConfig"],
     "models.mbart50": [],
     "models.megatron_bert": ["MEGATRON_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "MegatronBertConfig"],
+    "models.megatron_gpt2": [],
     "models.mluke": [],
     "models.mmbt": ["MMBTConfig"],
     "models.mobilebert": ["MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "MobileBertConfig", "MobileBertTokenizer"],
@@ -316,6 +329,7 @@ _import_structure = {
     "models.xlm_prophetnet": ["XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMProphetNetConfig"],
     "models.xlm_roberta": ["XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMRobertaConfig"],
     "models.xlnet": ["XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLNetConfig"],
+    "onnx": [],
     "pipelines": [
         "AudioClassificationPipeline",
         "AutomaticSpeechRecognitionPipeline",
@@ -343,6 +357,7 @@ _import_structure = {
         "ZeroShotClassificationPipeline",
         "pipeline",
     ],
+    "testing_utils": [],
     "tokenization_utils": ["PreTrainedTokenizer"],
     "tokenization_utils_base": [
         "AddedToken",
@@ -567,6 +582,7 @@ else:
 
 # PyTorch-backed objects
 if is_torch_available():
+    _import_structure["activations"] = []
     _import_structure["benchmark.benchmark"] = ["PyTorchBenchmark"]
     _import_structure["benchmark.benchmark_args"] = ["PyTorchBenchmarkArguments"]
     _import_structure["data.datasets"] = [
@@ -580,6 +596,7 @@ if is_torch_available():
         "TextDataset",
         "TextDatasetForNextSentencePrediction",
     ]
+    _import_structure["deepspeed"] = []
     _import_structure["generation_beam_search"] = ["BeamScorer", "BeamSearchScorer"]
     _import_structure["generation_logits_process"] = [
         "ForcedBOSTokenLogitsProcessor",
@@ -1455,6 +1472,7 @@ if is_torch_available():
         "get_polynomial_decay_schedule_with_warmup",
         "get_scheduler",
     ]
+    _import_structure["sagemaker"] = []
     _import_structure["trainer"] = ["Trainer"]
     _import_structure["trainer_pt_utils"] = ["torch_distributed_zero_first"]
     _import_structure["trainer_seq2seq"] = ["Seq2SeqTrainer"]
@@ -1465,6 +1483,7 @@ else:
 
 # TensorFlow-backed objects
 if is_tf_available():
+    _import_structure["activations_tf"] = []
     _import_structure["benchmark.benchmark_args_tf"] = ["TensorFlowBenchmarkArguments"]
     _import_structure["benchmark.benchmark_tf"] = ["TensorFlowBenchmark"]
     _import_structure["generation_tf_utils"] = ["tf_top_k_top_p_filtering"]
@@ -2129,6 +2148,7 @@ else:
             name for name in dir(dummy_flax_objects) if not name.startswith("_")
         ]
 
+
 # Direct imports for type-checking
 if TYPE_CHECKING:
     # Configuration
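All of the hunks above feed the lazy-import machinery behind `import transformers`: every submodule must appear among the keys of `_import_structure`, with an empty list as value when the module re-exports nothing at the top level. As a minimal sketch of the idea (not transformers' actual implementation, which at this point lives in `file_utils._LazyModule`), such a mapping can drive a PEP 562 module-level `__getattr__` in a package's `__init__.py`:

import importlib

# Hypothetical minimal sketch of a lazy package __init__.py; the real
# transformers implementation is more involved (caching, TYPE_CHECKING, etc.).
_import_structure = {
    "benchmark": [],                              # registered submodule, no top-level re-exports
    "configuration_utils": ["PretrainedConfig"],  # object re-exported at the top level
}

# Reverse map: re-exported object name -> submodule that defines it.
_object_to_module = {
    obj: module for module, objects in _import_structure.items() for obj in objects
}


def __getattr__(name):  # PEP 562: called for attributes not found the normal way
    if name in _import_structure:
        # First access to a registered submodule imports it on demand.
        return importlib.import_module(f".{name}", __name__)
    if name in _object_to_module:
        # First access to a registered object imports its home submodule.
        module = importlib.import_module(f".{_object_to_module[name]}", __name__)
        return getattr(module, name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

With this pattern, `import transformers` stays cheap because heavy backends (PyTorch, TensorFlow, Flax) are only touched when one of their objects is actually requested, and registering even object-less submodules like "benchmark" lets the checker added below treat `_import_structure.keys()` as the single source of truth.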
utils/check_inits.py

@@ -14,8 +14,10 @@
 # limitations under the License.
 
 import collections
+import importlib.util
 import os
 import re
+from pathlib import Path
 
 
 PATH_TO_TRANSFORMERS = "src/transformers"
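The two imports added here serve the checker below: `importlib.util` loads the repo's own `__init__.py` from an explicit path, so a pip-installed transformers cannot shadow the checkout, and `pathlib.Path` simplifies turning file paths into dotted module names. A minimal sketch of that loading pattern, using the modern `module_from_spec`/`exec_module` API rather than the deprecated `spec.loader.load_module()` the commit itself calls (assumes it runs from a checkout root):

import importlib.util
import os

PATH_TO_TRANSFORMERS = "src/transformers"

# Point the spec at the checkout's __init__.py, not at site-packages.
spec = importlib.util.spec_from_file_location(
    "transformers",
    os.path.join(PATH_TO_TRANSFORMERS, "__init__.py"),
    submodule_search_locations=[PATH_TO_TRANSFORMERS],
)
transformers = importlib.util.module_from_spec(spec)
spec.loader.exec_module(transformers)  # runs __init__.py, filling _import_structure
print(len(transformers._import_structure), "registered submodule keys")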
@@ -202,5 +204,58 @@ def check_all_inits():
         raise ValueError("\n\n".join(failures))
 
 
+def get_transformers_submodules():
+    """
+    Returns the list of Transformers submodules.
+    """
+    submodules = []
+    for path, directories, files in os.walk(PATH_TO_TRANSFORMERS):
+        for folder in directories:
+            if folder.startswith("_"):
+                directories.remove(folder)
+                continue
+            short_path = str((Path(path) / folder).relative_to(PATH_TO_TRANSFORMERS))
+            submodule = short_path.replace(os.path.sep, ".")
+            submodules.append(submodule)
+        for fname in files:
+            if fname == "__init__.py":
+                continue
+            short_path = str((Path(path) / fname).relative_to(PATH_TO_TRANSFORMERS))
+            submodule = short_path.replace(os.path.sep, ".").replace(".py", "")
+            if len(submodule.split(".")) == 1:
+                submodules.append(submodule)
+    return submodules
+
+
+IGNORE_SUBMODULES = [
+    "convert_pytorch_checkpoint_to_tf2",
+    "modeling_flax_pytorch_utils",
+]
+
+
+def check_submodules():
+    # This is to make sure the transformers module imported is the one in the repo.
+    spec = importlib.util.spec_from_file_location(
+        "transformers",
+        os.path.join(PATH_TO_TRANSFORMERS, "__init__.py"),
+        submodule_search_locations=[PATH_TO_TRANSFORMERS],
+    )
+    transformers = spec.loader.load_module()
+
+    module_not_registered = [
+        module
+        for module in get_transformers_submodules()
+        if module not in IGNORE_SUBMODULES and module not in transformers._import_structure.keys()
+    ]
+    if len(module_not_registered) > 0:
+        list_of_modules = "\n".join(f"- {module}" for module in module_not_registered)
+        raise ValueError(
+            "The following submodules are not properly registered in the main init of Transformers:\n"
+            f"{list_of_modules}\n"
+            "Make sure they appear somewhere in the keys of `_import_structure` with an empty list as value."
+        )
+
+
 if __name__ == "__main__":
     check_all_inits()
+    check_submodules()
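One subtlety in `get_transformers_submodules()`: calling `directories.remove(folder)` while iterating over `directories` mutates the list mid-loop, which can skip the entry right after a removed one. It is harmless here because underscore-prefixed folders such as `__pycache__` rarely sit next to each other, but the idiomatic `os.walk` prune rewrites the list in place with a slice assignment, as in this hypothetical sketch:

import os

PATH_TO_TRANSFORMERS = "src/transformers"

# Hypothetical sketch, not the committed code: the standard os.walk prune.
# Slice-assigning into `directories` edits the very list os.walk recurses on,
# so private folders (e.g. __pycache__) are never descended into, and no
# sibling entry gets skipped by mutating the list mid-iteration.
for path, directories, files in os.walk(PATH_TO_TRANSFORMERS):
    directories[:] = [d for d in directories if not d.startswith("_")]

Both checks run from the repository root with `python utils/check_inits.py`; fixing a flagged submodule is the one-line registration shown in the hunks above, e.g. adding "models.dialogpt": [] to `_import_structure`.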