mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00

* pr * pr * pr * pr * pr * pr * pr * pr * pr --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
134 lines
5.3 KiB
Python
134 lines
5.3 KiB
Python
import argparse
|
|
import json
|
|
import re
|
|
import string
|
|
|
|
|
|
# Upper bound on the number of job names kept in the suggestion that the script prints
# when it is invoked without a comment message.
MAX_NUM_JOBS_TO_SUGGEST = 16
|
|
|
|
|
|
def get_jobs_to_run(available_dirs=None):
    """
    Computes the test directories to run from the files added or modified in a pull request.

    The file `pr_files.txt` (read from the current working directory) contains the information about the files
    changed in a pull request, and it is prepared by the caller (using GitHub api). It must be a JSON list of objects
    having at least the keys `filename` and `status`.
    We can also use the following api to get the information if we don't have them before calling this script.
    url = f"https://api.github.com/repos/huggingface/transformers/pulls/PULL_NUMBER/files?ref={pr_sha}"

    Args:
        available_dirs (iterable of `str`, *optional*):
            The known directories, relative to `tests/` (e.g. `models/bert`). Only candidates found in this
            collection are returned. If not specified, the module-level `repo_content` (populated when this file is
            run as a script) is used.

    Returns:
        `list[str]`: The sorted and de-duplicated directories (prefixed by `models/` or `quantization/`) to run.
    """
    if available_dirs is None:
        # Keep the historical behavior: the script entry point fills `repo_content` before calling this function.
        available_dirs = repo_content
    # Membership is tested once per changed file, so use a set.
    available_dirs = set(available_dirs)

    with open("pr_files.txt", encoding="utf-8") as fp:
        pr_files = json.load(fp)
    # Deleted / renamed files can't tell us which tests to run: only keep the added or modified ones.
    changed_files = [item["filename"] for item in pr_files if item["status"] in ("added", "modified")]

    # The order matters: the first pattern that matches a file decides the candidate for that file.
    patterns = [
        # models or quantizers
        re.compile(r"src/transformers/(models/.*)/modeling_.*\.py"),
        re.compile(r"src/transformers/(quantizers/quantizer_.*)\.py"),
        # tests for models or quantizers
        re.compile(r"tests/(models/.*)/test_.*\.py"),
        re.compile(r"tests/(quantization/.*)/test_.*\.py"),
        # files in a model directory but not necessary a modeling file
        re.compile(r"src/transformers/(models/.*)/.*\.py"),
    ]

    jobs_to_run = set()
    for changed_file in changed_files:
        for pattern in patterns:
            match = pattern.search(changed_file)
            if match is None:
                continue
            # Map a quantizer source file to its (expected) test directory name.
            candidate = match.group(1).replace("quantizers/quantizer_", "quantization/")
            # TODO: for files in `quantizers`, the processed item above may not exist. Try using a fuzzy matching
            if candidate in available_dirs:
                jobs_to_run.add(candidate)
            break

    return sorted(jobs_to_run)
|
|
|
|
|
|
def parse_message(message: str) -> str:
    """
    Extracts, from a GitHub pull request's comment, the part that lists what to run in slow CI.

    Args:
        message (`str`): The body of a GitHub pull request's comment. `None` is accepted too.

    Returns:
        `str`: The (lowercased) text that follows a leading `run-slow`, `run_slow` or `run slow` in `message`, with
        any leading `:` removed. The empty string is returned if `message` is `None` or does not start with one of
        these prefixes.
    """
    if message is None:
        return ""

    normalized = message.strip().lower()

    # Expected form: `run-slow: model_1, model_2, quantization_1, quantization_2`
    accepted_prefixes = ("run-slow", "run_slow", "run slow")
    if not normalized.startswith(accepted_prefixes):
        return ""

    # The 3 accepted prefixes have the same length, so any of them can be used to cut the prefix off.
    remainder = normalized[len("run slow") :]

    # Drop every leading `:` (possibly separated by spaces).
    while True:
        trimmed = remainder.strip()
        if not trimmed.startswith(":"):
            break
        remainder = trimmed[1:]

    return remainder
|
|
|
|
|
|
def get_jobs(message: str):
    """
    Lists the job names requested in a GitHub pull request's comment.

    Args:
        message (`str`): The body of a GitHub pull request's comment.

    Returns:
        `list[str]`: The names found in what `parse_message` returns for `message`. Names may be separated by
        commas and/or whitespace; empty pieces are dropped.
    """
    requested = parse_message(message)
    # Commas and whitespace are both separators: split on commas first, then on whitespace inside each piece.
    return [name for piece in requested.split(",") for name in piece.split()]
|
|
|
|
|
|
def check_name(model_name: str):
    """
    Checks if a job name is well-formed.

    A valid name is non-empty, only contains ASCII letters, digits and `_`, and neither starts nor ends with `_`.

    Args:
        model_name (`str`): The name to check.

    Returns:
        `bool`: `True` if `model_name` is valid, `False` otherwise (including for an empty or `None` name).
    """
    # `all()` over an empty string would be `True`: reject an empty name explicitly.
    if not model_name:
        return False
    if model_name.startswith("_") or model_name.endswith("_"):
        return False

    allowed = set(string.ascii_letters + string.digits + "_")
    return all(char in allowed for char in model_name)
|
|
|
|
|
|
if __name__ == "__main__":
    cli = argparse.ArgumentParser()
    cli.add_argument("--message", type=str, default="", help="The content of a comment.")
    cli.add_argument("--quantization", action="store_true", help="If we collect quantization tests")
    args = cli.parse_args()

    # The listing files below are prepared by the caller (using GitHub api).
    # We can also use the following api to get the information if we don't have them before calling this script.
    # url = f"https://api.github.com/repos/OWNER/REPO/contents/PATH?ref={pr_sha}"
    # (we avoid checking out the repository using `actions/checkout` to reduce the run time, but mostly to avoid
    # the potential security issue as much as possible)
    #
    # `repo_content` must stay a module-level name: `get_jobs_to_run` reads it.
    repo_content = []
    for listing in ["tests_dir.txt", "tests_models_dir.txt", "tests_quantization_dir.txt"]:
        with open(listing) as fp:
            entries = json.load(fp)
        for entry in entries:
            # Only keep the directories, stored relative to `tests/`.
            if entry["type"] == "dir":
                repo_content.append(entry["path"][len("tests/") :])

    if args.message:
        # The names given in a comment don't have the prefix `models/` or `quantization/`, so we need to add them.
        requested = [name for name in get_jobs(args.message) if check_name(name)]

        selected = set()
        for name in requested:
            if args.quantization:
                # Only quantization tests are collected in this mode.
                quantization_dir = f"quantization/{name}"
                if quantization_dir in repo_content:
                    selected.add(quantization_dir)
                continue

            model_dir = f"models/{name}"
            if model_dir in repo_content:
                selected.add(model_dir)
            elif name in repo_content and name != "quantization":
                # A directory directly under `tests/`.
                selected.add(name)

        print(sorted(selected))

    else:
        # Compute (from the added/modified files) the directories under `tests/`, `tests/models/` and
        # `tests/quantization` to run tests.
        # These already have the prefix `models/` or `quantization/`: strip it to get the names to suggest.
        suggested = []
        for directory in get_jobs_to_run():
            short_name = directory.replace("models/", "").replace("quantization/", "")
            if check_name(short_name):
                suggested.append(short_name)

        # Keep the suggestion short.
        suggested = suggested[:MAX_NUM_JOBS_TO_SUGGEST]

        print(", ".join(suggested))
|