From bfe46c98b5d35e91d8c9e625fc12ae7315a152db Mon Sep 17 00:00:00 2001
From: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
Date: Thu, 13 Feb 2025 17:25:17 +0100
Subject: [PATCH] Make `check_repository_consistency` run faster by MP (#36175)

* speeddddd

* speeddddd

* speeddddd

* speeddddd

---------

Co-authored-by: ydshieh
---
 .circleci/config.yml              |  2 +-
 utils/check_modular_conversion.py | 40 ++++++++++++++++++++++++-------
 2 files changed, 33 insertions(+), 9 deletions(-)

diff --git a/.circleci/config.yml b/.circleci/config.yml
index dbbebe9fc06..7e497d755a1 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -170,7 +170,7 @@ jobs:
             - store_artifacts:
                   path: ~/transformers/installed.txt
             - run: python utils/check_copies.py
-            - run: python utils/check_modular_conversion.py
+            - run: python utils/check_modular_conversion.py --num_workers 4
             - run: python utils/check_table.py
             - run: python utils/check_dummies.py
             - run: python utils/check_repo.py
diff --git a/utils/check_modular_conversion.py b/utils/check_modular_conversion.py
index e08621b5c32..c12fc90dc1f 100644
--- a/utils/check_modular_conversion.py
+++ b/utils/check_modular_conversion.py
@@ -120,6 +120,12 @@ if __name__ == "__main__":
     parser.add_argument(
         "--fix_and_overwrite", action="store_true", help="Overwrite the modeling_xxx.py file if differences are found."
     )
+    parser.add_argument(
+        "--num_workers",
+        default=1,
+        type=int,
+        help="The number of workers to run. No effect if `fix_and_overwrite` is specified.",
+    )
     args = parser.parse_args()
     if args.files == ["all"]:
         args.files = glob.glob("src/transformers/models/**/modular_*.py", recursive=True)
@@ -135,14 +141,32 @@ if __name__ == "__main__":
     skipped_models = set()
     non_matching_files = 0
     ordered_files, dependencies = find_priority_list(args.files)
-    for modular_file_path in ordered_files:
-        is_guaranteed_no_diff = guaranteed_no_diff(modular_file_path, dependencies, models_in_diff)
-        if is_guaranteed_no_diff:
-            model_name = modular_file_path.rsplit("modular_", 1)[1].replace(".py", "")
-            skipped_models.add(model_name)
-            continue
-        non_matching_files += compare_files(modular_file_path, args.fix_and_overwrite)
-        models_in_diff = get_models_in_diff()  # When overwriting, the diff changes
+    if args.fix_and_overwrite or args.num_workers == 1:
+        for modular_file_path in ordered_files:
+            is_guaranteed_no_diff = guaranteed_no_diff(modular_file_path, dependencies, models_in_diff)
+            if is_guaranteed_no_diff:
+                model_name = modular_file_path.rsplit("modular_", 1)[1].replace(".py", "")
+                skipped_models.add(model_name)
+                continue
+            non_matching_files += compare_files(modular_file_path, args.fix_and_overwrite)
+            models_in_diff = get_models_in_diff()  # When overwriting, the diff changes
+    else:
+        # Not overwriting here: drop the guaranteed-no-diff files first, then compare the rest in parallel.
+        new_ordered_files = []
+        for modular_file_path in ordered_files:
+            is_guaranteed_no_diff = guaranteed_no_diff(modular_file_path, dependencies, models_in_diff)
+            if is_guaranteed_no_diff:
+                model_name = modular_file_path.rsplit("modular_", 1)[1].replace(".py", "")
+                skipped_models.add(model_name)
+            else:
+                new_ordered_files.append(modular_file_path)
+
+        import multiprocessing
+
+        with multiprocessing.Pool(args.num_workers) as p:
+            outputs = p.map(compare_files, new_ordered_files)
+        for output in outputs:
+            non_matching_files += output
 
     if non_matching_files and not args.fix_and_overwrite:
         raise ValueError("Some diff and their modeling code did not match.")