mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-02 04:10:06 +06:00

* Allow make-fixup on main branch, albeit slowly * Make the other style checks work correctly on main too * More update * More makefile update
188 lines
7.7 KiB
Python
188 lines
7.7 KiB
Python
import argparse
import difflib
import glob
import logging
import os
import subprocess
import sys
from io import StringIO

from create_dependency_mapping import find_priority_list
from modular_model_converter import convert_modular_file

# Console for rich printing
from rich.console import Console
from rich.syntax import Syntax


# Silence library logging below ERROR; this script reports its own results through `console`.
logging.basicConfig()
logging.getLogger().setLevel(level=logging.ERROR)

# Shared rich console used for all user-facing output (plain messages and syntax-highlighted diffs).
console = Console()


def _get_generated_file_path(modular_file_path, file_type):
    """Returns the path of the file generated for `file_type` from `modular_file_path` (see `process_file`)."""
    file_name_prefix = file_type.split("*")[0]
    file_name_suffix = file_type.split("*")[-1] if "*" in file_type else ""
    directory, file_name = os.path.split(modular_file_path)
    # Only rename the file itself, so directory names containing `modular_` or `.py` are left untouched.
    file_name = file_name.replace("modular_", f"{file_name_prefix}_")
    if file_name.endswith(".py"):
        file_name = file_name[: -len(".py")] + f"{file_name_suffix}.py"
    return os.path.join(directory, file_name)


def process_file(modular_file_path, generated_modeling_content, file_type="modeling_", fix_and_overwrite=False):
    """
    Compares one file generated from a modular file against the version currently on disk.

    Args:
        modular_file_path: Path to the `modular_xxx.py` file the content was generated from.
        generated_modeling_content: Mapping from file type to a sequence whose first element is the generated
            source code (the output of `convert_modular_file`).
        file_type: Key into `generated_modeling_content`. The part before an optional `*` replaces `modular` in the
            file name, and the part after it is inserted before `.py` (e.g. `"image_processing*_fast"` maps
            `modular_xxx.py` to `image_processing_xxx_fast.py`).
        fix_and_overwrite: If `True`, the file on disk is overwritten with the generated content when they differ,
            instead of printing the diff.

    Returns:
        `1` if the generated content differs from the file on disk (including when that file does not exist yet),
        `0` otherwise.
    """
    file_path = _get_generated_file_path(modular_file_path, file_type)
    generated_content = generated_modeling_content[file_type][0]

    # A file that does not exist yet (e.g. a brand-new model) is compared as if it were empty, so it is reported as
    # a difference (or created when `fix_and_overwrite` is set) instead of crashing the whole check.
    try:
        with open(file_path, "r", encoding="utf-8") as modeling_file:
            content = modeling_file.read()
    except FileNotFoundError:
        content = ""

    diff_list = list(
        difflib.unified_diff(
            generated_content.splitlines(),
            content.splitlines(),
            fromfile=f"{file_path}_generated",
            tofile=f"{file_path}",
            lineterm="",
        )
    )

    if not diff_list:
        console.print(f"[bold green]No differences found for {file_path}.[/bold green]")
        return 0

    if fix_and_overwrite:
        # Force "\n" line endings so the written file is identical on every platform.
        with open(file_path, "w", encoding="utf-8", newline="\n") as modeling_file:
            modeling_file.write(generated_content)
        console.print(f"[bold blue]Overwritten {file_path} with the generated content.[/bold blue]")
    else:
        console.print(f"\n[bold red]Differences found between the generated code and {file_path}:[/bold red]\n")
        diff_text = "\n".join(diff_list)
        syntax = Syntax(diff_text, "diff", theme="ansi_dark", line_numbers=True)
        console.print(syntax)
    return 1


def compare_files(modular_file_path, fix_and_overwrite=False):
    """
    Regenerates every file derived from `modular_file_path` and compares each one with its version on disk.

    Args:
        modular_file_path: Path to the `modular_xxx.py` file to convert.
        fix_and_overwrite: Forwarded to `process_file`; when `True`, mismatching files are overwritten.

    Returns:
        The number of generated files that did not match the version on disk.
    """
    converted_content = convert_modular_file(modular_file_path)
    return sum(
        process_file(modular_file_path, converted_content, file_type, fix_and_overwrite)
        for file_type in converted_content
    )


def get_models_in_diff():
    """
    Finds all models that have been modified since the current branch forked from `main`.

    Both tracked changes (committed or not) and untracked files are considered, so that a brand-new model whose
    files have not been `git add`-ed yet is still checked.

    Returns:
        A set containing the names of the models that have been modified (e.g. {'llama', 'whisper'}).
    """
    fork_point_sha = subprocess.check_output(["git", "merge-base", "main", "HEAD"], encoding="utf-8").strip()
    # `git diff <sha>` compares the working tree with the fork point; `--diff-filter=d` drops deleted files.
    modified_files = subprocess.check_output(
        ["git", "diff", "--diff-filter=d", "--name-only", fork_point_sha], encoding="utf-8"
    ).splitlines()
    # `git diff` never lists untracked files, so add them explicitly (respecting .gitignore).
    modified_files += subprocess.check_output(
        ["git", "ls-files", "--others", "--exclude-standard"], encoding="utf-8"
    ).splitlines()

    # Matches both modelling files and tests
    relevant_modified_files = [x for x in modified_files if "/models/" in x and x.endswith(".py")]
    # The model name is the directory directly containing the file.
    return {file_path.split("/")[-2] for file_path in relevant_modified_files}


