Make check_repository_consistency run faster by MP (#36175)

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Yih-Dar 2025-02-13 17:25:17 +01:00 committed by GitHub
parent 5f0fd1185b
commit bfe46c98b5
2 changed files with 33 additions and 9 deletions
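
For context: the speed-up comes from running the per-file modular-conversion checks in a multiprocessing pool instead of one after another. A minimal, self-contained sketch of that pattern (the helper `check_one_file` and the file names are illustrative, not the repository's actual code):

```python
import multiprocessing


def check_one_file(path: str) -> int:
    # Stand-in for a per-file consistency check: return 1 if the file does not
    # match its generated counterpart, 0 otherwise.
    return 0


if __name__ == "__main__":
    files = ["modular_llama.py", "modular_gemma.py", "modular_mistral.py"]
    # Each file is checked independently, so the work parallelizes cleanly.
    # The __main__ guard matters: worker processes re-import this module.
    with multiprocessing.Pool(processes=4) as pool:
        results = pool.map(check_one_file, files)
    print(f"{sum(results)} file(s) did not match")
```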

.circleci/config.yml

@@ -170,7 +170,7 @@ jobs:
 - store_artifacts:
     path: ~/transformers/installed.txt
 - run: python utils/check_copies.py
-- run: python utils/check_modular_conversion.py
+- run: python utils/check_modular_conversion.py --num_workers 4
 - run: python utils/check_table.py
 - run: python utils/check_dummies.py
 - run: python utils/check_repo.py
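
The CI step above now passes `--num_workers 4` to the modular-conversion check. To gauge the effect locally, a rough timing harness along these lines can be run from the repository root (the harness itself is illustrative and not part of this commit):

```python
import subprocess
import time

# Time the consistency check sequentially and with 4 workers.
for workers in (1, 4):
    start = time.perf_counter()
    subprocess.run(
        ["python", "utils/check_modular_conversion.py", "--num_workers", str(workers)],
        check=False,  # a non-zero exit just means mismatches were found
    )
    print(f"num_workers={workers}: {time.perf_counter() - start:.1f}s")
```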

utils/check_modular_conversion.py

@@ -120,6 +120,12 @@ if __name__ == "__main__":
     parser.add_argument(
         "--fix_and_overwrite", action="store_true", help="Overwrite the modeling_xxx.py file if differences are found."
     )
+    parser.add_argument(
+        "--num_workers",
+        default=1,
+        type=int,
+        help="The number of workers to run. No effect if `fix_and_overwrite` is specified.",
+    )
     args = parser.parse_args()
     if args.files == ["all"]:
         args.files = glob.glob("src/transformers/models/**/modular_*.py", recursive=True)
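
The new flag defaults to 1, so existing invocations keep the sequential behaviour, and it is documented as having no effect together with `--fix_and_overwrite`. A common refinement, not part of this diff, is to clamp the requested worker count to the available CPUs and the number of pending files, for example:

```python
import os


def effective_workers(requested: int, num_tasks: int) -> int:
    # Illustrative helper: never spawn more workers than CPUs or tasks.
    return max(1, min(requested, os.cpu_count() or 1, num_tasks))


print(effective_workers(4, num_tasks=2))  # 2 on a machine with >= 2 CPUs
```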
@@ -135,14 +141,32 @@ if __name__ == "__main__":
     skipped_models = set()
     non_matching_files = 0
     ordered_files, dependencies = find_priority_list(args.files)
-    for modular_file_path in ordered_files:
-        is_guaranteed_no_diff = guaranteed_no_diff(modular_file_path, dependencies, models_in_diff)
-        if is_guaranteed_no_diff:
-            model_name = modular_file_path.rsplit("modular_", 1)[1].replace(".py", "")
-            skipped_models.add(model_name)
-            continue
-        non_matching_files += compare_files(modular_file_path, args.fix_and_overwrite)
-        models_in_diff = get_models_in_diff()  # When overwriting, the diff changes
+    if args.fix_and_overwrite or args.num_workers == 1:
+        for modular_file_path in ordered_files:
+            is_guaranteed_no_diff = guaranteed_no_diff(modular_file_path, dependencies, models_in_diff)
+            if is_guaranteed_no_diff:
+                model_name = modular_file_path.rsplit("modular_", 1)[1].replace(".py", "")
+                skipped_models.add(model_name)
+                continue
+            non_matching_files += compare_files(modular_file_path, args.fix_and_overwrite)
+            models_in_diff = get_models_in_diff()  # When overwriting, the diff changes
+    else:
+        new_ordered_files = []
+        for modular_file_path in ordered_files:
+            is_guaranteed_no_diff = guaranteed_no_diff(modular_file_path, dependencies, models_in_diff)
+            if is_guaranteed_no_diff:
+                model_name = modular_file_path.rsplit("modular_", 1)[1].replace(".py", "")
+                skipped_models.add(model_name)
+            else:
+                new_ordered_files.append(modular_file_path)
+        import multiprocessing
+        with multiprocessing.Pool(args.num_workers) as p:
+            outputs = p.map(compare_files, new_ordered_files)
+        for output in outputs:
+            non_matching_files += output
     if non_matching_files and not args.fix_and_overwrite:
         raise ValueError("Some diff and their modeling code did not match.")
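
One detail worth noting: in the parallel branch, `Pool.map` calls `compare_files` with a single argument, so `fix_and_overwrite` has to fall back to a default value; that is safe here because the parallel path is only taken when `--fix_and_overwrite` is not set. If a worker ever needed extra arguments, the usual approach is `functools.partial`, sketched below with a hypothetical two-argument checker:

```python
import functools
import multiprocessing


def compare(path: str, fix_and_overwrite: bool = False) -> int:
    # Hypothetical stand-in for the real per-file comparison.
    return 0


if __name__ == "__main__":
    files = ["modular_a.py", "modular_b.py"]
    # partial pre-binds the keyword argument; the result is picklable because
    # `compare` is a module-level function.
    worker = functools.partial(compare, fix_and_overwrite=False)
    with multiprocessing.Pool(processes=2) as pool:
        total = sum(pool.map(worker, files))
    print(f"{total} file(s) did not match")
```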