mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-04 13:20:12 +06:00

* rework converter * Update modular_model_converter.py * Update modular_model_converter.py * Update modular_model_converter.py * Update modular_model_converter.py * cleaning * cleaning * finalize imports * imports * Update modular_model_converter.py * Better renaming to avoid visiting same file multiple times * start converting files * style * address most comments * style * remove unused stuff in get_needed_imports * style * move class dependency functions outside class * Move main functions outside class * style * Update modular_model_converter.py * rename func * add augmented dependencies * Update modular_model_converter.py * Add types_to_file_type + tweak annotation handling * Allow assignment dependency mapping + fix regex * style + update modular examples * fix modular_roberta example (wrong redefinition of __init__) * slightly correct order in which dependencies will appear * style * review comments * Performance + better handling of dependencies when they are imported * style * Add advanced new classes capabilities * style * add forgotten check * Update modeling_llava_next_video.py * Add priority list ordering in check_conversion as well * Update check_modular_conversion.py * Update configuration_gemma.py
79 lines
3.0 KiB
Python
79 lines
3.0 KiB
Python
import argparse
|
|
import difflib
|
|
import glob
|
|
import logging
|
|
from io import StringIO
|
|
|
|
from create_dependency_mapping import find_priority_list
|
|
|
|
# Console for rich printing
|
|
from modular_model_converter import convert_modular_file
|
|
from rich.console import Console
|
|
from rich.syntax import Syntax
|
|
|
|
|
|
# Install the default logging configuration, then restrict the root logger to ERROR and above.
logging.basicConfig()
logging.root.setLevel(logging.ERROR)

# Shared rich console through which this script prints its colored output.
console = Console()
|
|
|
|
|
|
def process_file(modular_file_path, generated_modeling_content, file_type="modeling_", fix_and_overwrite=False):
    """Compare one generated file with the file currently on disk.

    Args:
        modular_file_path: Path of the `modular_xxx.py` file. The path of the file to check is obtained by
            substituting `modular_` with `<file_type>_` in this path.
        generated_modeling_content: Mapping from file type to a sequence whose first item is the generated
            source code (a string) for that file type.
        file_type: Prefix of the file to check. A trailing underscore (as in the default value) is tolerated.
        fix_and_overwrite: If `True`, a differing file is overwritten with the generated code instead of
            having its diff printed.

    Returns:
        `1` when the file on disk differs from the generated code (including when it has just been
        overwritten), `0` when both are identical.
    """
    # Strip a trailing underscore so the default `"modeling_"` does not produce a `modeling__xxx.py` path.
    prefix = file_type.rstrip("_")
    file_path = modular_file_path.replace("modular_", f"{prefix}_")

    # Read the actual file. The encoding is explicit so the comparison does not depend on the platform locale.
    with open(file_path, "r", encoding="utf-8") as modeling_file:
        content = modeling_file.read()

    # Prefer the key exactly as given; fall back to the underscore-free form otherwise.
    content_key = file_type if file_type in generated_modeling_content else prefix
    generated_code = generated_modeling_content[content_key][0]

    diff_list = list(
        difflib.unified_diff(
            generated_code.splitlines(),
            content.splitlines(),
            fromfile=f"{file_path}_generated",
            tofile=f"{file_path}",
            lineterm="",
        )
    )

    # Nothing to report when both versions are identical
    if not diff_list:
        console.print(f"[bold green]No differences found for {file_path}.[/bold green]")
        return 0

    if fix_and_overwrite:
        with open(file_path, "w", encoding="utf-8") as modeling_file:
            modeling_file.write(generated_code)
        console.print(f"[bold blue]Overwritten {file_path} with the generated content.[/bold blue]")
    else:
        console.print(f"\n[bold red]Differences found between the generated code and {file_path}:[/bold red]\n")
        diff_text = "\n".join(diff_list)
        syntax = Syntax(diff_text, "diff", theme="ansi_dark", line_numbers=True)
        console.print(syntax)
    # A difference was found, whether or not it has been fixed on disk
    return 1
|
|
|
|
|
|
def compare_files(modular_file_path, fix_and_overwrite=False):
    """Convert a modular file and check every file type it generates.

    Args:
        modular_file_path: Path of the `modular_xxx.py` file to convert and check.
        fix_and_overwrite: Forwarded to `process_file` for each generated file type.

    Returns:
        The sum of the values returned by `process_file` over all generated file types.
    """
    # Generate the expected content for every file type
    generated_by_type = convert_modular_file(modular_file_path)
    return sum(
        process_file(modular_file_path, generated_by_type, kind, fix_and_overwrite)
        for kind in generated_by_type.keys()
    )
|
|
|
|
|
|
def parse_arguments(argv=None):
    """Parse the command-line arguments of this script.

    Args:
        argv: Argument strings to parse. `None` makes argparse read `sys.argv[1:]`.

    Returns:
        The parsed namespace, with `files` (a list of strings, `["all"]` by default) and
        `fix_and_overwrite` (a boolean).
    """
    parser = argparse.ArgumentParser(description="Compare modular_xxx.py files with modeling_xxx.py files.")
    # `type` is applied to each individual value, so it must be `str`: `type=list` would turn every
    # file name into a list of single characters.
    parser.add_argument(
        "--files", default=["all"], type=str, nargs="+", help="List of modular_xxx.py files to compare."
    )
    parser.add_argument(
        "--fix_and_overwrite", action="store_true", help="Overwrite the modeling_xxx.py file if differences are found."
    )
    return parser.parse_args(argv)


if __name__ == "__main__":
    args = parse_arguments()
    # The special value "all" selects every modular file found under the models directory
    if args.files == ["all"]:
        args.files = glob.glob("src/transformers/models/**/modular_*.py", recursive=True)
    non_matching_files = 0
    for modular_file_path in find_priority_list(args.files):
        non_matching_files += compare_files(modular_file_path, args.fix_and_overwrite)

    # Fail only in check mode; in fix mode the differing files have been rewritten
    if non_matching_files and not args.fix_and_overwrite:
        raise ValueError("Some diff and their modeling code did not match.")
|