Check all models are in an auto class (#8425)

This commit is contained in:
Sylvain Gugger 2020-11-09 15:44:54 -05:00 committed by GitHub
parent ef032ddd1e
commit a39218b75b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 73 additions and 0 deletions

View File

@ -31,6 +31,7 @@ from .configuration_auto import (
FunnelConfig,
GPT2Config,
LongformerConfig,
LxmertConfig,
MobileBertConfig,
OpenAIGPTConfig,
RobertaConfig,
@ -113,6 +114,7 @@ from .modeling_tf_funnel import (
)
from .modeling_tf_gpt2 import TFGPT2LMHeadModel, TFGPT2Model
from .modeling_tf_longformer import TFLongformerForMaskedLM, TFLongformerForQuestionAnswering, TFLongformerModel
from .modeling_tf_lxmert import TFLxmertForPreTraining, TFLxmertModel
from .modeling_tf_marian import TFMarianMTModel
from .modeling_tf_mbart import TFMBartForConditionalGeneration
from .modeling_tf_mobilebert import (
@ -168,6 +170,7 @@ logger = logging.get_logger(__name__)
TF_MODEL_MAPPING = OrderedDict(
[
(LxmertConfig, TFLxmertModel),
(T5Config, TFT5Model),
(DistilBertConfig, TFDistilBertModel),
(AlbertConfig, TFAlbertModel),
@ -192,6 +195,7 @@ TF_MODEL_MAPPING = OrderedDict(
TF_MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
[
(LxmertConfig, TFLxmertForPreTraining),
(T5Config, TFT5ForConditionalGeneration),
(DistilBertConfig, TFDistilBertForMaskedLM),
(AlbertConfig, TFAlbertForPreTraining),

View File

@ -70,6 +70,34 @@ MODEL_NAME_TO_DOC_FILE = {
"marian": "marian.rst",
}
# Names of model classes that are deliberately absent from every auto mapping (the `MODEL_XXX_MAPPING` of
# `modeling_auto` and the `TF_MODEL_XXX_MAPPING` of `modeling_tf_auto`). `check_models_are_auto_configured` does not
# report a failure for a model whose name is listed here. Update this list when adding such a model, but being in
# this list is an exception and should **not** be the rule.
IGNORE_NON_AUTO_CONFIGURED = [
"DPRContextEncoder",
"DPREncoder",
"DPRReader",
"DPRSpanPredictor",
"FlaubertForQuestionAnswering",
"FunnelBaseModel",
"GPT2DoubleHeadsModel",
"OpenAIGPTDoubleHeadsModel",
"ProphetNetDecoder",
"ProphetNetEncoder",
"RagModel",
"RagSequenceForGeneration",
"RagTokenForGeneration",
"T5Stack",
"TFBertForNextSentencePrediction",
"TFFunnelBaseModel",
"TFGPT2DoubleHeadsModel",
"TFMobileBertForNextSentencePrediction",
"TFOpenAIGPTDoubleHeadsModel",
"XLMForQuestionAnswering",
"XLMProphetNetDecoder",
"XLMProphetNetEncoder",
"XLNetForQuestionAnswering",
]
# This is to make sure the transformers module imported is the one in the repo.
spec = importlib.util.spec_from_file_location(
"transformers",
@ -282,6 +310,45 @@ def check_all_models_are_documented():
raise Exception(f"There were {len(failures)} failures:\n" + "\n".join(failures))
def get_all_auto_configured_models():
    """Collect the names of the model classes registered in at least one auto mapping.

    Every attribute of `transformers.modeling_auto` named `MODEL_...MAPPING` and every attribute of
    `transformers.modeling_tf_auto` named `TF_MODEL_...MAPPING` is read, and the values of each of them are gathered.

    Returns:
        A list with the `__name__` of each distinct class found, in no particular order.
    """
    # Each entry pairs a module to scan with the prefix its mapping attributes start with.
    sources = (
        (transformers.modeling_auto, "MODEL_"),
        (transformers.modeling_tf_auto, "TF_MODEL_"),
    )
    model_classes = set()  # A set, so a class present in several mappings is only kept once.
    for auto_module, prefix in sources:
        for name in dir(auto_module):
            if not (name.startswith(prefix) and name.endswith("MAPPING")):
                continue
            model_classes.update(getattr(auto_module, name).values())
    return [model_class.__name__ for model_class in model_classes]
def check_models_are_auto_configured(module, all_auto_models):
    """Find the models defined in `module` that are missing from every auto class.

    Args:
        module: The module to inspect; it is handed to `get_models` and its `__name__` is used in the messages.
        all_auto_models: Collection of model names that are present in at least one auto mapping.

    Returns:
        A list with one message per model that is neither in `all_auto_models` nor in `IGNORE_NON_AUTO_CONFIGURED`
        (empty when every model passes).
    """
    # `get_models` yields pairs; only the first element (the model name) matters for this check.
    return [
        f"{name} is defined in {module.__name__} but is not present in any of the auto mapping. "
        "If that is intended behavior, add its name to `IGNORE_NON_AUTO_CONFIGURED` in the file "
        "`utils/check_repo.py`."
        for name, _ in get_models(module)
        if not (name in all_auto_models or name in IGNORE_NON_AUTO_CONFIGURED)
    ]
def check_all_models_are_auto_configured():
    """Check that the models of every module returned by `get_model_modules` are each in an auto class.

    Raises:
        Exception: If `check_models_are_auto_configured` reports at least one failure. The message gives the number
            of failures followed by one line per failure.
    """
    modules = get_model_modules()
    # Converted to a set once: the names are probed for membership once per model of every module, and a set makes
    # each of those lookups constant-time instead of a scan of the whole list.
    all_auto_models = set(get_all_auto_configured_models())
    failures = []
    for module in modules:
        # The per-module check always returns a list (possibly empty), so it can be appended unconditionally.
        failures.extend(check_models_are_auto_configured(module, all_auto_models))
    if failures:
        raise Exception(f"There were {len(failures)} failures:\n" + "\n".join(failures))
# Matches a line holding only a decorator: optional leading whitespace, `@`, the decorator text without any
# whitespace (captured in group 1), then at least one trailing whitespace character before the end of the string
# (`\s+$`, so a line with nothing after the decorator text, not even a newline, does not match).
_re_decorator = re.compile(r"^\s*@(\S+)\s+$")
@ -325,6 +392,8 @@ def check_repo_quality():
check_all_models_are_tested()
print("Checking all models are properly documented.")
check_all_models_are_documented()
print("Checking all models are in at least one auto class.")
check_all_models_are_auto_configured()
if __name__ == "__main__":