transformers/utils/check_modular_conversion.py
Arthur 317e069ee7
Modular transformers: modularity and inheritance for new model additions (#33248)
* update example

* update

* push the converted diff files for testing and ci

* correct one example

* fix class attributes and docstring

* nits

* oops

* fixed config!

* update

* nits

* class attributes are not matched against the other, this is missing

* fixed overwriting self.xxx now onto the attributes I think

* partial fix, now order with docstring

* fix docstring order?

* more fixes

* update

* fix missing docstrings!

* examples don't all work yet

* fixup

* nit

* updated

* hick

* update

* delete

* update

* update

* update

* fix

* all default

* no local import

* fix more diff

* some fix related to "safe imports"

* push fixed

* add helper!

* style

* add a check

* all by default

* add the

* update

* FINALLY!

* nit

* fix config dependencies

* man that is it

* fix fix

* update diffs

* fix the last issue

* re-default to all

* all the fixes

* nice

* fix properties vs setter

* fixup

* updates

* update dependencies

* make sure to install what needs to be installed

* fixup

* quick fix for now

* fix!

* fixup

* update

* update

* updates

* whitespaces

* nit

* fix

* simplify everything, and make it file agnostic (should work for image processors)

* style

* finish fixing all import issues

* fixup

* empty modeling should not be written!

* Add logic to find who depends on what

* update

* cleanup

* update

* update gemma to support positions

* some small nits

* this is the correct docstring for gemma2

* fix merging of docstrings

* update

* fixup

* update

* take doc into account

* styling

* update

* fix hidden activation

* more fixes

* final fixes!

* fixup

* fixup instruct  blip video

* update

* fix bugs

* align gemma2 with the rest as well

* updates

* revert

* update

* more reversions

* grind

* more

* arf

* update

* order will matter

* finish del stuff

* update

* rename to modular

* fixup

* nits

* update makefile

* fixup

* update order of the checks!

* fix

* fix docstring that has a call inside

* fix conversion check

* style

* add some initial documentation

* update

* update doc

* some fixup

* updates

* yups

* Mostly TODO, give me a minute

* update

* fixup

* revert some stuff

* Review docs for the modular transformers (#33472)

Docs

* good update

* fixup

* mmm current updates lead to this code

* okay, this fixes it

* cool

* fixes

* update

* nit

* updates

* nits

* fix doc

* update

* revert bad changes

* update

* updates

* proper update

* update

* update?

* up

* update

* cool

* nits

* nits

* bon bon

* fix

* ?

* minimise changes

* update

* update

* update

* updates?

* fixed gemma2

* kind of a hack

* nits

* update

* remove `diffs` in favor of `modular`

* fix make fix copies

---------

Co-authored-by: Lysandre Debut <hi@lysand.re>
2024-09-24 15:54:07 +02:00

77 lines
2.9 KiB
Python

import argparse
import difflib
import glob
import logging
import os
from io import StringIO

from modular_model_converter import convert_modular_file
from rich.console import Console  # Console for rich printing
from rich.syntax import Syntax

# Set up basic logging, then raise the root logger's threshold so that records below ERROR
# are dropped for every logger that inherits the root level.
logging.basicConfig()
logging.root.setLevel(logging.ERROR)

# Rich console used for the script's coloured and syntax-highlighted output.
console = Console()


def process_file(modular_file_path, generated_modeling_content, file_type="modeling", fix_and_overwrite=False):
    """Compare one generated file with its on-disk counterpart, optionally overwriting it.

    The on-disk path is derived from ``modular_file_path`` by replacing the ``modular_`` prefix of the
    *file name* with ``f"{file_type}_"`` (e.g. ``modular_llama.py`` -> ``modeling_llama.py`` for
    ``file_type="modeling"``). The directory part of the path is never rewritten.

    Args:
        modular_file_path: Path to the ``modular_*.py`` source file.
        generated_modeling_content: Mapping from file type to a sequence whose first element is the
            generated source text (the value returned by ``convert_modular_file``).
        file_type: Key into ``generated_modeling_content``; also the prefix of the target file name.
        fix_and_overwrite: If True, overwrite the on-disk file with the generated content when they differ.

    Returns:
        1 if the on-disk file differed from the generated content (whether or not it was overwritten),
        otherwise 0.
    """
    directory, modular_file_name = os.path.split(modular_file_path)
    file_path = os.path.join(directory, modular_file_name.replace("modular_", f"{file_type}_", 1))
    generated_content = generated_modeling_content[file_type][0]

    # A missing target file is treated as empty. It is then reported as a full diff (or created in fix
    # mode) instead of aborting the whole check with FileNotFoundError.
    try:
        with open(file_path, "r", encoding="utf-8") as modeling_file:
            content = modeling_file.read()
    except FileNotFoundError:
        content = ""

    diff_list = list(
        difflib.unified_diff(
            generated_content.splitlines(),
            content.splitlines(),
            fromfile=f"{file_path}_generated",
            tofile=f"{file_path}",
            lineterm="",
        )
    )

    if not diff_list:
        console.print(f"[bold green]No differences found for {file_path}.[/bold green]")
        return 0

    if fix_and_overwrite:
        with open(file_path, "w", encoding="utf-8") as modeling_file:
            modeling_file.write(generated_content)
        console.print(f"[bold blue]Overwritten {file_path} with the generated content.[/bold blue]")
    else:
        console.print(f"\n[bold red]Differences found between the generated code and {file_path}:[/bold red]\n")
        diff_text = "\n".join(diff_list)
        syntax = Syntax(diff_text, "diff", theme="ansi_dark", line_numbers=True)
        console.print(syntax)
    return 1


def compare_files(modular_file_path, fix_and_overwrite=False):
    """Regenerate every file produced from one modular file and check each one against disk.

    Args:
        modular_file_path: Path to a ``modular_*.py`` source file.
        fix_and_overwrite: Forwarded to ``process_file``; when True, mismatching files are rewritten.

    Returns:
        The sum of ``process_file`` results over all generated file types.
    """
    generated = convert_modular_file(modular_file_path)
    return sum(
        process_file(modular_file_path, generated, file_type, fix_and_overwrite) for file_type in generated
    )


def build_arg_parser():
    """Build the command-line parser for this script.

    ``--files`` accepts one or more paths as plain strings and defaults to ``["all"]``. The
    ``--fix_and_overwrite`` flag is a boolean switch.
    """
    parser = argparse.ArgumentParser(description="Compare modular_xxx.py files with modeling_xxx.py files.")
    # No `type=` here: argparse's default (str) keeps each path intact. `type=list` would split every
    # argument into a list of characters, so neither `--files all` nor real paths would work.
    parser.add_argument("--files", default=["all"], nargs="+", help="List of modular_xxx.py files to compare.")
    parser.add_argument(
        "--fix_and_overwrite", action="store_true", help="Overwrite the modeling_xxx.py file if differences are found."
    )
    return parser


if __name__ == "__main__":
    args = build_arg_parser().parse_args()
    if args.files == ["all"]:
        # "all" expands to every modular file under the models tree, relative to the current directory.
        args.files = glob.glob("src/transformers/models/**/modular_*.py", recursive=True)
    non_matching_files = 0
    for modular_file_path in args.files:
        non_matching_files += compare_files(modular_file_path, args.fix_and_overwrite)
    # In fix mode, mismatches have already been rewritten, so only a pure check fails the run.
    if non_matching_files and not args.fix_and_overwrite:
        raise ValueError("Some diff and their modeling code did not match.")