Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-03 12:50:06 +06:00)
Test fetch v2 (#22367)
* Test fetcher v2
* Fix regexes
* Remove sanity check
* Fake modification to OPT
* Fixes some .sep issues
* Remove fake OPT change
* Fake modif for BERT
* Fake modif for init
* Exclude SageMaker tests
* Fix test and remove fake modif
* Fake setup modif
* Fake pipeline modif
* Remove all fake modifs
* Adds options to skip/force tests
* [test-all-models] Fake modif for BERT
* Try this way
* Does the command actually work?
* [test-all-models] Try again!
* [skip circleci] Remove fake modif
* Remove debug statements
* Add the list of important models
* Quality
* Update utils/tests_fetcher.py (co-authored by Lysandre Debut)
* Address review comments
* Address review comments
* Fix and add test
* Apply suggestions from code review (co-authored by Yih-Dar)
* Address review comments

Co-authored-by: Lysandre Debut <lysandre.debut@reseau.eseo.fr>
Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
This commit is contained in: parent 3a9464bd30, commit c612628045
@@ -176,7 +176,6 @@ jobs:
     - run: python utils/check_config_attributes.py
     - run: python utils/check_doctest_list.py
     - run: make deps_table_check_updated
-    - run: python utils/tests_fetcher.py --sanity_check
     - run: python utils/update_metadata.py --check-only
     - run: python utils/check_task_guides.py
Makefile | 1 -
@@ -41,7 +41,6 @@ repo-consistency:
 	python utils/check_config_docstrings.py
 	python utils/check_config_attributes.py
 	python utils/check_doctest_list.py
-	python utils/tests_fetcher.py --sanity_check
 	python utils/update_metadata.py --check-only
 	python utils/check_task_guides.py
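For context, the removed `--sanity_check` invocation above was the only entry point dropped; the fetcher itself is still driven through `utils/tests_fetcher.py`. As a hedged illustration (this snippet is not part of the PR), the `infer_tests_to_run` entry point exercised by the new tests below can be driven from Python like this, assuming the working directory is the repository root; the output filename `test_list.txt` is an arbitrary choice for this example:

import sys

sys.path.append("utils")  # assumes we are at the repo root

from tests_fetcher import infer_tests_to_run

# Write the space-separated list of impacted test files/folders, comparing the
# current head to its branching point from main (the default behavior).
infer_tests_to_run("test_list.txt")

with open("test_list.txt") as f:
    print(f.read().split(" "))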
tests/repo_utils/test_tests_fetcher.py
@@ -13,52 +13,661 @@
 # limitations under the License.
 
 import os
+import shutil
 import sys
+import tempfile
 import unittest
+from contextlib import contextmanager
+from pathlib import Path
 
 from git import Repo
 
-
-git_repo_path = os.path.abspath(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
-sys.path.append(os.path.join(git_repo_path, "utils"))
-
-transformers_path = os.path.join(git_repo_path, "src", "transformers")
-# Tests are run against this specific commit for reproducibility
-# https://github.com/huggingface/transformers/tree/07f6690206e39ed7a4d9dbc58824314f7089bb38
-GIT_TEST_SHA = "07f6690206e39ed7a4d9dbc58824314f7089bb38"
-
-from tests_fetcher import checkout_commit, clean_code, get_module_dependencies  # noqa: E402
-
-
-class CheckDummiesTester(unittest.TestCase):
+from transformers.testing_utils import CaptureStdout
+
+
+REPO_PATH = os.path.abspath(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
+sys.path.append(os.path.join(REPO_PATH, "utils"))
+
+import tests_fetcher  # noqa: E402
+from tests_fetcher import (  # noqa: E402
+    checkout_commit,
+    clean_code,
+    create_module_to_test_map,
+    create_reverse_dependency_map,
+    create_reverse_dependency_tree,
+    diff_is_docstring_only,
+    extract_imports,
+    get_all_tests,
+    get_diff,
+    get_module_dependencies,
+    get_tree_starting_at,
+    infer_tests_to_run,
+    parse_commit_message,
+    print_tree_deps_of,
+)
+
+
+BERT_MODELING_FILE = "src/transformers/models/bert/modeling_bert.py"
+BERT_MODEL_FILE = """from ...modeling_utils import PreTrainedModel
+from ...utils import is_torch_available
+from .configuration_bert import BertConfig
+
+class BertModel:
+    '''
+    This is the docstring.
+    '''
+    This is the code
+"""
+
+BERT_MODEL_FILE_NEW_DOCSTRING = """from ...modeling_utils import PreTrainedModel
+from ...utils import is_torch_available
+from .configuration_bert import BertConfig
+
+class BertModel:
+    '''
+    This is the docstring. It has been updated.
+    '''
+    This is the code
+"""
+
+BERT_MODEL_FILE_NEW_CODE = """from ...modeling_utils import PreTrainedModel
+from ...utils import is_torch_available
+from .configuration_bert import BertConfig
+
+class BertModel:
+    '''
+    This is the docstring.
+    '''
+    This is the code. It has been updated
+"""
+
+
+def create_tmp_repo(tmp_dir, models=None):
+    """
+    Creates a repository in a temporary directory mimicking the structure of Transformers. Uses the list of models
+    provided (which defaults to just `["bert"]`).
+    """
+    tmp_dir = Path(tmp_dir)
+    if tmp_dir.exists():
+        shutil.rmtree(tmp_dir)
+    tmp_dir.mkdir(exist_ok=True)
+    repo = Repo.init(tmp_dir)
+
+    if models is None:
+        models = ["bert"]
+    class_names = [model[0].upper() + model[1:] for model in models]
+
+    transformers_dir = tmp_dir / "src" / "transformers"
+    transformers_dir.mkdir(parents=True, exist_ok=True)
+    with open(transformers_dir / "__init__.py", "w") as f:
+        init_lines = ["from .utils import cached_file, is_torch_available"]
+        init_lines.extend(
+            [f"from .models.{model} import {cls}Config, {cls}Model" for model, cls in zip(models, class_names)]
+        )
+        f.write("\n".join(init_lines) + "\n")
+    with open(transformers_dir / "configuration_utils.py", "w") as f:
+        f.write("from .utils import cached_file\n\ncode")
+    with open(transformers_dir / "modeling_utils.py", "w") as f:
+        f.write("from .utils import cached_file\n\ncode")
+
+    utils_dir = tmp_dir / "src" / "transformers" / "utils"
+    utils_dir.mkdir(exist_ok=True)
+    with open(utils_dir / "__init__.py", "w") as f:
+        f.write("from .hub import cached_file\nfrom .imports import is_torch_available\n")
+    with open(utils_dir / "hub.py", "w") as f:
+        f.write("import huggingface_hub\n\ncode")
+    with open(utils_dir / "imports.py", "w") as f:
+        f.write("code")
+
+    model_dir = tmp_dir / "src" / "transformers" / "models"
+    model_dir.mkdir(parents=True, exist_ok=True)
+    with open(model_dir / "__init__.py", "w") as f:
+        f.write("\n".join([f"import {model}" for model in models]))
+
+    for model, cls in zip(models, class_names):
+        model_dir = tmp_dir / "src" / "transformers" / "models" / model
+        model_dir.mkdir(parents=True, exist_ok=True)
+        with open(model_dir / "__init__.py", "w") as f:
+            f.write(f"from .configuration_{model} import {cls}Config\nfrom .modeling_{model} import {cls}Model\n")
+        with open(model_dir / f"configuration_{model}.py", "w") as f:
+            f.write("from ...configuration_utils import PretrainedConfig\ncode")
+        with open(model_dir / f"modeling_{model}.py", "w") as f:
+            modeling_code = BERT_MODEL_FILE.replace("bert", model).replace("Bert", cls)
+            f.write(modeling_code)
+
+    test_dir = tmp_dir / "tests"
+    test_dir.mkdir(exist_ok=True)
+    with open(test_dir / "test_modeling_common.py", "w") as f:
+        f.write("from transformers.modeling_utils import PreTrainedModel\ncode")
+
+    for model, cls in zip(models, class_names):
+        test_model_dir = test_dir / "models" / model
+        test_model_dir.mkdir(parents=True, exist_ok=True)
+        (test_model_dir / "__init__.py").touch()
+        with open(test_model_dir / f"test_modeling_{model}.py", "w") as f:
+            f.write(
+                f"from transformers import {cls}Config, {cls}Model\nfrom ...test_modeling_common import ModelTesterMixin\n\ncode"
+            )
+
+    repo.index.add(["src", "tests"])
+    repo.index.commit("Initial commit")
+    repo.create_head("main")
+    repo.head.reference = repo.refs.main
+    repo.delete_head("master")
+    return repo
+
+
+@contextmanager
+def patch_transformer_repo_path(new_folder):
+    """
+    Temporarily patches the variables defined in `tests_fetcher` to use a different location for the repo.
+    """
+    old_repo_path = tests_fetcher.PATH_TO_REPO
+    tests_fetcher.PATH_TO_REPO = Path(new_folder).resolve()
+    tests_fetcher.PATH_TO_TRANFORMERS = tests_fetcher.PATH_TO_REPO / "src/transformers"
+    tests_fetcher.PATH_TO_TESTS = tests_fetcher.PATH_TO_REPO / "tests"
+    try:
+        yield
+    finally:
+        tests_fetcher.PATH_TO_REPO = old_repo_path
+        tests_fetcher.PATH_TO_TRANFORMERS = tests_fetcher.PATH_TO_REPO / "src/transformers"
+        tests_fetcher.PATH_TO_TESTS = tests_fetcher.PATH_TO_REPO / "tests"
+
+
+def commit_changes(filenames, contents, repo, commit_message="Commit"):
+    """
+    Commit new `contents` to `filenames` inside a given `repo`.
+    """
+    if not isinstance(filenames, list):
+        filenames = [filenames]
+    if not isinstance(contents, list):
+        contents = [contents]
+
+    folder = Path(repo.working_dir)
+    for filename, content in zip(filenames, contents):
+        with open(folder / filename, "w") as f:
+            f.write(content)
+    repo.index.add(filenames)
+    commit = repo.index.commit(commit_message)
+    return commit.hexsha
+
+
+class TestFetcherTester(unittest.TestCase):
+    def test_checkout_commit(self):
+        with tempfile.TemporaryDirectory() as tmp_folder:
+            tmp_folder = Path(tmp_folder)
+            repo = create_tmp_repo(tmp_folder)
+            initial_sha = repo.head.commit.hexsha
+            new_sha = commit_changes(BERT_MODELING_FILE, BERT_MODEL_FILE_NEW_DOCSTRING, repo)
+
+            assert repo.head.commit.hexsha == new_sha
+            with checkout_commit(repo, initial_sha):
+                assert repo.head.commit.hexsha == initial_sha
+                with open(tmp_folder / BERT_MODELING_FILE) as f:
+                    assert f.read() == BERT_MODEL_FILE
+
+            assert repo.head.commit.hexsha == new_sha
+            with open(tmp_folder / BERT_MODELING_FILE) as f:
+                assert f.read() == BERT_MODEL_FILE_NEW_DOCSTRING
+
     def test_clean_code(self):
         # Clean code removes all strings in triple quotes
-        self.assertEqual(clean_code('"""\nDocstring\n"""\ncode\n"""Long string"""\ncode\n'), "code\ncode")
+        assert clean_code('"""\nDocstring\n"""\ncode\n"""Long string"""\ncode\n') == "code\ncode"
-        self.assertEqual(clean_code("'''\nDocstring\n'''\ncode\n'''Long string'''\ncode\n'''"), "code\ncode")
+        assert clean_code("'''\nDocstring\n'''\ncode\n'''Long string'''\ncode\n'''") == "code\ncode"
 
         # Clean code removes all comments
-        self.assertEqual(clean_code("code\n# Comment\ncode"), "code\ncode")
+        assert clean_code("code\n# Comment\ncode") == "code\ncode"
-        self.assertEqual(clean_code("code # inline comment\ncode"), "code \ncode")
+        assert clean_code("code # inline comment\ncode") == "code \ncode"
 
-    def test_checkout_commit(self):
-        repo = Repo(git_repo_path)
-        self.assertNotEqual(repo.head.commit.hexsha, GIT_TEST_SHA)
-        with checkout_commit(repo, GIT_TEST_SHA):
-            self.assertEqual(repo.head.commit.hexsha, GIT_TEST_SHA)
-        self.assertNotEqual(repo.head.commit.hexsha, GIT_TEST_SHA)
+    def test_get_all_tests(self):
+        with tempfile.TemporaryDirectory() as tmp_folder:
+            tmp_folder = Path(tmp_folder)
+            create_tmp_repo(tmp_folder)
+            with patch_transformer_repo_path(tmp_folder):
+                assert get_all_tests() == ["tests/models/bert", "tests/test_modeling_common.py"]
+
+    def test_get_all_tests_on_full_repo(self):
+        all_tests = get_all_tests()
+        assert "tests/models/albert" in all_tests
+        assert "tests/models/bert" in all_tests
+        assert "tests/repo_utils" in all_tests
+        assert "tests/test_pipeline_mixin.py" in all_tests
+        assert "tests/models" not in all_tests
+        assert "tests/__pycache__" not in all_tests
+        assert "tests/models/albert/test_modeling_albert.py" not in all_tests
+        assert "tests/repo_utils/test_tests_fetcher.py" not in all_tests
+
+    def test_diff_is_docstring_only(self):
+        with tempfile.TemporaryDirectory() as tmp_folder:
+            tmp_folder = Path(tmp_folder)
+            repo = create_tmp_repo(tmp_folder)
+
+            branching_point = repo.refs.main.commit
+            bert_file = BERT_MODELING_FILE
+            commit_changes(bert_file, BERT_MODEL_FILE_NEW_DOCSTRING, repo)
+            assert diff_is_docstring_only(repo, branching_point, bert_file)
+
+            commit_changes(bert_file, BERT_MODEL_FILE_NEW_CODE, repo)
+            assert not diff_is_docstring_only(repo, branching_point, bert_file)
+
+    def test_get_diff(self):
+        with tempfile.TemporaryDirectory() as tmp_folder:
+            tmp_folder = Path(tmp_folder)
+            repo = create_tmp_repo(tmp_folder)
+
+            initial_commit = repo.refs.main.commit
+            bert_file = BERT_MODELING_FILE
+            commit_changes(bert_file, BERT_MODEL_FILE_NEW_DOCSTRING, repo)
+            assert get_diff(repo, repo.head.commit, repo.head.commit.parents) == []
+
+            commit_changes(bert_file, BERT_MODEL_FILE_NEW_DOCSTRING + "\n# Adding a comment\n", repo)
+            assert get_diff(repo, repo.head.commit, repo.head.commit.parents) == []
+
+            commit_changes(bert_file, BERT_MODEL_FILE_NEW_CODE, repo)
+            assert get_diff(repo, repo.head.commit, repo.head.commit.parents) == [
+                "src/transformers/models/bert/modeling_bert.py"
+            ]
+
+            commit_changes("src/transformers/utils/hub.py", "import huggingface_hub\n\nnew code", repo)
+            assert get_diff(repo, repo.head.commit, repo.head.commit.parents) == ["src/transformers/utils/hub.py"]
+            assert get_diff(repo, repo.head.commit, [initial_commit]) == [
+                "src/transformers/models/bert/modeling_bert.py",
+                "src/transformers/utils/hub.py",
+            ]
+
+    def test_extract_imports_relative(self):
+        with tempfile.TemporaryDirectory() as tmp_folder:
+            tmp_folder = Path(tmp_folder)
+            create_tmp_repo(tmp_folder)
+
+            expected_bert_imports = [
+                ("src/transformers/modeling_utils.py", ["PreTrainedModel"]),
+                ("src/transformers/utils/__init__.py", ["is_torch_available"]),
+                ("src/transformers/models/bert/configuration_bert.py", ["BertConfig"]),
+            ]
+            expected_utils_imports = [
+                ("src/transformers/utils/hub.py", ["cached_file"]),
+                ("src/transformers/utils/imports.py", ["is_torch_available"]),
+            ]
+            with patch_transformer_repo_path(tmp_folder):
+                assert extract_imports(BERT_MODELING_FILE) == expected_bert_imports
+                assert extract_imports("src/transformers/utils/__init__.py") == expected_utils_imports
+
+            with open(tmp_folder / BERT_MODELING_FILE, "w") as f:
+                f.write(
+                    "from ...utils import cached_file, is_torch_available\nfrom .configuration_bert import BertConfig\n"
+                )
+            expected_bert_imports = [
+                ("src/transformers/utils/__init__.py", ["cached_file", "is_torch_available"]),
+                ("src/transformers/models/bert/configuration_bert.py", ["BertConfig"]),
+            ]
+            with patch_transformer_repo_path(tmp_folder):
+                assert extract_imports(BERT_MODELING_FILE) == expected_bert_imports
+
+            # Test with multi-line imports
+            with open(tmp_folder / BERT_MODELING_FILE, "w") as f:
+                f.write(
+                    "from ...utils import (\n    cached_file,\n    is_torch_available\n)\nfrom .configuration_bert import BertConfig\n"
+                )
+            expected_bert_imports = [
+                ("src/transformers/models/bert/configuration_bert.py", ["BertConfig"]),
+                ("src/transformers/utils/__init__.py", ["cached_file", "is_torch_available"]),
+            ]
+            with patch_transformer_repo_path(tmp_folder):
+                assert extract_imports(BERT_MODELING_FILE) == expected_bert_imports
+
+    def test_extract_imports_absolute(self):
+        with tempfile.TemporaryDirectory() as tmp_folder:
+            tmp_folder = Path(tmp_folder)
+            create_tmp_repo(tmp_folder)
+
+            with open(tmp_folder / BERT_MODELING_FILE, "w") as f:
+                f.write(
+                    "from transformers.utils import cached_file, is_torch_available\nfrom transformers.models.bert.configuration_bert import BertConfig\n"
+                )
+            expected_bert_imports = [
+                ("src/transformers/utils/__init__.py", ["cached_file", "is_torch_available"]),
+                ("src/transformers/models/bert/configuration_bert.py", ["BertConfig"]),
+            ]
+            with patch_transformer_repo_path(tmp_folder):
+                assert extract_imports(BERT_MODELING_FILE) == expected_bert_imports
+
+            # Test with multi-line imports
+            with open(tmp_folder / BERT_MODELING_FILE, "w") as f:
+                f.write(
+                    "from transformers.utils import (\n    cached_file,\n    is_torch_available\n)\nfrom transformers.models.bert.configuration_bert import BertConfig\n"
+                )
+            expected_bert_imports = [
+                ("src/transformers/models/bert/configuration_bert.py", ["BertConfig"]),
+                ("src/transformers/utils/__init__.py", ["cached_file", "is_torch_available"]),
+            ]
+            with patch_transformer_repo_path(tmp_folder):
+                assert extract_imports(BERT_MODELING_FILE) == expected_bert_imports
+
+            # Test with base imports
+            with open(tmp_folder / BERT_MODELING_FILE, "w") as f:
+                f.write(
+                    "from transformers.utils import (\n    cached_file,\n    is_torch_available\n)\nfrom transformers import BertConfig\n"
+                )
+            expected_bert_imports = [
+                ("src/transformers/__init__.py", ["BertConfig"]),
+                ("src/transformers/utils/__init__.py", ["cached_file", "is_torch_available"]),
+            ]
+            with patch_transformer_repo_path(tmp_folder):
+                assert extract_imports(BERT_MODELING_FILE) == expected_bert_imports
+
     def test_get_module_dependencies(self):
-        bert_module = os.path.join(transformers_path, "models", "bert", "modeling_bert.py")
-        expected_deps = [
-            "activations.py",
-            "modeling_outputs.py",
-            "modeling_utils.py",
-            "pytorch_utils.py",
-            "models/bert/configuration_bert.py",
-        ]
-        expected_deps = {os.path.join(transformers_path, f) for f in expected_deps}
-        repo = Repo(git_repo_path)
-        with checkout_commit(repo, GIT_TEST_SHA):
-            deps = get_module_dependencies(bert_module)
-        deps = {os.path.expanduser(f) for f in deps}
-        self.assertEqual(deps, expected_deps)
+        with tempfile.TemporaryDirectory() as tmp_folder:
+            tmp_folder = Path(tmp_folder)
+            create_tmp_repo(tmp_folder)
+
+            expected_bert_dependencies = [
+                "src/transformers/modeling_utils.py",
+                "src/transformers/models/bert/configuration_bert.py",
+                "src/transformers/utils/imports.py",
+            ]
+            with patch_transformer_repo_path(tmp_folder):
+                assert get_module_dependencies(BERT_MODELING_FILE) == expected_bert_dependencies
+
+            expected_test_bert_dependencies = [
+                "tests/test_modeling_common.py",
+                "src/transformers/models/bert/configuration_bert.py",
+                "src/transformers/models/bert/modeling_bert.py",
+            ]
+
+            with patch_transformer_repo_path(tmp_folder):
+                assert (
+                    get_module_dependencies("tests/models/bert/test_modeling_bert.py")
+                    == expected_test_bert_dependencies
+                )
+
+            # Test with a submodule
+            (tmp_folder / "src/transformers/utils/logging.py").touch()
+            with open(tmp_folder / BERT_MODELING_FILE, "a") as f:
+                f.write("from ...utils import logging\n")
+
+            expected_bert_dependencies = [
+                "src/transformers/modeling_utils.py",
+                "src/transformers/models/bert/configuration_bert.py",
+                "src/transformers/utils/logging.py",
+                "src/transformers/utils/imports.py",
+            ]
+            with patch_transformer_repo_path(tmp_folder):
+                assert get_module_dependencies(BERT_MODELING_FILE) == expected_bert_dependencies
+
+            # Test with an object non-imported in the init
+            create_tmp_repo(tmp_folder)
+            with open(tmp_folder / BERT_MODELING_FILE, "a") as f:
+                f.write("from ...utils import CONSTANT\n")
+
+            expected_bert_dependencies = [
+                "src/transformers/modeling_utils.py",
+                "src/transformers/models/bert/configuration_bert.py",
+                "src/transformers/utils/__init__.py",
+                "src/transformers/utils/imports.py",
+            ]
+            with patch_transformer_repo_path(tmp_folder):
+                assert get_module_dependencies(BERT_MODELING_FILE) == expected_bert_dependencies
+
+    def test_create_reverse_dependency_tree(self):
+        with tempfile.TemporaryDirectory() as tmp_folder:
+            tmp_folder = Path(tmp_folder)
+            create_tmp_repo(tmp_folder)
+            with patch_transformer_repo_path(tmp_folder):
+                tree = create_reverse_dependency_tree()
+
+            init_edges = [
+                "src/transformers/utils/hub.py",
+                "src/transformers/utils/imports.py",
+                "src/transformers/models/bert/configuration_bert.py",
+                "src/transformers/models/bert/modeling_bert.py",
+            ]
+            assert {f for f, g in tree if g == "src/transformers/__init__.py"} == set(init_edges)
+
+            bert_edges = [
+                "src/transformers/modeling_utils.py",
+                "src/transformers/utils/imports.py",
+                "src/transformers/models/bert/configuration_bert.py",
+            ]
+            assert {f for f, g in tree if g == "src/transformers/models/bert/modeling_bert.py"} == set(bert_edges)
+
+            test_bert_edges = [
+                "tests/test_modeling_common.py",
+                "src/transformers/models/bert/configuration_bert.py",
+                "src/transformers/models/bert/modeling_bert.py",
+            ]
+            assert {f for f, g in tree if g == "tests/models/bert/test_modeling_bert.py"} == set(test_bert_edges)
+
+    def test_get_tree_starting_at(self):
+        with tempfile.TemporaryDirectory() as tmp_folder:
+            tmp_folder = Path(tmp_folder)
+            create_tmp_repo(tmp_folder)
+            with patch_transformer_repo_path(tmp_folder):
+                edges = create_reverse_dependency_tree()
+
+                bert_tree = get_tree_starting_at("src/transformers/models/bert/modeling_bert.py", edges)
+                config_utils_tree = get_tree_starting_at("src/transformers/configuration_utils.py", edges)
+
+            expected_bert_tree = [
+                "src/transformers/models/bert/modeling_bert.py",
+                [("src/transformers/models/bert/modeling_bert.py", "tests/models/bert/test_modeling_bert.py")],
+            ]
+            assert bert_tree == expected_bert_tree
+
+            expected_config_tree = [
+                "src/transformers/configuration_utils.py",
+                [("src/transformers/configuration_utils.py", "src/transformers/models/bert/configuration_bert.py")],
+                [
+                    ("src/transformers/models/bert/configuration_bert.py", "tests/models/bert/test_modeling_bert.py"),
+                    (
+                        "src/transformers/models/bert/configuration_bert.py",
+                        "src/transformers/models/bert/modeling_bert.py",
+                    ),
+                ],
+            ]
+            # Order of the edges is random
+            assert [set(v) for v in config_utils_tree] == [set(v) for v in expected_config_tree]
+
+    def test_print_tree_deps_of(self):
+        with tempfile.TemporaryDirectory() as tmp_folder:
+            tmp_folder = Path(tmp_folder)
+            create_tmp_repo(tmp_folder)
+
+            # There are two possible outputs since the order of the last two lines is non-deterministic.
+            expected_std_out = """src/transformers/models/bert/modeling_bert.py
+  tests/models/bert/test_modeling_bert.py
+src/transformers/configuration_utils.py
+  src/transformers/models/bert/configuration_bert.py
+    src/transformers/models/bert/modeling_bert.py
+    tests/models/bert/test_modeling_bert.py"""
+
+            expected_std_out_2 = """src/transformers/models/bert/modeling_bert.py
+  tests/models/bert/test_modeling_bert.py
+src/transformers/configuration_utils.py
+  src/transformers/models/bert/configuration_bert.py
+    tests/models/bert/test_modeling_bert.py
+    src/transformers/models/bert/modeling_bert.py"""
+
+            with patch_transformer_repo_path(tmp_folder), CaptureStdout() as cs:
+                print_tree_deps_of("src/transformers/models/bert/modeling_bert.py")
+                print_tree_deps_of("src/transformers/configuration_utils.py")
+
+            assert cs.out.strip() in [expected_std_out, expected_std_out_2]
+
+    def test_create_reverse_dependency_map(self):
+        with tempfile.TemporaryDirectory() as tmp_folder:
+            tmp_folder = Path(tmp_folder)
+            create_tmp_repo(tmp_folder)
+            with patch_transformer_repo_path(tmp_folder):
+                reverse_map = create_reverse_dependency_map()
+
+            # impact of BERT modeling file (note that we stop at the inits and don't go down further)
+            expected_bert_deps = {
+                "src/transformers/__init__.py",
+                "src/transformers/models/bert/__init__.py",
+                "tests/models/bert/test_modeling_bert.py",
+            }
+            assert set(reverse_map["src/transformers/models/bert/modeling_bert.py"]) == expected_bert_deps
+
+            # init gets the direct deps (and their recursive deps)
+            expected_init_deps = {
+                "src/transformers/utils/__init__.py",
+                "src/transformers/utils/hub.py",
+                "src/transformers/utils/imports.py",
+                "src/transformers/models/bert/__init__.py",
+                "src/transformers/models/bert/configuration_bert.py",
+                "src/transformers/models/bert/modeling_bert.py",
+                "src/transformers/configuration_utils.py",
+                "src/transformers/modeling_utils.py",
+                "tests/test_modeling_common.py",
+                "tests/models/bert/test_modeling_bert.py",
+            }
+            assert set(reverse_map["src/transformers/__init__.py"]) == expected_init_deps
+
+            expected_init_deps = {
+                "src/transformers/__init__.py",
+                "src/transformers/models/bert/configuration_bert.py",
+                "src/transformers/models/bert/modeling_bert.py",
+                "tests/models/bert/test_modeling_bert.py",
+            }
+            assert set(reverse_map["src/transformers/models/bert/__init__.py"]) == expected_init_deps
+
+            # Test that with more models init of bert only gets deps to bert.
+            create_tmp_repo(tmp_folder, models=["bert", "gpt2"])
+            with patch_transformer_repo_path(tmp_folder):
+                reverse_map = create_reverse_dependency_map()
+
+            # init gets the direct deps (and their recursive deps)
+            expected_init_deps = {
+                "src/transformers/__init__.py",
+                "src/transformers/models/bert/configuration_bert.py",
+                "src/transformers/models/bert/modeling_bert.py",
+                "tests/models/bert/test_modeling_bert.py",
+            }
+            assert set(reverse_map["src/transformers/models/bert/__init__.py"]) == expected_init_deps
+
+    def test_create_module_to_test_map(self):
+        with tempfile.TemporaryDirectory() as tmp_folder:
+            tmp_folder = Path(tmp_folder)
+            models = ["bert", "gpt2"] + [f"bert{i}" for i in range(10)]
+            create_tmp_repo(tmp_folder, models=models)
+            with patch_transformer_repo_path(tmp_folder):
+                test_map = create_module_to_test_map(filter_models=True)
+
+            for model in models:
+                assert test_map[f"src/transformers/models/{model}/modeling_{model}.py"] == [
+                    f"tests/models/{model}/test_modeling_{model}.py"
+                ]
+
+            # Init got filtered
+            expected_init_tests = {
+                "tests/test_modeling_common.py",
+                "tests/models/bert/test_modeling_bert.py",
+                "tests/models/gpt2/test_modeling_gpt2.py",
+            }
+            assert set(test_map["src/transformers/__init__.py"]) == expected_init_tests
+
+    def test_infer_tests_to_run(self):
+        with tempfile.TemporaryDirectory() as tmp_folder:
+            tmp_folder = Path(tmp_folder)
+            models = ["bert", "gpt2"] + [f"bert{i}" for i in range(10)]
+            repo = create_tmp_repo(tmp_folder, models=models)
+
+            commit_changes("src/transformers/models/bert/modeling_bert.py", BERT_MODEL_FILE_NEW_CODE, repo)
+
+            with patch_transformer_repo_path(tmp_folder):
+                infer_tests_to_run(tmp_folder / "test-output.txt", diff_with_last_commit=True)
+                with open(tmp_folder / "test-output.txt", "r") as f:
+                    tests_to_run = f.read()
+
+            assert tests_to_run == "tests/models/bert/test_modeling_bert.py"
+
+            # Fake a new model addition
+            repo = create_tmp_repo(tmp_folder, models=models)
+
+            branch = repo.create_head("new_model")
+            branch.checkout()
+
+            with open(tmp_folder / "src/transformers/__init__.py", "a") as f:
+                f.write("from .models.t5 import T5Config, T5Model\n")
+
+            model_dir = tmp_folder / "src/transformers/models/t5"
+            model_dir.mkdir(exist_ok=True)
+
+            with open(model_dir / "__init__.py", "w") as f:
+                f.write("from .configuration_t5 import T5Config\nfrom .modeling_t5 import T5Model\n")
+            with open(model_dir / "configuration_t5.py", "w") as f:
+                f.write("from ...configuration_utils import PretrainedConfig\ncode")
+            with open(model_dir / "modeling_t5.py", "w") as f:
+                modeling_code = BERT_MODEL_FILE.replace("bert", "t5").replace("Bert", "T5")
+                f.write(modeling_code)
+
+            test_dir = tmp_folder / "tests/models/t5"
+            test_dir.mkdir(exist_ok=True)
+            (test_dir / "__init__.py").touch()
+            with open(test_dir / "test_modeling_t5.py", "w") as f:
+                f.write(
+                    "from transformers import T5Config, T5Model\nfrom ...test_modeling_common import ModelTesterMixin\n\ncode"
+                )
+
+            repo.index.add(["src", "tests"])
+            repo.index.commit("Add T5 model")
+
+            with patch_transformer_repo_path(tmp_folder):
+                infer_tests_to_run(tmp_folder / "test-output.txt")
+                with open(tmp_folder / "test-output.txt", "r") as f:
+                    tests_to_run = f.read()
+
+            expected_tests = {
+                "tests/models/bert/test_modeling_bert.py",
+                "tests/models/gpt2/test_modeling_gpt2.py",
+                "tests/models/t5/test_modeling_t5.py",
+                "tests/test_modeling_common.py",
+            }
+            assert set(tests_to_run.split(" ")) == expected_tests
+
+            with patch_transformer_repo_path(tmp_folder):
+                infer_tests_to_run(tmp_folder / "test-output.txt", filter_models=False)
+                with open(tmp_folder / "test-output.txt", "r") as f:
+                    tests_to_run = f.read()
+
+            expected_tests = [f"tests/models/{name}/test_modeling_{name}.py" for name in models + ["t5"]]
+            expected_tests = set(expected_tests + ["tests/test_modeling_common.py"])
+            assert set(tests_to_run.split(" ")) == expected_tests
+
+    def test_infer_tests_to_run_with_test_modifs(self):
+        with tempfile.TemporaryDirectory() as tmp_folder:
+            tmp_folder = Path(tmp_folder)
+            models = ["bert", "gpt2"] + [f"bert{i}" for i in range(10)]
+            repo = create_tmp_repo(tmp_folder, models=models)
+
+            commit_changes(
+                "tests/models/bert/test_modeling_bert.py",
+                "from transformers import BertConfig, BertModel\nfrom ...test_modeling_common import ModelTesterMixin\n\ncode1",
+                repo,
+            )
+
+            with patch_transformer_repo_path(tmp_folder):
+                infer_tests_to_run(tmp_folder / "test-output.txt", diff_with_last_commit=True)
+                with open(tmp_folder / "test-output.txt", "r") as f:
+                    tests_to_run = f.read()
+
+            assert tests_to_run == "tests/models/bert/test_modeling_bert.py"
+
+    def test_parse_commit_message(self):
+        assert parse_commit_message("Normal commit") == {"skip": False, "no_filter": False, "test_all": False}
+
+        assert parse_commit_message("[skip ci] commit") == {"skip": True, "no_filter": False, "test_all": False}
+        assert parse_commit_message("[ci skip] commit") == {"skip": True, "no_filter": False, "test_all": False}
+        assert parse_commit_message("[skip-ci] commit") == {"skip": True, "no_filter": False, "test_all": False}
+        assert parse_commit_message("[skip_ci] commit") == {"skip": True, "no_filter": False, "test_all": False}
+
+        assert parse_commit_message("[no filter] commit") == {"skip": False, "no_filter": True, "test_all": False}
+        assert parse_commit_message("[no-filter] commit") == {"skip": False, "no_filter": True, "test_all": False}
+        assert parse_commit_message("[no_filter] commit") == {"skip": False, "no_filter": True, "test_all": False}
+        assert parse_commit_message("[filter-no] commit") == {"skip": False, "no_filter": True, "test_all": False}
+
+        assert parse_commit_message("[test all] commit") == {"skip": False, "no_filter": False, "test_all": True}
+        assert parse_commit_message("[all test] commit") == {"skip": False, "no_filter": False, "test_all": True}
+        assert parse_commit_message("[test-all] commit") == {"skip": False, "no_filter": False, "test_all": True}
+        assert parse_commit_message("[all_test] commit") == {"skip": False, "no_filter": False, "test_all": True}
utils/tests_fetcher.py
@@ -13,6 +13,27 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+"""
+Welcome to tests_fetcher V2.
+
+This util is designed to fetch tests to run on a PR so that only the tests impacted by the modifications are run, and
+when too many models are being impacted, only run the tests of a subset of core models. It works like this.
+
+Stage 1: Identify the modified files. This takes all the files from the branching point to the current commit (so
+all modifications in a PR, not just the last commit) but excludes modifications that are on docstrings or comments
+only.
+
+Stage 2: Extract the tests to run. This is done by looking at the imports in each module and test file: if module A
+imports module B, then changing module B impacts module A, so the tests using module A should be run. We thus get the
+dependencies of each module and then recursively build the 'reverse' map of dependencies to get all modules and tests
+impacted by a given file. We then only keep the tests (and only the core models tests if there are too many modules).
+
+Caveats:
+  - This module only filters tests by files (not individual tests) so it's better to have tests for different things
+    in different files.
+  - This module assumes inits are just importing things, not really building objects, so it's better to structure
+    them this way and move objects building in separate submodules.
+"""
+
 import argparse
 import collections
 import json
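A rough sketch of the two stages this docstring describes, composed from the `get_modified_python_files` and `create_module_to_test_map` functions defined further down in this file. This is an illustration of the flow only, not the PR's actual `infer_tests_to_run` (which also handles the commit flags and writes the result to a file):

def sketch_infer_tests(diff_with_last_commit=False):
    # Stage 1: python files modified in the PR, with docstring/comment-only
    # changes already excluded by get_modified_python_files/get_diff.
    modified_files = get_modified_python_files(diff_with_last_commit=diff_with_last_commit)

    # Stage 2: reverse-dependency map from each module to its impacted tests,
    # filtered down to IMPORTANT_MODELS when too many models are touched.
    test_map = create_module_to_test_map(filter_models=True)

    tests_to_run = set()
    for modified in modified_files:
        if modified.startswith("tests/"):
            tests_to_run.add(modified)  # a modified test file triggers itself
        else:
            tests_to_run.update(test_map.get(modified, []))
    return sorted(tests_to_run)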
@@ -24,13 +45,36 @@ from pathlib import Path
 from git import Repo
 
 
-# This script is intended to be run from the root of the repo but you can adapt this constant if you need to.
-PATH_TO_TRANFORMERS = "."
+PATH_TO_REPO = Path(__file__).parent.parent.resolve()
+PATH_TO_TRANFORMERS = PATH_TO_REPO / "src/transformers"
+PATH_TO_TESTS = PATH_TO_REPO / "tests"
 
-# A temporary way to trigger all pipeline tests contained in model test files after PR #21516
-all_model_test_files = [str(x) for x in Path("tests/models/").glob("**/**/test_modeling_*.py")]
-all_pipeline_test_files = [str(x) for x in Path("tests/pipelines/").glob("**/test_pipelines_*.py")]
+# List here the models to always test.
+IMPORTANT_MODELS = [
+    # Most downloaded models
+    "bert",
+    "clip",
+    "t5",
+    "xlm-roberta",
+    "gpt2",
+    "bart",
+    "mpnet",
+    "gpt-j",
+    "wav2vec2",
+    "deberta-v2",
+    "layoutlm",
+    "opt",
+    "longformer",
+    "vit",
+    # Pipeline-specific model (to be sure each pipeline has one model in this list)
+    "tapas",
+    "vilt",
+    "clap",
+    "detr",
+    "owlvit",
+    "dpt",
+    "videomae",
+]
 
 
 @contextmanager
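As a hedged illustration of how a list like `IMPORTANT_MODELS` can cap the tests triggered by a very central file (this helper is not the PR's code; the folder-name-to-list mapping and the `_`/`-` normalization are assumptions for the example):

def filter_to_important_models(test_files, important_models):
    # Keep non-model tests, and among tests/models/<name>/... keep only names in the list.
    model_tests = [t for t in test_files if t.startswith("tests/models/")]
    other_tests = [t for t in test_files if not t.startswith("tests/models/")]
    kept = [t for t in model_tests if t.split("/")[2].replace("_", "-") in important_models]
    return other_tests + kept

print(filter_to_important_models(
    ["tests/models/bert/test_modeling_bert.py", "tests/models/beit/test_modeling_beit.py",
     "tests/test_modeling_common.py"],
    ["bert", "gpt2"],
))
# -> ['tests/test_modeling_common.py', 'tests/models/bert/test_modeling_bert.py']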
@@ -79,17 +123,21 @@ def get_all_tests():
     - folders under `tests/models`: `bert`, `gpt2`, etc.
     - test files under `tests`: `test_modeling_common.py`, `test_tokenization_common.py`, etc.
     """
-    test_root_dir = os.path.join(PATH_TO_TRANFORMERS, "tests")
 
     # test folders/files directly under `tests` folder
-    tests = os.listdir(test_root_dir)
-    tests = sorted(filter(lambda x: os.path.isdir(x) or x.startswith("tests/test_"), [f"tests/{x}" for x in tests]))
+    tests = os.listdir(PATH_TO_TESTS)
+    tests = [f"tests/{f}" for f in tests if "__pycache__" not in f]
+    tests = sorted([f for f in tests if (PATH_TO_REPO / f).is_dir() or f.startswith("tests/test_")])
 
     # model specific test folders
-    model_tests_folders = os.listdir(os.path.join(test_root_dir, "models"))
-    model_test_folders = sorted(filter(os.path.isdir, [f"tests/models/{x}" for x in model_tests_folders]))
+    model_test_folders = os.listdir(PATH_TO_TESTS / "models")
+    model_test_folders = [f"tests/models/{f}" for f in model_test_folders if "__pycache__" not in f]
+    model_test_folders = sorted([f for f in model_test_folders if (PATH_TO_REPO / f).is_dir()])
 
     tests.remove("tests/models")
+    # Sagemaker tests are not meant to be run on the CI.
+    if "tests/sagemaker" in tests:
+        tests.remove("tests/sagemaker")
     tests = model_test_folders + tests
 
     return tests
@@ -99,11 +147,12 @@ def diff_is_docstring_only(repo, branching_point, filename):
     """
     Check if the diff is only in docstrings in a filename.
     """
+    folder = Path(repo.working_dir)
     with checkout_commit(repo, branching_point):
-        with open(filename, "r", encoding="utf-8") as f:
+        with open(folder / filename, "r", encoding="utf-8") as f:
             old_content = f.read()
 
-    with open(filename, "r", encoding="utf-8") as f:
+    with open(folder / filename, "r", encoding="utf-8") as f:
         new_content = f.read()
 
     old_content_clean = clean_code(old_content)
@@ -112,31 +161,6 @@ def diff_is_docstring_only(repo, branching_point, filename):
     return old_content_clean == new_content_clean
 
 
-def get_modified_python_files(diff_with_last_commit=False):
-    """
-    Return a list of python files that have been modified between:
-
-    - the current head and the main branch if `diff_with_last_commit=False` (default)
-    - the current head and its parent commit otherwise.
-    """
-    repo = Repo(PATH_TO_TRANFORMERS)
-
-    if not diff_with_last_commit:
-        print(f"main is at {repo.refs.main.commit}")
-        print(f"Current head is at {repo.head.commit}")
-
-        branching_commits = repo.merge_base(repo.refs.main, repo.head)
-        for commit in branching_commits:
-            print(f"Branching commit: {commit}")
-        return get_diff(repo, repo.head.commit, branching_commits)
-    else:
-        print(f"main is at {repo.head.commit}")
-        parent_commits = repo.head.commit.parents
-        for commit in parent_commits:
-            print(f"Parent commit: {commit}")
-        return get_diff(repo, repo.head.commit, parent_commits)
-
-
 def get_diff(repo, base_commit, commits):
     """
     Gets the diff between one or several commits and the head of the repository.
@ -166,96 +190,173 @@ def get_diff(repo, base_commit, commits):
|
|||||||
return code_diff
|
return code_diff
|
||||||
|
|
||||||
|
|
||||||
def get_module_dependencies(module_fname):
|
def get_modified_python_files(diff_with_last_commit=False):
|
||||||
"""
|
"""
|
||||||
Get the dependencies of a module.
|
Return a list of python files that have been modified between:
|
||||||
|
|
||||||
|
- the current head and the main branch if `diff_with_last_commit=False` (default)
|
||||||
|
- the current head and its parent commit otherwise.
|
||||||
"""
|
"""
|
||||||
with open(os.path.join(PATH_TO_TRANFORMERS, module_fname), "r", encoding="utf-8") as f:
|
repo = Repo(PATH_TO_REPO)
|
||||||
|
|
||||||
|
if not diff_with_last_commit:
|
||||||
|
print(f"main is at {repo.refs.main.commit}")
|
||||||
|
print(f"Current head is at {repo.head.commit}")
|
||||||
|
|
||||||
|
branching_commits = repo.merge_base(repo.refs.main, repo.head)
|
||||||
|
for commit in branching_commits:
|
||||||
|
print(f"Branching commit: {commit}")
|
||||||
|
return get_diff(repo, repo.head.commit, branching_commits)
|
||||||
|
else:
|
||||||
|
print(f"main is at {repo.head.commit}")
|
||||||
|
parent_commits = repo.head.commit.parents
|
||||||
|
for commit in parent_commits:
|
||||||
|
print(f"Parent commit: {commit}")
|
||||||
|
return get_diff(repo, repo.head.commit, parent_commits)
|
||||||
|
|
||||||
|
|
||||||
|
# (:?^|\n) -> Non-catching group for the beginning of the doc or a new line.
|
||||||
|
# \s*from\s+(\.+\S+)\s+import\s+([^\n]+) -> Line only contains from .xxx import yyy and we catch .xxx and yyy
|
||||||
|
# (?=\n) -> Look-ahead to a new line. We can't just put \n here or using find_all on this re will only catch every
|
||||||
|
# other import.
|
||||||
|
_re_single_line_relative_imports = re.compile(r"(?:^|\n)\s*from\s+(\.+\S+)\s+import\s+([^\n]+)(?=\n)")
|
||||||
|
# (:?^|\n) -> Non-catching group for the beginning of the doc or a new line.
|
||||||
|
# \s*from\s+(\.+\S+)\s+import\s+\(([^\)]+)\) -> Line continues with from .xxx import (yyy) and we catch .xxx and yyy
|
||||||
|
# yyy will take multiple lines otherwise there wouldn't be parenthesis.
|
||||||
|
_re_multi_line_relative_imports = re.compile(r"(?:^|\n)\s*from\s+(\.+\S+)\s+import\s+\(([^\)]+)\)")
|
||||||
|
# (:?^|\n) -> Non-catching group for the beginning of the doc or a new line.
|
||||||
|
# \s*from\s+transformers(\S*)\s+import\s+([^\n]+) -> Line only contains from transformers.xxx import yyy and we catch
|
||||||
|
# .xxx and yyy
|
||||||
|
# (?=\n) -> Look-ahead to a new line. We can't just put \n here or using find_all on this re will only catch every
|
||||||
|
# other import.
|
||||||
|
_re_single_line_direct_imports = re.compile(r"(?:^|\n)\s*from\s+transformers(\S*)\s+import\s+([^\n]+)(?=\n)")
|
||||||
|
# (:?^|\n) -> Non-catching group for the beginning of the doc or a new line.
|
||||||
|
# \s*from\s+transformers(\S*)\s+import\s+\(([^\)]+)\) -> Line continues with from transformers.xxx import (yyy) and we
|
||||||
|
# catch .xxx and yyy. yyy will take multiple lines otherwise there wouldn't be parenthesis.
|
||||||
|
_re_multi_line_direct_imports = re.compile(r"(?:^|\n)\s*from\s+transformers(\S*)\s+import\s+\(([^\)]+)\)")
|
||||||
|
|
||||||
|
|
||||||
|
def extract_imports(module_fname, cache=None):
|
||||||
|
"""
|
||||||
|
Get the imports a given module makes. This takes a module filename and returns the list of module filenames
|
||||||
|
imported in the module with the objects imported in that module filename.
|
||||||
|
"""
|
||||||
|
if cache is not None and module_fname in cache:
|
||||||
|
return cache[module_fname]
|
||||||
|
|
||||||
|
with open(PATH_TO_REPO / module_fname, "r", encoding="utf-8") as f:
|
||||||
content = f.read()
|
content = f.read()
|
||||||
|
|
||||||
module_parts = module_fname.split(os.path.sep)
|
# Filter out all docstrings to not get imports in code examples.
|
||||||
|
splits = content.split('"""')
|
||||||
|
content = "".join(splits[::2])
|
||||||
|
|
||||||
|
module_parts = str(module_fname).split(os.path.sep)
|
||||||
imported_modules = []
|
imported_modules = []
|
||||||
|
|
||||||
# Let's start with relative imports
|
# Let's start with relative imports
|
||||||
relative_imports = re.findall(r"from\s+(\.+\S+)\s+import\s+([^\n]+)\n", content)
|
relative_imports = _re_single_line_relative_imports.findall(content)
|
||||||
relative_imports = [mod for mod, imp in relative_imports if "# tests_ignore" not in imp]
|
relative_imports = [
|
||||||
for imp in relative_imports:
|
(mod, imp) for mod, imp in relative_imports if "# tests_ignore" not in imp and imp.strip() != "("
|
||||||
|
]
|
||||||
|
multiline_relative_imports = _re_multi_line_relative_imports.findall(content)
|
||||||
|
relative_imports += [(mod, imp) for mod, imp in multiline_relative_imports if "# tests_ignore" not in imp]
|
||||||
|
|
||||||
|
for module, imports in relative_imports:
|
||||||
level = 0
|
level = 0
|
||||||
while imp.startswith("."):
|
while module.startswith("."):
|
||||||
imp = imp[1:]
|
module = module[1:]
|
||||||
level += 1
|
level += 1
|
||||||
|
|
||||||
if len(imp) > 0:
|
if len(module) > 0:
|
||||||
dep_parts = module_parts[: len(module_parts) - level] + imp.split(".")
|
dep_parts = module_parts[: len(module_parts) - level] + module.split(".")
|
||||||
else:
|
else:
|
||||||
dep_parts = module_parts[: len(module_parts) - level] + ["__init__.py"]
|
dep_parts = module_parts[: len(module_parts) - level]
|
||||||
imported_module = os.path.sep.join(dep_parts)
|
imported_module = os.path.sep.join(dep_parts)
|
||||||
# We ignore the main init import as it's only for the __version__ that it's done
|
imported_modules.append((imported_module, [imp.strip() for imp in imports.split(",")]))
|
||||||
# and it would add everything as a dependency.
|
|
||||||
if not imported_module.endswith("transformers/__init__.py"):
|
|
||||||
imported_modules.append(imported_module)
|
|
||||||
|
|
||||||
# Let's continue with direct imports
|
# Let's continue with direct imports
|
||||||
# The import from the transformers module are ignored for the same reason we ignored the
|
direct_imports = _re_single_line_direct_imports.findall(content)
|
||||||
# main init before.
|
direct_imports = [(mod, imp) for mod, imp in direct_imports if "# tests_ignore" not in imp and imp.strip() != "("]
|
||||||
direct_imports = re.findall(r"from\s+transformers\.(\S+)\s+import\s+([^\n]+)\n", content)
|
multiline_direct_imports = _re_multi_line_direct_imports.findall(content)
|
||||||
direct_imports = [mod for mod, imp in direct_imports if "# tests_ignore" not in imp]
|
direct_imports += [(mod, imp) for mod, imp in multiline_direct_imports if "# tests_ignore" not in imp]
|
||||||
for imp in direct_imports:
|
|
||||||
import_parts = imp.split(".")
|
for module, imports in direct_imports:
|
||||||
|
import_parts = module.split(".")[1:] # ignore the first .
|
||||||
dep_parts = ["src", "transformers"] + import_parts
|
dep_parts = ["src", "transformers"] + import_parts
|
||||||
imported_modules.append(os.path.sep.join(dep_parts))
|
imported_module = os.path.sep.join(dep_parts)
|
||||||
|
imported_modules.append((imported_module, [imp.strip() for imp in imports.split(",")]))
|
||||||
|
|
||||||
# Now let's just check that we have proper module files, or append an init for submodules
|
result = []
|
||||||
|
for module_file, imports in imported_modules:
|
||||||
|
if (PATH_TO_REPO / f"{module_file}.py").is_file():
|
||||||
|
module_file = f"{module_file}.py"
|
||||||
|
elif (PATH_TO_REPO / module_file).is_dir() and (PATH_TO_REPO / module_file / "__init__.py").is_file():
|
||||||
|
module_file = os.path.sep.join([module_file, "__init__.py"])
|
||||||
|
imports = [imp for imp in imports if len(imp) > 0 and re.match("^[A-Za-z0-9_]*$", imp)]
|
||||||
|
if len(imports) > 0:
|
||||||
|
result.append((module_file, imports))
|
||||||
|
|
||||||
|
if cache is not None:
|
||||||
|
cache[module_fname] = result
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def get_module_dependencies(module_fname, cache=None):
|
||||||
|
"""
|
||||||
|
Get the dependencies of a module from the module filename as a list of module filenames. This will resolve any
|
||||||
|
__init__ we pass: if we import from a submodule utils, the dependencies will be utils/foo.py and utils/bar.py (if
|
||||||
|
the objects imported actually come from utils.foo and utils.bar) not utils/__init__.py.
|
||||||
|
"""
|
||||||
dependencies = []
|
dependencies = []
|
||||||
for imported_module in imported_modules:
|
imported_modules = extract_imports(module_fname, cache=cache)
|
||||||
if os.path.isfile(os.path.join(PATH_TO_TRANFORMERS, f"{imported_module}.py")):
|
# The while loop is to recursively traverse all inits we may encounter.
|
||||||
dependencies.append(f"{imported_module}.py")
|
while len(imported_modules) > 0:
|
||||||
elif os.path.isdir(os.path.join(PATH_TO_TRANFORMERS, imported_module)) and os.path.isfile(
|
new_modules = []
|
||||||
os.path.sep.join([PATH_TO_TRANFORMERS, imported_module, "__init__.py"])
|
for module, imports in imported_modules:
|
||||||
):
|
# If we end up in an __init__ we are often not actually importing from this init (except in the case where
|
||||||
dependencies.append(os.path.sep.join([imported_module, "__init__.py"]))
|
# the object is fully defined in the __init__)
|
||||||
|
+            if module.endswith("__init__.py"):
+                # So we get the imports from that init then try to find where our objects come from.
+                new_imported_modules = extract_imports(module, cache=cache)
+                for new_module, new_imports in new_imported_modules:
+                    if any([i in new_imports for i in imports]):
+                        if new_module not in dependencies:
+                            new_modules.append((new_module, [i for i in new_imports if i in imports]))
+                        imports = [i for i in imports if i not in new_imports]
+                if len(imports) > 0:
+                    # If there are any objects left, they may be a submodule
+                    path_to_module = PATH_TO_REPO / module.replace("__init__.py", "")
+                    dependencies.extend(
+                        [
+                            os.path.join(module.replace("__init__.py", ""), f"{i}.py")
+                            for i in imports
+                            if (path_to_module / f"{i}.py").is_file()
+                        ]
+                    )
+                    imports = [i for i in imports if not (path_to_module / f"{i}.py").is_file()]
+                    if len(imports) > 0:
+                        # Then if there are still objects left, they are fully defined in the init, so we keep it as a
+                        # dependency.
+                        dependencies.append(module)
+            else:
+                dependencies.append(module)
+
+        imported_modules = new_modules
     return dependencies


-def get_test_dependencies(test_fname):
-    """
-    Get the dependencies of a test file.
-    """
-    with open(os.path.join(PATH_TO_TRANFORMERS, test_fname), "r", encoding="utf-8") as f:
-        content = f.read()
-
-    # Tests only have relative imports for other test files
-    # TODO Sylvain: handle relative imports cleanly
-    relative_imports = re.findall(r"from\s+(\.\S+)\s+import\s+([^\n]+)\n", content)
-    relative_imports = [test for test, imp in relative_imports if "# tests_ignore" not in imp]
-
-    def _convert_relative_import_to_file(relative_import):
-        level = 0
-        while relative_import.startswith("."):
-            level += 1
-            relative_import = relative_import[1:]
-
-        directory = os.path.sep.join(test_fname.split(os.path.sep)[:-level])
-        return os.path.join(directory, f"{relative_import.replace('.', os.path.sep)}.py")
-
-    dependencies = [_convert_relative_import_to_file(relative_import) for relative_import in relative_imports]
-    return [f for f in dependencies if os.path.isfile(os.path.join(PATH_TO_TRANFORMERS, f))]
-
-
 def create_reverse_dependency_tree():
     """
     Create a list of all edges (a, b) which mean that modifying a impacts b with a going over all module and test files.
     """
-    modules = [
-        str(f.relative_to(PATH_TO_TRANFORMERS))
-        for f in (Path(PATH_TO_TRANFORMERS) / "src/transformers").glob("**/*.py")
-    ]
-    module_edges = [(d, m) for m in modules for d in get_module_dependencies(m)]
-
-    tests = [str(f.relative_to(PATH_TO_TRANFORMERS)) for f in (Path(PATH_TO_TRANFORMERS) / "tests").glob("**/*.py")]
-    test_edges = [(d, t) for t in tests for d in get_test_dependencies(t)]
-
-    return module_edges + test_edges
+    cache = {}
+    all_modules = list(PATH_TO_TRANFORMERS.glob("**/*.py")) + list(PATH_TO_TESTS.glob("**/*.py"))
+    all_modules = [str(mod.relative_to(PATH_TO_REPO)) for mod in all_modules]
+    edges = [(dep, mod) for mod in all_modules for dep in get_module_dependencies(mod, cache=cache)]
+
+    return list(set(edges))


 def get_tree_starting_at(module, edges):
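For orientation, a minimal sketch of calling the reworked dependency extraction directly (illustrative only; it assumes `utils/` is on `sys.path` and a full checkout of the repo, and the file path shown is just an example):

    from tests_fetcher import get_module_dependencies

    # Dependencies of the BERT test file, with imports re-exported through an
    # __init__.py resolved back to the file that actually defines them.
    deps = get_module_dependencies("tests/models/bert/test_modeling_bert.py", cache={})
    print(deps)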
@ -264,13 +365,17 @@ def get_tree_starting_at(module, edges):
     starting at module], [list of edges starting at the preceding level], ...]
     """
     vertices_seen = [module]
-    new_edges = [edge for edge in edges if edge[0] == module and edge[1] != module]
+    new_edges = [edge for edge in edges if edge[0] == module and edge[1] != module and "__init__.py" not in edge[1]]
     tree = [module]
     while len(new_edges) > 0:
         tree.append(new_edges)
         final_vertices = list({edge[1] for edge in new_edges})
         vertices_seen.extend(final_vertices)
-        new_edges = [edge for edge in edges if edge[0] in final_vertices and edge[1] not in vertices_seen]
+        new_edges = [
+            edge
+            for edge in edges
+            if edge[0] in final_vertices and edge[1] not in vertices_seen and "__init__.py" not in edge[1]
+        ]

     return tree

@ -308,290 +413,159 @@ def create_reverse_dependency_map():
     Create the dependency map from module/test filename to the list of modules/tests that depend on it (even
     recursively).
     """
-    modules = [
-        str(f.relative_to(PATH_TO_TRANFORMERS))
-        for f in (Path(PATH_TO_TRANFORMERS) / "src/transformers").glob("**/*.py")
-    ]
-    # We grab all the dependencies of each module.
-    direct_deps = {m: get_module_dependencies(m) for m in modules}
-
-    # We add all the dependencies of each test file
-    tests = [str(f.relative_to(PATH_TO_TRANFORMERS)) for f in (Path(PATH_TO_TRANFORMERS) / "tests").glob("**/*.py")]
-    direct_deps.update({t: get_test_dependencies(t) for t in tests})
-
-    all_files = modules + tests
-
+    cache = {}
+    all_modules = list(PATH_TO_TRANFORMERS.glob("**/*.py")) + list(PATH_TO_TESTS.glob("**/*.py"))
+    all_modules = [str(mod.relative_to(PATH_TO_REPO)) for mod in all_modules]
+    direct_deps = {m: get_module_dependencies(m, cache=cache) for m in all_modules}
     # This recurses the dependencies
     something_changed = True
     while something_changed:
         something_changed = False
-        for m in all_files:
+        for m in all_modules:
             for d in direct_deps[m]:
+                if d.endswith("__init__.py"):
+                    continue
                 if d not in direct_deps:
                     raise ValueError(f"KeyError:{d}. From {m}")
-                for dep in direct_deps[d]:
-                    if dep not in direct_deps[m]:
-                        direct_deps[m].append(dep)
-                        something_changed = True
+                new_deps = set(direct_deps[d]) - set(direct_deps[m])
+                if len(new_deps) > 0:
+                    direct_deps[m].extend(list(new_deps))
+                    something_changed = True

     # Finally we can build the reverse map.
     reverse_map = collections.defaultdict(list)
-    for m in all_files:
-        if m.endswith("__init__.py"):
-            reverse_map[m].extend(direct_deps[m])
+    for m in all_modules:
         for d in direct_deps[m]:
             reverse_map[d].append(m)

+    for m in [f for f in all_modules if f.endswith("__init__.py")]:
+        direct_deps = get_module_dependencies(m, cache=cache)
+        deps = sum([reverse_map[d] for d in direct_deps if not d.endswith("__init__.py")], direct_deps)
+        reverse_map[m] = list(set(deps) - {m})
+
     return reverse_map

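A quick look at the rebuilt reverse map (a sketch under the same assumptions as above; the key shown is an example path, not a guaranteed entry):

    from tests_fetcher import create_reverse_dependency_map

    reverse_map = create_reverse_dependency_map()
    # All files, tests included, that recursively depend on the BERT modeling file.
    print(len(reverse_map["src/transformers/models/bert/modeling_bert.py"]))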
-# Any module file that has a test name which can't be inferred automatically from its name should go here. A better
-# approach is to (re-)name the test file accordingly, and second best to add the correspondence map here.
-SPECIAL_MODULE_TO_TEST_MAP = {
-    "commands/add_new_model_like.py": "utils/test_add_new_model_like.py",
-    "configuration_utils.py": "test_configuration_common.py",
-    "convert_graph_to_onnx.py": "onnx/test_onnx.py",
-    "data/data_collator.py": "trainer/test_data_collator.py",
-    "deepspeed.py": "deepspeed/",
-    "feature_extraction_sequence_utils.py": "test_sequence_feature_extraction_common.py",
-    "feature_extraction_utils.py": "test_feature_extraction_common.py",
-    "file_utils.py": ["utils/test_file_utils.py", "utils/test_model_output.py"],
-    "image_processing_utils.py": ["test_image_processing_common.py", "utils/test_image_processing_utils.py"],
-    "image_transforms.py": "test_image_transforms.py",
-    "utils/generic.py": ["utils/test_file_utils.py", "utils/test_model_output.py", "utils/test_generic.py"],
-    "utils/hub.py": "utils/test_hub_utils.py",
-    "modelcard.py": "utils/test_model_card.py",
-    "modeling_flax_utils.py": "test_modeling_flax_common.py",
-    "modeling_tf_utils.py": ["test_modeling_tf_common.py", "utils/test_modeling_tf_core.py"],
-    "modeling_utils.py": ["test_modeling_common.py", "utils/test_offline.py"],
-    "models/auto/modeling_auto.py": [
-        "models/auto/test_modeling_auto.py",
-        "models/auto/test_modeling_tf_pytorch.py",
-        "models/bort/test_modeling_bort.py",
-        "models/dit/test_modeling_dit.py",
-    ],
-    "models/auto/modeling_flax_auto.py": "models/auto/test_modeling_flax_auto.py",
-    "models/auto/modeling_tf_auto.py": [
-        "models/auto/test_modeling_tf_auto.py",
-        "models/auto/test_modeling_tf_pytorch.py",
-        "models/bort/test_modeling_tf_bort.py",
-    ],
-    "models/gpt2/modeling_gpt2.py": [
-        "models/gpt2/test_modeling_gpt2.py",
-        "models/megatron_gpt2/test_modeling_megatron_gpt2.py",
-    ],
-    "models/dpt/modeling_dpt.py": [
-        "models/dpt/test_modeling_dpt.py",
-        "models/dpt/test_modeling_dpt_hybrid.py",
-    ],
-    "optimization.py": "optimization/test_optimization.py",
-    "optimization_tf.py": "optimization/test_optimization_tf.py",
-    "pipelines/__init__.py": all_pipeline_test_files + all_model_test_files,
-    "pipelines/base.py": all_pipeline_test_files + all_model_test_files,
-    "pipelines/text2text_generation.py": [
-        "pipelines/test_pipelines_text2text_generation.py",
-        "pipelines/test_pipelines_summarization.py",
-        "pipelines/test_pipelines_translation.py",
-    ],
-    "pipelines/zero_shot_classification.py": "pipelines/test_pipelines_zero_shot.py",
-    "testing_utils.py": "utils/test_skip_decorators.py",
-    "tokenization_utils.py": ["test_tokenization_common.py", "tokenization/test_tokenization_utils.py"],
-    "tokenization_utils_base.py": ["test_tokenization_common.py", "tokenization/test_tokenization_utils.py"],
-    "tokenization_utils_fast.py": [
-        "test_tokenization_common.py",
-        "tokenization/test_tokenization_utils.py",
-        "tokenization/test_tokenization_fast.py",
-    ],
-    "trainer.py": [
-        "trainer/test_trainer.py",
-        "extended/test_trainer_ext.py",
-        "trainer/test_trainer_distributed.py",
-        "trainer/test_trainer_tpu.py",
-    ],
-    "train_pt_utils.py": "trainer/test_trainer_utils.py",
-    "utils/versions.py": "utils/test_versions_utils.py",
-}
-
-
-def module_to_test_file(module_fname):
-    """
-    Returns the name of the file(s) where `module_fname` is tested.
-    """
-    splits = module_fname.split(os.path.sep)
-
-    # Special map has priority
-    short_name = os.path.sep.join(splits[2:])
-    if short_name in SPECIAL_MODULE_TO_TEST_MAP:
-        test_file = SPECIAL_MODULE_TO_TEST_MAP[short_name]
-        if isinstance(test_file, str):
-            return f"tests/{test_file}"
-        return [f"tests/{f}" for f in test_file]
-
-    module_name = splits[-1]
-    # Fast tokenizers are tested in the same file as the slow ones.
-    if module_name.endswith("_fast.py"):
-        module_name = module_name.replace("_fast.py", ".py")
-
-    # Special case for pipelines submodules
-    if len(splits) >= 2 and splits[-2] == "pipelines":
-        default_test_file = f"tests/pipelines/test_pipelines_{module_name}"
-        return [default_test_file] + all_model_test_files
-    # Special case for benchmarks submodules
-    elif len(splits) >= 2 and splits[-2] == "benchmark":
-        return ["tests/benchmark/test_benchmark.py", "tests/benchmark/test_benchmark_tf.py"]
-    # Special case for commands submodules
-    elif len(splits) >= 2 and splits[-2] == "commands":
-        return "tests/utils/test_cli.py"
-    # Special case for onnx submodules
-    elif len(splits) >= 2 and splits[-2] == "onnx":
-        return ["tests/onnx/test_features.py", "tests/onnx/test_onnx.py", "tests/onnx/test_onnx_v2.py"]
-    # Special case for utils (not the one in src/transformers, the ones at the root of the repo).
-    elif len(splits) > 0 and splits[0] == "utils":
-        default_test_file = f"tests/repo_utils/test_{module_name}"
-    elif len(splits) > 4 and splits[2] == "models":
-        default_test_file = f"tests/models/{splits[3]}/test_{module_name}"
-    elif len(splits) > 2 and splits[2].startswith("generation"):
-        default_test_file = f"tests/generation/test_{module_name}"
-    elif len(splits) > 2 and splits[2].startswith("trainer"):
-        default_test_file = f"tests/trainer/test_{module_name}"
-    else:
-        default_test_file = f"tests/utils/test_{module_name}"
-
-    if os.path.isfile(default_test_file):
-        return default_test_file
-    # Processing -> processor
-    if "processing" in default_test_file:
-        test_file = default_test_file.replace("processing", "processor")
-        if os.path.isfile(test_file):
-            return test_file
-
-
-# This list contains the list of test files we expect never to be launched from a change in a module/util. Those are
-# launched separately.
-EXPECTED_TEST_FILES_NEVER_TOUCHED = [
-    "tests/generation/test_framework_agnostic.py",  # Mixins inherited by actual test classes
-    "tests/mixed_int8/test_mixed_int8.py",  # Mixed-int8 bitsandbytes test
-    "tests/pipelines/test_pipelines_common.py",  # Actually checked by the pipeline based file
-    "tests/sagemaker/test_single_node_gpu.py",  # SageMaker test
-    "tests/sagemaker/test_multi_node_model_parallel.py",  # SageMaker test
-    "tests/sagemaker/test_multi_node_data_parallel.py",  # SageMaker test
-    "tests/test_pipeline_mixin.py",  # Contains no test of its own (only the common tester class)
-    "tests/utils/test_doc_samples.py",  # Doc tests
-]
+def create_module_to_test_map(reverse_map=None, filter_models=False):
+    """
+    Extract the tests from the reverse_dependency_map and potentially filter the model tests.
+    """
+    if reverse_map is None:
+        reverse_map = create_reverse_dependency_map()
+    test_map = {module: [f for f in deps if f.startswith("tests")] for module, deps in reverse_map.items()}
+
+    if not filter_models:
+        return test_map
+
+    num_model_tests = len(list(PATH_TO_TESTS.glob("models/*")))
+
+    def has_many_models(tests):
+        model_tests = {Path(t).parts[2] for t in tests if t.startswith("tests/models/")}
+        return len(model_tests) > num_model_tests // 2
+
+    def filter_tests(tests):
+        return [t for t in tests if not t.startswith("tests/models/") or Path(t).parts[2] in IMPORTANT_MODELS]
+
+    return {module: (filter_tests(tests) if has_many_models(tests) else tests) for module, tests in test_map.items()}
+
+
+def check_imports_all_exist():
+    """
+    Isn't used per se by the test fetcher but might be used later as a quality check. Putting this here for now so the
+    code is not lost.
+    """
+    cache = {}
+    all_modules = list(PATH_TO_TRANFORMERS.glob("**/*.py")) + list(PATH_TO_TESTS.glob("**/*.py"))
+    all_modules = [str(mod.relative_to(PATH_TO_REPO)) for mod in all_modules]
+    direct_deps = {m: get_module_dependencies(m, cache=cache) for m in all_modules}
+
+    for module, deps in direct_deps.items():
+        for dep in deps:
+            if not (PATH_TO_REPO / dep).is_file():
+                print(f"{module} has dependency on {dep} which does not exist.")

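The new module-to-test mapping can be inspected the same way (illustrative; the key is an example):

    from tests_fetcher import create_module_to_test_map

    test_map = create_module_to_test_map(filter_models=True)
    # With filter_models=True, entries whose tests span more than half of the model
    # folders keep only the tests of IMPORTANT_MODELS.
    print(test_map.get("src/transformers/modeling_utils.py", []))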
 def _print_list(l):
     return "\n".join([f"- {f}" for f in l])


-def sanity_check():
-    """
-    Checks that all test files can be touched by a modification in at least one module/utils. This test ensures that
-    newly-added test files are properly mapped to some module or utils, so they can be run by the CI.
-    """
-    # Grab all module and utils
-    all_files = [
-        str(p.relative_to(PATH_TO_TRANFORMERS))
-        for p in (Path(PATH_TO_TRANFORMERS) / "src/transformers").glob("**/*.py")
-    ]
-    all_files += [
-        str(p.relative_to(PATH_TO_TRANFORMERS)) for p in (Path(PATH_TO_TRANFORMERS) / "utils").glob("**/*.py")
-    ]
-
-    # Compute all the test files we get from those.
-    test_files_found = []
-    for f in all_files:
-        test_f = module_to_test_file(f)
-        if test_f is not None:
-            if isinstance(test_f, str):
-                test_files_found.append(test_f)
-            else:
-                test_files_found.extend(test_f)
-
-    # Some of the test files might actually be subfolders so we grab the tests inside.
-    test_files = []
-    for test_f in test_files_found:
-        if os.path.isdir(os.path.join(PATH_TO_TRANFORMERS, test_f)):
-            test_files.extend(
-                [
-                    str(p.relative_to(PATH_TO_TRANFORMERS))
-                    for p in (Path(PATH_TO_TRANFORMERS) / test_f).glob("**/test*.py")
-                ]
-            )
-        else:
-            test_files.append(test_f)
-
-    # Compare to existing test files
-    existing_test_files = [
-        str(p.relative_to(PATH_TO_TRANFORMERS)) for p in (Path(PATH_TO_TRANFORMERS) / "tests").glob("**/test*.py")
-    ]
-    not_touched_test_files = [f for f in existing_test_files if f not in test_files]
-
-    should_be_tested = set(not_touched_test_files) - set(EXPECTED_TEST_FILES_NEVER_TOUCHED)
-    if len(should_be_tested) > 0:
-        raise ValueError(
-            "The following test files are not currently associated with any module or utils files, which means they "
-            f"will never get run by the CI:\n{_print_list(should_be_tested)}\n. Make sure the names of these test "
-            "files match the name of the module or utils they are testing, or adapt the constant "
-            "`SPECIAL_MODULE_TO_TEST_MAP` in `utils/tests_fetcher.py` to add them. If your test file is triggered "
-            "separately and is not supposed to be run by the regular CI, add it to the "
-            "`EXPECTED_TEST_FILES_NEVER_TOUCHED` constant instead."
-        )
+def create_json_map(test_files_to_run, json_output_file):
+    if json_output_file is None:
+        return
+
+    test_map = {}
+    for test_file in test_files_to_run:
+        # `test_file` is a path to a test folder/file, starting with `tests/`. For example,
+        # - `tests/models/bert/test_modeling_bert.py` or `tests/models/bert`
+        # - `tests/trainer/test_trainer.py` or `tests/trainer`
+        # - `tests/test_modeling_common.py`
+        names = test_file.split(os.path.sep)
+        if names[1] == "models":
+            # take the part like `models/bert` for modeling tests
+            key = os.path.sep.join(names[1:3])
+        elif len(names) > 2 or not test_file.endswith(".py"):
+            # test folders under `tests` or python files under them
+            # take the part like tokenization, `pipeline`, etc. for other test categories
+            key = os.path.sep.join(names[1:2])
+        else:
+            # common test files directly under `tests/`
+            key = "common"
+
+        if key not in test_map:
+            test_map[key] = []
+        test_map[key].append(test_file)
+
+    # sort the keys & values
+    keys = sorted(test_map.keys())
+    test_map = {k: " ".join(sorted(test_map[k])) for k in keys}
+    with open(json_output_file, "w", encoding="UTF-8") as fp:
+        json.dump(test_map, fp, ensure_ascii=False)

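To see the shape of the JSON file `create_json_map` writes, a small self-contained sketch (file name and test list are made up; the keys assume a POSIX path separator):

    import json

    from tests_fetcher import create_json_map

    create_json_map(
        ["tests/models/bert/test_modeling_bert.py", "tests/trainer/test_trainer.py", "tests/test_modeling_common.py"],
        "test_map.json",
    )
    with open("test_map.json") as fp:
        print(json.load(fp))
    # {"common": "tests/test_modeling_common.py",
    #  "models/bert": "tests/models/bert/test_modeling_bert.py",
    #  "trainer": "tests/trainer/test_trainer.py"}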
-def infer_tests_to_run(output_file, diff_with_last_commit=False, filters=None, json_output_file=None):
+def infer_tests_to_run(
+    output_file, diff_with_last_commit=False, filters=None, filter_models=True, json_output_file=None
+):
     modified_files = get_modified_python_files(diff_with_last_commit=diff_with_last_commit)
     print(f"\n### MODIFIED FILES ###\n{_print_list(modified_files)}")

     # Create the map that will give us all impacted modules.
-    impacted_modules_map = create_reverse_dependency_map()
+    reverse_map = create_reverse_dependency_map()
     impacted_files = modified_files.copy()
     for f in modified_files:
-        if f in impacted_modules_map:
-            impacted_files.extend(impacted_modules_map[f])
+        if f in reverse_map:
+            impacted_files.extend(reverse_map[f])

     # Remove duplicates
     impacted_files = sorted(set(impacted_files))
     print(f"\n### IMPACTED FILES ###\n{_print_list(impacted_files)}")

     # Grab the corresponding test files:
-    if "setup.py" in impacted_files:
+    if "setup.py" in modified_files:
         test_files_to_run = ["tests"]
         repo_utils_launch = True
     else:
-        # Grab the corresponding test files:
-        test_files_to_run = []
-        for f in impacted_files:
-            # Modified test files are always added
-            if f.startswith("tests/"):
-                test_files_to_run.append(f)
-            # Example files are tested separately
-            elif f.startswith("examples/pytorch"):
-                test_files_to_run.append("examples/pytorch/test_pytorch_examples.py")
-                test_files_to_run.append("examples/pytorch/test_accelerate_examples.py")
-            elif f.startswith("examples/tensorflow"):
-                test_files_to_run.append("examples/tensorflow/test_tensorflow_examples.py")
-            elif f.startswith("examples/flax"):
-                test_files_to_run.append("examples/flax/test_flax_examples.py")
-            else:
-                new_tests = module_to_test_file(f)
-                if new_tests is not None:
-                    if isinstance(new_tests, str):
-                        test_files_to_run.append(new_tests)
-                    else:
-                        test_files_to_run.extend(new_tests)
-
-        # Remove duplicates
+        # All modified tests need to be run.
+        test_files_to_run = [
+            f for f in modified_files if f.startswith("tests") and f.split(os.path.sep)[-1].startswith("test")
+        ]
+        # Then we grab the corresponding test files.
+        test_map = create_module_to_test_map(reverse_map=reverse_map, filter_models=filter_models)
+        for f in modified_files:
+            if f in test_map:
+                test_files_to_run.extend(test_map[f])
         test_files_to_run = sorted(set(test_files_to_run))
+        # Remove SageMaker tests
+        test_files_to_run = [f for f in test_files_to_run if not f.split(os.path.sep)[1] == "sagemaker"]
         # Make sure we did not end up with a test file that was removed
-        test_files_to_run = [f for f in test_files_to_run if os.path.isfile(f) or os.path.isdir(f)]
+        test_files_to_run = [f for f in test_files_to_run if (PATH_TO_REPO / f).exists()]
         if filters is not None:
             filtered_files = []
-            for filter in filters:
-                filtered_files.extend([f for f in test_files_to_run if f.startswith(filter)])
+            for _filter in filters:
+                filtered_files.extend([f for f in test_files_to_run if f.startswith(_filter)])
             test_files_to_run = filtered_files
-        repo_utils_launch = any(f.split(os.path.sep)[1] == "repo_utils" for f in test_files_to_run)
+
+    repo_utils_launch = any(f.split(os.path.sep)[1] == "repo_utils" for f in modified_files)

     if repo_utils_launch:
         repo_util_file = Path(output_file).parent / "test_repo_utils.txt"
@ -610,34 +584,7 @@ def infer_tests_to_run(output_file, diff_with_last_commit=False, filters=None, json_output_file=None):
     if "tests" in test_files_to_run:
         test_files_to_run = get_all_tests()

-    if json_output_file is not None:
-        test_map = {}
-        for test_file in test_files_to_run:
-            # `test_file` is a path to a test folder/file, starting with `tests/`. For example,
-            # - `tests/models/bert/test_modeling_bert.py` or `tests/models/bert`
-            # - `tests/trainer/test_trainer.py` or `tests/trainer`
-            # - `tests/test_modeling_common.py`
-            names = test_file.split(os.path.sep)
-            if names[1] == "models":
-                # take the part like `models/bert` for modeling tests
-                key = "/".join(names[1:3])
-            elif len(names) > 2 or not test_file.endswith(".py"):
-                # test folders under `tests` or python files under them
-                # take the part like tokenization, `pipeline`, etc. for other test categories
-                key = "/".join(names[1:2])
-            else:
-                # common test files directly under `tests/`
-                key = "common"
-
-            if key not in test_map:
-                test_map[key] = []
-            test_map[key].append(test_file)
-
-        # sort the keys & values
-        keys = sorted(test_map.keys())
-        test_map = {k: " ".join(sorted(test_map[k])) for k in keys}
-        with open(json_output_file, "w", encoding="UTF-8") as fp:
-            json.dump(test_map, fp, ensure_ascii=False)
+    create_json_map(test_files_to_run, json_output_file)


 def filter_tests(output_file, filters):
@ -667,11 +614,29 @@ def filter_tests(output_file, filters):
         f.write(" ".join(test_files))


+def parse_commit_message(commit_message):
+    """
+    Parses the commit message to detect if a command is there to skip, force all or part of the CI.
+
+    Returns a dictionary of strings to bools with keys skip, no_filter and test_all.
+    """
+    if commit_message is None:
+        return {"skip": False, "no_filter": False, "test_all": False}
+
+    command_search = re.search(r"\[([^\]]*)\]", commit_message)
+    if command_search is not None:
+        command = command_search.groups()[0]
+        command = command.lower().replace("-", " ").replace("_", " ")
+        skip = command in ["ci skip", "skip ci", "circleci skip", "skip circleci"]
+        no_filter = set(command.split(" ")) == {"no", "filter"}
+        test_all = set(command.split(" ")) == {"test", "all"}
+        return {"skip": skip, "no_filter": no_filter, "test_all": test_all}
+    else:
+        return {"skip": False, "no_filter": False, "test_all": False}
+
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--sanity_check", action="store_true", help="Only test that all tests and modules are accounted for."
-    )
     parser.add_argument(
         "--output_file", type=str, default="test_list.txt", help="Where to store the list of tests to run"
     )
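The new commit-message commands resolve as follows (a sketch; the commit messages are made up, and `utils/` is assumed to be on `sys.path`):

    from tests_fetcher import parse_commit_message

    parse_commit_message("Fix typo [skip ci]")      # {"skip": True, "no_filter": False, "test_all": False}
    parse_commit_message("Refactor [no-filter]")    # {"skip": False, "no_filter": True, "test_all": False}
    parse_commit_message("Big change [test all]")   # {"skip": False, "no_filter": False, "test_all": True}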
@ -704,33 +669,54 @@ if __name__ == "__main__":
         help="Will only print the tree of modules depending on the file passed.",
         default=None,
     )
+    parser.add_argument(
+        "--commit_message",
+        type=str,
+        help="The commit message (which could contain a command to force all tests or skip the CI).",
+        default=None,
+    )
     args = parser.parse_args()
     if args.print_dependencies_of is not None:
         print_tree_deps_of(args.print_dependencies_of)
-    elif args.sanity_check:
-        sanity_check()
     elif args.filter_tests:
         filter_tests(args.output_file, ["pipelines", "repo_utils"])
     else:
-        repo = Repo(PATH_TO_TRANFORMERS)
+        repo = Repo(PATH_TO_REPO)
+        commit_message = repo.head.commit.message
+        commit_flags = parse_commit_message(commit_message)
+        if commit_flags["skip"]:
+            print("Force-skipping the CI")
+            quit()
+        if commit_flags["no_filter"]:
+            print("Running all tests fetched without filtering.")
+        if commit_flags["test_all"]:
+            print("Force-launching all tests")

         diff_with_last_commit = args.diff_with_last_commit
         if not diff_with_last_commit and not repo.head.is_detached and repo.head.ref == repo.refs.main:
             print("main branch detected, fetching tests against last commit.")
             diff_with_last_commit = True

-        try:
-            infer_tests_to_run(
-                args.output_file,
-                diff_with_last_commit=diff_with_last_commit,
-                filters=args.filters,
-                json_output_file=args.json_output_file,
-            )
-            filter_tests(args.output_file, ["repo_utils"])
-        except Exception as e:
-            print(f"\nError when trying to grab the relevant tests: {e}\n\nRunning all tests.")
+        if not commit_flags["test_all"]:
+            try:
+                infer_tests_to_run(
+                    args.output_file,
+                    diff_with_last_commit=diff_with_last_commit,
+                    filters=args.filters,
+                    json_output_file=args.json_output_file,
+                    filter_models=not commit_flags["no_filter"],
+                )
+                filter_tests(args.output_file, ["repo_utils"])
+            except Exception as e:
+                print(f"\nError when trying to grab the relevant tests: {e}\n\nRunning all tests.")
+                commit_flags["test_all"] = True
+
+        if commit_flags["test_all"]:
             with open(args.output_file, "w", encoding="utf-8") as f:
                 if args.filters is None:
                     f.write("./tests/")
                 else:
                     f.write(" ".join(args.filters))
+
+            test_files_to_run = get_all_tests()
+            create_json_map(test_files_to_run, args.json_output_file)
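Taken together, the new flow can also be driven programmatically; a minimal sketch (assuming `utils/` is on `sys.path` and a git checkout of the repo, with arbitrary output file names):

    from tests_fetcher import infer_tests_to_run

    infer_tests_to_run(
        "test_list.txt",
        diff_with_last_commit=True,
        json_output_file="test_map.json",
        filter_models=True,  # False mimics a [no filter] commit command
    )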