diff --git a/.github/workflows/get-pr-info.yml b/.github/workflows/get-pr-info.yml new file mode 100644 index 00000000000..989281e5b90 --- /dev/null +++ b/.github/workflows/get-pr-info.yml @@ -0,0 +1,157 @@ +name: Get PR commit SHA +on: + workflow_call: + inputs: + pr_number: + required: true + type: string + outputs: + PR_HEAD_REPO_FULL_NAME: + description: "The full name of the repository from which the pull request is created" + value: ${{ jobs.get-pr-info.outputs.PR_HEAD_REPO_FULL_NAME }} + PR_BASE_REPO_FULL_NAME: + description: "The full name of the repository to which the pull request is created" + value: ${{ jobs.get-pr-info.outputs.PR_BASE_REPO_FULL_NAME }} + PR_HEAD_REPO_OWNER: + description: "The owner of the repository from which the pull request is created" + value: ${{ jobs.get-pr-info.outputs.PR_HEAD_REPO_OWNER }} + PR_BASE_REPO_OWNER: + description: "The owner of the repository to which the pull request is created" + value: ${{ jobs.get-pr-info.outputs.PR_BASE_REPO_OWNER }} + PR_HEAD_REPO_NAME: + description: "The name of the repository from which the pull request is created" + value: ${{ jobs.get-pr-info.outputs.PR_HEAD_REPO_NAME }} + PR_BASE_REPO_NAME: + description: "The name of the repository to which the pull request is created" + value: ${{ jobs.get-pr-info.outputs.PR_BASE_REPO_NAME }} + PR_HEAD_REF: + description: "The branch name of the pull request in the head repository" + value: ${{ jobs.get-pr-info.outputs.PR_HEAD_REF }} + PR_BASE_REF: + description: "The branch name in the base repository (to merge into)" + value: ${{ jobs.get-pr-info.outputs.PR_BASE_REF }} + PR_HEAD_SHA: + description: "The head sha of the pull request branch in the head repository" + value: ${{ jobs.get-pr-info.outputs.PR_HEAD_SHA }} + PR_BASE_SHA: + description: "The head sha of the target branch in the base repository" + value: ${{ jobs.get-pr-info.outputs.PR_BASE_SHA }} + PR_MERGE_COMMIT_SHA: + description: "The sha of the merge commit for the pull request (created by GitHub) in the base repository" + value: ${{ jobs.get-pr-info.outputs.PR_MERGE_COMMIT_SHA }} + PR_HEAD_COMMIT_DATE: + description: "The date of the head sha of the pull request branch in the head repository" + value: ${{ jobs.get-pr-info.outputs.PR_HEAD_COMMIT_DATE }} + PR_MERGE_COMMIT_DATE: + description: "The date of the merge commit for the pull request (created by GitHub) in the base repository" + value: ${{ jobs.get-pr-info.outputs.PR_MERGE_COMMIT_DATE }} + PR_HEAD_COMMIT_TIMESTAMP: + description: "The timestamp of the head sha of the pull request branch in the head repository" + value: ${{ jobs.get-pr-info.outputs.PR_HEAD_COMMIT_TIMESTAMP }} + PR_MERGE_COMMIT_TIMESTAMP: + description: "The timestamp of the merge commit for the pull request (created by GitHub) in the base repository" + value: ${{ jobs.get-pr-info.outputs.PR_MERGE_COMMIT_TIMESTAMP }} + PR: + description: "The PR" + value: ${{ jobs.get-pr-info.outputs.PR }} + PR_FILES: + description: "The files touched in the PR" + value: ${{ jobs.get-pr-info.outputs.PR_FILES }} + + +jobs: + get-pr-info: + runs-on: ubuntu-22.04 + name: Get PR commit SHA better + outputs: + PR_HEAD_REPO_FULL_NAME: ${{ steps.pr_info.outputs.head_repo_full_name }} + PR_BASE_REPO_FULL_NAME: ${{ steps.pr_info.outputs.base_repo_full_name }} + PR_HEAD_REPO_OWNER: ${{ steps.pr_info.outputs.head_repo_owner }} + PR_BASE_REPO_OWNER: ${{ steps.pr_info.outputs.base_repo_owner }} + PR_HEAD_REPO_NAME: ${{ steps.pr_info.outputs.head_repo_name }} + PR_BASE_REPO_NAME: ${{ steps.pr_info.outputs.base_repo_name }} + PR_HEAD_REF: ${{ steps.pr_info.outputs.head_ref }} + PR_BASE_REF: ${{ steps.pr_info.outputs.base_ref }} + PR_HEAD_SHA: ${{ steps.pr_info.outputs.head_sha }} + PR_BASE_SHA: ${{ steps.pr_info.outputs.base_sha }} + PR_MERGE_COMMIT_SHA: ${{ steps.pr_info.outputs.merge_commit_sha }} + PR_HEAD_COMMIT_DATE: ${{ steps.pr_info.outputs.head_commit_date }} + PR_MERGE_COMMIT_DATE: ${{ steps.pr_info.outputs.merge_commit_date }} + PR_HEAD_COMMIT_TIMESTAMP: ${{ steps.get_timestamps.outputs.head_commit_timestamp }} + PR_MERGE_COMMIT_TIMESTAMP: ${{ steps.get_timestamps.outputs.merge_commit_timestamp }} + PR: ${{ steps.pr_info.outputs.pr }} + PR_FILES: ${{ steps.pr_info.outputs.files }} + if: ${{ inputs.pr_number != '' }} + steps: + - name: Extract PR details + id: pr_info + uses: actions/github-script@v6 + with: + script: | + const { data: pr } = await github.rest.pulls.get({ + owner: context.repo.owner, + repo: context.repo.repo, + pull_number: ${{ inputs.pr_number }} + }); + + const { data: head_commit } = await github.rest.repos.getCommit({ + owner: pr.head.repo.owner.login, + repo: pr.head.repo.name, + ref: pr.head.ref + }); + + const { data: merge_commit } = await github.rest.repos.getCommit({ + owner: pr.base.repo.owner.login, + repo: pr.base.repo.name, + ref: pr.merge_commit_sha, + }); + + const { data: files } = await github.rest.pulls.listFiles({ + owner: context.repo.owner, + repo: context.repo.repo, + pull_number: ${{ inputs.pr_number }} + }); + + core.setOutput('head_repo_full_name', pr.head.repo.full_name); + core.setOutput('base_repo_full_name', pr.base.repo.full_name); + core.setOutput('head_repo_owner', pr.head.repo.owner.login); + core.setOutput('base_repo_owner', pr.base.repo.owner.login); + core.setOutput('head_repo_name', pr.head.repo.name); + core.setOutput('base_repo_name', pr.base.repo.name); + core.setOutput('head_ref', pr.head.ref); + core.setOutput('base_ref', pr.base.ref); + core.setOutput('head_sha', pr.head.sha); + core.setOutput('base_sha', pr.base.sha); + core.setOutput('merge_commit_sha', pr.merge_commit_sha); + core.setOutput('pr', pr); + + core.setOutput('head_commit_date', head_commit.commit.committer.date); + core.setOutput('merge_commit_date', merge_commit.commit.committer.date); + + core.setOutput('files', files); + + console.log('PR head commit:', { + head_commit: head_commit, + commit: head_commit.commit, + date: head_commit.commit.committer.date + }); + + console.log('PR merge commit:', { + merge_commit: merge_commit, + commit: merge_commit.commit, + date: merge_commit.commit.committer.date + }); + + - name: Convert dates to timestamps + id: get_timestamps + run: | + head_commit_date=${{ steps.pr_info.outputs.head_commit_date }} + merge_commit_date=${{ steps.pr_info.outputs.merge_commit_date }} + echo $head_commit_date + echo $merge_commit_date + head_commit_timestamp=$(date -d "$head_commit_date" +%s) + merge_commit_timestamp=$(date -d "$merge_commit_date" +%s) + echo $head_commit_timestamp + echo $merge_commit_timestamp + echo "head_commit_timestamp=$head_commit_timestamp" >> $GITHUB_OUTPUT + echo "merge_commit_timestamp=$merge_commit_timestamp" >> $GITHUB_OUTPUT diff --git a/.github/workflows/get-pr-number.yml b/.github/workflows/get-pr-number.yml new file mode 100644 index 00000000000..316b0f7503f --- /dev/null +++ b/.github/workflows/get-pr-number.yml @@ -0,0 +1,36 @@ +name: Get PR number +on: + workflow_call: + outputs: + PR_NUMBER: + description: "The extracted PR number" + value: ${{ jobs.get-pr-number.outputs.PR_NUMBER }} + +jobs: + get-pr-number: + runs-on: ubuntu-22.04 + name: Get PR number + outputs: + PR_NUMBER: ${{ steps.set_pr_number.outputs.PR_NUMBER }} + steps: + - name: Get PR number + shell: bash + run: | + if [[ "${{ github.event.issue.number }}" != "" && "${{ github.event.issue.pull_request }}" != "" ]]; then + echo "PR_NUMBER=${{ github.event.issue.number }}" >> $GITHUB_ENV + elif [[ "${{ github.event.pull_request.number }}" != "" ]]; then + echo "PR_NUMBER=${{ github.event.pull_request.number }}" >> $GITHUB_ENV + elif [[ "${{ github.event.pull_request }}" != "" ]]; then + echo "PR_NUMBER=${{ github.event.number }}" >> $GITHUB_ENV + else + echo "PR_NUMBER=" >> $GITHUB_ENV + fi + + - name: Check PR number + shell: bash + run: | + echo "${{ env.PR_NUMBER }}" + + - name: Set PR number + id: set_pr_number + run: echo "PR_NUMBER=${{ env.PR_NUMBER }}" >> "$GITHUB_OUTPUT" diff --git a/.github/workflows/pr_run_slow_ci.yml b/.github/workflows/pr_run_slow_ci.yml new file mode 100644 index 00000000000..f3070a6f4d2 --- /dev/null +++ b/.github/workflows/pr_run_slow_ci.yml @@ -0,0 +1,163 @@ +name: PR slow CI +on: + pull_request_target: + types: [opened, synchronize, reopened] + +jobs: + get-pr-number: + name: Get PR number + uses: ./.github/workflows/get-pr-number.yml + + get-pr-info: + name: Get PR commit SHA + needs: get-pr-number + if: ${{ needs.get-pr-number.outputs.PR_NUMBER != ''}} + uses: ./.github/workflows/get-pr-info.yml + with: + pr_number: ${{ needs.get-pr-number.outputs.PR_NUMBER }} + + # We only need to verify the timestamp if the workflow is triggered by `issue_comment`. + verity_pr_commit: + name: Verity PR commit corresponds to a specific event by comparing timestamps + if: ${{ github.event.comment.created_at != '' }} + runs-on: ubuntu-22.04 + needs: get-pr-info + env: + COMMENT_DATE: ${{ github.event.comment.created_at }} + PR_MERGE_COMMIT_DATE: ${{ needs.get-pr-info.outputs.PR_MERGE_COMMIT_DATE }} + PR_MERGE_COMMIT_TIMESTAMP: ${{ needs.get-pr-info.outputs.PR_MERGE_COMMIT_TIMESTAMP }} + steps: + - run: | + COMMENT_TIMESTAMP=$(date -d "${COMMENT_DATE}" +"%s") + echo "COMMENT_DATE: $COMMENT_DATE" + echo "PR_MERGE_COMMIT_DATE: $PR_MERGE_COMMIT_DATE" + echo "COMMENT_TIMESTAMP: $COMMENT_TIMESTAMP" + echo "PR_MERGE_COMMIT_TIMESTAMP: $PR_MERGE_COMMIT_TIMESTAMP" + if [ $COMMENT_TIMESTAMP -le $PR_MERGE_COMMIT_TIMESTAMP ]; then + echo "Last commit on the pull request is newer than the issue comment triggering this run! Abort!"; + exit -1; + fi + + get-jobs: + name: Get test files to run + runs-on: ubuntu-22.04 + needs: [get-pr-number, get-pr-info] + outputs: + jobs: ${{ steps.get_jobs.outputs.jobs_to_run }} + steps: + - name: Get repository content + id: repo_content + uses: actions/github-script@v6 + with: + script: | + const { data: tests_dir } = await github.rest.repos.getContent({ + owner: '${{ needs.get-pr-info.outputs.PR_HEAD_REPO_OWNER }}', + repo: '${{ needs.get-pr-info.outputs.PR_HEAD_REPO_NAME }}', + path: 'tests', + ref: '${{ needs.get-pr-info.outputs.PR_HEAD_SHA }}', + }); + + const { data: tests_models_dir } = await github.rest.repos.getContent({ + owner: '${{ needs.get-pr-info.outputs.PR_HEAD_REPO_OWNER }}', + repo: '${{ needs.get-pr-info.outputs.PR_HEAD_REPO_NAME }}', + path: 'tests/models', + ref: '${{ needs.get-pr-info.outputs.PR_HEAD_SHA }}', + }); + + const { data: tests_quantization_dir } = await github.rest.repos.getContent({ + owner: '${{ needs.get-pr-info.outputs.PR_HEAD_REPO_OWNER }}', + repo: '${{ needs.get-pr-info.outputs.PR_HEAD_REPO_NAME }}', + path: 'tests/quantization', + ref: '${{ needs.get-pr-info.outputs.PR_HEAD_SHA }}', + }); + + core.setOutput('tests_dir', tests_dir); + core.setOutput('tests_models_dir', tests_models_dir); + core.setOutput('tests_quantization_dir', tests_quantization_dir); + + # This checkout to the main branch + - uses: actions/checkout@v4 + with: + fetch-depth: "0" + + - name: Write pr_files file + run: | + cat > pr_files.txt << 'EOF' + ${{ needs.get-pr-info.outputs.PR_FILES }} + EOF + + - name: Write tests_dir file + run: | + cat > tests_dir.txt << 'EOF' + ${{ steps.repo_content.outputs.tests_dir }} + EOF + + - name: Write tests_models_dir file + run: | + cat > tests_models_dir.txt << 'EOF' + ${{ steps.repo_content.outputs.tests_models_dir }} + EOF + + - name: Write tests_quantization_dir file + run: | + cat > tests_quantization_dir.txt << 'EOF' + ${{ steps.repo_content.outputs.tests_quantization_dir }} + EOF + + - name: Run script to get jobs to run + id: get_jobs + run: | + python utils/get_pr_run_slow_jobs.py | tee output.txt + echo "jobs_to_run: $(tail -n 1 output.txt)" + echo "jobs_to_run=$(tail -n 1 output.txt)" >> $GITHUB_OUTPUT + + send_comment: + name: Send a comment to suggest jobs to run + if: ${{ needs.get-jobs.outputs.jobs != '' }} + needs: [get-pr-number, get-jobs] + permissions: + pull-requests: write + runs-on: ubuntu-22.04 + steps: + - name: Delete existing comment and send new one + uses: actions/github-script@v7 + env: + BODY: "\n\nrun-slow: ${{ needs.get-jobs.outputs.jobs }}" + with: + script: | + const prNumber = ${{ needs.get-pr-number.outputs.PR_NUMBER }}; + const commentPrefix = "**[For maintainers]** Suggested jobs to run (before merge)"; + + // Get all comments on the PR + const { data: comments } = await github.rest.issues.listComments({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: prNumber + }); + + // Find existing comment(s) that start with our prefix + const existingComments = comments.filter(comment => + comment.user.login === 'github-actions[bot]' && + comment.body.startsWith(commentPrefix) + ); + + // Delete existing comment(s) + for (const comment of existingComments) { + console.log(`Deleting existing comment #${comment.id}`); + await github.rest.issues.deleteComment({ + owner: context.repo.owner, + repo: context.repo.repo, + comment_id: comment.id + }); + } + + // Create new comment + const newBody = `${commentPrefix}${process.env.BODY}`; + await github.rest.issues.createComment({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: prNumber, + body: newBody + }); + + console.log('✅ Comment updated successfully'); \ No newline at end of file diff --git a/utils/get_pr_run_slow_jobs.py b/utils/get_pr_run_slow_jobs.py new file mode 100644 index 00000000000..fa56a6c305e --- /dev/null +++ b/utils/get_pr_run_slow_jobs.py @@ -0,0 +1,133 @@ +import argparse +import json +import re +import string + + +MAX_NUM_JOBS_TO_SUGGEST = 16 + + +def get_jobs_to_run(): + # The file `pr_files.txt` contains the information about the files changed in a pull request, and it is prepared by + # the caller (using GitHub api). + # We can also use the following api to get the information if we don't have them before calling this script. + # url = f"https://api.github.com/repos/huggingface/transformers/pulls/PULL_NUMBER/files?ref={pr_sha}" + with open("pr_files.txt") as fp: + pr_files = json.load(fp) + pr_files = [{k: v for k, v in item.items() if k in ["filename", "status"]} for item in pr_files] + pr_files = [item["filename"] for item in pr_files if item["status"] in ["added", "modified"]] + + # models or quantizers + re_1 = re.compile(r"src/transformers/(models/.*)/modeling_.*\.py") + re_2 = re.compile(r"src/transformers/(quantizers/quantizer_.*)\.py") + + # tests for models or quantizers + re_3 = re.compile(r"tests/(models/.*)/test_.*\.py") + re_4 = re.compile(r"tests/(quantization/.*)/test_.*\.py") + + # files in a model directory but not necessary a modeling file + re_5 = re.compile(r"src/transformers/(models/.*)/.*\.py") + + regexes = [re_1, re_2, re_3, re_4, re_5] + + jobs_to_run = [] + for pr_file in pr_files: + for regex in regexes: + matched = regex.findall(pr_file) + if len(matched) > 0: + item = matched[0] + item = item.replace("quantizers/quantizer_", "quantization/") + # TODO: for files in `quantizers`, the processed item above may not exist. Try using a fuzzy matching + if item in repo_content: + jobs_to_run.append(item) + break + jobs_to_run = sorted(set(jobs_to_run)) + + return jobs_to_run + + +def parse_message(message: str) -> str: + """ + Parses a GitHub pull request's comment to find the models specified in it to run slow CI. + + Args: + message (`str`): The body of a GitHub pull request's comment. + + Returns: + `str`: The substring in `message` after `run-slow`, run_slow` or run slow`. If no such prefix is found, the + empty string is returned. + """ + if message is None: + return "" + + message = message.strip().lower() + + # run-slow: model_1, model_2, quantization_1, quantization_2 + if not message.startswith(("run-slow", "run_slow", "run slow")): + return "" + message = message[len("run slow") :] + # remove leading `:` + while message.strip().startswith(":"): + message = message.strip()[1:] + + return message + + +def get_jobs(message: str): + models = parse_message(message) + return models.replace(",", " ").split() + + +def check_name(model_name: str): + allowed = string.ascii_letters + string.digits + "_" + return not (model_name.startswith("_") or model_name.endswith("_")) and all(c in allowed for c in model_name) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--message", type=str, default="", help="The content of a comment.") + parser.add_argument("--quantization", action="store_true", help="If we collect quantization tests") + args = parser.parse_args() + + # The files are prepared by the caller (using GitHub api). + # We can also use the following api to get the information if we don't have them before calling this script. + # url = f"https://api.github.com/repos/OWNER/REPO/contents/PATH?ref={pr_sha}" + # (we avoid to checkout the repository using `actions/checkout` to reduce the run time, but mostly to avoid the potential security issue as much as possible) + repo_content = [] + for filename in ["tests_dir.txt", "tests_models_dir.txt", "tests_quantization_dir.txt"]: + with open(filename) as fp: + data = json.load(fp) + data = [item["path"][len("tests/") :] for item in data if item["type"] == "dir"] + repo_content.extend(data) + + # These don't have the prefix `models/` or `quantization/`, so we need to add them. + if args.message: + specified_jobs = get_jobs(args.message) + specified_jobs = [job for job in specified_jobs if check_name(job)] + + # Add prefix (`models/` or `quantization`) + jobs_to_run = [] + for job in specified_jobs: + if not args.quantization: + if f"models/{job}" in repo_content: + jobs_to_run.append(f"models/{job}") + elif job in repo_content and job != "quantization": + jobs_to_run.append(job) + elif f"quantization/{job}" in repo_content: + jobs_to_run.append(f"quantization/{job}") + + print(sorted(set(jobs_to_run))) + + else: + # Compute (from the added/modified files) the directories under `tests/`, `tests/models/` and `tests/quantization`to run tests. + # These are already with the prefix `models/` or `quantization/`, so we don't need to add them. + jobs_to_run = get_jobs_to_run() + jobs_to_run = [x.replace("models/", "").replace("quantization/", "") for x in jobs_to_run] + jobs_to_run = [job for job in jobs_to_run if check_name(job)] + + if len(jobs_to_run) > MAX_NUM_JOBS_TO_SUGGEST: + jobs_to_run = jobs_to_run[:MAX_NUM_JOBS_TO_SUGGEST] + + suggestion = f"{', '.join(jobs_to_run)}" + + print(suggestion)