From 0f968ddaa391ceefb177a633d7122da132b38efe Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Wed, 21 Jun 2023 17:31:57 +0200 Subject: [PATCH] Check auto mappings could be imported via `from transformers` (#24400) * fix * fix * fix * fix * fix * fix --------- Co-authored-by: ydshieh --- src/transformers/__init__.py | 4 +++ src/transformers/models/auto/__init__.py | 4 +++ .../models/auto/modeling_flax_auto.py | 3 +++ src/transformers/utils/dummy_flax_objects.py | 3 +++ src/transformers/utils/dummy_pt_objects.py | 3 +++ utils/check_repo.py | 25 +++++++++++++++++++ 6 files changed, 42 insertions(+) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 3217a9ca046..548e9bf660d 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1028,6 +1028,7 @@ else: _import_structure["models.auto"].extend( [ "MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING", + "MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING", "MODEL_FOR_AUDIO_XVECTOR_MAPPING", "MODEL_FOR_BACKBONE_MAPPING", "MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING", @@ -3646,6 +3647,7 @@ else: ) _import_structure["models.auto"].extend( [ + "FLAX_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING", "FLAX_MODEL_FOR_CAUSAL_LM_MAPPING", "FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING", "FLAX_MODEL_FOR_MASKED_LM_MAPPING", @@ -4780,6 +4782,7 @@ if TYPE_CHECKING: ) from .models.auto import ( MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING, + MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING, MODEL_FOR_AUDIO_XVECTOR_MAPPING, MODEL_FOR_BACKBONE_MAPPING, MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING, @@ -6901,6 +6904,7 @@ if TYPE_CHECKING: FlaxAlbertPreTrainedModel, ) from .models.auto import ( + FLAX_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING, FLAX_MODEL_FOR_CAUSAL_LM_MAPPING, FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, FLAX_MODEL_FOR_MASKED_LM_MAPPING, diff --git a/src/transformers/models/auto/__init__.py b/src/transformers/models/auto/__init__.py index 5af79da56f7..df9958a80e4 100644 --- a/src/transformers/models/auto/__init__.py +++ b/src/transformers/models/auto/__init__.py @@ -40,6 +40,7 @@ except OptionalDependencyNotAvailable: else: _import_structure["modeling_auto"] = [ "MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING", + "MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING", "MODEL_FOR_AUDIO_XVECTOR_MAPPING", "MODEL_FOR_BACKBONE_MAPPING", "MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING", @@ -162,6 +163,7 @@ except OptionalDependencyNotAvailable: pass else: _import_structure["modeling_flax_auto"] = [ + "FLAX_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING", "FLAX_MODEL_FOR_CAUSAL_LM_MAPPING", "FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING", "FLAX_MODEL_FOR_MASKED_LM_MAPPING", @@ -207,6 +209,7 @@ if TYPE_CHECKING: else: from .modeling_auto import ( MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING, + MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING, MODEL_FOR_AUDIO_XVECTOR_MAPPING, MODEL_FOR_BACKBONE_MAPPING, MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING, @@ -329,6 +332,7 @@ if TYPE_CHECKING: pass else: from .modeling_flax_auto import ( + FLAX_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING, FLAX_MODEL_FOR_CAUSAL_LM_MAPPING, FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, FLAX_MODEL_FOR_MASKED_LM_MAPPING, diff --git a/src/transformers/models/auto/modeling_flax_auto.py b/src/transformers/models/auto/modeling_flax_auto.py index e3b8d9cf5b5..44ef8444811 100644 --- a/src/transformers/models/auto/modeling_flax_auto.py +++ b/src/transformers/models/auto/modeling_flax_auto.py @@ -264,6 +264,9 @@ FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping( FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping( CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES ) +FLAX_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES +) class FlaxAutoModel(_BaseAutoModelClass): diff --git a/src/transformers/utils/dummy_flax_objects.py b/src/transformers/utils/dummy_flax_objects.py index ce571bc9f8d..cc3f48cb25e 100644 --- a/src/transformers/utils/dummy_flax_objects.py +++ b/src/transformers/utils/dummy_flax_objects.py @@ -135,6 +135,9 @@ class FlaxAlbertPreTrainedModel(metaclass=DummyObject): requires_backends(self, ["flax"]) +FLAX_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = None + + FLAX_MODEL_FOR_CAUSAL_LM_MAPPING = None diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 5971244e66d..1490adf3b20 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -452,6 +452,9 @@ class ASTPreTrainedModel(metaclass=DummyObject): MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = None +MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING = None + + MODEL_FOR_AUDIO_XVECTOR_MAPPING = None diff --git a/utils/check_repo.py b/utils/check_repo.py index 81dbfcb5f05..03edeab1778 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -711,6 +711,29 @@ def check_all_auto_mapping_names_in_config_mapping_names(): raise Exception(f"There were {len(failures)} failures:\n" + "\n".join(failures)) +def check_all_auto_mappings_importable(): + """Check all auto mappings could be imported.""" + check_missing_backends() + + failures = [] + mappings_to_check = {} + # Each auto modeling files contains multiple mappings. Let's get them in a dynamic way. + for module_name in ["modeling_auto", "modeling_tf_auto", "modeling_flax_auto"]: + module = getattr(transformers.models.auto, module_name, None) + if module is None: + continue + # all mappings in a single auto modeling file + mapping_names = [x for x in dir(module) if x.endswith("_MAPPING_NAMES")] + mappings_to_check.update({name: getattr(module, name) for name in mapping_names}) + + for name, _ in mappings_to_check.items(): + name = name.replace("_MAPPING_NAMES", "_MAPPING") + if not hasattr(transformers, name): + failures.append(f"`{name}` should be defined in the main `__init__` file.") + if len(failures) > 0: + raise Exception(f"There were {len(failures)} failures:\n" + "\n".join(failures)) + + _re_decorator = re.compile(r"^\s*@(\S+)\s+$") @@ -993,6 +1016,8 @@ def check_repo_quality(): check_all_auto_object_names_being_defined() print("Checking all keys in auto name mappings are defined in `CONFIG_MAPPING_NAMES`.") check_all_auto_mapping_names_in_config_mapping_names() + print("Checking all auto mappings could be imported.") + check_all_auto_mappings_importable() if __name__ == "__main__":