Mirror of https://github.com/huggingface/transformers.git, synced 2025-08-02 11:11:05 +06:00
Check all models are in an auto class (#8425)
This commit is contained in:
parent ef032ddd1e
commit a39218b75b
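
Adds a new repository consistency check to utils/check_repo.py: every model class must appear in at least one of the auto-model mappings, with an explicit IGNORE_NON_AUTO_CONFIGURED list as the escape hatch for intentional exceptions. It also registers the TF Lxmert classes (TFLxmertModel, TFLxmertForPreTraining) in the TF auto mappings so they pass the new check.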
src/transformers/modeling_tf_auto.py
@@ -31,6 +31,7 @@ from .configuration_auto import (
     FunnelConfig,
     GPT2Config,
     LongformerConfig,
+    LxmertConfig,
     MobileBertConfig,
     OpenAIGPTConfig,
     RobertaConfig,
@@ -113,6 +114,7 @@ from .modeling_tf_funnel import (
 )
 from .modeling_tf_gpt2 import TFGPT2LMHeadModel, TFGPT2Model
 from .modeling_tf_longformer import TFLongformerForMaskedLM, TFLongformerForQuestionAnswering, TFLongformerModel
+from .modeling_tf_lxmert import TFLxmertForPreTraining, TFLxmertModel
 from .modeling_tf_marian import TFMarianMTModel
 from .modeling_tf_mbart import TFMBartForConditionalGeneration
 from .modeling_tf_mobilebert import (
@@ -168,6 +170,7 @@ logger = logging.get_logger(__name__)

 TF_MODEL_MAPPING = OrderedDict(
     [
+        (LxmertConfig, TFLxmertModel),
         (T5Config, TFT5Model),
         (DistilBertConfig, TFDistilBertModel),
         (AlbertConfig, TFAlbertModel),
@@ -192,6 +195,7 @@ TF_MODEL_MAPPING = OrderedDict(

 TF_MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
     [
+        (LxmertConfig, TFLxmertForPreTraining),
         (T5Config, TFT5ForConditionalGeneration),
         (DistilBertConfig, TFDistilBertForMaskedLM),
         (AlbertConfig, TFAlbertForPreTraining),
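
The point of these mapping entries: once (LxmertConfig, TFLxmertModel) is registered in TF_MODEL_MAPPING, the TF auto classes can resolve Lxmert from its config type alone. A minimal sketch of what that enables, assuming the standard TFAutoModel API:

from transformers import LxmertConfig, TFAutoModel

# TFAutoModel looks up the config class in TF_MODEL_MAPPING and
# instantiates the matching model class, here TFLxmertModel.
config = LxmertConfig()
model = TFAutoModel.from_config(config)
print(type(model).__name__)  # TFLxmertModel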
utils/check_repo.py
@@ -70,6 +70,34 @@ MODEL_NAME_TO_DOC_FILE = {
     "marian": "marian.rst",
 }

+# Update this list for models that are not in any of the auto MODEL_XXX_MAPPING. Being in this list is an exception and
+# should **not** be the rule.
+IGNORE_NON_AUTO_CONFIGURED = [
+    "DPRContextEncoder",
+    "DPREncoder",
+    "DPRReader",
+    "DPRSpanPredictor",
+    "FlaubertForQuestionAnswering",
+    "FunnelBaseModel",
+    "GPT2DoubleHeadsModel",
+    "OpenAIGPTDoubleHeadsModel",
+    "ProphetNetDecoder",
+    "ProphetNetEncoder",
+    "RagModel",
+    "RagSequenceForGeneration",
+    "RagTokenForGeneration",
+    "T5Stack",
+    "TFBertForNextSentencePrediction",
+    "TFFunnelBaseModel",
+    "TFGPT2DoubleHeadsModel",
+    "TFMobileBertForNextSentencePrediction",
+    "TFOpenAIGPTDoubleHeadsModel",
+    "XLMForQuestionAnswering",
+    "XLMProphetNetDecoder",
+    "XLMProphetNetEncoder",
+    "XLNetForQuestionAnswering",
+]
+
 # This is to make sure the transformers module imported is the one in the repo.
 spec = importlib.util.spec_from_file_location(
     "transformers",
@@ -282,6 +310,45 @@ def check_all_models_are_documented():
         raise Exception(f"There were {len(failures)} failures:\n" + "\n".join(failures))


+def get_all_auto_configured_models():
+    """ Return the list of all models in at least one auto class."""
+    result = set()  # To avoid duplicates we concatenate all model classes in a set.
+    for attr_name in dir(transformers.modeling_auto):
+        if attr_name.startswith("MODEL_") and attr_name.endswith("MAPPING"):
+            result = result | set(getattr(transformers.modeling_auto, attr_name).values())
+    for attr_name in dir(transformers.modeling_tf_auto):
+        if attr_name.startswith("TF_MODEL_") and attr_name.endswith("MAPPING"):
+            result = result | set(getattr(transformers.modeling_tf_auto, attr_name).values())
+    return [cls.__name__ for cls in result]
+
+
+def check_models_are_auto_configured(module, all_auto_models):
+    """ Check models defined in module are each in an auto class."""
+    defined_models = get_models(module)
+    failures = []
+    for model_name, _ in defined_models:
+        if model_name not in all_auto_models and model_name not in IGNORE_NON_AUTO_CONFIGURED:
+            failures.append(
+                f"{model_name} is defined in {module.__name__} but is not present in any of the auto mapping. "
+                "If that is intended behavior, add its name to `IGNORE_NON_AUTO_CONFIGURED` in the file "
+                "`utils/check_repo.py`."
+            )
+    return failures
+
+
+def check_all_models_are_auto_configured():
+    """ Check all models are each in an auto class."""
+    modules = get_model_modules()
+    all_auto_models = get_all_auto_configured_models()
+    failures = []
+    for module in modules:
+        new_failures = check_models_are_auto_configured(module, all_auto_models)
+        if new_failures is not None:
+            failures += new_failures
+    if len(failures) > 0:
+        raise Exception(f"There were {len(failures)} failures:\n" + "\n".join(failures))
+
+
 _re_decorator = re.compile(r"^\s*@(\S+)\s+$")


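
The scan in get_all_auto_configured_models relies purely on naming conventions: any module attribute named like MODEL_*MAPPING (or TF_MODEL_*MAPPING) is treated as an auto mapping whose values are the covered model classes. A self-contained sketch of the same pattern, using a hypothetical stand-in module instead of the real transformers ones:

import types

# Hypothetical stand-ins; the real mappings live in
# transformers.modeling_auto and transformers.modeling_tf_auto.
class BertConfig: ...
class BertModel: ...
class BertForMaskedLM: ...

fake_auto = types.ModuleType("fake_auto")
fake_auto.MODEL_MAPPING = {BertConfig: BertModel}
fake_auto.MODEL_WITH_LM_HEAD_MAPPING = {BertConfig: BertForMaskedLM}

# Same convention-based scan: attribute names select the mappings,
# and the mapping values are the model classes they cover.
covered = set()
for attr_name in dir(fake_auto):
    if attr_name.startswith("MODEL_") and attr_name.endswith("MAPPING"):
        covered |= set(getattr(fake_auto, attr_name).values())

print(sorted(cls.__name__ for cls in covered))
# ['BertForMaskedLM', 'BertModel']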
@@ -325,6 +392,8 @@ def check_repo_quality():
     check_all_models_are_tested()
     print("Checking all models are properly documented.")
     check_all_models_are_documented()
+    print("Checking all models are in at least one auto class.")
+    check_all_models_are_auto_configured()


 if __name__ == "__main__":
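
Usage note: assuming the script's __main__ block calls check_repo_quality() (the entry point shown above), the new check runs alongside the other repo-quality checks via:

    python utils/check_repo.py

A failure for a hypothetical model class missing from every auto mapping would read:

    TFSomeModel is defined in transformers.modeling_tf_some_model but is not present in any of the auto mapping. If that is intended behavior, add its name to `IGNORE_NON_AUTO_CONFIGURED` in the file `utils/check_repo.py`.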