[Modular] skip modular checks based on diff (#36130)

skip modular checks based on diff
This commit is contained in:
Joao Gante 2025-02-13 12:53:21 +00:00 committed by GitHub
parent 6397916dd2
commit d114a6f78e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 91 additions and 41 deletions

View File

@ -58,7 +58,7 @@ jobs:
- run:
name: "Prepare pipeline parameters"
command: |
python utils/process_test_artifacts.py
python utils/process_test_artifacts.py
# To avoid too long generated_config.yaml on the continuation orb, we pass the links to the artifacts as parameters.
# Otherwise the list of tests was just too big. Explicit is good but for that it was a limitation.
@ -110,7 +110,7 @@ jobs:
- run:
name: "Prepare pipeline parameters"
command: |
python utils/process_test_artifacts.py
python utils/process_test_artifacts.py
# To avoid too long generated_config.yaml on the continuation orb, we pass the links to the artifacts as parameters.
# Otherwise the list of tests was just too big. Explicit is good but for that it was a limitation.

View File

@ -48,7 +48,7 @@ def appear_after(model1: str, model2: str, priority_list: list[str]) -> bool:
class ConversionOrderTest(unittest.TestCase):
def test_conversion_order(self):
# Find the order
priority_list = create_dependency_mapping.find_priority_list(FILES_TO_PARSE)
priority_list, _ = create_dependency_mapping.find_priority_list(FILES_TO_PARSE)
# Extract just the model names
model_priority_list = [file.rsplit("modular_")[-1].replace(".py", "") for file in priority_list]

View File

@ -1024,40 +1024,6 @@ def convert_to_localized_md(model_list: str, localized_model_list: str, format_s
return readmes_match, "\n".join((x[1] for x in sorted_index)) + "\n"
def _find_text_in_file(filename: str, start_prompt: str, end_prompt: str) -> Tuple[str, int, int, List[str]]:
"""
Find the text in a file between two prompts.
Args:
filename (`str`): The name of the file to look into.
start_prompt (`str`): The string to look for that introduces the content looked for.
end_prompt (`str`): The string to look for that ends the content looked for.
Returns:
Tuple[str, int, int, List[str]]: The content between the two prompts, the index of the start line in the
original file, the index of the end line in the original file and the list of lines of that file.
"""
with open(filename, "r", encoding="utf-8", newline="\n") as f:
lines = f.readlines()
# Find the start prompt.
start_index = 0
while not lines[start_index].startswith(start_prompt):
start_index += 1
start_index += 1
end_index = start_index
while not lines[end_index].startswith(end_prompt):
end_index += 1
end_index -= 1
while len(lines[start_index]) <= 1:
start_index += 1
while len(lines[end_index]) <= 1:
end_index -= 1
end_index += 1
return "".join(lines[start_index:end_index]), start_index, end_index, lines
# Map a model name with the name it has in the README for the check_readme check
SPECIAL_MODEL_NAMES = {
"Bert Generation": "BERT For Sequence Generation",

View File

@ -2,6 +2,7 @@ import argparse
import difflib
import glob
import logging
import subprocess
from io import StringIO
from create_dependency_mapping import find_priority_list
@ -61,6 +62,56 @@ def compare_files(modular_file_path, fix_and_overwrite=False):
return diff
def get_models_in_diff():
    """
    Finds all models that have been modified in the diff.

    The diff is computed with `git diff --diff-filter=d --name-only` against the merge base of `main` and `HEAD`.
    Only `.py` files whose path contains `/models/` are considered, and the model name is taken from the directory
    that directly contains the file.

    Returns:
        A set containing the names of the models that have been modified (e.g. {'llama', 'whisper'}).

    Raises:
        subprocess.CalledProcessError: If one of the git commands exits with a non-zero status.
    """
    # Pass explicit argument lists and strip the trailing newline of the SHA, rather than relying on
    # whitespace-splitting a formatted command string.
    fork_point_sha = subprocess.check_output(["git", "merge-base", "main", "HEAD"]).decode("utf-8").strip()
    # `--name-only` prints one path per line: split on lines so that a path containing spaces stays in one piece.
    modified_files = (
        subprocess.check_output(["git", "diff", "--diff-filter=d", "--name-only", fork_point_sha])
        .decode("utf-8")
        .splitlines()
    )

    # Matches both modelling files and tests
    return {
        file_path.split("/")[-2]
        for file_path in modified_files
        if "/models/" in file_path and file_path.endswith(".py")
    }
def guaranteed_no_diff(modular_file_path, dependencies, models_in_diff):
    """
    Tells whether the modular file and its modeling file are guaranteed to have no differences.

    The guarantee only holds when neither the model itself nor any of its dependencies is in the diff:
        - the model is in the diff -> no guarantee
        - one of its dependencies is in the diff -> no guarantee
        - otherwise -> guaranteed to have no differences

    Args:
        modular_file_path: The path to the modular file.
        dependencies: A dictionary containing the dependencies of each modular file.
        models_in_diff: A set containing the names of the models that have been modified.

    Returns:
        A boolean indicating whether the model (code and tests) is guaranteed to have no differences.
    """
    # `.../modular_<name>.py` -> `<name>`
    own_name = modular_file_path.rsplit("modular_", 1)[1].replace(".py", "")
    if own_name in models_in_diff:
        return False

    # Dependencies follow one of two patterns, `transformers.models.model_name.(...)` or `model_name.(...)`:
    # in both cases the model name is the second-to-last dotted component.
    dependency_in_diff = any(
        dependency.split(".")[-2] in models_in_diff for dependency in dependencies[modular_file_path]
    )
    return not dependency_in_diff
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Compare modular_xxx.py files with modeling_xxx.py files.")
parser.add_argument(
@ -72,9 +123,32 @@ if __name__ == "__main__":
args = parser.parse_args()
if args.files == ["all"]:
args.files = glob.glob("src/transformers/models/**/modular_*.py", recursive=True)
# Assuming there is a topological sort on the dependency mapping: if the file being checked and its dependencies
# are not in the diff, then it is guaranteed to have no differences. If no models are in the diff, then this
# script will do nothing.
models_in_diff = get_models_in_diff()
if not models_in_diff:
console.print("[bold green]No models files or model tests in the diff, skipping modular checks[/bold green]")
exit(0)
skipped_models = set()
non_matching_files = 0
for modular_file_path in find_priority_list(args.files):
ordered_files, dependencies = find_priority_list(args.files)
for modular_file_path in ordered_files:
is_guaranteed_no_diff = guaranteed_no_diff(modular_file_path, dependencies, models_in_diff)
if is_guaranteed_no_diff:
model_name = modular_file_path.rsplit("modular_", 1)[1].replace(".py", "")
skipped_models.add(model_name)
continue
non_matching_files += compare_files(modular_file_path, args.fix_and_overwrite)
models_in_diff = get_models_in_diff() # When overwriting, the diff changes
if non_matching_files and not args.fix_and_overwrite:
raise ValueError("Some diff and their modeling code did not match.")
if skipped_models:
console.print(
f"[bold green]Skipped {len(skipped_models)} models and their dependencies that are not in the diff: "
f"{', '.join(skipped_models)}[/bold green]"
)

View File

@ -55,6 +55,16 @@ def map_dependencies(py_files):
def find_priority_list(py_files):
    """
    Given a list of modular files, sorts them by topological order. Modular models that DON'T depend on other modular
    models will be higher in the topological order.

    Args:
        py_files: List of paths to the modular files

    Returns:
        A tuple with the ordered files (list) and their dependencies (dict)
    """
    dependencies = map_dependencies(py_files)
    ordered_files = topological_sort(dependencies)
    # Callers unpack two values: the dependency mapping is returned next to the ordering so that it does not
    # have to be recomputed.
    return ordered_files, dependencies

View File

@ -1716,7 +1716,7 @@ if __name__ == "__main__":
if args.files_to_parse == ["examples"]:
args.files_to_parse = glob.glob("examples/**/modular_*.py", recursive=True)
priority_list = find_priority_list(args.files_to_parse)
priority_list, _ = find_priority_list(args.files_to_parse)
assert len(priority_list) == len(args.files_to_parse), "Some files will not be converted"
for file_name in priority_list: