transformers/utils/check_modular_conversion.py
Pavel Iakubovskii fe1a5b73e6
[modular] speedup check_modular_conversion with multiprocessing (#37456)
* Change topological sort to return level-based output (lists of lists)

* Update main for modular converter

* Update test

* update check_modular_conversion

* Update gitignore

* Fix missing conversion for glm4

* Update

* Fix error msg

* Fixup

* fix docstring

* update docs

* Add comment

* delete qwen3_moe
2025-07-10 19:07:59 +01:00
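
The first bullet is where the speedup comes from: instead of a flat topological order, `find_priority_list` now returns *levels* (lists of lists), so every file in a level can be converted in parallel once the previous levels are done. A minimal sketch of level-based topological sorting, assuming a simple `deps` mapping from each file to the files it depends on (illustrative only, not the actual `find_priority_list` implementation):

```python
def topological_levels(deps: dict[str, set[str]]) -> list[list[str]]:
    """Group nodes into levels: level 0 has no dependencies,
    and level N depends only on nodes from earlier levels."""
    remaining = {node: set(d) for node, d in deps.items()}
    levels = []
    while remaining:
        # nodes whose dependencies have all been emitted already
        ready = sorted(n for n, d in remaining.items() if not d)
        if not ready:
            raise ValueError("dependency cycle detected")
        levels.append(ready)
        for n in ready:
            del remaining[n]
        for d in remaining.values():
            d.difference_update(ready)
    return levels


# hypothetical dependency graph: whisper and glm4 build on llama, qwen3 builds on glm4
print(topological_levels({"llama": set(), "whisper": {"llama"}, "glm4": {"llama"}, "qwen3": {"glm4"}}))
# -> [['llama'], ['glm4', 'whisper'], ['qwen3']]
```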

233 lines · 9.6 KiB · Python

import argparse
import difflib
import glob
import logging
import multiprocessing
import os
import shutil
import subprocess
from functools import partial
from io import StringIO

from create_dependency_mapping import find_priority_list
from modular_model_converter import convert_modular_file

# Console for rich printing
from rich.console import Console
from rich.syntax import Syntax

logging.basicConfig()
logging.getLogger().setLevel(logging.ERROR)
console = Console()

BACKUP_EXT = ".modular_backup"


def process_file(
    modular_file_path,
    generated_modeling_content,
    file_type="modeling_",
    show_diff=True,
):
    file_name_prefix = file_type.split("*")[0]
    file_name_suffix = file_type.split("*")[-1] if "*" in file_type else ""
    file_path = modular_file_path.replace("modular_", f"{file_name_prefix}_").replace(".py", f"{file_name_suffix}.py")

    # Read the actual modeling file
    with open(file_path, "r", encoding="utf-8") as modeling_file:
        content = modeling_file.read()

    output_buffer = StringIO(generated_modeling_content[file_type][0])
    output_buffer.seek(0)
    output_content = output_buffer.read()

    diff = difflib.unified_diff(
        output_content.splitlines(),
        content.splitlines(),
        fromfile=f"{file_path}_generated",
        tofile=f"{file_path}",
        lineterm="",
    )
    diff_list = list(diff)

    # Check for differences
    if diff_list:
        # first save a copy of the original file, to be able to restore it later
        if os.path.exists(file_path):
            shutil.copy(file_path, file_path + BACKUP_EXT)
        # we always save the generated content, to be able to update dependent files
        with open(file_path, "w", encoding="utf-8", newline="\n") as modeling_file:
            modeling_file.write(generated_modeling_content[file_type][0])
        console.print(f"[bold blue]Overwritten {file_path} with the generated content.[/bold blue]")
        if show_diff:
            console.print(f"\n[bold red]Differences found between the generated code and {file_path}:[/bold red]\n")
            diff_text = "\n".join(diff_list)
            syntax = Syntax(diff_text, "diff", theme="ansi_dark", line_numbers=True)
            console.print(syntax)
        return 1
    else:
        console.print(f"[bold green]No differences found for {file_path}.[/bold green]")
        return 0


def compare_files(modular_file_path, show_diff=True):
    # Generate the expected modeling content
    generated_modeling_content = convert_modular_file(modular_file_path)
    diff = 0
    for file_type in generated_modeling_content.keys():
        diff += process_file(modular_file_path, generated_modeling_content, file_type, show_diff)
    return diff


def get_models_in_diff():
    """
    Finds all models that have been modified in the diff.

    Returns:
        A set containing the names of the models that have been modified (e.g. {'llama', 'whisper'}).
    """
    fork_point_sha = subprocess.check_output("git merge-base main HEAD".split()).decode("utf-8")
    modified_files = (
        subprocess.check_output(f"git diff --diff-filter=d --name-only {fork_point_sha}".split())
        .decode("utf-8")
        .split()
    )
    # Matches both modeling files and tests
    relevant_modified_files = [x for x in modified_files if "/models/" in x and x.endswith(".py")]
    model_names = set()
    for file_path in relevant_modified_files:
        model_name = file_path.split("/")[-2]
        model_names.add(model_name)
    return model_names


def guaranteed_no_diff(modular_file_path, dependencies, models_in_diff):
    """
    Returns whether the modular file is guaranteed to have no differences from its generated modeling file.

    Model is in the diff -> not guaranteed to have no differences
    Dependency is in the diff -> not guaranteed to have no differences
    Otherwise -> guaranteed to have no differences

    Args:
        modular_file_path: The path to the modular file.
        dependencies: A dictionary containing the dependencies of each modular file.
        models_in_diff: A set containing the names of the models that have been modified.

    Returns:
        A boolean indicating whether the model (code and tests) is guaranteed to have no differences.
    """
    model_name = modular_file_path.rsplit("modular_", 1)[1].replace(".py", "")
    if model_name in models_in_diff:
        return False

    for dep in dependencies[modular_file_path]:
        # two possible patterns: `transformers.models.model_name.(...)` or `model_name.(...)`
        dependency_model_name = dep.split(".")[-2]
        if dependency_model_name in models_in_diff:
            return False

    return True


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Compare modular_xxx.py files with modeling_xxx.py files.")
    parser.add_argument(
        "--files", default=["all"], type=str, nargs="+", help="List of modular_xxx.py files to compare."
    )
    parser.add_argument(
        "--fix_and_overwrite", action="store_true", help="Overwrite the modeling_xxx.py file if differences are found."
    )
    parser.add_argument("--check_all", action="store_true", help="Check all files, not just the ones in the diff.")
    parser.add_argument(
        "--num_workers",
        default=-1,
        type=int,
        help="The number of workers to run. Default is -1, which means the number of CPU cores.",
    )
    args = parser.parse_args()

    if args.files == ["all"]:
        args.files = glob.glob("src/transformers/models/**/modular_*.py", recursive=True)

    if args.num_workers == -1:
        args.num_workers = multiprocessing.cpu_count()

    # Assuming there is a topological sort on the dependency mapping: if the file being checked and its dependencies
    # are not in the diff, then it is guaranteed to have no differences. If no models are in the diff, then this
    # script will do nothing.
    current_branch = subprocess.check_output(["git", "branch", "--show-current"], text=True).strip()
    if current_branch == "main":
        console.print(
            "[bold red]You are developing on the main branch. We cannot identify the list of changed files and will "
            "have to check all files. This may take a while.[/bold red]"
        )
        models_in_diff = {file_path.split("/")[-2] for file_path in args.files}
    else:
        models_in_diff = get_models_in_diff()
        if not models_in_diff and not args.check_all:
            console.print(
                "[bold green]No model files or model tests in the diff, skipping modular checks[/bold green]"
            )
            exit(0)

    skipped_models = set()
    non_matching_files = []
    ordered_files, dependencies = find_priority_list(args.files)
    flat_ordered_files = [item for sublist in ordered_files for item in sublist]

    # ordered_files is a *sorted* list of lists of filepaths
    # - files from the first list do NOT depend on other files
    # - files in the second list depend on files from the first list
    # - files in the third list depend on files from the second and (optionally) the first list
    # - ... and so on
    # files (models) within the same list are *independent* of each other;
    # we apply modular conversion to each list in parallel, starting from the first list
    console.print(f"[bold yellow]Number of dependency levels: {len(ordered_files)}[/bold yellow]")
    console.print(f"[bold yellow]Files per level: {tuple([len(x) for x in ordered_files])}[/bold yellow]")

    try:
        for dependency_level_files in ordered_files:
            # Filter out files that are guaranteed to have no diff
            files_to_check = []
            for file_path in dependency_level_files:
                if not args.check_all and guaranteed_no_diff(file_path, dependencies, models_in_diff):
                    skipped_models.add(file_path.split("/")[-2])  # save model folder name
                else:
                    files_to_check.append(file_path)

            if not files_to_check:
                continue

            # Process files with diff
            num_workers = min(args.num_workers, len(files_to_check))
            with multiprocessing.Pool(num_workers) as p:
                is_changed_flags = p.map(
                    partial(compare_files, show_diff=not args.fix_and_overwrite),
                    files_to_check,
                )

            # Collect changed files and their original paths
            for is_changed, file_path in zip(is_changed_flags, files_to_check):
                if is_changed:
                    non_matching_files.append(file_path)
                    # Update changed models after each round of conversions (save model folder name)
                    models_in_diff.add(file_path.split("/")[-2])
    finally:
        # Restore files overwritten by modular conversion (if needed)
        backup_files = glob.glob("**/*" + BACKUP_EXT, recursive=True)
        for backup_file_path in backup_files:
            overwritten_path = backup_file_path.replace(BACKUP_EXT, "")
            if not args.fix_and_overwrite and os.path.exists(overwritten_path):
                shutil.copy(backup_file_path, overwritten_path)
            os.remove(backup_file_path)

    if non_matching_files and not args.fix_and_overwrite:
        diff_models = set(file_path.split("/")[-2] for file_path in non_matching_files)  # noqa
        models_str = "\n - " + "\n - ".join(sorted(diff_models))
        raise ValueError(f"Some modular files and their generated modeling code did not match. Models in diff:{models_str}")

    if skipped_models:
        console.print(
            f"[bold green]Skipped {len(skipped_models)} models and their dependencies that are not in the diff: "
            f"{', '.join(sorted(skipped_models))}[/bold green]"
        )
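
For reference, the CLI above is typically run from the root of the `transformers` repo (the glob paths assume it); a few example invocations using the flags defined in the argparse setup:

    python utils/check_modular_conversion.py                 # check only models touched by the current diff
    python utils/check_modular_conversion.py --check_all     # check every modular file
    python utils/check_modular_conversion.py --files src/transformers/models/glm4/modular_glm4.py --fix_and_overwrite

Note the backup/restore design: every overwritten modeling file first gets a `.modular_backup` copy, and unless `--fix_and_overwrite` is passed, the `finally` block restores the originals, so a failed check does not leave regenerated files in the working tree. The intermediate overwrites are still necessary during the run, because later dependency levels are converted against the files written by earlier levels.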