From 6ba540b7475f095d591b4766cac897007b1d5db0 Mon Sep 17 00:00:00 2001
From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Date: Fri, 7 Aug 2020 09:18:37 -0400
Subject: [PATCH] Add a script to check all models are tested and documented
 (#6298)

* Add a script to check all models are tested and documented

* Apply suggestions from code review

Co-authored-by: Kevin Canwen Xu

* Address comments

Co-authored-by: Kevin Canwen Xu
---
 Makefile                                      |   1 +
 docs/source/model_doc/bert.rst                |  14 +
 docs/source/model_doc/camembert.rst           |   7 +
 docs/source/model_doc/flaubert.rst            |   7 +
 docs/source/model_doc/reformer.rst            |   7 +
 docs/source/model_doc/xlm.rst                 |  14 +
 src/transformers/modeling_electra.py          |   2 +
 src/transformers/modeling_tf_albert.py        |   9 +-
 tests/test_modeling_electra.py                |   1 +
 tests/test_modeling_tf_albert.py              |  22 ++
 ...enai_gpt.py => test_modeling_tf_openai.py} |   0
 utils/check_repo.py                           | 269 ++++++++++++++++++
 12 files changed, 351 insertions(+), 2 deletions(-)
 rename tests/{test_modeling_tf_openai_gpt.py => test_modeling_tf_openai.py} (100%)
 create mode 100644 utils/check_repo.py

diff --git a/Makefile b/Makefile
index dc2a6491ee8..48edc234f70 100644
--- a/Makefile
+++ b/Makefile
@@ -6,6 +6,7 @@ quality:
 	black --check --line-length 119 --target-version py35 examples templates tests src utils
 	isort --check-only --recursive examples templates tests src utils
 	flake8 examples templates tests src utils
+	python utils/check_repo.py
 
 # Format source code automatically
 
diff --git a/docs/source/model_doc/bert.rst b/docs/source/model_doc/bert.rst
index 5e35b520d86..13bc47e260d 100644
--- a/docs/source/model_doc/bert.rst
+++ b/docs/source/model_doc/bert.rst
@@ -78,6 +78,13 @@ BertForPreTraining
     :members:
 
 
+BertLMHeadModel
+~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.BertLMHeadModel
+    :members:
+
+
 BertForMaskedLM
 ~~~~~~~~~~~~~~~~~~~~~~~~~~
 
@@ -134,6 +141,13 @@ TFBertForPreTraining
     :members:
 
 
+TFBertLMHeadModel
+~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.TFBertLMHeadModel
+    :members:
+
+
 TFBertForMaskedLM
 ~~~~~~~~~~~~~~~~~~~~~~~~~~
 
diff --git a/docs/source/model_doc/camembert.rst b/docs/source/model_doc/camembert.rst
index b2f28842dbf..5ccdfe5b877 100644
--- a/docs/source/model_doc/camembert.rst
+++ b/docs/source/model_doc/camembert.rst
@@ -105,6 +105,13 @@ TFCamembertForSequenceClassification
     :members:
 
 
+TFCamembertForMultipleChoice
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.TFCamembertForMultipleChoice
+    :members:
+
+
 TFCamembertForTokenClassification
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
diff --git a/docs/source/model_doc/flaubert.rst b/docs/source/model_doc/flaubert.rst
index 4a0e4ca5811..e454f96cba7 100644
--- a/docs/source/model_doc/flaubert.rst
+++ b/docs/source/model_doc/flaubert.rst
@@ -61,6 +61,13 @@ FlaubertForSequenceClassification
     :members:
 
 
+FlaubertForMultipleChoice
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.FlaubertForMultipleChoice
+    :members:
+
+
 FlaubertForTokenClassification
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
diff --git a/docs/source/model_doc/reformer.rst b/docs/source/model_doc/reformer.rst
index f82d972b72a..370187dfb89 100644
--- a/docs/source/model_doc/reformer.rst
+++ b/docs/source/model_doc/reformer.rst
@@ -121,6 +121,13 @@ ReformerForMaskedLM
     :members:
 
 
+ReformerForSequenceClassification
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.ReformerForSequenceClassification
+    :members:
+
+
 ReformerForQuestionAnswering
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
diff --git a/docs/source/model_doc/xlm.rst b/docs/source/model_doc/xlm.rst
index cd14a77cbbf..9f6254d01ce 100644
--- a/docs/source/model_doc/xlm.rst
+++ b/docs/source/model_doc/xlm.rst
@@ -75,6 +75,20 @@ XLMForSequenceClassification
     :members:
 
 
+XLMForMultipleChoice
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.XLMForMultipleChoice
+    :members:
+
+
+XLMForTokenClassification
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.XLMForTokenClassification
+    :members:
+
+
 XLMForQuestionAnsweringSimple
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
diff --git a/src/transformers/modeling_electra.py b/src/transformers/modeling_electra.py
index 1eb58c1486d..f41c230ca7a 100644
--- a/src/transformers/modeling_electra.py
+++ b/src/transformers/modeling_electra.py
@@ -879,6 +879,7 @@ class ElectraForMultipleChoice(ElectraPreTrainedModel):
         inputs_embeds=None,
         labels=None,
         output_attentions=None,
+        output_hidden_states=None,
         return_dict=None,
     ):
         r"""
@@ -908,6 +909,7 @@ class ElectraForMultipleChoice(ElectraPreTrainedModel):
             head_mask=head_mask,
             inputs_embeds=inputs_embeds,
             output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
             return_dict=return_dict,
         )
 
diff --git a/src/transformers/modeling_tf_albert.py b/src/transformers/modeling_tf_albert.py
index 73a1e825274..086a7ac24d3 100644
--- a/src/transformers/modeling_tf_albert.py
+++ b/src/transformers/modeling_tf_albert.py
@@ -1260,7 +1260,7 @@ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss):
             head_mask = inputs.get("head_mask", head_mask)
             inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
             output_attentions = inputs.get("output_attentions", output_attentions)
-            output_hidden_states = inputs.get("output_hidden_states", output_attentions)
+            output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
             return_dict = inputs.get("return_dict", return_dict)
             labels = inputs.get("labels", labels)
             assert len(inputs) <= 10, "Too many inputs."
@@ -1279,6 +1279,11 @@ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss):
         flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
         flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
         flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
+        flat_inputs_embeds = (
+            tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))
+            if inputs_embeds is not None
+            else None
+        )
 
         outputs = self.albert(
             flat_input_ids,
@@ -1286,7 +1291,7 @@ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss):
             flat_token_type_ids,
             flat_position_ids,
             head_mask,
-            inputs_embeds,
+            flat_inputs_embeds,
             output_attentions,
             output_hidden_states,
             return_dict=return_dict,
diff --git a/tests/test_modeling_electra.py b/tests/test_modeling_electra.py
index 935f4a27298..88c0eafa578 100644
--- a/tests/test_modeling_electra.py
+++ b/tests/test_modeling_electra.py
@@ -275,6 +275,7 @@ class ElectraModelTest(ModelTesterMixin, unittest.TestCase):
             ElectraModel,
             ElectraForPreTraining,
             ElectraForMaskedLM,
+            ElectraForMultipleChoice,
             ElectraForTokenClassification,
             ElectraForSequenceClassification,
             ElectraForQuestionAnswering,
diff --git a/tests/test_modeling_tf_albert.py b/tests/test_modeling_tf_albert.py
index ca807e84877..58a832cfdce 100644
--- a/tests/test_modeling_tf_albert.py
+++ b/tests/test_modeling_tf_albert.py
@@ -24,10 +24,12 @@ from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
 
 
 if is_tf_available():
+    import tensorflow as tf
     from transformers.modeling_tf_albert import (
         TFAlbertModel,
         TFAlbertForPreTraining,
         TFAlbertForMaskedLM,
+        TFAlbertForMultipleChoice,
         TFAlbertForSequenceClassification,
         TFAlbertForQuestionAnswering,
         TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
@@ -180,6 +182,22 @@ class TFAlbertModelTester:
         self.parent.assertListEqual(list(result["start_logits"].shape), [self.batch_size, self.seq_length])
         self.parent.assertListEqual(list(result["end_logits"].shape), [self.batch_size, self.seq_length])
 
+    def create_and_check_albert_for_multiple_choice(
+        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
+    ):
+        config.num_choices = self.num_choices
+        model = TFAlbertForMultipleChoice(config=config)
+        multiple_choice_inputs_ids = tf.tile(tf.expand_dims(input_ids, 1), (1, self.num_choices, 1))
+        multiple_choice_input_mask = tf.tile(tf.expand_dims(input_mask, 1), (1, self.num_choices, 1))
+        multiple_choice_token_type_ids = tf.tile(tf.expand_dims(token_type_ids, 1), (1, self.num_choices, 1))
+        inputs = {
+            "input_ids": multiple_choice_inputs_ids,
+            "attention_mask": multiple_choice_input_mask,
+            "token_type_ids": multiple_choice_token_type_ids,
+        }
+        result = model(inputs)
+        self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_choices])
+
     def prepare_config_and_inputs_for_common(self):
         config_and_inputs = self.prepare_config_and_inputs()
         (
@@ -229,6 +247,10 @@ class TFAlbertModelTest(TFModelTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_albert_for_masked_lm(*config_and_inputs)
 
+    def test_for_multiple_choice(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_albert_for_multiple_choice(*config_and_inputs)
+
     def test_for_sequence_classification(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_albert_for_sequence_classification(*config_and_inputs)
 
diff --git a/tests/test_modeling_tf_openai_gpt.py b/tests/test_modeling_tf_openai.py
similarity index 100%
rename from tests/test_modeling_tf_openai_gpt.py
rename to tests/test_modeling_tf_openai.py
diff --git a/utils/check_repo.py b/utils/check_repo.py
new file mode 100644
index 00000000000..ca3743b2657
--- /dev/null
+++ b/utils/check_repo.py
@@ -0,0 +1,269 @@
+import importlib
+import inspect
+import os
+import re
+
+
+# All paths are set with the intent that you should run this script from the root of the repo with the command
+# python utils/check_repo.py
+PATH_TO_TRANSFORMERS = "src/transformers"
+PATH_TO_TESTS = "tests"
+PATH_TO_DOC = "docs/source/model_doc"
+
+# Update this list for models that are not tested, with a comment explaining the reason they should not be.
+# Being in this list is an exception and should **not** be the rule.
+IGNORE_NON_TESTED = [
+    "BertLMHeadModel",  # Needs to be set up as a decoder.
+    "DPREncoder",  # Building part of bigger (tested) model.
+    "DPRSpanPredictor",  # Building part of bigger (tested) model.
+    "ReformerForMaskedLM",  # Needs to be set up as a decoder.
+    "T5Stack",  # Building part of bigger (tested) model.
+    "TFAlbertForMultipleChoice",  # TODO: fix
+    "TFAlbertForTokenClassification",  # TODO: fix
+    "TFBertLMHeadModel",  # TODO: fix
+    "TFElectraForMultipleChoice",  # Fix is in #6284
+    "TFElectraForQuestionAnswering",  # TODO: fix
+    "TFElectraForSequenceClassification",  # Fix is in #6284
+    "TFElectraMainLayer",  # Building part of bigger (tested) model (should it be a TFPreTrainedModel?)
+    "TFRobertaForMultipleChoice",  # TODO: fix
+]
+
+# Update this list with test files that don't have a tester with an `all_model_classes` variable and which don't
+# trigger the common tests.
+TEST_FILES_WITH_NO_COMMON_TESTS = [
+    "test_modeling_camembert.py",
+    "test_modeling_tf_camembert.py",
+    "test_modeling_tf_xlm_roberta.py",
+    "test_modeling_xlm_roberta.py",
+]
+
+# Update this list for models that are not documented, with a comment explaining the reason they should not be.
+# Being in this list is an exception and should **not** be the rule.
+IGNORE_NON_DOCUMENTED = [
+    "DPREncoder",  # Building part of bigger (documented) model.
+    "DPRSpanPredictor",  # Building part of bigger (documented) model.
+    "T5Stack",  # Building part of bigger (documented) model.
+    "TFElectraMainLayer",  # Building part of bigger (documented) model (should it be a TFPreTrainedModel?)
+]
+
+# Update this dict with any special correspondence between a model name (used in modeling_xxx.py) and its doc file.
+MODEL_NAME_TO_DOC_FILE = {
+    "openai": "gpt.rst",
+    "transfo_xl": "transformerxl.rst",
+    "xlm_roberta": "xlmroberta.rst",
+}
+
+# This is to make sure the transformers module imported is the one in the repo.
+spec = importlib.util.spec_from_file_location(
+    "transformers",
+    os.path.join(PATH_TO_TRANSFORMERS, "__init__.py"),
+    submodule_search_locations=[PATH_TO_TRANSFORMERS],
+)
+transformers = spec.loader.load_module()
+
+
+# If some modeling modules should be ignored for all checks, they should be added in the nested list
+# _ignore_modules of this function.
+def get_model_modules():
+    """ Get the model modules inside the transformers library. """
""" + _ignore_modules = [ + "modeling_auto", + "modeling_encoder_decoder", + "modeling_marian", + "modeling_mmbt", + "modeling_outputs", + "modeling_retribert", + "modeling_utils", + "modeling_transfo_xl_utilities", + "modeling_tf_auto", + "modeling_tf_outputs", + "modeling_tf_pytorch_utils", + "modeling_tf_utils", + "modeling_tf_transfo_xl_utilities", + ] + modules = [] + for attr_name in dir(transformers): + if attr_name.startswith("modeling") and attr_name not in _ignore_modules: + module = getattr(transformers, attr_name) + if inspect.ismodule(module): + modules.append(module) + return modules + + +def get_models(module): + """ Get the objects in module that are models.""" + models = [] + model_classes = (transformers.PreTrainedModel, transformers.TFPreTrainedModel) + for attr_name in dir(module): + if "Pretrained" in attr_name or "PreTrained" in attr_name: + continue + attr = getattr(module, attr_name) + if isinstance(attr, type) and issubclass(attr, model_classes) and attr.__module__ == module.__name__: + models.append((attr_name, attr)) + return models + + +# If some test_modeling files should be ignored when checking models are all tested, they should be added in the +# nested list _ignore_files of this function. +def get_model_test_files(): + """ Get the model test files.""" + _ignore_files = [ + "test_modeling_common", + "test_modeling_encoder_decoder", + "test_modeling_marian", + "test_modeling_mbart", + "test_modeling_tf_common", + ] + test_files = [] + for filename in os.listdir(PATH_TO_TESTS): + if ( + os.path.isfile(f"{PATH_TO_TESTS}/{filename}") + and filename.startswith("test_modeling") + and not os.path.splitext(filename)[0] in _ignore_files + ): + test_files.append(filename) + return test_files + + +# If some doc source files should be ignored when checking models are all documented, they should be added in the +# nested list _ignore_modules of this function. +def get_model_doc_files(): + """ Get the model doc files.""" + _ignore_modules = [ + "auto", + "dialogpt", + "marian", + "retribert", + ] + doc_files = [] + for filename in os.listdir(PATH_TO_DOC): + if os.path.isfile(f"{PATH_TO_DOC}/{filename}") and not os.path.splitext(filename)[0] in _ignore_modules: + doc_files.append(filename) + return doc_files + + +# This is a bit hacky but I didn't find a way to import the test_file as a module and read inside the tester class +# for the all_model_classes variable. +def find_tested_models(test_file): + """ Parse the content of test_file to detect what's in all_model_classes""" + with open(os.path.join(PATH_TO_TESTS, test_file)) as f: + content = f.read() + all_models = re.search(r"all_model_classes\s+=\s+\(\s*\(([^\)]*)\)", content) + # Check with one less parenthesis + if all_models is None: + all_models = re.search(r"all_model_classes\s+=\s+\(([^\)]*)\)", content) + if all_models is not None: + model_tested = [] + for line in all_models.groups()[0].split(","): + name = line.strip() + if len(name) > 0: + model_tested.append(name) + return model_tested + + +def check_models_are_tested(module, test_file): + """ Check models defined in module are tested in test_file.""" + defined_models = get_models(module) + tested_models = find_tested_models(test_file) + if tested_models is None: + if test_file in TEST_FILES_WITH_NO_COMMON_TESTS: + return + return [ + f"{test_file} should define `all_model_classes` to apply common tests to the models it tests. " + + "If this intentional, add the test filename to `TEST_FILES_WITH_NO_COMMON_TESTS` in the file " + + "`utils/check_repo.py`." 
+        ]
+    failures = []
+    for model_name, _ in defined_models:
+        if model_name not in tested_models and model_name not in IGNORE_NON_TESTED:
+            failures.append(
+                f"{model_name} is defined in {module.__name__} but is not tested in "
+                + f"{os.path.join(PATH_TO_TESTS, test_file)}. Add it to the all_model_classes in that file. "
+                + "If common tests should not be applied to that model, add its name to `IGNORE_NON_TESTED` "
+                + "in the file `utils/check_repo.py`."
+            )
+    return failures
+
+
+def check_all_models_are_tested():
+    """ Check all models are properly tested."""
+    modules = get_model_modules()
+    test_files = get_model_test_files()
+    failures = []
+    for module in modules:
+        test_file = f"test_{module.__name__.split('.')[1]}.py"
+        if test_file not in test_files:
+            failures.append(f"{module.__name__} does not have its corresponding test file {test_file}.")
+        new_failures = check_models_are_tested(module, test_file)
+        if new_failures is not None:
+            failures += new_failures
+    if len(failures) > 0:
+        raise Exception(f"There were {len(failures)} failures:\n" + "\n".join(failures))
+
+
+def find_documented_classes(doc_file):
+    """ Parse the content of doc_file to detect which classes it documents."""
+    with open(os.path.join(PATH_TO_DOC, doc_file)) as f:
+        content = f.read()
+    return re.findall(r"autoclass:: transformers.(\S+)\s+", content)
+
+
+def check_models_are_documented(module, doc_file):
+    """ Check models defined in module are documented in doc_file."""
+    defined_models = get_models(module)
+    documented_classes = find_documented_classes(doc_file)
+    failures = []
+    for model_name, _ in defined_models:
+        if model_name not in documented_classes and model_name not in IGNORE_NON_DOCUMENTED:
+            failures.append(
+                f"{model_name} is defined in {module.__name__} but is not documented in "
+                + f"{os.path.join(PATH_TO_DOC, doc_file)}. Add it to that file. "
+                + "If this model should not be documented, add its name to `IGNORE_NON_DOCUMENTED` "
+                + "in the file `utils/check_repo.py`."
+            )
+    return failures
+
+
+def _get_model_name(module):
+    """ Get the model name for the module defining it."""
+    splits = module.__name__.split("_")
+    # Special case for transfo_xl
+    if splits[-1] == "xl":
+        return "_".join(splits[-2:])
+    # Special case for xlm_roberta
+    if splits[-1] == "roberta" and splits[-2] == "xlm":
+        return "_".join(splits[-2:])
+    return splits[-1]
+
+
+def check_all_models_are_documented():
+    """ Check all models are properly documented."""
+    modules = get_model_modules()
+    doc_files = get_model_doc_files()
+    failures = []
+    for module in modules:
+        model_name = _get_model_name(module)
+        doc_file = MODEL_NAME_TO_DOC_FILE.get(model_name, f"{model_name}.rst")
+        if doc_file not in doc_files:
+            failures.append(
+                f"{module.__name__} does not have its corresponding doc file {doc_file}. "
+                + f"If the doc file exists but isn't named {doc_file}, update `MODEL_NAME_TO_DOC_FILE` "
+                + "in the file `utils/check_repo.py`."
+            )
+        new_failures = check_models_are_documented(module, doc_file)
+        if new_failures is not None:
+            failures += new_failures
+    if len(failures) > 0:
+        raise Exception(f"There were {len(failures)} failures:\n" + "\n".join(failures))
+
+
+def check_repo_quality():
+    """ Check all models are properly tested and documented."""
+    print("Checking all models are properly tested.")
+    check_all_models_are_tested()
+    print("Checking all models are properly documented.")
+    check_all_models_are_documented()
+
+
+if __name__ == "__main__":
+    check_repo_quality()
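
The TFAlbert hunks above all follow the library's multiple-choice convention: inputs arrive with shape (batch_size, num_choices, seq_length) and are flattened to (batch_size * num_choices, seq_length) before the encoder runs. Below is a minimal standalone sketch of that tiling-and-flattening step in plain TensorFlow; the shapes are made up for illustration, and the library's `shape_list` helper is not needed here because the static shapes are known.

    import tensorflow as tf

    batch_size, num_choices, seq_length = 2, 4, 8

    # Build multiple-choice inputs the way the new Albert test does: tile a
    # (batch_size, seq_length) tensor across a new num_choices axis.
    input_ids = tf.random.uniform((batch_size, seq_length), maxval=100, dtype=tf.int32)
    multiple_choice_input_ids = tf.tile(tf.expand_dims(input_ids, 1), (1, num_choices, 1))
    print(multiple_choice_input_ids.shape)  # (2, 4, 8)

    # Flatten the choices into the batch axis before the encoder, as the
    # modeling_tf_albert.py hunk does for input_ids and inputs_embeds.
    flat_input_ids = tf.reshape(multiple_choice_input_ids, (-1, seq_length))
    print(flat_input_ids.shape)  # (8, 8)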
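
As the Makefile hunk shows, the new check runs as part of `make quality`, and it can also be invoked directly with `python utils/check_repo.py` from the repository root. The core of the test-side check is the regex in `find_tested_models`; here is a minimal sketch of what it extracts, run against a hypothetical test-file snippet (the class and model names below are made up for illustration):

    import re

    # Hypothetical test-file content, in the layout the script expects:
    # `all_model_classes` assigned either a flat tuple or a nested one.
    content = '''
    class BertModelTest(ModelTesterMixin, unittest.TestCase):
        all_model_classes = (
            (BertModel, BertForMaskedLM, BertForQuestionAnswering) if is_torch_available() else ()
        )
    '''

    # Same two-step match as find_tested_models: try the nested-tuple form
    # first, then fall back to the flat form.
    match = re.search(r"all_model_classes\s+=\s+\(\s*\(([^\)]*)\)", content)
    if match is None:
        match = re.search(r"all_model_classes\s+=\s+\(([^\)]*)\)", content)

    tested = [name.strip() for name in match.groups()[0].split(",") if name.strip()]
    print(tested)  # ['BertModel', 'BertForMaskedLM', 'BertForQuestionAnswering']

Any model class defined in a modeling module but missing from that list (and from `IGNORE_NON_TESTED`) becomes a failure message, and the script raises with all collected failures, which fails the `quality` target.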