mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
Make check_repository_consistency
run faster by MP (#36175)
* speeddddd * speeddddd * speeddddd * speeddddd --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
parent
5f0fd1185b
commit
bfe46c98b5
@ -170,7 +170,7 @@ jobs:
|
||||
- store_artifacts:
|
||||
path: ~/transformers/installed.txt
|
||||
- run: python utils/check_copies.py
|
||||
- run: python utils/check_modular_conversion.py
|
||||
- run: python utils/check_modular_conversion.py --num_workers 4
|
||||
- run: python utils/check_table.py
|
||||
- run: python utils/check_dummies.py
|
||||
- run: python utils/check_repo.py
|
||||
|
@ -120,6 +120,12 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--fix_and_overwrite", action="store_true", help="Overwrite the modeling_xxx.py file if differences are found."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_workers",
|
||||
default=1,
|
||||
type=int,
|
||||
help="The number of workers to run. No effect if `fix_and_overwrite` is specified.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
if args.files == ["all"]:
|
||||
args.files = glob.glob("src/transformers/models/**/modular_*.py", recursive=True)
|
||||
@ -135,14 +141,32 @@ if __name__ == "__main__":
|
||||
skipped_models = set()
|
||||
non_matching_files = 0
|
||||
ordered_files, dependencies = find_priority_list(args.files)
|
||||
for modular_file_path in ordered_files:
|
||||
is_guaranteed_no_diff = guaranteed_no_diff(modular_file_path, dependencies, models_in_diff)
|
||||
if is_guaranteed_no_diff:
|
||||
model_name = modular_file_path.rsplit("modular_", 1)[1].replace(".py", "")
|
||||
skipped_models.add(model_name)
|
||||
continue
|
||||
non_matching_files += compare_files(modular_file_path, args.fix_and_overwrite)
|
||||
models_in_diff = get_models_in_diff() # When overwriting, the diff changes
|
||||
if args.fix_and_overwrite or args.num_workers == 1:
|
||||
for modular_file_path in ordered_files:
|
||||
is_guaranteed_no_diff = guaranteed_no_diff(modular_file_path, dependencies, models_in_diff)
|
||||
if is_guaranteed_no_diff:
|
||||
model_name = modular_file_path.rsplit("modular_", 1)[1].replace(".py", "")
|
||||
skipped_models.add(model_name)
|
||||
continue
|
||||
non_matching_files += compare_files(modular_file_path, args.fix_and_overwrite)
|
||||
models_in_diff = get_models_in_diff() # When overwriting, the diff changes
|
||||
else:
|
||||
new_ordered_files = []
|
||||
for modular_file_path in ordered_files:
|
||||
is_guaranteed_no_diff = guaranteed_no_diff(modular_file_path, dependencies, models_in_diff)
|
||||
if is_guaranteed_no_diff:
|
||||
model_name = modular_file_path.rsplit("modular_", 1)[1].replace(".py", "")
|
||||
skipped_models.add(model_name)
|
||||
else:
|
||||
new_ordered_files.append(modular_file_path)
|
||||
|
||||
new_ordered_files = ordered_files
|
||||
import multiprocessing
|
||||
|
||||
with multiprocessing.Pool(4) as p:
|
||||
outputs = p.map(compare_files, new_ordered_files)
|
||||
for output in outputs:
|
||||
non_matching_files += output
|
||||
|
||||
if non_matching_files and not args.fix_and_overwrite:
|
||||
raise ValueError("Some diff and their modeling code did not match.")
|
||||
|
Loading…
Reference in New Issue
Block a user