def guaranteed_no_diff(modular_file_path, dependencies, models_in_diff):
    """
    Returns whether it is guaranteed to have no differences between the modular file and the modeling file.

    Model is in the diff -> not guaranteed to have no differences
    Dependency is in the diff -> not guaranteed to have no differences
    Otherwise -> guaranteed to have no differences

    The model is looked up both by the name in its file name (`modular_<name>.py`) and by its parent directory,
    because `models_in_diff` is built from directory names. A file is only skipped when neither is in the diff.

    Args:
        modular_file_path: The path to the modular file.
        dependencies: A dictionary mapping each modular file to its dependencies, given as dotted module names.
        models_in_diff: A set containing the names of the models that have been modified.

    Returns:
        A boolean indicating whether the model (code and tests) is guaranteed to have no differences.
    """
    file_model_name = modular_file_path.rsplit("modular_", 1)[1].replace(".py", "")
    directory_model_name = os.path.basename(os.path.dirname(modular_file_path))
    if file_model_name in models_in_diff or directory_model_name in models_in_diff:
        return False
    for dep in dependencies[modular_file_path]:
        # two possible patterns: `transformers.models.model_name.(...)` or `model_name.(...)`
        dependency_model_name = dep.split(".")[-2]
        if dependency_model_name in models_in_diff:
            return False
    return True


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Compare modular_xxx.py files with modeling_xxx.py files.")
    parser.add_argument(
        "--files", default=["all"], type=str, nargs="+", help="List of modular_xxx.py files to compare."
    )
    parser.add_argument(
        "--fix_and_overwrite", action="store_true", help="Overwrite the modeling_xxx.py file if differences are found."
    )
    parser.add_argument(
        "--num_workers",
        default=1,
        type=int,
        help="The number of workers to run. No effect if `fix_and_overwrite` is specified.",
    )
    args = parser.parse_args()
    if args.files == ["all"]:
        args.files = glob.glob("src/transformers/models/**/modular_*.py", recursive=True)

    # Assuming there is a topological sort on the dependency mapping: if the file being checked and its dependencies
    # are not in the diff, then it is guaranteed to have no differences. If no models are in the diff, then this
    # script will do nothing.
    current_branch = subprocess.check_output(["git", "branch", "--show-current"], text=True).strip()
    if current_branch == "main":
        console.print(
            "[bold red]You are developing on the main branch. We cannot identify the list of changed files and will have to check all files. This may take a while.[/bold red]"
        )
        # There is no fork point to diff against, so every requested model counts as modified. `os.path` is used so
        # platform-specific path separators are handled as well.
        models_in_diff = {os.path.basename(os.path.dirname(file_path)) for file_path in args.files}
    else:
        models_in_diff = get_models_in_diff()
        if not models_in_diff:
            console.print(
                "[bold green]No models files or model tests in the diff, skipping modular checks[/bold green]"
            )
            # `sys.exit` rather than the `exit` builtin, which only exists when the `site` module is loaded.
            sys.exit(0)

    skipped_models = set()
    non_matching_files = 0
    ordered_files, dependencies = find_priority_list(args.files)
    if args.fix_and_overwrite or args.num_workers == 1:
        # Sequential path: files are visited in the order returned by `find_priority_list` (assumed topological),
        # so a file regenerated here marks its model as modified before its dependents are visited.
        for modular_file_path in ordered_files:
            if guaranteed_no_diff(modular_file_path, dependencies, models_in_diff):
                skipped_models.add(modular_file_path.rsplit("modular_", 1)[1].replace(".py", ""))
                continue
            file_mismatches = compare_files(modular_file_path, args.fix_and_overwrite)
            non_matching_files += file_mismatches
            # The git diff only changes when a file was actually overwritten, so only then is git re-queried.
            if args.fix_and_overwrite and file_mismatches and current_branch != "main":
                models_in_diff = get_models_in_diff()
    else:
        files_to_check = []
        for modular_file_path in ordered_files:
            if guaranteed_no_diff(modular_file_path, dependencies, models_in_diff):
                skipped_models.add(modular_file_path.rsplit("modular_", 1)[1].replace(".py", ""))
            else:
                files_to_check.append(modular_file_path)

        import multiprocessing

        with multiprocessing.Pool(args.num_workers) as pool:
            non_matching_files += sum(pool.map(compare_files, files_to_check))

    # Report the skipped models before a possible failure, so the summary is shown in both cases.
    if skipped_models:
        console.print(
            f"[bold green]Skipped {len(skipped_models)} models and their dependencies that are not in the diff: "
            f"{', '.join(sorted(skipped_models))}[/bold green]"
        )

    if non_matching_files and not args.fix_and_overwrite:
        raise ValueError("Some modular files and their generated modeling code did not match.")