From 084873b0258d45ad4b1882d12976b694e03c5c7c Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Wed, 14 Jul 2021 10:56:55 -0400 Subject: [PATCH] Only test the files impacted by changes in the diff (#12644) * Base test * More test * Fix mistake * Add a docstring change * Add doc ignore * Add changes * Add recursive dep search * Add recursive dep search * save * Finalize test mapping * Fix bug * Print prettier * Ignore comments and empty lines * Make script runnable from anywhere * Need dev install * Like that * Adapt * Add as artifact * Try on torch tests * Fix yaml error * Install GitPython * Apply everywhere * Be more defensive * Revert to all tests if something is wrong * Install GitPython * Test if there are tests before launching. * Fixes * Fixes * Fixes * Fixes * Bash syntax is horrible * Be less stupid * Try differently * Typo * Typo * Typo * Style * Better name * Escape quotes * Ignore black unhelpful re-formatting * Not a docstring * Deal with inits in dependency map * Run all tests once PR is merged. * Add last job * Apply suggestions from code review Co-authored-by: Stas Bekman * Stronger dependencies gather * Ignore empty lines too! * Clean up * Fix quality Co-authored-by: Stas Bekman --- .circleci/config.yml | 88 +++- Makefile | 1 + setup.py | 3 +- src/transformers/dependency_versions_table.py | 1 + tests/conftest.py | 6 + utils/style_doc.py | 6 +- utils/tests_fetcher.py | 427 ++++++++++++++++++ 7 files changed, 517 insertions(+), 15 deletions(-) create mode 100644 utils/tests_fetcher.py diff --git a/.circleci/config.yml b/.circleci/config.yml index f76343ac665..ec29b0fb74a 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -86,7 +86,13 @@ jobs: key: v0.4-{{ checksum "setup.py" }} paths: - '~/.cache/pip' - - run: python -m pytest -n 8 --dist=loadfile -rA -s --make-reports=tests_torch_and_tf ./tests/ -m is_pt_tf_cross_test --durations=0 | tee tests_output.txt + - run: python utils/tests_fetcher.py | tee test_preparation.txt + - store_artifacts: + path: ~/transformers/test_preparation.txt + - run: | + if [ -f test_list.txt ]; then + python -m pytest -n 8 --dist=loadfile -rA -s --make-reports=tests_torch_and_tf $(cat test_list.txt) -m is_pt_tf_cross_test --durations=0 | tee tests_output.txt + fi - store_artifacts: path: ~/transformers/tests_output.txt - store_artifacts: @@ -116,7 +122,13 @@ jobs: key: v0.4-{{ checksum "setup.py" }} paths: - '~/.cache/pip' - - run: python -m pytest -n 8 --dist=loadfile -rA -s --make-reports=tests_torch_and_flax ./tests/ -m is_pt_flax_cross_test --durations=0 | tee tests_output.txt + - run: python utils/tests_fetcher.py | tee test_preparation.txt + - store_artifacts: + path: ~/transformers/test_preparation.txt + - run: | + if [ -f test_list.txt ]; then + python -m pytest -n 8 --dist=loadfile -rA -s --make-reports=tests_torch_and_flax $(cat test_list.txt) -m is_pt_flax_cross_test --durations=0 | tee tests_output.txt + fi - store_artifacts: path: ~/transformers/tests_output.txt - store_artifacts: @@ -145,7 +157,13 @@ jobs: key: v0.4-torch-{{ checksum "setup.py" }} paths: - '~/.cache/pip' - - run: python -m pytest -n 3 --dist=loadfile -s --make-reports=tests_torch ./tests/ | tee tests_output.txt + - run: python utils/tests_fetcher.py | tee test_preparation.txt + - store_artifacts: + path: ~/transformers/test_preparation.txt + - run: | + if [ -f test_list.txt ]; then + python -m pytest -n 3 --dist=loadfile -s --make-reports=tests_torch $(cat test_list.txt) | tee tests_output.txt + fi - 
store_artifacts: path: ~/transformers/tests_output.txt - store_artifacts: @@ -172,7 +190,13 @@ jobs: key: v0.4-tf-{{ checksum "setup.py" }} paths: - '~/.cache/pip' - - run: python -m pytest -n 8 --dist=loadfile -rA -s --make-reports=tests_tf ./tests/ | tee tests_output.txt + - run: python utils/tests_fetcher.py | tee test_preparation.txt + - store_artifacts: + path: ~/transformers/test_preparation.txt + - run: | + if [ -f test_list.txt ]; then + python -m pytest -n 8 --dist=loadfile -rA -s --make-reports=tests_tf $(cat test_list.txt) | tee tests_output.txt + fi - store_artifacts: path: ~/transformers/tests_output.txt - store_artifacts: @@ -199,7 +223,13 @@ jobs: key: v0.4-flax-{{ checksum "setup.py" }} paths: - '~/.cache/pip' - - run: python -m pytest -n 8 --dist=loadfile -rA -s --make-reports=tests_flax ./tests/ | tee tests_output.txt + - run: python utils/tests_fetcher.py | tee test_preparation.txt + - store_artifacts: + path: ~/transformers/test_preparation.txt + - run: | + if [ -f test_list.txt ]; then + python -m pytest -n 8 --dist=loadfile -rA -s --make-reports=tests_flax $(cat test_list.txt) | tee tests_output.txt + fi - store_artifacts: path: ~/transformers/tests_output.txt - store_artifacts: @@ -229,7 +259,13 @@ jobs: key: v0.4-torch-{{ checksum "setup.py" }} paths: - '~/.cache/pip' - - run: python -m pytest -n 8 --dist=loadfile -rA -s --make-reports=tests_pipelines_torch -m is_pipeline_test ./tests/ | tee tests_output.txt + - run: python utils/tests_fetcher.py | tee test_preparation.txt + - store_artifacts: + path: ~/transformers/test_preparation.txt + - run: | + if [ -f test_list.txt ]; then + python -m pytest -n 8 --dist=loadfile -rA -s --make-reports=tests_pipelines_torch -m is_pipeline_test $(cat test_list.txt) | tee tests_output.txt + fi - store_artifacts: path: ~/transformers/tests_output.txt - store_artifacts: @@ -257,7 +293,13 @@ jobs: key: v0.4-tf-{{ checksum "setup.py" }} paths: - '~/.cache/pip' - - run: python -m pytest -n 8 --dist=loadfile -rA -s --make-reports=tests_pipelines_tf ./tests/ -m is_pipeline_test | tee tests_output.txt + - run: python utils/tests_fetcher.py | tee test_preparation.txt + - store_artifacts: + path: ~/transformers/test_preparation.txt + - run: | + if [ -f test_list.txt ]; then + python -m pytest -n 8 --dist=loadfile -rA -s --make-reports=tests_pipelines_tf $(cat test_list.txt) -m is_pipeline_test | tee tests_output.txt + fi - store_artifacts: path: ~/transformers/tests_output.txt - store_artifacts: @@ -283,7 +325,10 @@ jobs: key: v0.4-custom_tokenizers-{{ checksum "setup.py" }} paths: - '~/.cache/pip' - - run: python -m pytest -s --make-reports=tests_custom_tokenizers ./tests/test_tokenization_bert_japanese.py | tee tests_output.txt + - run: | + if [ -f test_list.txt ]; then + python -m pytest -s --make-reports=tests_custom_tokenizers ./tests/test_tokenization_bert_japanese.py | tee tests_output.txt + fi - store_artifacts: path: ~/transformers/tests_output.txt - store_artifacts: @@ -311,7 +356,13 @@ jobs: key: v0.4-torch_examples-{{ checksum "setup.py" }} paths: - '~/.cache/pip' - - run: TRANSFORMERS_IS_CI=1 python -m pytest -n 8 --dist=loadfile -s --make-reports=examples_torch ./examples/pytorch/ | tee examples_output.txt + - run: python utils/tests_fetcher.py | tee test_preparation.txt + - store_artifacts: + path: ~/transformers/test_preparation.txt + - run: | + if [ -f test_list.txt ]; then + TRANSFORMERS_IS_CI=1 python -m pytest -n 8 --dist=loadfile -s --make-reports=examples_torch ./examples/pytorch/ | tee examples_output.txt + fi - 
store_artifacts: path: ~/transformers/examples_output.txt - store_artifacts: @@ -343,7 +394,13 @@ jobs: key: v0.4-hub-{{ checksum "setup.py" }} paths: - '~/.cache/pip' - - run: python -m pytest -sv ./tests/ -m is_staging_test + - run: python utils/tests_fetcher.py | tee test_preparation.txt + - store_artifacts: + path: ~/transformers/test_preparation.txt + - run: | + if [ -f test_list.txt ]; then + python -m pytest -sv $(cat test_list.txt) -m is_staging_test + fi run_tests_onnxruntime: working_directory: ~/transformers @@ -366,7 +423,13 @@ jobs: key: v0.4-onnx-{{ checksum "setup.py" }} paths: - '~/.cache/pip' - - run: python -m pytest -n 1 --dist=loadfile -s --make-reports=tests_torch ./tests/* -k onnx | tee tests_output.txt + - run: python utils/tests_fetcher.py | tee test_preparation.txt + - store_artifacts: + path: ~/transformers/test_preparation.txt + - run: | + if [ -f test_list.txt ]; then + python -m pytest -n 1 --dist=loadfile -s --make-reports=tests_torch $(cat test_list.txt) -k onnx | tee tests_output.txt + fi - store_artifacts: path: ~/transformers/tests_output.txt - store_artifacts: @@ -431,7 +494,7 @@ jobs: - v0.4-code_quality-{{ checksum "setup.py" }} - v0.4-{{ checksum "setup.py" }} - run: pip install --upgrade pip - - run: pip install isort + - run: pip install isort GitPython - run: pip install .[all,quality] - save_cache: key: v0.4-code_quality-{{ checksum "setup.py" }} @@ -448,6 +511,7 @@ jobs: - run: python utils/check_repo.py - run: python utils/check_inits.py - run: make deps_table_check_updated + - run: python utils/tests_fetcher.py --sanity_check check_repository_consistency: working_directory: ~/transformers diff --git a/Makefile b/Makefile index 28645600cec..4ea50b9d486 100644 --- a/Makefile +++ b/Makefile @@ -40,6 +40,7 @@ extra_quality_checks: python utils/check_dummies.py python utils/check_repo.py python utils/check_inits.py + python utils/tests_fetcher.py --sanity_check # this target runs checks on all files quality: diff --git a/setup.py b/setup.py index 56f46b4885f..d19882e4a6a 100644 --- a/setup.py +++ b/setup.py @@ -100,6 +100,7 @@ _deps = [ "flake8>=3.8.3", "flax>=0.3.4", "fugashi>=1.0", + "GitPython", "huggingface-hub==0.0.12", "importlib_metadata", "ipadic>=1.0.0,<2.0", @@ -259,7 +260,7 @@ extras["codecarbon"] = deps_list("codecarbon") extras["sentencepiece"] = deps_list("sentencepiece", "protobuf") extras["testing"] = ( deps_list( - "pytest", "pytest-xdist", "timeout-decorator", "parameterized", "psutil", "datasets", "pytest-timeout", "black", "sacrebleu", "rouge-score", "nltk" + "pytest", "pytest-xdist", "timeout-decorator", "parameterized", "psutil", "datasets", "pytest-timeout", "black", "sacrebleu", "rouge-score", "nltk", "GitPython" ) + extras["retrieval"] + extras["modelcreation"] diff --git a/src/transformers/dependency_versions_table.py b/src/transformers/dependency_versions_table.py index 681b985f534..97ea67e6033 100644 --- a/src/transformers/dependency_versions_table.py +++ b/src/transformers/dependency_versions_table.py @@ -17,6 +17,7 @@ deps = { "flake8": "flake8>=3.8.3", "flax": "flax>=0.3.4", "fugashi": "fugashi>=1.0", + "GitPython": "GitPython", "huggingface-hub": "huggingface-hub==0.0.12", "importlib_metadata": "importlib_metadata", "ipadic": "ipadic>=1.0.0,<2.0", diff --git a/tests/conftest.py b/tests/conftest.py index 7c5f161436d..5dc776e2227 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -53,3 +53,9 @@ def pytest_terminal_summary(terminalreporter): make_reports = terminalreporter.config.getoption("--make-reports") if 
make_reports:
         pytest_terminal_summary_main(terminalreporter, id=make_reports)
+
+
+def pytest_sessionfinish(session, exitstatus):
+    # If no tests are collected, pytest exits with code 5, which makes the CI fail.
+    if exitstatus == 5:
+        session.exitstatus = 0
diff --git a/utils/style_doc.py b/utils/style_doc.py
index 82341a07c41..85113a9fd08 100644
--- a/utils/style_doc.py
+++ b/utils/style_doc.py
@@ -489,12 +489,14 @@ def style_file_docstrings(code_file, max_len=119, check_only=False):
     """Style all docstrings in `code_file` to `max_len`."""
     with open(code_file, "r", encoding="utf-8", newline="\n") as f:
         code = f.read()
-    splits = code.split('"""')
+    # fmt: off
+    splits = code.split('\"\"\"')
     splits = [
         (s if i % 2 == 0 or _re_doc_ignore.search(splits[i - 1]) is not None else style_docstring(s, max_len=max_len))
         for i, s in enumerate(splits)
     ]
-    clean_code = '"""'.join(splits)
+    clean_code = '\"\"\"'.join(splits)
+    # fmt: on
 
     diff = clean_code != code
     if not check_only and diff:
diff --git a/utils/tests_fetcher.py b/utils/tests_fetcher.py
new file mode 100644
index 00000000000..8fd9b0e372c
--- /dev/null
+++ b/utils/tests_fetcher.py
@@ -0,0 +1,427 @@
+# coding=utf-8
+# Copyright 2021 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import collections
+import os
+import re
+from contextlib import contextmanager
+from pathlib import Path
+
+from git import Repo
+
+
+# This script is intended to be run from the root of the repo, but you can adapt this constant if you need to.
+PATH_TO_TRANFORMERS = "."
+
+
+@contextmanager
+def checkout_commit(repo, commit_id):
+    """
+    Context manager that checks out a commit in the repo.
+    """
+    current_head = repo.head.commit if repo.head.is_detached else repo.head.ref
+
+    try:
+        repo.git.checkout(commit_id)
+        yield
+
+    finally:
+        repo.git.checkout(current_head)
+
+
+def clean_code(content):
+    """
+    Remove docstrings, empty lines and comments from `content`.
+    """
+    # fmt: off
+    # Remove docstrings by splitting on triple " then triple ':
+    splits = content.split('\"\"\"')
+    content = "".join(splits[::2])
+    splits = content.split("\'\'\'")
+    # fmt: on
+    content = "".join(splits[::2])
+
+    # Remove empty lines and comments
+    lines_to_keep = []
+    for line in content.split("\n"):
+        # Remove anything that is after a # sign.
+        line = re.sub("#.*$", "", line)
+        if len(line) == 0 or line.isspace():
+            continue
+        lines_to_keep.append(line)
+    return "\n".join(lines_to_keep)
+
+
+def diff_is_docstring_only(repo, branching_point, filename):
+    """
+    Check if the diff in `filename` since `branching_point` only concerns docstrings or comments.
+ """ + with checkout_commit(repo, branching_point): + with open(filename, "r", encoding="utf-8") as f: + old_content = f.read() + + with open(filename, "r", encoding="utf-8") as f: + new_content = f.read() + + old_content_clean = clean_code(old_content) + new_content_clean = clean_code(new_content) + + return old_content_clean == new_content_clean + + +def get_modified_python_files(): + """ + Return a list of python files that have been modified between the current head and the master branch. + """ + repo = Repo(PATH_TO_TRANFORMERS) + + print(f"Master is at {repo.refs.master.commit}") + print(f"Current head is at {repo.head.commit}") + + branching_commits = repo.merge_base(repo.refs.master, repo.head) + for commit in branching_commits: + print(f"Branching commit: {commit}") + + print("\n### DIFF ###\n") + code_diff = [] + for commit in branching_commits: + for diff_obj in commit.diff(repo.head.commit): + # We always add new python files + if diff_obj.change_type == "A" and diff_obj.b_path.endswith(".py"): + code_diff.append(diff_obj.b_path) + # We check that deleted python files won't break corresponding tests. + elif diff_obj.change_type == "D" and diff_obj.a_path.endswith(".py"): + code_diff.append(diff_obj.a_path) + # Now for modified files + elif diff_obj.change_type in ["M", "R"] and diff_obj.b_path.endswith(".py"): + # In case of renames, we'll look at the tests using both the old and new name. + if diff_obj.a_path != diff_obj.b_path: + code_diff.extend([diff_obj.a_path, diff_obj.b_path]) + else: + # Otherwise, we check modifications are in code and not docstrings. + if diff_is_docstring_only(repo, commit, diff_obj.b_path): + print(f"Ignoring diff in {diff_obj.b_path} as it only concerns docstrings or comments.") + else: + code_diff.append(diff_obj.a_path) + + return code_diff + + +def get_module_dependencies(module_fname): + """ + Get the dependencies of a module. + """ + with open(os.path.join(PATH_TO_TRANFORMERS, module_fname), "r", encoding="utf-8") as f: + content = f.read() + + module_parts = module_fname.split(os.path.sep) + imported_modules = [] + + # Let's start with relative imports + relative_imports = re.findall(r"from\s+(\.+\S+)\s+import\s+\S+\s", content) + for imp in relative_imports: + level = 0 + while imp.startswith("."): + imp = imp[1:] + level += 1 + + if len(imp) > 0: + dep_parts = module_parts[: len(module_parts) - level] + imp.split(".") + else: + dep_parts = module_parts[: len(module_parts) - level] + ["__init__.py"] + imported_module = os.path.sep.join(dep_parts) + # We ignore the main init import as it's only for the __version__ that it's done + # and it would add everything as a dependency. + if not imported_module.endswith("transformers/__init__.py"): + imported_modules.append(imported_module) + + # Let's continue with direct imports + # The import from the transformers module are ignored for the same reason we ignored the + # main init before. 
+    direct_imports = re.findall(r"from\s+transformers\.(\S+)\s+import\s+\S+\s", content)
+    for imp in direct_imports:
+        import_parts = imp.split(".")
+        dep_parts = ["src", "transformers"] + import_parts
+        imported_modules.append(os.path.sep.join(dep_parts))
+
+    # Now let's just check that we have proper module files, or append an init for submodules.
+    dependencies = []
+    for imported_module in imported_modules:
+        if os.path.isfile(os.path.join(PATH_TO_TRANFORMERS, f"{imported_module}.py")):
+            dependencies.append(f"{imported_module}.py")
+        elif os.path.isdir(os.path.join(PATH_TO_TRANFORMERS, imported_module)) and os.path.isfile(
+            os.path.sep.join([PATH_TO_TRANFORMERS, imported_module, "__init__.py"])
+        ):
+            dependencies.append(os.path.sep.join([imported_module, "__init__.py"]))
+    return dependencies
+
+
+def create_reverse_dependency_map():
+    """
+    Create the dependency map from module filename to the list of modules that depend on it (even recursively).
+    """
+    modules = [
+        str(f.relative_to(PATH_TO_TRANFORMERS))
+        for f in (Path(PATH_TO_TRANFORMERS) / "src/transformers").glob("**/*.py")
+    ]
+    # We grab all the dependencies of each module.
+    direct_deps = {m: get_module_dependencies(m) for m in modules}
+
+    # This recursively propagates the dependencies: dependencies of dependencies are added until nothing changes.
+    something_changed = True
+    while something_changed:
+        something_changed = False
+        for m in modules:
+            for d in direct_deps[m]:
+                for dep in direct_deps[d]:
+                    if dep not in direct_deps[m]:
+                        direct_deps[m].append(dep)
+                        something_changed = True
+
+    # Finally we can build the reverse map.
+    reverse_map = collections.defaultdict(list)
+    for m in modules:
+        for d in direct_deps[m]:
+            reverse_map[d].append(m)
+
+    return reverse_map
+
+
+# Any module file whose test file can't be inferred automatically from its name should go here. The best fix is to
+# (re-)name the test file accordingly; adding the correspondence to this map is second best.
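+# Keys are paths relative to src/transformers; values are a single test file (or directory) or a list of test
+# files, all relative to tests/.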
+SPECIAL_MODULE_TO_TEST_MAP = { + "configuration_utils.py": "test_configuration_common.py", + "convert_graph_to_onnx.py": "test_onnx.py", + "data/data_collator.py": "test_data_collator.py", + "deepspeed.py": "deepspeed/", + "feature_extraction_sequence_utils.py": "test_sequence_feature_extraction_common.py", + "feature_extraction_utils.py": "test_feature_extraction_common.py", + "file_utils.py": ["test_file_utils.py", "test_model_output.py"], + "modelcard.py": "test_model_card.py", + "modeling_flax_utils.py": "test_modeling_flax_common.py", + "modeling_tf_utils.py": "test_modeling_tf_common.py", + "modeling_utils.py": ["test_modeling_common.py", "test_offline.py"], + "models/auto/modeling_auto.py": ["test_modeling_auto.py", "test_modeling_tf_pytorch.py", "test_modeling_bort.py"], + "models/auto/modeling_flax_auto.py": "test_flax_auto.py", + "models/auto/modeling_tf_auto.py": [ + "test_modeling_tf_auto.py", + "test_modeling_tf_pytorch.py", + "test_modeling_tf_bort.py", + ], + "models/blenderbot_small/tokenization_blenderbot_small.py": "test_tokenization_small_blenderbot.py", + "models/blenderbot_small/tokenization_blenderbot_small_fast.py": "test_tokenization_small_blenderbot.py", + "models/gpt2/modeling_gpt2.py": ["test_modeling_gpt2.py", "test_modeling_megatron_gpt2.py"], + "pipelines/base.py": "test_pipelines_common.py", + "pipelines/text2text_generation.py": [ + "test_pipelines_text2text_generation.py", + "test_pipelines_summarization.py", + "test_pipelines_translation.py", + ], + "pipelines/zero_shot_classification.py": "test_pipelines_zero_shot.py", + "testing_utils.py": "test_skip_decorators.py", + "tokenization_utils.py": "test_tokenization_common.py", + "tokenization_utils_base.py": "test_tokenization_common.py", + "tokenization_utils_fast.py": "test_tokenization_fast.py", + "trainer.py": [ + "test_trainer.py", + "extended/test_trainer_ext.py", + "test_trainer_distributed.py", + "test_trainer_tpu.py", + ], + "utils/versions.py": "test_versions_utils.py", +} + + +def module_to_test_file(module_fname): + """ + Returns the name of the file(s) where `module_fname` is tested. + """ + splits = module_fname.split(os.path.sep) + + # Special map has priority + short_name = os.path.sep.join(splits[2:]) + if short_name in SPECIAL_MODULE_TO_TEST_MAP: + test_file = SPECIAL_MODULE_TO_TEST_MAP[short_name] + if isinstance(test_file, str): + return f"tests/{test_file}" + return [f"tests/{f}" for f in test_file] + + module_name = splits[-1] + # Fast tokenizers are tested in the same file as the slow ones. + if module_name.endswith("_fast.py"): + module_name = module_name.replace("_fast.py", ".py") + + # Special case for pipelines submodules + if len(splits) >= 2 and splits[-2] == "pipelines": + default_test_file = f"tests/test_pipelines_{module_name}" + # Special case for benchmarks submodules + elif len(splits) >= 2 and splits[-2] == "benchmark": + return ["tests/test_benchmark.py", "tests/test_benchmark_tf.py"] + # Special case for commands submodules + elif len(splits) >= 2 and splits[-2] == "commands": + return "tests/test_cli.py" + # Special case for onnx submodules + elif len(splits) >= 2 and splits[-2] == "onnx": + return ["tests/test_onnx.py", "tests/test_onnx_v2.py"] + # Special case for utils (not the one in src/transformers, the ones at the root of the repo). 
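+    # (e.g. utils/check_dummies.py would map to tests/test_utils_check_dummies.py, if such a test exists).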
+    elif len(splits) > 0 and splits[0] == "utils":
+        default_test_file = f"tests/test_utils_{module_name}"
+    else:
+        default_test_file = f"tests/test_{module_name}"
+
+    if os.path.isfile(default_test_file):
+        return default_test_file
+
+    # Processing -> processor
+    if "processing" in default_test_file:
+        test_file = default_test_file.replace("processing", "processor")
+        if os.path.isfile(test_file):
+            return test_file
+
+
+# This list contains the test files we expect never to be launched from a change in a module/util. Those are
+# launched separately.
+EXPECTED_TEST_FILES_NEVER_TOUCHED = [
+    "tests/test_doc_samples.py",  # Doc tests
+    "tests/sagemaker/test_single_node_gpu.py",  # SageMaker test
+    "tests/sagemaker/test_multi_node_model_parallel.py",  # SageMaker test
+    "tests/sagemaker/test_multi_node_data_parallel.py",  # SageMaker test
+]
+
+
+def _print_list(l):
+    return "\n".join([f"- {f}" for f in l])
+
+
+def sanity_check():
+    """
+    Checks that all test files can be touched by a modification in at least one module or util. This ensures that
+    newly added test files are properly mapped to some module or util, so they can be run by the CI.
+    """
+    # Grab all module and util files.
+    all_files = [
+        str(p.relative_to(PATH_TO_TRANFORMERS))
+        for p in (Path(PATH_TO_TRANFORMERS) / "src/transformers").glob("**/*.py")
+    ]
+    all_files += [
+        str(p.relative_to(PATH_TO_TRANFORMERS)) for p in (Path(PATH_TO_TRANFORMERS) / "utils").glob("**/*.py")
+    ]
+
+    # Compute all the test files we get from those.
+    test_files_found = []
+    for f in all_files:
+        test_f = module_to_test_file(f)
+        if test_f is not None:
+            if isinstance(test_f, str):
+                test_files_found.append(test_f)
+            else:
+                test_files_found.extend(test_f)
+
+    # Some of the test files might actually be subfolders, so we grab the tests inside.
+    test_files = []
+    for test_f in test_files_found:
+        if os.path.isdir(os.path.join(PATH_TO_TRANFORMERS, test_f)):
+            test_files.extend(
+                [
+                    str(p.relative_to(PATH_TO_TRANFORMERS))
+                    for p in (Path(PATH_TO_TRANFORMERS) / test_f).glob("**/test*.py")
+                ]
+            )
+        else:
+            test_files.append(test_f)
+
+    # Compare to existing test files.
+    existing_test_files = [
+        str(p.relative_to(PATH_TO_TRANFORMERS)) for p in (Path(PATH_TO_TRANFORMERS) / "tests").glob("**/test*.py")
+    ]
+    not_touched_test_files = [f for f in existing_test_files if f not in test_files]
+
+    should_be_tested = set(not_touched_test_files) - set(EXPECTED_TEST_FILES_NEVER_TOUCHED)
+    if len(should_be_tested) > 0:
+        raise ValueError(
+            "The following test files are not currently associated with any module or utils files, which means they "
+            f"will never get run by the CI:\n{_print_list(should_be_tested)}\nMake sure the names of these test "
+            "files match the name of the module or utils they are testing, or adapt the constant "
+            "`SPECIAL_MODULE_TO_TEST_MAP` in `utils/tests_fetcher.py` to add them. If your test file is triggered "
+            "separately and is not supposed to be run by the regular CI, add it to the "
+            "`EXPECTED_TEST_FILES_NEVER_TOUCHED` constant instead."
+        )
+
+
+def infer_tests_to_run(output_file):
+    modified_files = get_modified_python_files()
+    print(f"\n### MODIFIED FILES ###\n{_print_list(modified_files)}")
+
+    # Create the map that will give us all impacted modules.
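+    # (i.e. the modified modules plus every module that depends on them, directly or transitively).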
+    impacted_modules_map = create_reverse_dependency_map()
+    impacted_files = modified_files.copy()
+    for f in modified_files:
+        if f in impacted_modules_map:
+            impacted_files.extend(impacted_modules_map[f])
+
+    # Remove duplicates.
+    impacted_files = sorted(list(set(impacted_files)))
+    print(f"\n### IMPACTED FILES ###\n{_print_list(impacted_files)}")
+
+    # Grab the corresponding test files:
+    test_files_to_run = []
+    for f in impacted_files:
+        # Modified test files are always added.
+        if f.startswith("tests/"):
+            test_files_to_run.append(f)
+        else:
+            new_tests = module_to_test_file(f)
+            if new_tests is not None:
+                if isinstance(new_tests, str):
+                    test_files_to_run.append(new_tests)
+                else:
+                    test_files_to_run.extend(new_tests)
+
+    # Remove duplicates.
+    test_files_to_run = sorted(list(set(test_files_to_run)))
+    print(f"\n### TESTS TO RUN ###\n{_print_list(test_files_to_run)}")
+    if len(test_files_to_run) > 0:
+        with open(output_file, "w", encoding="utf-8") as f:
+            f.write(" ".join(test_files_to_run))
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--sanity_check", action="store_true", help="Only test that all tests and modules are accounted for."
+    )
+    parser.add_argument(
+        "--output_file", type=str, default="test_list.txt", help="Where to store the list of tests to run."
+    )
+    args = parser.parse_args()
+    if args.sanity_check:
+        sanity_check()
+    else:
+        repo = Repo(PATH_TO_TRANFORMERS)
+        # For now we run all tests when on the master branch. Once this has been tested more and shown to work most
+        # of the time, we will apply the same test-fetching logic to master too and only run the whole suite once
+        # per day.
+        if not repo.head.is_detached and repo.head.ref == repo.refs.master:
+            print("Master branch detected, running all tests.")
+            with open(args.output_file, "w", encoding="utf-8") as f:
+                f.write("./tests/")
+        else:
+            try:
+                infer_tests_to_run(args.output_file)
+            except Exception as e:
+                print(f"\nError when trying to grab the relevant tests: {e}\n\nRunning all tests.")
+                with open(args.output_file, "w", encoding="utf-8") as f:
+                    f.write("./tests/")
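+
+
+# Example invocations from the root of the repo (a usage sketch based on the arguments defined above):
+#   python utils/tests_fetcher.py                  # writes the tests to run to test_list.txt
+#   python utils/tests_fetcher.py --sanity_check   # only checks that every test file maps to a module/util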