From 53735d7c3b484cc62af4d6341306daced97d6c5c Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Wed, 1 Mar 2023 17:53:29 +0100 Subject: [PATCH] Add an utility file to get information from test files (#21856) * Add an utility file to get information from test files --------- Co-authored-by: ydshieh --- .circleci/create_circleci_config.py | 2 +- tests/repo_utils/test_get_test_info.py | 109 ++++++++++++++ utils/get_test_info.py | 190 +++++++++++++++++++++++++ 3 files changed, 300 insertions(+), 1 deletion(-) create mode 100644 tests/repo_utils/test_get_test_info.py create mode 100644 utils/get_test_info.py diff --git a/.circleci/create_circleci_config.py b/.circleci/create_circleci_config.py index bf8c9e281a6..338f2508246 100644 --- a/.circleci/create_circleci_config.py +++ b/.circleci/create_circleci_config.py @@ -379,7 +379,7 @@ repo_utils_job = CircleCIJob( "repo_utils", install_steps=[ "pip install --upgrade pip", - "pip install .[quality,testing]", + "pip install .[quality,testing,torch]", ], parallelism=None, pytest_num_workers=1, diff --git a/tests/repo_utils/test_get_test_info.py b/tests/repo_utils/test_get_test_info.py new file mode 100644 index 00000000000..e432dd945ee --- /dev/null +++ b/tests/repo_utils/test_get_test_info.py @@ -0,0 +1,109 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import sys +import unittest + + +git_repo_path = os.path.abspath(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) +sys.path.append(os.path.join(git_repo_path, "utils")) + +import get_test_info # noqa: E402 +from get_test_info import ( # noqa: E402 + get_model_to_test_mapping, + get_model_to_tester_mapping, + get_test_to_tester_mapping, +) + + +BERT_TEST_FILE = os.path.join("tests", "models", "bert", "test_modeling_bert.py") +BLIP_TEST_FILE = os.path.join("tests", "models", "blip", "test_modeling_blip.py") + + +class GetTestInfoTester(unittest.TestCase): + def test_get_test_to_tester_mapping(self): + bert_test_tester_mapping = get_test_to_tester_mapping(BERT_TEST_FILE) + blip_test_tester_mapping = get_test_to_tester_mapping(BLIP_TEST_FILE) + + EXPECTED_BERT_MAPPING = {"BertModelTest": "BertModelTester"} + + EXPECTED_BLIP_MAPPING = { + "BlipModelTest": "BlipModelTester", + "BlipTextImageModelTest": "BlipTextImageModelsModelTester", + "BlipTextModelTest": "BlipTextModelTester", + "BlipTextRetrievalModelTest": "BlipTextRetrievalModelTester", + "BlipVQAModelTest": "BlipModelTester", + "BlipVisionModelTest": "BlipVisionModelTester", + } + + self.assertEqual(get_test_info.to_json(bert_test_tester_mapping), EXPECTED_BERT_MAPPING) + self.assertEqual(get_test_info.to_json(blip_test_tester_mapping), EXPECTED_BLIP_MAPPING) + + def test_get_model_to_test_mapping(self): + bert_model_test_mapping = get_model_to_test_mapping(BERT_TEST_FILE) + blip_model_test_mapping = get_model_to_test_mapping(BLIP_TEST_FILE) + + EXPECTED_BERT_MAPPING = { + "BertForMaskedLM": ["BertModelTest"], + "BertForMultipleChoice": ["BertModelTest"], + "BertForNextSentencePrediction": ["BertModelTest"], + "BertForPreTraining": ["BertModelTest"], + "BertForQuestionAnswering": ["BertModelTest"], + "BertForSequenceClassification": ["BertModelTest"], + "BertForTokenClassification": ["BertModelTest"], + "BertLMHeadModel": ["BertModelTest"], + "BertModel": ["BertModelTest"], + } + + 
EXPECTED_BLIP_MAPPING = { + "BlipForConditionalGeneration": ["BlipTextImageModelTest"], + "BlipForImageTextRetrieval": ["BlipTextRetrievalModelTest"], + "BlipForQuestionAnswering": ["BlipTextImageModelTest", "BlipVQAModelTest"], + "BlipModel": ["BlipModelTest"], + "BlipTextModel": ["BlipTextModelTest"], + "BlipVisionModel": ["BlipVisionModelTest"], + } + + self.assertEqual(get_test_info.to_json(bert_model_test_mapping), EXPECTED_BERT_MAPPING) + self.assertEqual(get_test_info.to_json(blip_model_test_mapping), EXPECTED_BLIP_MAPPING) + + def test_get_model_to_tester_mapping(self): + bert_model_tester_mapping = get_model_to_tester_mapping(BERT_TEST_FILE) + blip_model_tester_mapping = get_model_to_tester_mapping(BLIP_TEST_FILE) + + EXPECTED_BERT_MAPPING = { + "BertForMaskedLM": ["BertModelTester"], + "BertForMultipleChoice": ["BertModelTester"], + "BertForNextSentencePrediction": ["BertModelTester"], + "BertForPreTraining": ["BertModelTester"], + "BertForQuestionAnswering": ["BertModelTester"], + "BertForSequenceClassification": ["BertModelTester"], + "BertForTokenClassification": ["BertModelTester"], + "BertLMHeadModel": ["BertModelTester"], + "BertModel": ["BertModelTester"], + } + + EXPECTED_BLIP_MAPPING = { + "BlipForConditionalGeneration": ["BlipTextImageModelsModelTester"], + "BlipForImageTextRetrieval": ["BlipTextRetrievalModelTester"], + "BlipForQuestionAnswering": ["BlipModelTester", "BlipTextImageModelsModelTester"], + "BlipModel": ["BlipModelTester"], + "BlipTextModel": ["BlipTextModelTester"], + "BlipVisionModel": ["BlipVisionModelTester"], + } + + self.assertEqual(get_test_info.to_json(bert_model_tester_mapping), EXPECTED_BERT_MAPPING) + self.assertEqual(get_test_info.to_json(blip_model_tester_mapping), EXPECTED_BLIP_MAPPING) diff --git a/utils/get_test_info.py b/utils/get_test_info.py new file mode 100644 index 00000000000..d6b451e71f3 --- /dev/null +++ b/utils/get_test_info.py @@ -0,0 +1,190 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib +import os +import sys + + +# This is required to make the module import work (when the python process is running from the root of the repo) +sys.path.append(".") + + +r""" +The argument `test_file` in this file refers to a model test file. This should be a string of the form +`tests/models/*/test_modeling_*.py`. +""" + + +def get_module_path(test_file): + """Return the module path of a model test file.""" + components = test_file.split(os.path.sep) + if components[0:2] != ["tests", "models"]: + raise ValueError( + "`test_file` should start with `tests/models/` (with `/` being the OS specific path separator). Got " + f"{test_file} instead." + ) + test_fn = components[-1] + if not test_fn.endswith("py"): + raise ValueError(f"`test_file` should be a python file. Got {test_fn} instead.") + if not test_fn.startswith("test_modeling_"): + raise ValueError( + f"`test_file` should point to a file name of the form `test_modeling_*.py`. Got {test_fn} instead." 
+ ) + + components = components[:-1] + [test_fn.replace(".py", "")] + test_module_path = ".".join(components) + + return test_module_path + + +def get_test_module(test_file): + """Get the module of a model test file.""" + test_module_path = get_module_path(test_file) + test_module = importlib.import_module(test_module_path) + + return test_module + + +def get_tester_classes(test_file): + """Get all classes in a model test file whose names end with `ModelTester`.""" + tester_classes = [] + test_module = get_test_module(test_file) + for attr in dir(test_module): + if attr.endswith("ModelTester"): + tester_classes.append(getattr(test_module, attr)) + + # sort with class names + return sorted(tester_classes, key=lambda x: x.__name__) + + +def get_test_classes(test_file): + """Get all [test] classes in a model test file with attribute `all_model_classes` that are non-empty. + + These are usually the (model) test classes containing the (non-slow) tests to run and are subclasses of one of the + classes `ModelTesterMixin`, `TFModelTesterMixin` or `FlaxModelTesterMixin`, as well as a subclass of + `unittest.TestCase`. Exceptions include `RagTestMixin` (and its subclasses). + """ + test_classes = [] + test_module = get_test_module(test_file) + for attr in dir(test_module): + attr_value = getattr(test_module, attr) + # (TF/Flax)ModelTesterMixin is also an attribute in specific model test module. Let's exclude them by checking + # `all_model_classes` is not empty (which also excludes other special classes). 
+ model_classes = getattr(attr_value, "all_model_classes", []) + if len(model_classes) > 0: + test_classes.append(attr_value) + + # sort with class names + return sorted(test_classes, key=lambda x: x.__name__) + + +def get_model_classes(test_file): + """Get all model classes that appear in `all_model_classes` attributes in a model test file.""" + test_classes = get_test_classes(test_file) + model_classes = set() + for test_class in test_classes: + model_classes.update(test_class.all_model_classes) + + # sort with class names + return sorted(model_classes, key=lambda x: x.__name__) + + +def get_model_tester_from_test_class(test_class): + """Get the model tester class of a model test class.""" + test = test_class() + if hasattr(test, "setUp"): + test.setUp() + + model_tester = None + if hasattr(test, "model_tester"): + # `(TF/Flax)ModelTesterMixin` has this attribute default to `None`. Let's skip this case. + if test.model_tester is not None: + model_tester = test.model_tester.__class__ + + return model_tester + + +def get_test_classes_for_model(test_file, model_class): + """Get all [test] classes in `test_file` that have `model_class` in their `all_model_classes`.""" + test_classes = get_test_classes(test_file) + + target_test_classes = [] + for test_class in test_classes: + if model_class in test_class.all_model_classes: + target_test_classes.append(test_class) + + # sort with class names + return sorted(target_test_classes, key=lambda x: x.__name__) + + +def get_tester_classes_for_model(test_file, model_class): + """Get all model tester classes in `test_file` that are associated to `model_class`.""" + test_classes = get_test_classes_for_model(test_file, model_class) + + tester_classes = [] + for test_class in test_classes: + tester_class = get_model_tester_from_test_class(test_class) + if tester_class is not None: + tester_classes.append(tester_class) + + # sort with class names + return sorted(tester_classes, key=lambda x: x.__name__) + + +def 
get_test_to_tester_mapping(test_file): + """Get a mapping from [test] classes to model tester classes in `test_file`. + + This uses `get_test_classes` which may return classes that are NOT subclasses of `unittest.TestCase`. + """ + test_classes = get_test_classes(test_file) + test_tester_mapping = {test_class: get_model_tester_from_test_class(test_class) for test_class in test_classes} + return test_tester_mapping + + +def get_model_to_test_mapping(test_file): + """Get a mapping from model classes to test classes in `test_file`.""" + model_classes = get_model_classes(test_file) + model_test_mapping = { + model_class: get_test_classes_for_model(test_file, model_class) for model_class in model_classes + } + return model_test_mapping + + +def get_model_to_tester_mapping(test_file): + """Get a mapping from model classes to model tester classes in `test_file`.""" + model_classes = get_model_classes(test_file) + model_to_tester_mapping = { + model_class: get_tester_classes_for_model(test_file, model_class) for model_class in model_classes + } + return model_to_tester_mapping + + +def to_json(o): + """Make the information succinct and easy to read. + + Avoid the full class representation like `<class 'transformers.models.bert.modeling_bert.BertForMaskedLM'>` when + displaying the results. Instead, we use the class name (`BertForMaskedLM`) for readability. + """ + if isinstance(o, str): + return o + elif isinstance(o, type): + return o.__name__ + elif isinstance(o, (list, tuple)): + return [to_json(x) for x in o] + elif isinstance(o, dict): + return {to_json(k): to_json(v) for k, v in o.items()} + else: + return o