import argparse
import difflib
import glob
import logging
from io import StringIO  # noqa: F401  (no longer used below; file-level import intentionally retained)

# Console for rich printing
from modular_model_converter import convert_modular_file
from rich.console import Console
from rich.syntax import Syntax


logging.basicConfig()
logging.getLogger().setLevel(logging.ERROR)
console = Console()


def process_file(modular_file_path, generated_modeling_content, file_type="modeling_", fix_and_overwrite=False):
    """Diff one generated file against its on-disk counterpart.

    The on-disk path is derived by replacing ``"modular_"`` in
    ``modular_file_path`` with ``f"{file_type}_"``.

    Args:
        modular_file_path: Path of the ``modular_*.py`` source file.
        generated_modeling_content: Mapping indexed by ``file_type``; element
            ``0`` of each value is the generated source text.
        file_type: Key into ``generated_modeling_content`` and prefix of the
            target file name.
            NOTE(review): the default already ends with an underscore, so
            relying on it yields a ``modeling__`` prefix. ``compare_files``
            always passes the key explicitly; the default is left untouched
            to keep the signature stable.
        fix_and_overwrite: When true, differing files are overwritten with the
            generated text instead of having their diff printed.

    Returns:
        ``1`` when the generated text differs from the file on disk (whether
        or not it was overwritten), ``0`` otherwise.
    """
    file_path = modular_file_path.replace("modular_", f"{file_type}_")
    generated_code = generated_modeling_content[file_type][0]

    # Read the actual modeling file. The encoding is explicit so the result
    # does not depend on the platform's locale default.
    with open(file_path, "r", encoding="utf-8") as modeling_file:
        content = modeling_file.read()

    diff_list = list(
        difflib.unified_diff(
            generated_code.splitlines(),
            content.splitlines(),
            fromfile=f"{file_path}_generated",
            tofile=f"{file_path}",
            lineterm="",
        )
    )

    if not diff_list:
        console.print(f"[bold green]No differences found for {file_path}.[/bold green]")
        return 0

    if fix_and_overwrite:
        with open(file_path, "w", encoding="utf-8") as modeling_file:
            modeling_file.write(generated_code)
        console.print(f"[bold blue]Overwritten {file_path} with the generated content.[/bold blue]")
    else:
        console.print(f"\n[bold red]Differences found between the generated code and {file_path}:[/bold red]\n")
        diff_text = "\n".join(diff_list)
        syntax = Syntax(diff_text, "diff", theme="ansi_dark", line_numbers=True)
        console.print(syntax)
    # A difference is reported even after overwriting; the caller decides
    # whether that should be treated as a failure.
    return 1


def compare_files(modular_file_path, fix_and_overwrite=False):
    """Regenerate every file derived from a modular file and compare each one.

    Args:
        modular_file_path: Path of the ``modular_*.py`` file to convert.
        fix_and_overwrite: Forwarded to ``process_file``.

    Returns:
        The number of generated files that differ from their on-disk version.
    """
    # Generate the expected modeling content
    generated_modeling_content = convert_modular_file(modular_file_path)
    return sum(
        process_file(modular_file_path, generated_modeling_content, file_type, fix_and_overwrite)
        for file_type in generated_modeling_content
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Compare modular_xxx.py files with modeling_xxx.py files.")
    # argparse applies `type` to every individual token. `type=list` therefore
    # split each path into a list of characters, which broke both explicit
    # paths and the `["all"]` sentinel check below; `str` keeps tokens intact.
    parser.add_argument(
        "--files",
        default=["all"],
        type=str,
        nargs="+",
        help="List of modular_xxx.py files to compare.",
    )
    parser.add_argument(
        "--fix_and_overwrite",
        action="store_true",
        help="Overwrite the modeling_xxx.py file if differences are found.",
    )
    args = parser.parse_args()

    if args.files == ["all"]:
        args.files = glob.glob("src/transformers/models/**/modular_*.py", recursive=True)

    non_matching_files = 0
    for modular_file_path in args.files:
        non_matching_files += compare_files(modular_file_path, args.fix_and_overwrite)

    # Differences are only fatal in check mode; in fix mode they were repaired.
    if non_matching_files and not args.fix_and_overwrite:
        raise ValueError("Some diff and their modeling code did not match.")