diff --git a/.circleci/config.yml b/.circleci/config.yml index 67f294bc971..ef49dc7e023 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -43,16 +43,6 @@ jobs: parallelism: 1 steps: - checkout - - run: python3 utils/extract_pr_number_from_circleci.py > pr_number.txt - - run: echo $(cat pr_number.txt) - - run: if [[ "$(cat pr_number.txt)" == "" && "$CIRCLE_BRANCH" != "main" && "$CIRCLE_BRANCH" != *-release ]]; then echo "Not a PR, not the main branch and not a release branch, skip test!"; circleci-agent step halt; fi - - run: 'curl -L -H "Accept: application/vnd.github+json" -H "X-GitHub-Api-Version: 2022-11-28" https://api.github.com/repos/$CIRCLE_PROJECT_USERNAME/$CIRCLE_PROJECT_REPONAME/pulls/$(cat pr_number.txt) >> github.txt' - - run: cat github.txt - - run: (python3 -c 'import json; from datetime import datetime; fp = open("github.txt"); data = json.load(fp); fp.close(); f = "%Y-%m-%dT%H:%M:%SZ"; created = datetime.strptime(data["created_at"], f); updated = datetime.strptime(data["updated_at"], f); s = (updated - created).total_seconds(); print(int(s))' || true) > elapsed.txt - - run: if [ "$(cat elapsed.txt)" == "" ]; then echo 60 > elapsed.txt; fi - - run: cat elapsed.txt - - run: if [ "$(cat elapsed.txt)" -lt "30" ]; then echo "PR is just opened, wait some actions from GitHub"; sleep 30; fi - - run: 'if grep -q "\"draft\": true," github.txt; then echo "draft mode, skip test!"; circleci-agent step halt; fi' - run: uv pip install -U -e . - run: echo 'export "GIT_COMMIT_MESSAGE=$(git show -s --format=%s)"' >> "$BASH_ENV" && source "$BASH_ENV" - run: mkdir -p test_preparation @@ -122,8 +112,6 @@ jobs: - run: name: "Retrieve Artifact Paths" - env: - CIRCLE_TOKEN: ${{ secrets.CI_ARTIFACT_TOKEN }} command: | project_slug="gh/${CIRCLE_PROJECT_USERNAME}/${CIRCLE_PROJECT_REPONAME}" job_number=${CIRCLE_BUILD_NUM} diff --git a/.github/workflows/build_pr_documentation.yml b/.github/workflows/build_pr_documentation.yml index c2f61c45354..4ab3f239279 100644 --- a/.github/workflows/build_pr_documentation.yml +++ b/.github/workflows/build_pr_documentation.yml @@ -2,6 +2,15 @@ name: Build PR Documentation on: pull_request: + workflow_call: + inputs: + pr_number: + type: string + required: true + commit_sha: + type: string + required: true + concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} @@ -9,9 +18,9 @@ concurrency: jobs: build: - uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main + uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@6e2eb04a2604817c97be03786efa494fe3acae90 with: - commit_sha: ${{ github.event.pull_request.head.sha }} - pr_number: ${{ github.event.number }} + commit_sha: ${{ inputs.commit_sha || github.event.pull_request.head.sha }} + pr_number: ${{ inputs.pr_number || github.event.number }} package: transformers languages: en diff --git a/.github/workflows/check_failed_model_tests.yml b/.github/workflows/check_failed_tests.yml similarity index 76% rename from .github/workflows/check_failed_model_tests.yml rename to .github/workflows/check_failed_tests.yml index 653b50e4cf6..478f9d0ae2a 100644 --- a/.github/workflows/check_failed_model_tests.yml +++ b/.github/workflows/check_failed_tests.yml @@ -9,6 +9,18 @@ on: start_sha: required: true type: string + job: + required: true + type: string + slack_report_channel: + required: true + type: string + ci_event: + required: true + type: string + report_repo_id: + required: true + type: string env: @@ -26,7 +38,7 @@ env: jobs: - 
run_models_gpu: + check_new_failures: name: " " runs-on: group: aws-g4dn-4xlarge-cache @@ -36,17 +48,17 @@ jobs: steps: - uses: actions/download-artifact@v4 with: - name: ci_results_run_models_gpu - path: /transformers/ci_results_run_models_gpu + name: ci_results_${{ inputs.job }} + path: /transformers/ci_results_${{ inputs.job }} - name: Check file working-directory: /transformers run: | - if [ -f ci_results_run_models_gpu/new_model_failures.json ]; then - echo "`ci_results_run_models_gpu/new_model_failures.json` exists, continue ..." + if [ -f ci_results_${{ inputs.job }}/new_failures.json ]; then + echo "`ci_results_${{ inputs.job }}/new_failures.json` exists, continue ..." echo "process=true" >> $GITHUB_ENV else - echo "`ci_results_run_models_gpu/new_model_failures.json` doesn't exist, abort." + echo "`ci_results_${{ inputs.job }}/new_failures.json` doesn't exist, abort." echo "process=false" >> $GITHUB_ENV fi @@ -112,14 +124,14 @@ jobs: - name: Check failed tests working-directory: /transformers if: ${{ env.process == 'true' }} - run: python3 utils/check_bad_commit.py --start_commit ${{ inputs.start_sha }} --end_commit ${{ env.END_SHA }} --file ci_results_run_models_gpu/new_model_failures.json --output_file new_model_failures_with_bad_commit.json + run: python3 utils/check_bad_commit.py --start_commit ${{ inputs.start_sha }} --end_commit ${{ env.END_SHA }} --file ci_results_${{ inputs.job }}/new_failures.json --output_file new_failures_with_bad_commit.json - name: Show results working-directory: /transformers if: ${{ env.process == 'true' }} run: | - ls -l new_model_failures_with_bad_commit.json - cat new_model_failures_with_bad_commit.json + ls -l new_failures_with_bad_commit.json + cat new_failures_with_bad_commit.json - name: Checkout back working-directory: /transformers @@ -134,6 +146,8 @@ jobs: env: ACCESS_REPO_INFO_TOKEN: ${{ secrets.ACCESS_REPO_INFO_TOKEN }} TRANSFORMERS_CI_RESULTS_UPLOAD_TOKEN: ${{ secrets.TRANSFORMERS_CI_RESULTS_UPLOAD_TOKEN }} + JOB_NAME: ${{ inputs.job }} + REPORT_REPO_ID: ${{ inputs.report_repo_id }} run: | python3 utils/process_bad_commit_report.py @@ -144,6 +158,8 @@ jobs: env: ACCESS_REPO_INFO_TOKEN: ${{ secrets.ACCESS_REPO_INFO_TOKEN }} TRANSFORMERS_CI_RESULTS_UPLOAD_TOKEN: ${{ secrets.TRANSFORMERS_CI_RESULTS_UPLOAD_TOKEN }} + JOB_NAME: ${{ inputs.job }} + REPORT_REPO_ID: ${{ inputs.report_repo_id }} run: | { echo 'REPORT_TEXT<> "$GITHUB_ENV" + - name: Prepare Slack report title + working-directory: /transformers + if: ${{ env.process == 'true' }} + run: | + pip install slack_sdk + echo "title=$(python3 -c 'import sys; sys.path.append("utils"); from utils.notification_service import job_to_test_map; ci_event = "${{ inputs.ci_event }}"; job = "${{ inputs.job }}"; test_name = job_to_test_map[job]; title = f"New failed tests of {ci_event}" + ":" + f" {test_name}"; print(title)')" >> $GITHUB_ENV + - name: Send processed report if: ${{ env.process == 'true' && !endsWith(env.REPORT_TEXT, '{}') }} uses: slackapi/slack-github-action@6c661ce58804a1a20f6dc5fbee7f0381b469e001 with: # Slack channel id, channel name, or user id to post message. 
# See also: https://api.slack.com/methods/chat.postMessage#channels - channel-id: '#transformers-ci-feedback-tests' + channel-id: '#${{ inputs.slack_report_channel }}' # For posting a rich message using Block Kit payload: | { "blocks": [ + { + "type": "header", + "text": { + "type": "plain_text", + "text": "${{ env.title }}" + } + }, { "type": "section", "text": { diff --git a/.github/workflows/pr-style-bot.yml b/.github/workflows/pr-style-bot.yml index 9ca716ec50d..fdb76d8db3d 100644 --- a/.github/workflows/pr-style-bot.yml +++ b/.github/workflows/pr-style-bot.yml @@ -11,9 +11,24 @@ permissions: jobs: style: - uses: huggingface/huggingface_hub/.github/workflows/style-bot-action.yml@main + uses: huggingface/huggingface_hub/.github/workflows/style-bot-action.yml@639ee721e149a281fe726a50a2cc1354b48bc463 with: python_quality_dependencies: "[quality]" style_command_type: "default" secrets: bot_token: ${{ secrets.GITHUB_TOKEN }} + + check-outputs: + runs-on: ubuntu-latest + needs: style + steps: + - run: echo ${{ needs.style.outputs.pr_number }} + - run: echo ${{ needs.style.outputs.new_commit_sha }} + + trigger: + needs: style + if: needs.style.outputs.new_commit_sha != '' + uses: "./.github/workflows/build_pr_documentation.yml" + with: + pr_number: ${{ needs.style.outputs.pr_number }} + commit_sha: ${{ needs.style.outputs.new_commit_sha }} diff --git a/.github/workflows/self-comment-ci.yml b/.github/workflows/self-comment-ci.yml index dc4b394e2d3..f9c25abd4d4 100644 --- a/.github/workflows/self-comment-ci.yml +++ b/.github/workflows/self-comment-ci.yml @@ -29,7 +29,7 @@ jobs: runs-on: ubuntu-22.04 name: Get PR number # For security: only allow team members to run - if: ${{ github.event.issue.state == 'open' && contains(fromJSON('["ydshieh", "ArthurZucker", "zucchini-nlp", "qubvel", "molbap", "gante", "LysandreJik", "Cyrilvallez", "Rocketknight1", "SunMarc", "muellerzr", "eustlb", "MekkCyber", "manueldeprada"]'), github.actor) && (startsWith(github.event.comment.body, 'run-slow') || startsWith(github.event.comment.body, 'run slow') || startsWith(github.event.comment.body, 'run_slow')) }} + if: ${{ github.event.issue.state == 'open' && contains(fromJSON('["ydshieh", "ArthurZucker", "zucchini-nlp", "qubvel", "molbap", "gante", "LysandreJik", "Cyrilvallez", "Rocketknight1", "SunMarc", "muellerzr", "eustlb", "MekkCyber", "manueldeprada", "vasqu"]'), github.actor) && (startsWith(github.event.comment.body, 'run-slow') || startsWith(github.event.comment.body, 'run slow') || startsWith(github.event.comment.body, 'run_slow')) }} outputs: PR_NUMBER: ${{ steps.set_pr_number.outputs.PR_NUMBER }} steps: diff --git a/.github/workflows/self-scheduled-amd-mi210-caller.yml b/.github/workflows/self-scheduled-amd-mi210-caller.yml deleted file mode 100644 index 6109faca009..00000000000 --- a/.github/workflows/self-scheduled-amd-mi210-caller.yml +++ /dev/null @@ -1,55 +0,0 @@ -name: Self-hosted runner (AMD mi210 scheduled CI caller) - -on: - workflow_run: - workflows: ["Self-hosted runner (AMD scheduled CI caller)"] - branches: ["main"] - types: [completed] - push: - branches: - - run_amd_scheduled_ci_caller* - -jobs: - model-ci: - name: Model CI - uses: huggingface/hf-workflows/.github/workflows/transformers_amd_ci_scheduled.yaml@main - with: - job: run_models_gpu - slack_report_channel: "#transformers-ci-daily-amd" - runner: mi210 - docker: huggingface/transformers-pytorch-amd-gpu - ci_event: Scheduled CI (AMD) - mi210 - secrets: inherit - - torch-pipeline: - name: Torch pipeline CI - uses: 
huggingface/hf-workflows/.github/workflows/transformers_amd_ci_scheduled.yaml@main - with: - job: run_pipelines_torch_gpu - slack_report_channel: "#transformers-ci-daily-amd" - runner: mi210 - docker: huggingface/transformers-pytorch-amd-gpu - ci_event: Scheduled CI (AMD) - mi210 - secrets: inherit - - example-ci: - name: Example CI - uses: huggingface/hf-workflows/.github/workflows/transformers_amd_ci_scheduled.yaml@main - with: - job: run_examples_gpu - slack_report_channel: "#transformers-ci-daily-amd" - runner: mi210 - docker: huggingface/transformers-pytorch-amd-gpu - ci_event: Scheduled CI (AMD) - mi210 - secrets: inherit - - deepspeed-ci: - name: DeepSpeed CI - uses: huggingface/hf-workflows/.github/workflows/transformers_amd_ci_scheduled.yaml@main - with: - job: run_torch_cuda_extensions_gpu - slack_report_channel: "#transformers-ci-daily-amd" - runner: mi210 - docker: huggingface/transformers-pytorch-deepspeed-amd-gpu - ci_event: Scheduled CI (AMD) - mi210 - secrets: inherit diff --git a/.github/workflows/self-scheduled-amd-mi250-caller.yml b/.github/workflows/self-scheduled-amd-mi250-caller.yml index 4c6284a78cd..581d9137709 100644 --- a/.github/workflows/self-scheduled-amd-mi250-caller.yml +++ b/.github/workflows/self-scheduled-amd-mi250-caller.yml @@ -15,10 +15,11 @@ jobs: uses: huggingface/hf-workflows/.github/workflows/transformers_amd_ci_scheduled.yaml@main with: job: run_models_gpu - slack_report_channel: "#amd-hf-ci" + slack_report_channel: "#transformers-ci-daily-amd" runner: mi250 docker: huggingface/transformers-pytorch-amd-gpu ci_event: Scheduled CI (AMD) - mi250 + report_repo_id: optimum-amd/transformers_daily_ci secrets: inherit torch-pipeline: @@ -26,10 +27,11 @@ jobs: uses: huggingface/hf-workflows/.github/workflows/transformers_amd_ci_scheduled.yaml@main with: job: run_pipelines_torch_gpu - slack_report_channel: "#amd-hf-ci" + slack_report_channel: "#transformers-ci-daily-amd" runner: mi250 docker: huggingface/transformers-pytorch-amd-gpu ci_event: Scheduled CI (AMD) - mi250 + report_repo_id: optimum-amd/transformers_daily_ci secrets: inherit example-ci: @@ -37,10 +39,11 @@ jobs: uses: huggingface/hf-workflows/.github/workflows/transformers_amd_ci_scheduled.yaml@main with: job: run_examples_gpu - slack_report_channel: "#amd-hf-ci" + slack_report_channel: "#transformers-ci-daily-amd" runner: mi250 docker: huggingface/transformers-pytorch-amd-gpu ci_event: Scheduled CI (AMD) - mi250 + report_repo_id: optimum-amd/transformers_daily_ci secrets: inherit deepspeed-ci: @@ -48,8 +51,9 @@ jobs: uses: huggingface/hf-workflows/.github/workflows/transformers_amd_ci_scheduled.yaml@main with: job: run_torch_cuda_extensions_gpu - slack_report_channel: "#amd-hf-ci" + slack_report_channel: "#transformers-ci-daily-amd" runner: mi250 docker: huggingface/transformers-pytorch-deepspeed-amd-gpu ci_event: Scheduled CI (AMD) - mi250 + report_repo_id: optimum-amd/transformers_daily_ci secrets: inherit diff --git a/.github/workflows/self-scheduled-amd-mi300-caller.yml b/.github/workflows/self-scheduled-amd-mi300-caller.yml new file mode 100644 index 00000000000..d5310fb3072 --- /dev/null +++ b/.github/workflows/self-scheduled-amd-mi300-caller.yml @@ -0,0 +1,63 @@ +name: Self-hosted runner scale set (AMD mi300 scheduled CI caller) + +# Note: For every job in this workflow, the name of the runner scale set is finalized in the runner yaml i.e. 
huggingface/hf-workflows/.github/workflows/transformers_amd_ci_scheduled_arc_scale_set.yaml +# For example, 1gpu scale set: amd-mi300-ci-1gpu +# 2gpu scale set: amd-mi300-ci-2gpu + +on: + workflow_run: + workflows: ["Self-hosted runner (AMD scheduled CI caller)"] + branches: ["main"] + types: [completed] + push: + branches: + - run_amd_scheduled_ci_caller* + +jobs: + model-ci: + name: Model CI + uses: huggingface/hf-workflows/.github/workflows/transformers_amd_ci_scheduled_arc_scale_set.yaml@main + with: + job: run_models_gpu + slack_report_channel: "#amd-hf-ci" + runner_scale_set: amd-mi300-ci + docker: huggingface/transformers-pytorch-amd-gpu + ci_event: Scheduled CI (AMD) - mi300 + report_repo_id: optimum-amd/transformers_daily_ci + secrets: inherit + + torch-pipeline: + name: Torch pipeline CI + uses: huggingface/hf-workflows/.github/workflows/transformers_amd_ci_scheduled_arc_scale_set.yaml@main + with: + job: run_pipelines_torch_gpu + slack_report_channel: "#amd-hf-ci" + runner_scale_set: amd-mi300-ci + docker: huggingface/transformers-pytorch-amd-gpu + ci_event: Scheduled CI (AMD) - mi300 + report_repo_id: optimum-amd/transformers_daily_ci + secrets: inherit + + example-ci: + name: Example CI + uses: huggingface/hf-workflows/.github/workflows/transformers_amd_ci_scheduled_arc_scale_set.yaml@main + with: + job: run_examples_gpu + slack_report_channel: "#amd-hf-ci" + runner_scale_set: amd-mi300-ci + docker: huggingface/transformers-pytorch-amd-gpu + ci_event: Scheduled CI (AMD) - mi300 + report_repo_id: optimum-amd/transformers_daily_ci + secrets: inherit + + deepspeed-ci: + name: DeepSpeed CI + uses: huggingface/hf-workflows/.github/workflows/transformers_amd_ci_scheduled_arc_scale_set.yaml@main + with: + job: run_torch_cuda_extensions_gpu + slack_report_channel: "#amd-hf-ci" + runner_scale_set: amd-mi300-ci + docker: huggingface/transformers-pytorch-deepspeed-amd-gpu + ci_event: Scheduled CI (AMD) - mi300 + report_repo_id: optimum-amd/transformers_daily_ci + secrets: inherit diff --git a/.github/workflows/self-scheduled-caller.yml b/.github/workflows/self-scheduled-caller.yml index 77b33850fe4..f48d357cd5d 100644 --- a/.github/workflows/self-scheduled-caller.yml +++ b/.github/workflows/self-scheduled-caller.yml @@ -54,6 +54,7 @@ jobs: runner: daily-ci docker: huggingface/transformers-all-latest-gpu ci_event: Daily CI + report_repo_id: hf-internal-testing/transformers_daily_ci secrets: inherit torch-pipeline: @@ -65,6 +66,7 @@ jobs: runner: daily-ci docker: huggingface/transformers-pytorch-gpu ci_event: Daily CI + report_repo_id: hf-internal-testing/transformers_daily_ci secrets: inherit tf-pipeline: @@ -76,6 +78,7 @@ jobs: runner: daily-ci docker: huggingface/transformers-tensorflow-gpu ci_event: Daily CI + report_repo_id: hf-internal-testing/transformers_daily_ci secrets: inherit example-ci: @@ -87,6 +90,7 @@ jobs: runner: daily-ci docker: huggingface/transformers-all-latest-gpu ci_event: Daily CI + report_repo_id: hf-internal-testing/transformers_daily_ci secrets: inherit trainer-fsdp-ci: @@ -98,6 +102,7 @@ jobs: runner: daily-ci docker: huggingface/transformers-all-latest-gpu ci_event: Daily CI + report_repo_id: hf-internal-testing/transformers_daily_ci secrets: inherit deepspeed-ci: @@ -110,6 +115,7 @@ jobs: docker: huggingface/transformers-pytorch-deepspeed-latest-gpu ci_event: Daily CI working-directory-prefix: /workspace + report_repo_id: hf-internal-testing/transformers_daily_ci secrets: inherit quantization-ci: @@ -121,4 +127,5 @@ jobs: runner: daily-ci docker: 
huggingface/transformers-quantization-latest-gpu ci_event: Daily CI + report_repo_id: hf-internal-testing/transformers_daily_ci secrets: inherit diff --git a/.github/workflows/self-scheduled.yml b/.github/workflows/self-scheduled.yml index 1198148fd63..36c113190ca 100644 --- a/.github/workflows/self-scheduled.yml +++ b/.github/workflows/self-scheduled.yml @@ -28,6 +28,10 @@ on: default: '' required: false type: string + report_repo_id: + required: true + type: string + env: HF_HOME: /mnt/cache @@ -584,15 +588,21 @@ jobs: folder_slices: ${{ needs.setup.outputs.folder_slices }} quantization_matrix: ${{ needs.setup.outputs.quantization_matrix }} ci_event: ${{ inputs.ci_event }} + report_repo_id: ${{ inputs.report_repo_id }} secrets: inherit - check_new_model_failures: - if: ${{ always() && inputs.ci_event == 'Daily CI' && inputs.job == 'run_models_gpu' && needs.send_results.result == 'success' }} - name: Check new model failures + check_new_failures: + if: ${{ always() && inputs.ci_event == 'Daily CI' && needs.send_results.result == 'success' }} + name: Check new failures needs: send_results - uses: ./.github/workflows/check_failed_model_tests.yml + uses: ./.github/workflows/check_failed_tests.yml with: docker: ${{ inputs.docker }} start_sha: ${{ github.sha }} + job: ${{ inputs.job }} + slack_report_channel: ${{ inputs.slack_report_channel }} + ci_event: ${{ inputs.ci_event }} + report_repo_id: ${{ inputs.report_repo_id }} + secrets: inherit diff --git a/.github/workflows/slack-report.yml b/.github/workflows/slack-report.yml index bea113ca031..5ef74946964 100644 --- a/.github/workflows/slack-report.yml +++ b/.github/workflows/slack-report.yml @@ -21,6 +21,9 @@ on: ci_event: required: true type: string + report_repo_id: + required: true + type: string env: TRANSFORMERS_CI_RESULTS_UPLOAD_TOKEN: ${{ secrets.TRANSFORMERS_CI_RESULTS_UPLOAD_TOKEN }} @@ -55,7 +58,7 @@ jobs: fi - name: Send message to Slack - if: ${{ inputs.job != 'run_quantization_torch_gpu' }} + shell: bash env: CI_SLACK_BOT_TOKEN: ${{ secrets.CI_SLACK_BOT_TOKEN }} CI_SLACK_CHANNEL_ID: ${{ secrets.CI_SLACK_CHANNEL_ID }} @@ -67,6 +70,7 @@ jobs: CI_SHA: ${{ github.sha }} CI_TEST_JOB: ${{ inputs.job }} SETUP_STATUS: ${{ inputs.setup_status }} + REPORT_REPO_ID: ${{ inputs.report_repo_id }} # We pass `needs.setup.outputs.matrix` as the argument. A processing in `notification_service.py` to change # `models/bert` to `models_bert` is required, as the artifact names use `_` instead of `/`. # For a job that doesn't depend on (i.e. `needs`) `setup`, the value for `inputs.folder_slices` would be an @@ -75,7 +79,11 @@ jobs: pip install huggingface_hub pip install slack_sdk pip show slack_sdk - python utils/notification_service.py "${{ inputs.folder_slices }}" + if [ "${{ inputs.quantization_matrix }}" != "" ]; then + python utils/notification_service.py "${{ inputs.quantization_matrix }}" + else + python utils/notification_service.py "${{ inputs.folder_slices }}" + fi # Upload complete failure tables, as they might be big and only truncated versions could be sent to Slack. 
- name: Failure table artifacts @@ -83,31 +91,3 @@ jobs: with: name: ci_results_${{ inputs.job }} path: ci_results_${{ inputs.job }} - - - uses: actions/checkout@v4 - - uses: actions/download-artifact@v4 - - name: Send message to Slack for quantization workflow - if: ${{ inputs.job == 'run_quantization_torch_gpu' }} - env: - CI_SLACK_BOT_TOKEN: ${{ secrets.CI_SLACK_BOT_TOKEN }} - ACCESS_REPO_INFO_TOKEN: ${{ secrets.ACCESS_REPO_INFO_TOKEN }} - SLACK_REPORT_CHANNEL: ${{ inputs.slack_report_channel }} - CI_EVENT: ${{ inputs.ci_event }} - CI_SHA: ${{ github.sha }} - CI_TEST_JOB: ${{ inputs.job }} - SETUP_STATUS: ${{ inputs.setup_status }} - # We pass `needs.setup.outputs.quantization_matrix` as the argument. A processing in `notification_service_quantization.py` to change - # `quantization/bnb` to `quantization_bnb` is required, as the artifact names use `_` instead of `/`. - run: | - pip install huggingface_hub - pip install slack_sdk - pip show slack_sdk - python utils/notification_service_quantization.py "${{ inputs.quantization_matrix }}" - - # Upload complete failure tables, as they might be big and only truncated versions could be sent to Slack. - - name: Failure table artifacts - if: ${{ inputs.job == 'run_quantization_torch_gpu' }} - uses: actions/upload-artifact@v4 - with: - name: ci_results_${{ inputs.job }} - path: ci_results_${{ inputs.job }} diff --git a/docker/transformers-pytorch-amd-gpu/Dockerfile b/docker/transformers-pytorch-amd-gpu/Dockerfile index a71043dc821..7e51233779b 100644 --- a/docker/transformers-pytorch-amd-gpu/Dockerfile +++ b/docker/transformers-pytorch-amd-gpu/Dockerfile @@ -1,4 +1,4 @@ -FROM rocm/dev-ubuntu-22.04:6.2.4 +FROM rocm/pytorch:rocm6.4_ubuntu22.04_py3.10_pytorch_release_2.6.0 LABEL maintainer="Hugging Face" ARG DEBIAN_FRONTEND=noninteractive @@ -11,9 +11,6 @@ RUN apt update && \ RUN git lfs install RUN python3 -m pip install --no-cache-dir --upgrade pip numpy - -RUN python3 -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.2.4 - RUN python3 -m pip install --no-cache-dir --upgrade importlib-metadata setuptools ninja git+https://github.com/facebookresearch/detectron2.git pytesseract "itsdangerous<2.1.0" ARG REF=main @@ -33,3 +30,6 @@ RUN cd transformers && python3 setup.py develop # Remove nvml and nvidia-ml-py as it is not compatible with ROCm. apex is not tested on NVIDIA either. 
RUN python3 -m pip uninstall py3nvml pynvml nvidia-ml-py apex -y + +# `kernels` may causes many failing tests +RUN python3 -m pip uninstall -y kernels \ No newline at end of file diff --git a/docker/transformers-pytorch-deepspeed-amd-gpu/Dockerfile b/docker/transformers-pytorch-deepspeed-amd-gpu/Dockerfile index f70b1549410..e38345ca0f7 100644 --- a/docker/transformers-pytorch-deepspeed-amd-gpu/Dockerfile +++ b/docker/transformers-pytorch-deepspeed-amd-gpu/Dockerfile @@ -48,3 +48,6 @@ RUN python3 -c "from deepspeed.launcher.runner import main" # Remove nvml as it is not compatible with ROCm RUN python3 -m pip uninstall py3nvml pynvml nvidia-ml-py apex -y + +# `kernels` may causes many failing tests +RUN python3 -m pip uninstall -y kernels diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 700e218d8be..d3e7c9438be 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -76,12 +76,12 @@ title: Prompt engineering - local: llm_optims title: Optimizing inference + - local: cache_explanation + title: Caching - local: kv_cache title: KV cache strategies - local: serving title: Serving - - local: cache_explanation - title: Caching - local: llm_tutorial_optimization title: Getting the most out of LLMs - local: perplexity @@ -388,7 +388,7 @@ - local: model_doc/bert-japanese title: BertJapanese - local: model_doc/bertweet - title: Bertweet + title: BERTweet - local: model_doc/big_bird title: BigBird - local: model_doc/bigbird_pegasus @@ -544,7 +544,7 @@ - local: model_doc/mamba title: Mamba - local: model_doc/mamba2 - title: mamba2 + title: Mamba2 - local: model_doc/marian title: MarianMT - local: model_doc/markuplm @@ -1123,4 +1123,9 @@ - local: internal/time_series_utils title: Utilities for Time Series title: Internal helpers + - sections: + - local: reference/environment_variables + title: Environment Variables + title: Reference title: API + diff --git a/docs/source/en/cache_explanation.md b/docs/source/en/cache_explanation.md index 59496e4298f..0ccf612d217 100644 --- a/docs/source/en/cache_explanation.md +++ b/docs/source/en/cache_explanation.md @@ -15,8 +15,7 @@ rendered properly in your Markdown viewer. --> # Caching - -Imagine you’re having a conversation with someone, and instead of remembering what they previously said, they have to start from scratch every time you respond. This would be slow and inefficient, right? +Imagine you're having a conversation with someone, and instead of remembering what they previously said, they have to start from scratch every time you respond. This would be slow and inefficient, right? You can extend this analogy to transformer models. Autoregressive model generation can be slow because it makes a prediction one token at a time. Each new prediction is dependent on all the previous context. @@ -29,8 +28,50 @@ A key-value (KV) cache eliminates this inefficiency by storing kv pairs derived > [!WARNING] > Caching should only be used for **inference**. It may cause unexpected errors if it's enabled during training. +To better understand how and why caching works, let's take a closer look at the structure of the attention matrices. + +## Attention matrices + +The **scaled dot-product attention** is calculated as shown below for a batch of size `b`, number of attention heads `h`, sequence length so far `T`, and dimension per attention head `d_head`. 
+
+$$
+\text{Attention}(Q, K, V) = \text{softmax}\left( \frac{Q K^\top}{\sqrt{d_{\text{head}}}} \times \text{mask} \right) V
+$$
+
+The query (`Q`), key (`K`), and value (`V`) matrices are projections from the input embeddings of shape `(b, h, T, d_head)`.
+
+For causal attention, the mask prevents the model from attending to future tokens. Once a token is processed, its representation never changes with respect to future tokens, which means \\( K_{\text{past}} \\) and \\( V_{\text{past}} \\) can be cached and reused to compute the last token's representation.
+
+$$
+\text{Attention}(q_t, [\underbrace{k_1, k_2, \dots, k_{t-1}}_{\text{cached}}, k_{t}], [\underbrace{v_1, v_2, \dots, v_{t-1}}_{\text{cached}}, v_{t}])
+$$
+
+At inference time, you only need the last token's query to compute the representation \\( x_t \\) that predicts the next token \\( t+1 \\). At each step, the new key and value vectors are **stored** in the cache and **appended** to the past keys and values.
+
+$$
+K_{\text{cache}} \leftarrow \text{concat}(K_{\text{past}}, k_t), \quad V_{\text{cache}} \leftarrow \text{concat}(V_{\text{past}}, v_t)
+$$
+
+Attention is calculated independently in each layer of the model, and caching is done on a per-layer basis.
+
+Refer to the table below to compare how caching improves efficiency.
+
+| without caching | with caching |
+|---|---|
+| for each step, recompute all previous `K` and `V` | for each step, only compute current `K` and `V` |
+| attention cost per step is **quadratic** with sequence length | attention cost per step is **linear** with sequence length (memory grows linearly, but compute/token remains low) |
+
+
 ## Cache class
 
+A basic KV cache interface takes a key and value tensor for the current token and returns the updated `K` and `V` tensors. This is internally managed by a model's `forward` method.
+
+```py
+new_K, new_V = cache.update(k_t, v_t, layer_idx)
+attn_output = attn_layer_idx_fn(q_t, new_K, new_V)
+```
+
 When you use Transformers' [`Cache`] class, the self-attention module performs several critical steps to integrate past and present information.
 
 1. The attention module concatenates current kv pairs with past kv pairs stored in the cache. This creates attentions weights with the shape `(new_tokens_length, past_kv_length + new_tokens_length)`. The current and past kv pairs are essentially combined to compute the attention scores, ensuring a model is aware of previous context and the current input.
@@ -39,6 +80,27 @@ When you use Transformers' [`Cache`] class, the self-attention module performs s
 3. It is also important to be aware of the `cache_position`. This is important if you want to reuse a prefilled [`Cache`] with the `forward` method because you have to pass a valid `cache_position` value. This indicates the input positions in a sequence. `cache_position` is unaffected by padding, and it always adds one more position for each token. For example, if a kv cache contains 10 tokens - regardless of pad tokens - the cache position for the next token should be `torch.tensor([10])`.
 
+## Cache storage implementation
+
+The actual storage of key-value pairs varies between cache implementations. As an example, consider the [`DynamicCache`].
+
+In [`DynamicCache`], the key-value pairs are stored as two lists of tensors. Each tensor in the lists has the shape `[batch_size, num_heads, seq_len, head_dim]`.
+- `key_cache`: A list of tensors, one for each layer.
+- `value_cache`: A list of tensors, one for each layer.
+ +When new tokens are processed: + +1. For each layer, the new key and value states are concatenated with the existing cache. +```py +self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) +self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) +``` + +2. The cache grows dynamically as more tokens are processed. The sequence length dimension (`seq_len`) increases with each new token. + +3. The cache maintains a count of seen tokens through `self._seen_tokens`. This is updated when the first layer processes a new token. + The example below demonstrates how to create a generation loop with [`DynamicCache`]. As discussed, the attention mask is a concatenation of past and current token values and `1` is added to the cache position for the next token. ```py @@ -72,10 +134,14 @@ for _ in range(max_new_tokens): print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]) "[INST] Hello, what's your name. [/INST] Hello! My name is LLaMA," ``` - ## Legacy cache format -Before the [`Cache`] class, the cache used to be stored as a tuple of tuples of tensors. This format has is dynamic because it grows as text is generated, similar to [`DynamicCache`]. +Before the [`Cache`] class, the cache used to be stored as a tuple of tuples of tensors. This format is dynamic because it grows as text is generated, similar to [`DynamicCache`]. + +The legacy format is essentially the same data structure but organized differently. +- It's a tuple of tuples, where each inner tuple contains the key and value tensors for a layer. +- The tensors have the same shape `[batch_size, num_heads, seq_len, head_dim]`. +- The format is less flexible and doesn't support features like quantization or offloading. If your project depends on this legacy format, you can convert between [`DynamicCache`] and a tuple of tuples as shown below with the [`~DynamicCache.from_legacy_cache`] and [`DynamicCache.to_legacy_cache`] functions. This is helpful if you have custom logic for manipulating a cache in a specific format. diff --git a/docs/source/en/generation_strategies.md b/docs/source/en/generation_strategies.md index c6cb322e882..9e2cbf485c4 100644 --- a/docs/source/en/generation_strategies.md +++ b/docs/source/en/generation_strategies.md @@ -327,7 +327,6 @@ We enable custom decoding methods through model repositories, assuming a specifi If a model repository holds a custom decoding method, the easiest way to try it out is to load the model and generate with it: - ```py from transformers import AutoModelForCausalLM, AutoTokenizer @@ -430,7 +429,7 @@ This is the core of your decoding method. It *must* contain a method named `gene > [!WARNING] > `generate.py` must be placed in a folder named `custom_generate`, and not at the root level of the repository. The file paths for this feature are hardcoded. -Under the hood, when the base [`~GenerationMixin.generate`] method is called with a `custom_generate` argument, it first checks its Python requirements (if any), then locates the custom `generate` method in `generate.py`, and finally calls the custom `generate`. All received arguments and `model` are forwarded to your custom `generate` method. +Under the hood, when the base [`~GenerationMixin.generate`] method is called with a `custom_generate` argument, it first checks its Python requirements (if any), then locates the custom `generate` method in `generate.py`, and finally calls the custom `generate`. 
All received arguments and `model` are forwarded to your custom `generate` method, with the exception of the arguments used to trigger the custom generation (`trust_remote_code` and `custom_generate`). This means your `generate` can have a mix of original and custom arguments (as well as a different output type) as shown below. diff --git a/docs/source/en/internal/model_debugging_utils.md b/docs/source/en/internal/model_debugging_utils.md index 6d30668c634..69f622ae109 100644 --- a/docs/source/en/internal/model_debugging_utils.md +++ b/docs/source/en/internal/model_debugging_utils.md @@ -16,7 +16,8 @@ rendered properly in your Markdown viewer. # Model debugging toolboxes -This page lists all the debugging and model adding tools used by the library, as well as the utility functions it provides for it. +This page lists all the debugging and model adding tools used by the library, as well as the utility functions it +provides for it. Most of those are only useful if you are adding new models in the library. @@ -26,13 +27,14 @@ Most of those are only useful if you are adding new models in the library. ### Model addition debugger - context manager for model adders -This context manager is a power user tool intended for model adders. -It tracks all forward calls within a model forward and logs a slice of each input and output on a nested Json. -To note, this context manager enforces `torch.no_grad()`. +This context manager is a power user tool intended for model adders. It tracks all forward calls within a model forward +and logs a slice of each input and output on a nested JSON. To note, this context manager enforces `torch.no_grad()`. ### Rationale -Because when porting models to transformers, even from python to python, model adders often have to do a lot of manual operations, involving saving and loading tensors, comparing dtypes, etc. This small tool can hopefully shave off some time. +When porting models to transformers, even from python to python, model adders often have to do a lot of manual +operations, involving saving and loading tensors, comparing dtypes, etc. This small tool can hopefully shave off some +time. ### Usage @@ -62,10 +64,10 @@ inputs = processor(text=prompt, images=random_image, return_tensors="pt") # call forward method (not .generate!) with model_addition_debugger_context( - model, - debug_path="optional_path_to_your_directory", - do_prune_layers=False # This will output ALL the layers of a model. - ): + model, + debug_path="optional_path_to_your_directory", + do_prune_layers=False # This will output ALL the layers of a model. +): output = model.forward(**inputs) ``` @@ -73,8 +75,8 @@ with model_addition_debugger_context( ### Reading results -The debugger generates two files from the forward call, both with the same base name, -but ending either with `_SUMMARY.json` or with `_FULL_TENSORS.json`. +The debugger generates two files from the forward call, both with the same base name, but ending either with +`_SUMMARY.json` or with `_FULL_TENSORS.json`. The first one will contain a summary of each module's _input_ and _output_ tensor values and shapes. @@ -142,8 +144,8 @@ The first one will contain a summary of each module's _input_ and _output_ tenso { ... and so on ``` -The `_FULL_TENSORS.json` file will display a full view of all tensors, which is useful -for comparing two files. +The `_FULL_TENSORS.json` file will display a full view of all tensors, which is useful for comparing two files. 
+ ```json "pixel_values": { "shape": "torch.Size([1, 5, 576, 588])", @@ -196,9 +198,38 @@ for comparing two files. }, ``` +#### Saving tensors to disk + +Some model adders may benefit from logging full tensor values to disk to support, for example, numerical analysis +across implementations. + +Set `use_repr=False` to write tensors to disk using [SafeTensors](https://huggingface.co/docs/safetensors/en/index). + +```python +with model_addition_debugger_context( + model, + debug_path="optional_path_to_your_directory", + do_prune_layers=False, + use_repr=False, # Defaults to True +): + output = model.forward(**inputs) +``` + +When using `use_repr=False`, tensors are written to the same disk location as the `_SUMMARY.json` and +`_FULL_TENSORS.json` files. The `value` property of entries in the `_FULL_TENSORS.json` file will contain a relative +path reference to the associated `.safetensors` file. Each tensor is written to its own file as the `data` property of +the state dictionary. File names are constructed using the `module_path` as a prefix with a few possible postfixes that +are built recursively. + +* Module inputs are denoted with the `_inputs` and outputs by `_outputs`. +* `list` and `tuple` instances, such as `args` or function return values, will be postfixed with `_{index}`. +* `dict` instances will be postfixed with `_{key}`. + ### Comparing between implementations -Once the forward passes of two models have been traced by the debugger, one can compare the `json` output files. See below: we can see slight differences between these two implementations' key projection layer. Inputs are mostly identical, but not quite. Looking through the file differences makes it easier to pinpoint which layer is wrong. +Once the forward passes of two models have been traced by the debugger, one can compare the `json` output files. See +below: we can see slight differences between these two implementations' key projection layer. Inputs are mostly +identical, but not quite. Looking through the file differences makes it easier to pinpoint which layer is wrong. ![download-icon](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/files_difference_debugging.png) @@ -206,8 +237,13 @@ Once the forward passes of two models have been traced by the debugger, one can ### Limitations and scope -This feature will only work for torch-based models, and would require more work and case-by-case approach for say `jax`-based models that are usually compiled. Models relying heavily on external kernel calls may work, but trace will probably miss some things. Regardless, any python implementation that aims at mimicking another implementation can be traced once instead of reran N times with breakpoints. +This feature will only work for torch-based models, and would require more work and case-by-case approach for say +`jax`-based models that are usually compiled. Models relying heavily on external kernel calls may work, but trace will +probably miss some things. Regardless, any python implementation that aims at mimicking another implementation can be +traced once instead of reran N times with breakpoints. -If you pass `do_prune_layers=False` to your model debugger, ALL the layers will be outputted to `json`. Else, only the first and last layer will be shown. This is useful when some layers (typically cross-attention) appear only after N layers. +If you pass `do_prune_layers=False` to your model debugger, ALL the layers will be outputted to `json`. 
Else, only the
+first and last layer will be shown. This is useful when some layers (typically cross-attention) appear only after N
+layers.
 
 [[autodoc]] model_addition_debugger_context
diff --git a/docs/source/en/llm_tutorial.md b/docs/source/en/llm_tutorial.md
index a191cdb4634..1283e8b6a4c 100644
--- a/docs/source/en/llm_tutorial.md
+++ b/docs/source/en/llm_tutorial.md
@@ -84,14 +84,17 @@ GenerationConfig {
 }
 ```
 
-You can customize [`~GenerationMixin.generate`] by overriding the parameters and values in [`GenerationConfig`]. Some of the most commonly adjusted parameters are [max_new_tokens](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig.max_new_tokens), [num_beams](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig.num_beams), [do_sample](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig.do_sample), and [num_return_sequences](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig.num_return_sequences).
+You can customize [`~GenerationMixin.generate`] by overriding the parameters and values in [`GenerationConfig`]. See [this section below](#common-options) for commonly adjusted parameters.
 
 ```py
 # enable beam search sampling strategy
 model.generate(**inputs, num_beams=4, do_sample=True)
 ```
 
-[`~GenerationMixin.generate`] can also be extended with external libraries or custom code. The `logits_processor` parameter accepts custom [`LogitsProcessor`] instances for manipulating the next token probability distribution. `stopping_criteria` supports custom [`StoppingCriteria`] to stop text generation. Check out the [logits-processor-zoo](https://github.com/NVIDIA/logits-processor-zoo) for more examples of external [`~GenerationMixin.generate`]-compatible extensions.
+[`~GenerationMixin.generate`] can also be extended with external libraries or custom code:
+1. the `logits_processor` parameter accepts custom [`LogitsProcessor`] instances for manipulating the next token probability distribution;
+2. the `stopping_criteria` parameter supports custom [`StoppingCriteria`] to stop text generation;
+3. other custom generation methods can be loaded through the `custom_generate` flag ([docs](generation_strategies.md/#custom-decoding-methods)).
 
 Refer to the [Generation strategies](./generation_strategies) guide to learn more about search, sampling, and decoding strategies.
diff --git a/docs/source/en/main_classes/video_processor.md b/docs/source/en/main_classes/video_processor.md
index bdff30e9c50..4ff973d2ed2 100644
--- a/docs/source/en/main_classes/video_processor.md
+++ b/docs/source/en/main_classes/video_processor.md
@@ -21,7 +21,7 @@ A **Video Processor** is a utility responsible for preparing input features for
 
 The video processor extends the functionality of image processors by allowing Vision Large Language Models (VLMs) to handle videos with a distinct set of arguments compared to images. It serves as the bridge between raw video data and the model, ensuring that input features are optimized for the VLM.
 
-When adding a new VLM or updating an existing one to enable distinct video preprocessing, saving and reloading the processor configuration will store the video related arguments in a dedicated file named `video_preprocessing_config.json`. Don't worry if you haven't upadted your VLM, the processor will try to load video related configurations from a file named `preprocessing_config.json`.
+When adding a new VLM or updating an existing one to enable distinct video preprocessing, saving and reloading the processor configuration will store the video related arguments in a dedicated file named `video_preprocessing_config.json`. Don't worry if you haven't updated your VLM, the processor will try to load video related configurations from a file named `preprocessing_config.json`. ### Usage Example diff --git a/docs/source/en/model_doc/auto.md b/docs/source/en/model_doc/auto.md index afe343228f2..adab8591e29 100644 --- a/docs/source/en/model_doc/auto.md +++ b/docs/source/en/model_doc/auto.md @@ -389,3 +389,9 @@ The following auto classes are available for the following multimodal tasks. ### AutoModelForImageTextToText [[autodoc]] AutoModelForImageTextToText + +## Time Series + +### AutoModelForTimeSeriesPrediction + +[[autodoc]] AutoModelForTimeSeriesPrediction diff --git a/docs/source/en/model_doc/bart.md b/docs/source/en/model_doc/bart.md index b24daa3e6e1..d269b391ccc 100644 --- a/docs/source/en/model_doc/bart.md +++ b/docs/source/en/model_doc/bart.md @@ -14,116 +14,87 @@ rendered properly in your Markdown viewer. --> -# BART -
-PyTorch -TensorFlow -Flax -FlashAttention -SDPA +
+
+ PyTorch + TensorFlow + Flax + FlashAttention + SDPA
-## Overview +# BART +[BART](https://huggingface.co/papers/1910.13461) is a sequence-to-sequence model that combines the pretraining objectives from BERT and GPT. It’s pretrained by corrupting text in different ways like deleting words, shuffling sentences, or masking tokens and learning how to fix it. The encoder encodes the corrupted document and the corrupted text is fixed by the decoder. As it learns to recover the original text, BART gets really good at both understanding and generating language. -The Bart model was proposed in [BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, -Translation, and Comprehension](https://arxiv.org/abs/1910.13461) by Mike Lewis, Yinhan Liu, Naman Goyal, Marjan -Ghazvininejad, Abdelrahman Mohamed, Omer Levy, Ves Stoyanov and Luke Zettlemoyer on 29 Oct, 2019. +You can find all the original BART checkpoints under the [AI at Meta](https://huggingface.co/facebook?search_models=bart) organization. -According to the abstract, +The example below demonstrates how to predict the `[MASK]` token with [`Pipeline`], [`AutoModel`], and from the command line. -- Bart uses a standard seq2seq/machine translation architecture with a bidirectional encoder (like BERT) and a - left-to-right decoder (like GPT). -- The pretraining task involves randomly shuffling the order of the original sentences and a novel in-filling scheme, - where spans of text are replaced with a single mask token. -- BART is particularly effective when fine tuned for text generation but also works well for comprehension tasks. It - matches the performance of RoBERTa with comparable training resources on GLUE and SQuAD, achieves new - state-of-the-art results on a range of abstractive dialogue, question answering, and summarization tasks, with gains - of up to 6 ROUGE. + + -This model was contributed by [sshleifer](https://huggingface.co/sshleifer). The authors' code can be found [here](https://github.com/pytorch/fairseq/tree/master/examples/bart). +```py +import torch +from transformers import pipeline -## Usage tips: +pipeline = pipeline( + task="fill-mask", + model="facebook/bart-large", + torch_dtype=torch.float16, + device=0 +) +pipeline("Plants create through a process known as photosynthesis.") -- BART is a model with absolute position embeddings so it's usually advised to pad the inputs on the right rather than - the left. -- Sequence-to-sequence model with an encoder and a decoder. Encoder is fed a corrupted version of the tokens, decoder is fed the original tokens (but has a mask to hide the future words like a regular transformers decoder). A composition of the following transformations are applied on the pretraining tasks for the encoder: +``` + + - * mask random tokens (like in BERT) - * delete random tokens - * mask a span of k tokens with a single mask token (a span of 0 tokens is an insertion of a mask token) - * permute sentences - * rotate the document to make it start at a specific token -- The `head_mask` argument is ignored when using all attention implementation other than "eager". 
If you have a `head_mask` and want it to have effect, load the model with `XXXModel.from_pretrained(model_id, attn_implementation="eager")` +```py +import torch +from transformers import AutoModelForMaskedLM, AutoTokenizer -## Implementation Notes +tokenizer = AutoTokenizer.from_pretrained( + "facebook/bart-large", +) +model = AutoModelForMaskedLM.from_pretrained( + "facebook/bart-large", + torch_dtype=torch.float16, + device_map="auto", + attn_implementation="sdpa" +) +inputs = tokenizer("Plants create through a process known as photosynthesis.", return_tensors="pt").to("cuda") -- Bart doesn't use `token_type_ids` for sequence classification. Use [`BartTokenizer`] or - [`~BartTokenizer.encode`] to get the proper splitting. -- The forward pass of [`BartModel`] will create the `decoder_input_ids` if they are not passed. - This is different than some other modeling APIs. A typical use case of this feature is mask filling. -- Model predictions are intended to be identical to the original implementation when - `forced_bos_token_id=0`. This only works, however, if the string you pass to - [`fairseq.encode`] starts with a space. -- [`~generation.GenerationMixin.generate`] should be used for conditional generation tasks like - summarization, see the example in that docstrings. -- Models that load the *facebook/bart-large-cnn* weights will not have a `mask_token_id`, or be able to perform - mask-filling tasks. +with torch.no_grad(): + outputs = model(**inputs) + predictions = outputs.logits -## Mask Filling +masked_index = torch.where(inputs['input_ids'] == tokenizer.mask_token_id)[1] +predicted_token_id = predictions[0, masked_index].argmax(dim=-1) +predicted_token = tokenizer.decode(predicted_token_id) -The `facebook/bart-base` and `facebook/bart-large` checkpoints can be used to fill multi-token masks. - -```python -from transformers import BartForConditionalGeneration, BartTokenizer - -model = BartForConditionalGeneration.from_pretrained("facebook/bart-large", forced_bos_token_id=0) -tok = BartTokenizer.from_pretrained("facebook/bart-large") -example_english_phrase = "UN Chief Says There Is No in Syria" -batch = tok(example_english_phrase, return_tensors="pt") -generated_ids = model.generate(batch["input_ids"]) -assert tok.batch_decode(generated_ids, skip_special_tokens=True) == [ - "UN Chief Says There Is No Plan to Stop Chemical Weapons in Syria" -] +print(f"The predicted token is: {predicted_token}") ``` -## Resources + + -A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with BART. If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource. +```bash +echo -e "Plants create through a process known as photosynthesis." | transformers-cli run --task fill-mask --model facebook/bart-large --device 0 +``` - + + -- A blog post on [Distributed Training: Train BART/T5 for Summarization using πŸ€— Transformers and Amazon SageMaker](https://huggingface.co/blog/sagemaker-distributed-training-seq2seq). -- A notebook on how to [finetune BART for summarization with fastai using blurr](https://colab.research.google.com/github/ohmeow/ohmeow_website/blob/master/posts/2021-05-25-mbart-sequence-classification-with-blurr.ipynb). 
🌎 -- A notebook on how to [finetune BART for summarization in two languages with Trainer class](https://colab.research.google.com/github/elsanns/xai-nlp-notebooks/blob/master/fine_tune_bart_summarization_two_langs.ipynb). 🌎 -- [`BartForConditionalGeneration`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/pytorch/summarization) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/summarization.ipynb). -- [`TFBartForConditionalGeneration`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/tensorflow/summarization) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/summarization-tf.ipynb). -- [`FlaxBartForConditionalGeneration`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/flax/summarization). -- An example of how to train [`BartForConditionalGeneration`] with a Hugging Face `datasets` object can be found in this [forum discussion](https://discuss.huggingface.co/t/train-bart-for-conditional-generation-e-g-summarization/1904) -- [Summarization](https://huggingface.co/course/chapter7/5?fw=pt#summarization) chapter of the πŸ€— Hugging Face course. -- [Summarization task guide](../tasks/summarization) +## Notes - - -- [`BartForConditionalGeneration`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/pytorch/language-modeling#robertabertdistilbert-and-masked-language-modeling) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/language_modeling.ipynb). -- [`TFBartForConditionalGeneration`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/tensorflow/language-modeling#run_mlmpy) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/language_modeling-tf.ipynb). -- [`FlaxBartForConditionalGeneration`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/flax/language-modeling#masked-language-modeling) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/masked_language_modeling_flax.ipynb). -- [Masked language modeling](https://huggingface.co/course/chapter7/3?fw=pt) chapter of the πŸ€— Hugging Face Course. -- [Masked language modeling task guide](../tasks/masked_language_modeling) - - - -- A notebook on how to [finetune mBART using Seq2SeqTrainer for Hindi to English translation](https://colab.research.google.com/github/vasudevgupta7/huggingface-tutorials/blob/main/translation_training.ipynb). 🌎 -- [`BartForConditionalGeneration`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/pytorch/translation) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/translation.ipynb). -- [`TFBartForConditionalGeneration`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/tensorflow/translation) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/translation-tf.ipynb). 
- [Translation task guide](../tasks/translation)
-
-See also:
-- [Text classification task guide](../tasks/sequence_classification)
-- [Question answering task guide](../tasks/question_answering)
-- [Causal language modeling task guide](../tasks/language_modeling)
-- [Distilled checkpoints](https://huggingface.co/models?search=distilbart) are described in this [paper](https://arxiv.org/abs/2010.13002).
+- Inputs should be padded on the right because BART uses absolute position embeddings.
+- The [facebook/bart-large-cnn](https://huggingface.co/facebook/bart-large-cnn) checkpoint doesn't include `mask_token_id` which means it can't perform mask-filling tasks.
+- BART doesn’t use `token_type_ids` for sequence classification. Use [`BartTokenizer`] or [`~PreTrainedTokenizerBase.encode`] to get the proper splitting.
+- The forward pass of [`BartModel`] creates the `decoder_input_ids` if they're not passed. This can be different from other model APIs, but it is a useful feature for mask-filling tasks.
+- Model predictions are intended to be identical to the original implementation when `forced_bos_token_id=0`. This only works if the text passed to `fairseq.encode` begins with a space.
+- [`~GenerationMixin.generate`] should be used for conditional generation tasks like summarization.
 
 ## BartConfig
diff --git a/docs/source/en/model_doc/bertweet.md b/docs/source/en/model_doc/bertweet.md
index be489643173..f1f6ff877b0 100644
--- a/docs/source/en/model_doc/bertweet.md
+++ b/docs/source/en/model_doc/bertweet.md
@@ -16,60 +16,82 @@ rendered properly in your Markdown viewer.
 
 # BERTweet
 
-
-PyTorch -TensorFlow -Flax +
+
+ PyTorch + TensorFlow + Flax
-## Overview +## BERTweet -The BERTweet model was proposed in [BERTweet: A pre-trained language model for English Tweets](https://www.aclweb.org/anthology/2020.emnlp-demos.2.pdf) by Dat Quoc Nguyen, Thanh Vu, Anh Tuan Nguyen. +[BERTweet](https://huggingface.co/papers/2005.10200) shares the same architecture as [BERT-base](./bert), but it’s pretrained like [RoBERTa](./roberta) on English Tweets. It performs really well on Tweet-related tasks like part-of-speech tagging, named entity recognition, and text classification. -The abstract from the paper is the following: -*We present BERTweet, the first public large-scale pre-trained language model for English Tweets. Our BERTweet, having -the same architecture as BERT-base (Devlin et al., 2019), is trained using the RoBERTa pre-training procedure (Liu et -al., 2019). Experiments show that BERTweet outperforms strong baselines RoBERTa-base and XLM-R-base (Conneau et al., -2020), producing better performance results than the previous state-of-the-art models on three Tweet NLP tasks: -Part-of-speech tagging, Named-entity recognition and text classification.* +You can find all the original BERTweet checkpoints under the [VinAI Research](https://huggingface.co/vinai?search_models=BERTweet) organization. -This model was contributed by [dqnguyen](https://huggingface.co/dqnguyen). The original code can be found [here](https://github.com/VinAIResearch/BERTweet). +> [!TIP] +> Refer to the [BERT](./bert) docs for more examples of how to apply BERTweet to different language tasks. -## Usage example +The example below demonstrates how to predict the `` token with [`Pipeline`], [`AutoModel`], and from the command line. -```python ->>> import torch ->>> from transformers import AutoModel, AutoTokenizer + + ->>> bertweet = AutoModel.from_pretrained("vinai/bertweet-base") +```py +import torch +from transformers import pipeline ->>> # For transformers v4.x+: ->>> tokenizer = AutoTokenizer.from_pretrained("vinai/bertweet-base", use_fast=False) +pipeline = pipeline( + task="fill-mask", + model="vinai/bertweet-base", + torch_dtype=torch.float16, + device=0 +) +pipeline("Plants create through a process known as photosynthesis.") +``` + + ->>> # For transformers v3.x: ->>> # tokenizer = AutoTokenizer.from_pretrained("vinai/bertweet-base") +```py +import torch +from transformers import AutoModelForMaskedLM, AutoTokenizer ->>> # INPUT TWEET IS ALREADY NORMALIZED! ->>> line = "SC has first two presumptive cases of coronavirus , DHEC confirms HTTPURL via @USER :cry:" +tokenizer = AutoTokenizer.from_pretrained( + "vinai/bertweet-base", +) +model = AutoModelForMaskedLM.from_pretrained( + "vinai/bertweet-base", + torch_dtype=torch.float16, + device_map="auto" +) +inputs = tokenizer("Plants create through a process known as photosynthesis.", return_tensors="pt").to("cuda") ->>> input_ids = torch.tensor([tokenizer.encode(line)]) +with torch.no_grad(): + outputs = model(**inputs) + predictions = outputs.logits ->>> with torch.no_grad(): -... features = bertweet(input_ids) # Models outputs are now tuples +masked_index = torch.where(inputs['input_ids'] == tokenizer.mask_token_id)[1] +predicted_token_id = predictions[0, masked_index].argmax(dim=-1) +predicted_token = tokenizer.decode(predicted_token_id) ->>> # With TensorFlow 2.0+: ->>> # from transformers import TFAutoModel ->>> # bertweet = TFAutoModel.from_pretrained("vinai/bertweet-base") +print(f"The predicted token is: {predicted_token}") ``` - + + -This implementation is the same as BERT, except for tokenization method. 
Refer to [BERT documentation](bert) for
-API reference information.
+```bash
+echo -e "Plants create <mask> through a process known as photosynthesis." | transformers-cli run --task fill-mask --model vinai/bertweet-base --device 0
+```

-
+
+
+
+## Notes
+- Use the [`AutoTokenizer`] or [`BertweetTokenizer`] because it’s preloaded with a custom vocabulary adapted to tweet-specific tokens like hashtags (#), mentions (@), emojis, and common abbreviations. Make sure to also install the [emoji](https://pypi.org/project/emoji/) library.
+- Inputs should be padded on the right (`padding="max_length"`) because BERT uses absolute position embeddings.

 ## BertweetTokenizer

diff --git a/docs/source/en/model_doc/big_bird.md b/docs/source/en/model_doc/big_bird.md
index 32ca5a2062a..16e1a3bff84 100644
--- a/docs/source/en/model_doc/big_bird.md
+++ b/docs/source/en/model_doc/big_bird.md
@@ -14,63 +14,87 @@ rendered properly in your Markdown viewer.
 -->

-# BigBird
-
-
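To make the BERTweet tokenizer note above more concrete, the sketch below enables tweet normalization when loading the tokenizer. It assumes the slow [`BertweetTokenizer`] with its `normalization` option and the `emoji` package installed; the raw tweet string is only an illustration.

```py
from transformers import AutoTokenizer

# normalization=True maps user mentions to @USER, links to HTTPURL, and emoji to
# text aliases via the emoji package (pip install emoji)
tokenizer = AutoTokenizer.from_pretrained(
    "vinai/bertweet-base", use_fast=False, normalization=True
)

raw_tweet = "SC has first two presumptive cases of coronavirus, DHEC confirms https://t.co/abc via @CNN 😢"
print(tokenizer.tokenize(raw_tweet))
```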
-PyTorch -Flax +
+ PyTorch + Flax +
-## Overview +# BigBird -The BigBird model was proposed in [Big Bird: Transformers for Longer Sequences](https://arxiv.org/abs/2007.14062) by -Zaheer, Manzil and Guruganesh, Guru and Dubey, Kumar Avinava and Ainslie, Joshua and Alberti, Chris and Ontanon, -Santiago and Pham, Philip and Ravula, Anirudh and Wang, Qifan and Yang, Li and others. BigBird, is a sparse-attention -based transformer which extends Transformer based models, such as BERT to much longer sequences. In addition to sparse -attention, BigBird also applies global attention as well as random attention to the input sequence. Theoretically, it -has been shown that applying sparse, global, and random attention approximates full attention, while being -computationally much more efficient for longer sequences. As a consequence of the capability to handle longer context, -BigBird has shown improved performance on various long document NLP tasks, such as question answering and -summarization, compared to BERT or RoBERTa. +[BigBird](https://huggingface.co/papers/2007.14062) is a transformer model built to handle sequence lengths up to 4096 compared to 512 for [BERT](./bert). Traditional transformers struggle with long inputs because attention gets really expensive as the sequence length grows. BigBird fixes this by using a sparse attention mechanism, which means it doesn’t try to look at everything at once. Instead, it mixes in local attention, random attention, and a few global tokens to process the whole input. This combination gives it the best of both worlds. It keeps the computation efficient while still capturing enough of the sequence to understand it well. Because of this, BigBird is great at tasks involving long documents, like question answering, summarization, and genomic applications. -The abstract from the paper is the following: -*Transformers-based models, such as BERT, have been one of the most successful deep learning models for NLP. -Unfortunately, one of their core limitations is the quadratic dependency (mainly in terms of memory) on the sequence -length due to their full attention mechanism. To remedy this, we propose, BigBird, a sparse attention mechanism that -reduces this quadratic dependency to linear. We show that BigBird is a universal approximator of sequence functions and -is Turing complete, thereby preserving these properties of the quadratic, full attention model. Along the way, our -theoretical analysis reveals some of the benefits of having O(1) global tokens (such as CLS), that attend to the entire -sequence as part of the sparse attention mechanism. The proposed sparse attention can handle sequences of length up to -8x of what was previously possible using similar hardware. As a consequence of the capability to handle longer context, -BigBird drastically improves performance on various NLP tasks such as question answering and summarization. We also -propose novel applications to genomics data.* +You can find all the original BigBird checkpoints under the [Google](https://huggingface.co/google?search_models=bigbird) organization. -This model was contributed by [vasudevgupta](https://huggingface.co/vasudevgupta). The original code can be found -[here](https://github.com/google-research/bigbird). +> [!TIP] +> Click on the BigBird models in the right sidebar for more examples of how to apply BigBird to different language tasks. -## Usage tips +The example below demonstrates how to predict the `[MASK]` token with [`Pipeline`], [`AutoModel`], and from the command line. 
-- For an in-detail explanation on how BigBird's attention works, see [this blog post](https://huggingface.co/blog/big-bird). -- BigBird comes with 2 implementations: **original_full** & **block_sparse**. For the sequence length < 1024, using - **original_full** is advised as there is no benefit in using **block_sparse** attention. -- The code currently uses window size of 3 blocks and 2 global blocks. -- Sequence length must be divisible by block size. -- Current implementation supports only **ITC**. -- Current implementation doesn't support **num_random_blocks = 0** -- BigBird is a model with absolute position embeddings so it's usually advised to pad the inputs on the right rather than - the left. + + +```py +import torch +from transformers import pipeline + +pipeline = pipeline( + task="fill-mask", + model="google/bigbird-roberta-base", + torch_dtype=torch.float16, + device=0 +) +pipeline("Plants create [MASK] through a process known as photosynthesis.") +``` + + + +```py +import torch +from transformers import AutoModelForMaskedLM, AutoTokenizer + +tokenizer = AutoTokenizer.from_pretrained( + "google/bigbird-roberta-base", +) +model = AutoModelForMaskedLM.from_pretrained( + "google/bigbird-roberta-base", + torch_dtype=torch.float16, + device_map="auto", +) +inputs = tokenizer("Plants create [MASK] through a process known as photosynthesis.", return_tensors="pt").to("cuda") + +with torch.no_grad(): + outputs = model(**inputs) + predictions = outputs.logits + +masked_index = torch.where(inputs['input_ids'] == tokenizer.mask_token_id)[1] +predicted_token_id = predictions[0, masked_index].argmax(dim=-1) +predicted_token = tokenizer.decode(predicted_token_id) + +print(f"The predicted token is: {predicted_token}") +``` + + + + +```bash +!echo -e "Plants create [MASK] through a process known as photosynthesis." | transformers-cli run --task fill-mask --model google/bigbird-roberta-base --device 0 +``` + + + +## Notes +- Inputs should be padded on the right because BigBird uses absolute position embeddings. +- BigBird supports `original_full` and `block_sparse` attention. If the input sequence length is less than 1024, it is recommended to use `original_full` since sparse patterns don't offer much benefit for smaller inputs. +- The current implementation uses window size of 3 blocks and 2 global blocks, only supports the ITC-implementation, and doesn't support `num_random_blocks=0`. +- The sequence length must be divisible by the block size. ## Resources -- [Text classification task guide](../tasks/sequence_classification) -- [Token classification task guide](../tasks/token_classification) -- [Question answering task guide](../tasks/question_answering) -- [Causal language modeling task guide](../tasks/language_modeling) -- [Masked language modeling task guide](../tasks/masked_language_modeling) -- [Multiple choice task guide](../tasks/multiple_choice) +- Read the [BigBird](https://huggingface.co/blog/big-bird) blog post for more details about how its attention works. ## BigBirdConfig diff --git a/docs/source/en/model_doc/biogpt.md b/docs/source/en/model_doc/biogpt.md index d7145993a89..0b6eb877647 100644 --- a/docs/source/en/model_doc/biogpt.md +++ b/docs/source/en/model_doc/biogpt.md @@ -14,78 +14,121 @@ rendered properly in your Markdown viewer. --> -# BioGPT - -
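As a follow-up to the BigBird notes above, the sketch below shows how the attention pattern could be selected at load time. The `attention_type`, `block_size`, and `num_random_blocks` overrides are based on [`BigBirdConfig`]; treat this as a sketch rather than a prescribed setup.

```py
from transformers import AutoModelForMaskedLM

# short inputs: fall back to full attention, sparse patterns offer little benefit
model = AutoModelForMaskedLM.from_pretrained(
    "google/bigbird-roberta-base",
    attention_type="original_full",
)

# long inputs: keep block sparse attention and tune the sparsity pattern
# (the sequence length must stay divisible by block_size)
sparse_model = AutoModelForMaskedLM.from_pretrained(
    "google/bigbird-roberta-base",
    attention_type="block_sparse",
    block_size=64,
    num_random_blocks=3,
)
```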
-PyTorch -SDPA +
+
+ PyTorch + FlashAttention + SDPA +
-## Overview +# BioGPT -The BioGPT model was proposed in [BioGPT: generative pre-trained transformer for biomedical text generation and mining](https://academic.oup.com/bib/advance-article/doi/10.1093/bib/bbac409/6713511?guestAccessKey=a66d9b5d-4f83-4017-bb52-405815c907b9) by Renqian Luo, Liai Sun, Yingce Xia, Tao Qin, Sheng Zhang, Hoifung Poon and Tie-Yan Liu. BioGPT is a domain-specific generative pre-trained Transformer language model for biomedical text generation and mining. BioGPT follows the Transformer language model backbone, and is pre-trained on 15M PubMed abstracts from scratch. +[BioGPT](https://huggingface.co/papers/2210.10341) is a generative Transformer model based on [GPT-2](./gpt2) and pretrained on 15 million PubMed abstracts. It is designed for biomedical language tasks. -The abstract from the paper is the following: +You can find all the original BioGPT checkpoints under the [Microsoft](https://huggingface.co/microsoft?search_models=biogpt) organization. -*Pre-trained language models have attracted increasing attention in the biomedical domain, inspired by their great success in the general natural language domain. Among the two main branches of pre-trained language models in the general language domain, i.e. BERT (and its variants) and GPT (and its variants), the first one has been extensively studied in the biomedical domain, such as BioBERT and PubMedBERT. While they have achieved great success on a variety of discriminative downstream biomedical tasks, the lack of generation ability constrains their application scope. In this paper, we propose BioGPT, a domain-specific generative Transformer language model pre-trained on large-scale biomedical literature. We evaluate BioGPT on six biomedical natural language processing tasks and demonstrate that our model outperforms previous models on most tasks. Especially, we get 44.98%, 38.42% and 40.76% F1 score on BC5CDR, KD-DTI and DDI end-to-end relation extraction tasks, respectively, and 78.2% accuracy on PubMedQA, creating a new record. Our case study on text generation further demonstrates the advantage of BioGPT on biomedical literature to generate fluent descriptions for biomedical terms.* +> [!TIP] +> Click on the BioGPT models in the right sidebar for more examples of how to apply BioGPT to different language tasks. -This model was contributed by [kamalkraj](https://huggingface.co/kamalkraj). The original code can be found [here](https://github.com/microsoft/BioGPT). +The example below demonstrates how to generate biomedical text with [`Pipeline`], [`AutoModel`], and also from the command line. -## Usage tips + + -- BioGPT is a model with absolute position embeddings so it's usually advised to pad the inputs on the right rather than the left. -- BioGPT was trained with a causal language modeling (CLM) objective and is therefore powerful at predicting the next token in a sequence. Leveraging this feature allows BioGPT to generate syntactically coherent text as it can be observed in the run_generation.py example script. -- The model can take the `past_key_values` (for PyTorch) as input, which is the previously computed key/value attention pairs. Using this (past_key_values or past) value prevents the model from re-computing pre-computed values in the context of text generation. For PyTorch, see past_key_values argument of the BioGptForCausalLM.forward() method for more information on its usage. -- The `head_mask` argument is ignored when using all attention implementation other than "eager". 
If you have a `head_mask` and want it to have effect, load the model with `XXXModel.from_pretrained(model_id, attn_implementation="eager")` +```py +import torch +from transformers import pipeline -### Using Scaled Dot Product Attention (SDPA) - -PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function -encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the -[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) -or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention) -page for more information. - -SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set -`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used. - -``` -from transformers import BioGptForCausalLM -model = BioGptForCausalLM.from_pretrained("microsoft/biogpt", attn_implementation="sdpa", torch_dtype=torch.float16) +generator = pipeline( + task="text-generation", + model="microsoft/biogpt", + torch_dtype=torch.float16, + device=0, +) +result = generator("Ibuprofen is best used for", truncation=True, max_length=50, do_sample=True)[0]["generated_text"] +print(result) ``` -On a local benchmark (NVIDIA GeForce RTX 2060-8GB, PyTorch 2.3.1, OS Ubuntu 20.04) with `float16` and `microsoft/biogpt` model with a CausalLM head, -we saw the following speedups during training. + + -For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`). +```py +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer -| num_training_steps | batch_size | seq_len | is cuda | Time per batch (eager - s) | Time per batch (sdpa - s) | Speedup (%) | Eager peak mem (MB) | sdpa peak mem (MB) | Mem saving (%) | -|--------------------|------------|---------|---------|----------------------------|---------------------------|-------------|---------------------|--------------------|----------------| -| 100 | 1 | 128 | False | 0.038 | 0.031 | 21.301 | 1601.862 | 1601.497 | 0.023 | -| 100 | 1 | 256 | False | 0.039 | 0.034 | 15.084 | 1624.944 | 1625.296 | -0.022 | -| 100 | 2 | 128 | False | 0.039 | 0.033 | 16.820 | 1624.567 | 1625.296 | -0.045 | -| 100 | 2 | 256 | False | 0.065 | 0.059 | 10.255 | 1672.164 | 1672.164 | 0.000 | -| 100 | 4 | 128 | False | 0.062 | 0.058 | 6.998 | 1671.435 | 1672.164 | -0.044 | -| 100 | 4 | 256 | False | 0.113 | 0.100 | 13.316 | 2350.179 | 1848.435 | 27.144 | -| 100 | 8 | 128 | False | 0.107 | 0.098 | 9.883 | 2098.521 | 1848.435 | 13.530 | -| 100 | 8 | 256 | False | 0.222 | 0.196 | 13.413 | 3989.980 | 2986.492 | 33.601 | +tokenizer = AutoTokenizer.from_pretrained("microsoft/biogpt") +model = AutoModelForCausalLM.from_pretrained( + "microsoft/biogpt", + torch_dtype=torch.float16, + device_map="auto", + attn_implementation="sdpa" +) -On a local benchmark (NVIDIA GeForce RTX 2060-8GB, PyTorch 2.3.1, OS Ubuntu 20.04) with `float16` and `microsoft/biogpt` model with a simple AutoModel head, -we saw the following speedups during inference. 
+input_text = "Ibuprofen is best used for"
+inputs = tokenizer(input_text, return_tensors="pt").to(model.device)

-| num_batches | batch_size | seq_len | is cuda | is half | use mask | Per token latency eager (ms) | Per token latency SDPA (ms) | Speedup (%) | Mem eager (MB) | Mem BT (MB) | Mem saved (%) |
-|-------------|------------|---------|---------|---------|----------|------------------------------|-----------------------------|-------------|----------------|--------------|---------------|
-| 50 | 1 | 64 | True | True | True | 0.115 | 0.098 | 17.392 | 716.998 | 716.998 | 0.000 |
-| 50 | 1 | 128 | True | True | True | 0.115 | 0.093 | 24.640 | 730.916 | 730.916 | 0.000 |
-| 50 | 2 | 64 | True | True | True | 0.114 | 0.096 | 19.204 | 730.900 | 730.900 | 0.000 |
-| 50 | 2 | 128 | True | True | True | 0.117 | 0.095 | 23.529 | 759.262 | 759.262 | 0.000 |
-| 50 | 4 | 64 | True | True | True | 0.113 | 0.096 | 18.325 | 759.229 | 759.229 | 0.000 |
-| 50 | 4 | 128 | True | True | True | 0.186 | 0.178 | 4.289 | 816.478 | 816.478 | 0.000 |
+with torch.no_grad():
+    generated_ids = model.generate(**inputs, max_length=50)
+
+output = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
+print(output)
+```
+
+
-## Resources
+```bash
+echo -e "Ibuprofen is best used for" | transformers-cli run --task text-generation --model microsoft/biogpt --device 0
+```
-- [Causal language modeling task guide](../tasks/language_modeling)
+
+
+
+Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the [Quantization](../quantization/overview) overview for more available quantization backends.
+
+The example below uses [bitsandbytes](../quantization/bitsandbytes) to only quantize the weights to 4-bit precision.
+
+```py
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
+
+bnb_config = BitsAndBytesConfig(
+    load_in_4bit=True,
+    bnb_4bit_quant_type="nf4",
+    bnb_4bit_compute_dtype=torch.bfloat16,
+    bnb_4bit_use_double_quant=True
+)
+
+tokenizer = AutoTokenizer.from_pretrained("microsoft/BioGPT-Large")
+model = AutoModelForCausalLM.from_pretrained(
+    "microsoft/BioGPT-Large",
+    quantization_config=bnb_config,
+    torch_dtype=torch.bfloat16,
+    device_map="auto"
+)
+
+input_text = "Ibuprofen is best used for"
+inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
+with torch.no_grad():
+    generated_ids = model.generate(**inputs, max_length=50)
+output = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
+print(output)
+```
+
+## Notes
+
+- Pad inputs on the right because BioGPT uses absolute position embeddings.
+- BioGPT can reuse previously computed key-value attention pairs. Access this feature with the [past_key_values](https://huggingface.co/docs/transformers/main/en/model_doc/biogpt#transformers.BioGptModel.forward.past_key_values) parameter in [`BioGptModel.forward`].
+- The `head_mask` argument is ignored when using an attention implementation other than "eager". If you want to use `head_mask`, make sure `attn_implementation="eager"`.
+
+  ```py
+  from transformers import AutoModelForCausalLM
+
+  model = AutoModelForCausalLM.from_pretrained(
+      "microsoft/biogpt",
+      attn_implementation="eager"
+  )
+  ```

 ## BioGptConfig

@@ -109,7 +152,7 @@ we saw the following speedups during inference.
[[autodoc]] BioGptForCausalLM - forward - + ## BioGptForTokenClassification [[autodoc]] BioGptForTokenClassification diff --git a/docs/source/en/model_doc/blenderbot-small.md b/docs/source/en/model_doc/blenderbot-small.md index 647a865de33..341e43c0304 100644 --- a/docs/source/en/model_doc/blenderbot-small.md +++ b/docs/source/en/model_doc/blenderbot-small.md @@ -21,6 +21,8 @@ rendered properly in your Markdown viewer. TensorFlow Flax +FlashAttention +SDPA
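To make the BioGPT `past_key_values` note above concrete, here is a minimal sketch of reusing the cache across two forward passes. It assumes the standard causal-LM cache API (`use_cache=True` plus the `past_key_values` argument) and is an illustration rather than the reference generation loop.

```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("microsoft/biogpt")
model = AutoModelForCausalLM.from_pretrained("microsoft/biogpt")
model.eval()

inputs = tokenizer("Ibuprofen is best used for", return_tensors="pt")

# first pass: run the full prompt and keep the key/value cache
with torch.no_grad():
    out = model(**inputs, use_cache=True)
next_token = out.logits[:, -1].argmax(dim=-1, keepdim=True)

# second pass: feed only the new token plus the cached keys/values
with torch.no_grad():
    out = model(input_ids=next_token, past_key_values=out.past_key_values, use_cache=True)
print(out.logits.shape)  # logits for the single new position only
```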
Note that [`BlenderbotSmallModel`] and @@ -52,7 +54,7 @@ found [here](https://github.com/facebookresearch/ParlAI). ## Usage tips -Blenderbot Small is a model with absolute position embeddings so it's usually advised to pad the inputs on the right rather than +Blenderbot Small is a model with absolute position embeddings so it's usually advised to pad the inputs on the right rather than the left. diff --git a/docs/source/en/model_doc/blenderbot.md b/docs/source/en/model_doc/blenderbot.md index ec24d5ed749..adfa6841e10 100644 --- a/docs/source/en/model_doc/blenderbot.md +++ b/docs/source/en/model_doc/blenderbot.md @@ -21,6 +21,8 @@ rendered properly in your Markdown viewer. TensorFlow Flax +FlashAttention +SDPA
## Overview @@ -45,7 +47,7 @@ This model was contributed by [sshleifer](https://huggingface.co/sshleifer). The ## Usage tips and example -Blenderbot is a model with absolute position embeddings so it's usually advised to pad the inputs on the right +Blenderbot is a model with absolute position embeddings so it's usually advised to pad the inputs on the right rather than the left. An example: @@ -71,7 +73,7 @@ An example: `facebook/blenderbot_small_90M`, have a different architecture and consequently should be used with [BlenderbotSmall](blenderbot-small). - + ## Resources - [Causal language modeling task guide](../tasks/language_modeling) diff --git a/docs/source/en/model_doc/csm.md b/docs/source/en/model_doc/csm.md index 2d916da161f..833ddb697b5 100644 --- a/docs/source/en/model_doc/csm.md +++ b/docs/source/en/model_doc/csm.md @@ -39,7 +39,7 @@ CSM can be used to simply generate speech from a text prompt: import torch from transformers import CsmForConditionalGeneration, AutoProcessor -model_id = "eustlb/csm-1b" +model_id = "sesame/csm-1b" device = "cuda" if torch.cuda.is_available() else "cpu" # load the model and the processor @@ -74,7 +74,7 @@ import torch from transformers import CsmForConditionalGeneration, AutoProcessor from datasets import load_dataset, Audio -model_id = "eustlb/csm-1b" +model_id = "sesame/csm-1b" device = "cuda" if torch.cuda.is_available() else "cpu" # load the model and the processor @@ -119,7 +119,7 @@ import torch from transformers import CsmForConditionalGeneration, AutoProcessor from datasets import load_dataset, Audio -model_id = "eustlb/csm-1b" +model_id = "sesame/csm-1b" device = "cuda" if torch.cuda.is_available() else "cpu" # load the model and the processor @@ -176,7 +176,7 @@ import copy from transformers import CsmForConditionalGeneration, AutoProcessor from datasets import load_dataset -model_id = "eustlb/csm-1b" +model_id = "sesame/csm-1b" device = "cuda" # set logs to ensure no recompilation and graph breaks @@ -308,13 +308,14 @@ CSM Transformers integration supports training! from transformers import CsmForConditionalGeneration, AutoProcessor from datasets import load_dataset, Audio -model_id = "eustlb/csm-1b" +model_id = "sesame/csm-1b" device = "cuda" # load the model and the processor processor = AutoProcessor.from_pretrained(model_id) model = CsmForConditionalGeneration.from_pretrained(model_id, device_map=device) model.train() +model.codec_model.eval() ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train") # ensure the audio is 24kHz @@ -355,6 +356,10 @@ The original code can be found [here](https://github.com/SesameAILabs/csm). ## CsmProcessor +
+ +
+ [[autodoc]] CsmProcessor - __call__ diff --git a/docs/source/en/model_doc/deepseek_v3.md b/docs/source/en/model_doc/deepseek_v3.md index c3322a102f6..ae2bb42a625 100644 --- a/docs/source/en/model_doc/deepseek_v3.md +++ b/docs/source/en/model_doc/deepseek_v3.md @@ -28,8 +28,8 @@ We present DeepSeek-V3, a strong Mixture-of-Experts (MoE) language model with 67 We are super happy to make this code community-powered, and would love to see how you can best optimize the following: - current implementation uses the "naive" attention compution (so not really MLA) -- current implementation loops through the experts. This should be replaced. Pointers to use `get_packed_weights` from `intetrations/tensor_parallel`. -- current implementation uses the eleuther formula for ROPE, using the orginal one would be more efficient! (should still follow our API) +- current implementation loops through the experts. This should be replaced. Pointers to use `get_packed_weights` from `integrations/tensor_parallel`. +- current implementation uses the eleuther formula for ROPE, using the original one would be more efficient! (should still follow our API) - static cache is not supported (this should be just a generation config issue / config shape issues) ### Usage tips diff --git a/docs/source/en/model_doc/granite.md b/docs/source/en/model_doc/granite.md index 0326bc5ad24..0f54db1bd2e 100644 --- a/docs/source/en/model_doc/granite.md +++ b/docs/source/en/model_doc/granite.md @@ -9,12 +9,11 @@ Unless required by applicable law or agreed to in writing, software distributed an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. -⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be +⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be rendered properly in your Markdown viewer. --> -# Granite
PyTorch @@ -22,49 +21,94 @@ rendered properly in your Markdown viewer. SDPA
-## Overview
+# Granite

-The Granite model was proposed in [Power Scheduler: A Batch Size and Token Number Agnostic Learning Rate Scheduler](https://arxiv.org/abs/2408.13359) by Yikang Shen, Matthew Stallone, Mayank Mishra, Gaoyuan Zhang, Shawn Tan, Aditya Prasad, Adriana Meza Soria, David D. Cox and Rameswar Panda.
+[Granite](https://huggingface.co/papers/2408.13359) is a 3B parameter language model trained with the Power scheduler. Discovering a good learning rate for pretraining large language models is difficult because it depends on so many variables (batch size, number of training tokens, etc.) and it is expensive to perform a hyperparameter search. The Power scheduler is based on a power-law relationship between the variables and their transferability to larger models. Combining the Power scheduler with Maximum Update Parameterization (MUP) allows a model to be pretrained with one set of hyperparameters regardless of all the variables.

-PowerLM-3B is a 3B state-of-the-art small language model trained with the Power learning rate scheduler. It is trained on a wide range of open-source and synthetic datasets with permissive licenses. PowerLM-3B has shown promising results compared to other models in the size categories across various benchmarks, including natural language multi-choices, code generation, and math reasoning.
+You can find all the original Granite checkpoints under the [IBM-Granite](https://huggingface.co/ibm-granite) organization.

-The abstract from the paper is the following:
+> [!TIP]
+> Click on the Granite models in the right sidebar for more examples of how to apply Granite to different language tasks.

-*Finding the optimal learning rate for language model pretraining is a challenging task.
-This is not only because there is a complicated correlation between learning rate, batch size, number of training tokens, model size, and other hyperparameters but also because it is prohibitively expensive to perform a hyperparameter search for large language models with Billions or Trillions of parameters. Recent studies propose using small proxy models and small corpus to perform hyperparameter searches and transposing the optimal parameters to large models and large corpus. While the zero-shot transferability is theoretically and empirically proven for model size related hyperparameters, like depth and width, the zero-shot transfer from small corpus to large corpus is underexplored.
-In this paper, we study the correlation between optimal learning rate, batch size, and number of training tokens for the recently proposed WSD scheduler. After thousands of small experiments, we found a power-law relationship between variables and demonstrated its transferability across model sizes. Based on the observation, we propose a new learning rate scheduler, Power scheduler, that is agnostic about the number of training tokens and batch size. The experiment shows that combining the Power scheduler with Maximum Update Parameterization (\mup) can consistently achieve impressive performance with one set of hyperparameters regardless of the number of training tokens, batch size, model size, and even model architecture. Our 3B dense and MoE models trained with the Power scheduler achieve comparable performance as state-of-the-art small language models.
-We [open source](https://huggingface.co/collections/ibm/power-lm-66be64ae647ddf11b9808000) these pretrained models.*
+The example below demonstrates how to generate text with [`Pipeline`], [`AutoModel`], and from the command line.
-Tips:
+
+
+
+```python
+import torch
+from transformers import pipeline
+
+pipe = pipeline(
+    task="text-generation",
+    model="ibm-granite/granite-3.3-2b-base",
+    torch_dtype=torch.bfloat16,
+    device=0
+)
+pipe("Explain quantum computing in simple terms", max_new_tokens=50)
+```
+
+
 ```python
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
-model_path = "ibm/PowerLM-3b"
-tokenizer = AutoTokenizer.from_pretrained(model_path)
+tokenizer = AutoTokenizer.from_pretrained("ibm-granite/granite-3.3-2b-base")
+model = AutoModelForCausalLM.from_pretrained(
+    "ibm-granite/granite-3.3-2b-base",
+    torch_dtype=torch.bfloat16,
+    device_map="auto",
+    attn_implementation="sdpa"
+)

-# drop device_map if running on CPU
-model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
-model.eval()
+inputs = tokenizer("Explain quantum computing in simple terms", return_tensors="pt").to("cuda")
+outputs = model.generate(**inputs, max_length=50, cache_implementation="static")
+print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+```
+
+

-# change input text as desired
-prompt = "Write a code to find the maximum value in a list of numbers."
+```bash
+echo -e "Explain quantum computing simply." | transformers-cli run --task text-generation --model ibm-granite/granite-3.3-8b-instruct --device 0
+```
+
+

-# tokenize the text
-input_tokens = tokenizer(prompt, return_tensors="pt")
-# generate output tokens
-output = model.generate(**input_tokens, max_new_tokens=100)
-# decode output tokens into text
-output = tokenizer.batch_decode(output)
-# loop over the batch to print, in this example the batch size is 1
-for i in output:
-    print(i)
+Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the [Quantization](../quantization/overview) overview for more available quantization backends.
+
+The example below uses [bitsandbytes](../quantization/bitsandbytes) to only quantize the weights to int4.
+
+```python
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
+
+quantization_config = BitsAndBytesConfig(load_in_4bit=True)
+tokenizer = AutoTokenizer.from_pretrained("ibm-granite/granite-3.3-8b-base")
+model = AutoModelForCausalLM.from_pretrained("ibm-granite/granite-3.3-8b-base", torch_dtype=torch.bfloat16, device_map="auto", attn_implementation="sdpa", quantization_config=quantization_config)
+
+inputs = tokenizer("Explain quantum computing in simple terms", return_tensors="pt").to("cuda")
+outputs = model.generate(**inputs, max_length=50, cache_implementation="static")
+print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+
+quantization_config = BitsAndBytesConfig(load_in_4bit=True)
+
+tokenizer = AutoTokenizer.from_pretrained("ibm-granite/granite-3.3-2b-base")
+model = AutoModelForCausalLM.from_pretrained(
+    "ibm-granite/granite-3.3-2b-base",
+    torch_dtype=torch.bfloat16,
+    device_map="auto",
+    attn_implementation="sdpa",
+    quantization_config=quantization_config,
+)
+
+input_ids = tokenizer("Explain artificial intelligence to a 10 year old", return_tensors="pt").to("cuda")
+outputs = model.generate(**input_ids, max_length=50, cache_implementation="static")
+print(tokenizer.decode(outputs[0], skip_special_tokens=True))
 ```
-This model was contributed by [mayank-mishra](https://huggingface.co/mayank-mishra).
- - + ## GraniteConfig [[autodoc]] GraniteConfig diff --git a/docs/source/en/model_doc/jamba.md b/docs/source/en/model_doc/jamba.md index a096f238418..5dad796f260 100644 --- a/docs/source/en/model_doc/jamba.md +++ b/docs/source/en/model_doc/jamba.md @@ -99,7 +99,7 @@ quantization_config = BitsAndBytesConfig(load_in_8bit=True, device_map = {'model.embed_tokens': 0, 'model.layers.0': 0, 'model.layers.1': 0, 'model.layers.2': 0, 'model.layers.3': 0, 'model.layers.4': 0, 'model.layers.5': 0, 'model.layers.6': 0, 'model.layers.7': 0, 'model.layers.8': 0, 'model.layers.9': 1, 'model.layers.10': 1, 'model.layers.11': 1, 'model.layers.12': 1, 'model.layers.13': 1, 'model.layers.14': 1, 'model.layers.15': 1, 'model.layers.16': 1, 'model.layers.17': 1, 'model.layers.18': 2, 'model.layers.19': 2, 'model.layers.20': 2, 'model.layers.21': 2, 'model.layers.22': 2, 'model.layers.23': 2, 'model.layers.24': 2, 'model.layers.25': 2, 'model.layers.26': 2, 'model.layers.27': 3, 'model.layers.28': 3, 'model.layers.29': 3, 'model.layers.30': 3, 'model.layers.31': 3, 'model.layers.32': 3, 'model.layers.33': 3, 'model.layers.34': 3, 'model.layers.35': 3, 'model.layers.36': 4, 'model.layers.37': 4, 'model.layers.38': 4, 'model.layers.39': 4, 'model.layers.40': 4, 'model.layers.41': 4, 'model.layers.42': 4, 'model.layers.43': 4, 'model.layers.44': 4, 'model.layers.45': 5, 'model.layers.46': 5, 'model.layers.47': 5, 'model.layers.48': 5, 'model.layers.49': 5, 'model.layers.50': 5, 'model.layers.51': 5, 'model.layers.52': 5, 'model.layers.53': 5, 'model.layers.54': 6, 'model.layers.55': 6, 'model.layers.56': 6, 'model.layers.57': 6, 'model.layers.58': 6, 'model.layers.59': 6, 'model.layers.60': 6, 'model.layers.61': 6, 'model.layers.62': 6, 'model.layers.63': 7, 'model.layers.64': 7, 'model.layers.65': 7, 'model.layers.66': 7, 'model.layers.67': 7, 'model.layers.68': 7, 'model.layers.69': 7, 'model.layers.70': 7, 'model.layers.71': 7, 'model.final_layernorm': 7, 'lm_head': 7} model = AutoModelForCausalLM.from_pretrained("ai21labs/AI21-Jamba-Large-1.6", torch_dtype=torch.bfloat16, - attn_implementation="flash_attention_2", + attn_implementation="flash_attention_2", quantization_config=quantization_config, device_map=device_map) diff --git a/docs/source/en/model_doc/mamba2.md b/docs/source/en/model_doc/mamba2.md index 8d88d6c0265..5a577983a74 100644 --- a/docs/source/en/model_doc/mamba2.md +++ b/docs/source/en/model_doc/mamba2.md @@ -14,47 +14,94 @@ rendered properly in your Markdown viewer. --> +
+
+ PyTorch +
+ # Mamba 2 -
-PyTorch -
+[Mamba 2](https://huggingface.co/papers/2405.21060) is based on the state space duality (SSD) framework which connects structured state space models (SSMs) and attention variants. It uses a more efficient SSD algorithm that is 2-8x faster than Mamba and modifies the architecture to enable tensor parallelism and a grouped-value attention (GVA) head structure. -## Overview +You can find all the original Mamba 2 checkpoints under the [State Space Models](https://huggingface.co/state-spaces) organization, but the examples shown below use [mistralai/Mamba-Codestral-7B-v0.1](https://huggingface.co/mistralai/Mamba-Codestral-7B-v0.1) because a Hugging Face implementation isn't supported yet for the original checkpoints. -The Mamba2 model was proposed in [Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality](https://arxiv.org/abs/2405.21060) by Tri Dao and Albert Gu. It is a State Space Model similar to Mamba 1, with better performances in a simplified architecture. +> [!TIP] +> Click on the Mamba models in the right sidebar for more examples of how to apply Mamba to different language tasks. +The example below demonstrates how to generate text with [`Pipeline`], [`AutoModel`], and from the command line. -The abstract from the paper is the following: +hfoptions id="usage"> + -*While Transformers have been the main architecture behind deep learning's success in language modeling, state-space models (SSMs) such as Mamba have recently been shown to match or outperform Transformers at small to medium scale. We show that these families of models are actually quite closely related, and develop a rich framework of theoretical connections between SSMs and variants of attention, connected through various decompositions of a well-studied class of structured semiseparable matrices. Our state space duality (SSD) framework allows us to design a new architecture (Mamba-2) whose core layer is an a refinement of Mamba's selective SSM that is 2-8X faster, while continuing to be competitive with Transformers on language modeling.* - -Tips: - -This version should support all implementations of Mamba 2, and in particular [Mamba-2 codestral](https://huggingface.co/mistralai/Mamba-Codestral-7B-v0.1) from Mistral AI. In particular, mamba 2 codestral was released with a number of `groups` equal to 8, which can be thought intuitively as similar to the number of kv heads in an attention-based model. -This model has two different forward passes, `torch_forward` or `cuda_kernels_forward`. The latter uses the original cuda kernels if they are found in your environment, and is slower on the prefill i.e. requires a "warmup run" due to high cpu overhead, see [here](https://github.com/state-spaces/mamba/issues/389#issuecomment-2171755306) and [also here](https://github.com/state-spaces/mamba/issues/355#issuecomment-2147597457). Without compilation, the `torch_forward` implementation is faster by a factor 3 to 4. Further, there are no positional embeddings in this model, but there is an `attention_mask` and a specific logic to mask out hidden states in two places in the case of batched generation, see [here](https://github.com/state-spaces/mamba/issues/66#issuecomment-1863563829) as well. Due to this, in addition to the reimplementation of mamba2 kernels, batched generation and cached generation are expected to have slight discrepancies. Further, the results given by the cuda kernels or the torch forward are expected to be slightly different. 
The SSM algorithm heavily relies on tensor contractions, which have matmul equivalents but the order of operations is slightly different, making the difference greater at smaller precisions. -Another note, shutdown of hidden states corresponding to padding tokens is done in 2 places and mostly has been tested with left-padding. Right-padding will propagate noise down the line and is not guaranteed to yield satisfactory results. `tokenizer.padding_side = "left"` ensures you are using the correct padding side. - -This model was contributed by [Molbap](https://huggingface.co/Molbap), with tremendous help from [Anton Vlasjuk](https://github.com/vasqu). -The original code can be found [here](https://github.com/state-spaces/mamba). - - -# Usage - -### A simple generation example: -```python -from transformers import Mamba2Config, Mamba2ForCausalLM, AutoTokenizer +```python import torch -model_id = 'mistralai/Mamba-Codestral-7B-v0.1' -tokenizer = AutoTokenizer.from_pretrained(model_id, revision='refs/pr/9', from_slow=True, legacy=False) -model = Mamba2ForCausalLM.from_pretrained(model_id, revision='refs/pr/9') -input_ids = tokenizer("Hey how are you doing?", return_tensors= "pt")["input_ids"] +from transformers import pipeline -out = model.generate(input_ids, max_new_tokens=10) -print(tokenizer.batch_decode(out)) +pipeline = pipeline( + task="text-generation", + model="mistralai/Mamba-Codestral-7B-v0.1", + torch_dtype=torch.bfloat16, + device=0 +) +pipeline("Plants create energy through a process known as") ``` -Here's a draft script for finetuning: + + + +```python +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + +tokenizer = AutoTokenizer.from_pretrained("mistralai/Mamba-Codestral-7B-v0.1") +model = AutoModelForCausalLM.from_pretrained("mistralai/Mamba-Codestral-7B-v0.1", torch_dtype=torch.bfloat16, device_map="auto") +input_ids = tokenizer("Plants create energy through a process known as", return_tensors="pt").to("cuda") + +output = model.generate(**input_ids) +print(tokenizer.decode(output[0], skip_special_tokens=True)) +``` + + + + +```bash +echo -e "Plants create energy through a process known as" | transformers-cli run --task text-generation --model mistralai/Mamba-Codestral-7B-v0.1 --device 0 +``` + + + + +Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the [Quantization](../quantization/overview) overview for more available quantization backends. + +The example below uses [torchao](../quantization/torchao) to only quantize the weights to 4-bit integers. + +```py +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig + +quantization_config = TorchAoConfig("int4_weight_only", group_size=128) +tokenizer = AutoTokenizer.from_pretrained("mistralai/Mamba-Codestral-7B-v0.1") +model = AutoModelForCausalLM.from_pretrained("mistralai/Mamba-Codestral-7B-v0.1", torch_dtype=torch.bfloat16, quantization_config=quantization_config, device_map="auto") +input_ids = tokenizer("Plants create energy through a process known as", return_tensors="pt").to("cuda") + +output = model.generate(**input_ids) +print(tokenizer.decode(output[0], skip_special_tokens=True)) +``` +## Notes + +- Codestral Mamba has `groups=8` which are similar to the number of kv heads in an attention-based model. +- Codestral Mamba has two different forward passes, `torch_forward` or `cuda_kernels_forward`, and their results are expected to be slightly different. 
+    - `torch_forward` without compilation is 3-4x faster than `cuda_kernels_forward`.
+    - `cuda_kernels_forward` uses the original CUDA kernels if they're available in your environment. It is slower during prefill because it requires a "warmup run" due to the higher CPU overhead (see [these](https://github.com/state-spaces/mamba/issues/389#issuecomment-2171755306) [comments](https://github.com/state-spaces/mamba/issues/355#issuecomment-2147597457) for more details).
+
+- There are no positional embeddings in this model, but there is an `attention_mask` and a specific logic to mask out hidden states in two places in the case of batched generation (see this [comment](https://github.com/state-spaces/mamba/issues/66#issuecomment-1863563829) for more details). This (and the addition of the reimplemented Mamba 2 kernels) results in a slight discrepancy between batched and cached generation.
+
+- The SSM algorithm heavily relies on tensor contractions, which have matmul equivalents but the order of operations is slightly different. This makes the difference greater at smaller precisions.
+
+- Hidden states that correspond to padding tokens are shut down in 2 places, and this is mostly tested with left-padding. Right-padding propagates noise down the line and is not guaranteed to yield satisfactory results. `tokenizer.padding_side = "left"` ensures you are using the correct padding side.
+
+- The example below demonstrates how to fine-tune Mamba 2 with [PEFT](https://huggingface.co/docs/peft).
+
 ```python
 from trl import SFTTrainer
 from peft import LoraConfig
diff --git a/docs/source/en/model_doc/marian.md b/docs/source/en/model_doc/marian.md
index 80bb73d26df..4fcd6363559 100644
--- a/docs/source/en/model_doc/marian.md
+++ b/docs/source/en/model_doc/marian.md
@@ -21,6 +21,8 @@ rendered properly in your Markdown viewer.
 TensorFlow
 Flax
+FlashAttention
+SDPA
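Building on the left-padding note in the Mamba 2 section above, the sketch below pads a batch on the left before generation. The pad-token fallback to the EOS token is an assumption for checkpoints that do not define a dedicated pad token.

```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mamba-Codestral-7B-v0.1")
tokenizer.padding_side = "left"  # right-padding can propagate noise
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # assumption: no dedicated pad token

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mamba-Codestral-7B-v0.1", torch_dtype=torch.bfloat16, device_map="auto"
)

prompts = ["Plants create energy through a process known as", "The capital of France is"]
batch = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
outputs = model.generate(**batch, max_new_tokens=20)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```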
## Overview @@ -155,7 +157,7 @@ Example of translating english to many romance languages, using old-style 2 char >>> model = MarianMTModel.from_pretrained(model_name) >>> translated = model.generate(**tokenizer(src_text, return_tensors="pt", padding=True)) >>> tgt_text = [tokenizer.decode(t, skip_special_tokens=True) for t in translated] -["c'est une phrase en anglais que nous voulons traduire en franΓ§ais", +["c'est une phrase en anglais que nous voulons traduire en franΓ§ais", 'Isto deve ir para o portuguΓͺs.', 'Y esto al espaΓ±ol'] ``` diff --git a/docs/source/en/model_doc/nllb-moe.md b/docs/source/en/model_doc/nllb-moe.md index 65a4812ed6a..fc8c8c92115 100644 --- a/docs/source/en/model_doc/nllb-moe.md +++ b/docs/source/en/model_doc/nllb-moe.md @@ -51,10 +51,10 @@ The original code can be found [here](https://github.com/facebookresearch/fairse ## Implementation differences with SwitchTransformers -The biggest difference is the way the tokens are routed. NLLB-MoE uses a `top-2-gate` which means that for each input, only the top two experts are selected based on the -highest predicted probabilities from the gating network, and the remaining experts are ignored. In `SwitchTransformers`, only the top-1 probabilities are computed, -which means that tokens have less probability of being forwarded. Moreover, if a token is not routed to any expert, `SwitchTransformers` still adds its unmodified hidden -states (kind of like a residual connection) while they are masked in `NLLB`'s top-2 routing mechanism. +The biggest difference is the way the tokens are routed. NLLB-MoE uses a `top-2-gate` which means that for each input, only the top two experts are selected based on the +highest predicted probabilities from the gating network, and the remaining experts are ignored. In `SwitchTransformers`, only the top-1 probabilities are computed, +which means that tokens have less probability of being forwarded. Moreover, if a token is not routed to any expert, `SwitchTransformers` still adds its unmodified hidden +states (kind of like a residual connection) while they are masked in `NLLB`'s top-2 routing mechanism. ## Generating with NLLB-MoE diff --git a/docs/source/en/model_doc/olmo2.md b/docs/source/en/model_doc/olmo2.md index 24030b85524..1ed21b660f1 100644 --- a/docs/source/en/model_doc/olmo2.md +++ b/docs/source/en/model_doc/olmo2.md @@ -14,27 +14,119 @@ rendered properly in your Markdown viewer. --> -# OLMo2 - -
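To illustrate the routing difference with SwitchTransformers described in the NLLB-MoE section above, here is a toy sketch of top-2 versus top-1 gating on random logits. It is purely illustrative and not the actual routing code of either model.

```py
import torch

torch.manual_seed(0)
gate_logits = torch.randn(4, 8)  # (num_tokens, num_experts)
probs = gate_logits.softmax(dim=-1)

# NLLB-MoE style: keep the two most probable experts per token, ignore the rest
top2_probs, top2_experts = probs.topk(2, dim=-1)
print("top-2 experts per token:", top2_experts.tolist())

# SwitchTransformers style: a single expert per token
top1_probs, top1_expert = probs.max(dim=-1)
print("top-1 expert per token:", top1_expert.tolist())
```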
-PyTorch -FlashAttention -SDPA +
+
+ PyTorch + FlashAttention + SDPA +
-## Overview +# OLMo2 +[OLMo2](https://huggingface.co/papers/2501.00656) improves on [OLMo](./olmo) by changing the architecture and training recipes of the original models. This includes excluding all biases to improve training stability, non-parametric layer norm, SwiGLU activation function, rotary positional embeddings, and a modified BPE-based tokenizer that masks personal identifiable information. It is pretrained on [Dolma](https://huggingface.co/datasets/allenai/dolma), a dataset of 3T tokens. -The OLMo2 model is the successor of the OLMo model, which was proposed in -[OLMo: Accelerating the Science of Language Models](https://arxiv.org/abs/2402.00838). +You can find all the original OLMo2 checkpoints under the [OLMo2](https://huggingface.co/collections/allenai/olmo-2-674117b93ab84e98afc72edc) collection. - The architectural changes from the original OLMo model to this model are: +> [!TIP] +> Click on the OLMo2 models in the right sidebar for more examples of how to apply OLMo2 to different language tasks. -- RMSNorm is used instead of standard layer norm. -- Norm is applied to attention queries and keys. -- Norm is applied after attention/feedforward layers rather than before. +The example below demonstrates how to generate text with [`Pipeline`], [`AutoModel`] and from the command line. -This model was contributed by [shanearora](https://huggingface.co/shanearora). -The original code can be found [here](https://github.com/allenai/OLMo/tree/main/olmo). + + + +```py +import torch +from transformers import pipeline + +pipe = pipeline( + task="text-generation", + model="allenai/OLMo-2-0425-1B", + torch_dtype=torch.float16, + device=0, +) + +result = pipe("Plants create energy through a process known as") +print(result) +``` + + + + +```py +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + +tokenizer = AutoTokenizer.from_pretrained( + "allenai/OLMo-2-0425-1B" +) + +model = AutoModelForCausalLM.from_pretrained( + "allenai/OLMo-2-0425-1B", + torch_dtype=torch.float16, + device_map="auto", + attn_implementation="sdpa" +) +input_ids = tokenizer("Plants create energy through a process known as", return_tensors="pt").to(model.device) + +output = model.generate(**input_ids, max_length=50, cache_implementation="static") +print(tokenizer.decode(output[0], skip_special_tokens=True)) +``` + + + + +```bash +echo -e "Plants create energy through a process known as" | transformers-cli run --task text-generation --model allenai/OLMo-2-0425-1B --device 0 +``` + + + + +Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the [Quantization](../quantization/overview) overview for more available quantization backends. + +The example below uses [torchao](../quantization/torchao) to only quantize the weights to 4-bits. 
+```py + +#pip install torchao +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig + +torchao_config = TorchAoConfig( + "int4_weight_only", + group_size=128 +) + +tokenizer = AutoTokenizer.from_pretrained( + "allenai/OLMo-2-0425-1B" +) + +model = AutoModelForCausalLM.from_pretrained( + "allenai/OLMo-2-0425-1B", + quantization_config=torchao_config, + torch_dtype=torch.bfloat16, + device_map="auto", + attn_implementation="sdpa" +) +input_ids = tokenizer("Plants create energy through a process known as", return_tensors="pt").to(model.device) + +output = model.generate(**input_ids, max_length=50, cache_implementation="static") +print(tokenizer.decode(output[0], skip_special_tokens=True)) + +``` + + +## Notes + +- OLMo2 uses RMSNorm instead of standard layer norm. The RMSNorm is applied to attention queries and keys, and it is applied after the attention and feedforward layers rather than before. +- OLMo2 requires Transformers v4.48 or higher. +- Load specific intermediate checkpoints by adding the `revision` parameter to [`~PreTrainedModel.from_pretrained`]. + + ```py + from transformers import AutoModelForCausalLM + + model = AutoModelForCausalLM.from_pretrained("allenai/OLMo-2-0425-1B", revision="stage1-step140000-tokens294B") + ``` ## Olmo2Config diff --git a/docs/source/en/model_doc/pegasus.md b/docs/source/en/model_doc/pegasus.md index bdb61e66d98..5681ac9b58a 100644 --- a/docs/source/en/model_doc/pegasus.md +++ b/docs/source/en/model_doc/pegasus.md @@ -21,6 +21,8 @@ rendered properly in your Markdown viewer. TensorFlow Flax +FlashAttention +SDPA
## Overview diff --git a/docs/source/en/model_doc/pegasus_x.md b/docs/source/en/model_doc/pegasus_x.md index 3f982263cdb..97e50601b72 100644 --- a/docs/source/en/model_doc/pegasus_x.md +++ b/docs/source/en/model_doc/pegasus_x.md @@ -18,6 +18,7 @@ rendered properly in your Markdown viewer.
PyTorch +FlashAttention
## Overview diff --git a/docs/source/en/model_doc/plbart.md b/docs/source/en/model_doc/plbart.md index bac567615d4..d57ee8ed99e 100644 --- a/docs/source/en/model_doc/plbart.md +++ b/docs/source/en/model_doc/plbart.md @@ -18,6 +18,8 @@ rendered properly in your Markdown viewer.
PyTorch +FlashAttention +SDPA
## Overview @@ -29,7 +31,7 @@ on Java, Python and English. According to the abstract *Code summarization and generation empower conversion between programming language (PL) and natural language (NL), -while code translation avails the migration of legacy code from one PL to another. This paper introduces PLBART, +while code translation avails the migration of legacy code from one PL to another. This paper introduces PLBART, a sequence-to-sequence model capable of performing a broad spectrum of program and language understanding and generation tasks. PLBART is pre-trained on an extensive collection of Java and Python functions and associated NL text via denoising autoencoding. Experiments on code summarization in the English language, code generation, and code translation in seven programming languages @@ -50,7 +52,7 @@ target text format is `[tgt_lang_code] X [eos]`. `bos` is never used. However, for fine-tuning, in some cases no language token is provided in cases where a single language is used. Please refer to [the paper](https://arxiv.org/abs/2103.06333) to learn more about this. -In cases where the language code is needed, the regular [`~PLBartTokenizer.__call__`] will encode source text format +In cases where the language code is needed, the regular [`~PLBartTokenizer.__call__`] will encode source text format when you pass texts as the first argument or with the keyword argument `text`, and will encode target text format if it's passed with the `text_target` keyword argument. diff --git a/docs/source/en/model_doc/roformer.md b/docs/source/en/model_doc/roformer.md index 83d01c2fc91..48c652036e5 100644 --- a/docs/source/en/model_doc/roformer.md +++ b/docs/source/en/model_doc/roformer.md @@ -14,46 +14,78 @@ rendered properly in your Markdown viewer. --> -# RoFormer - -
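As a concrete illustration of the PLBart source and target formats described above, the sketch below encodes a Python snippet as the source and an English sentence as the target with `text_target`. The checkpoint and example strings follow the usual PLBart setup but are assumptions here.

```py
from transformers import PLBartTokenizer

# src_lang/tgt_lang control which language codes the tokenizer appends
tokenizer = PLBartTokenizer.from_pretrained(
    "uclanlp/plbart-python-en_XX", src_lang="python", tgt_lang="en_XX"
)

example_python_phrase = "def maximum(a,b,c):NEW_LINE_INDENTreturn max([a,b,c])"
expected_translation_english = "Returns the maximum value of a b c."

# the first argument is encoded in the source-side format,
# text_target is encoded in the target-side format described above
inputs = tokenizer(example_python_phrase, text_target=expected_translation_english, return_tensors="pt")
print(inputs["input_ids"])
print(inputs["labels"])
```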
-PyTorch +
+
+ PyTorch TensorFlow Flax +
-## Overview +# RoFormer -The RoFormer model was proposed in [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/pdf/2104.09864v1.pdf) by Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu. +[RoFormer](https://huggingface.co/papers/2104.09864) introduces Rotary Position Embedding (RoPE) to encode token positions by rotating the inputs in 2D space. This allows a model to track absolute positions and model relative relationships. RoPE can scale to longer sequences, account for the natural decay of token dependencies, and works with the more efficient linear self-attention. -The abstract from the paper is the following: +You can find all the RoFormer checkpoints on the [Hub](https://huggingface.co/models?search=roformer). -*Position encoding in transformer architecture provides supervision for dependency modeling between elements at -different positions in the sequence. We investigate various methods to encode positional information in -transformer-based language models and propose a novel implementation named Rotary Position Embedding(RoPE). The -proposed RoPE encodes absolute positional information with rotation matrix and naturally incorporates explicit relative -position dependency in self-attention formulation. Notably, RoPE comes with valuable properties such as flexibility of -being expand to any sequence lengths, decaying inter-token dependency with increasing relative distances, and -capability of equipping the linear self-attention with relative position encoding. As a result, the enhanced -transformer with rotary position embedding, or RoFormer, achieves superior performance in tasks with long texts. We -release the theoretical analysis along with some preliminary experiment results on Chinese data. The undergoing -experiment for English benchmark will soon be updated.* +> [!TIP] +> Click on the RoFormer models in the right sidebar for more examples of how to apply RoFormer to different language tasks. -This model was contributed by [junnyu](https://huggingface.co/junnyu). The original code can be found [here](https://github.com/ZhuiyiTechnology/roformer). +The example below demonstrates how to predict the `[MASK]` token with [`Pipeline`], [`AutoModel`], and from the command line. -## Usage tips -RoFormer is a BERT-like autoencoding model with rotary position embeddings. Rotary position embeddings have shown -improved performance on classification tasks with long texts. 
+ + -## Resources +```py +# uncomment to install rjieba which is needed for the tokenizer +# !pip install rjieba +import torch +from transformers import pipeline -- [Text classification task guide](../tasks/sequence_classification) -- [Token classification task guide](../tasks/token_classification) -- [Question answering task guide](../tasks/question_answering) -- [Causal language modeling task guide](../tasks/language_modeling) -- [Masked language modeling task guide](../tasks/masked_language_modeling) -- [Multiple choice task guide](../tasks/multiple_choice) +pipe = pipeline( + task="fill-mask", + model="junnyu/roformer_chinese_base", + torch_dtype=torch.float16, + device=0 +) +output = pipe("ζ°΄εœ¨ι›ΆεΊ¦ζ—ΆδΌš[MASK]") +print(output) +``` + + + + +```py +# uncomment to install rjieba which is needed for the tokenizer +# !pip install rjieba +import torch +from transformers import AutoModelForMaskedLM, AutoTokenizer + +model = AutoModelForMaskedLM.from_pretrained( + "junnyu/roformer_chinese_base", torch_dtype=torch.float16 +) +tokenizer = AutoTokenizer.from_pretrained("junnyu/roformer_chinese_base") + +input_ids = tokenizer("ζ°΄εœ¨ι›ΆεΊ¦ζ—ΆδΌš[MASK]", return_tensors="pt").to(model.device) +outputs = model(**input_ids) +decoded = tokenizer.batch_decode(outputs.logits.argmax(-1), skip_special_tokens=True) +print(decoded) +``` + + + + +```bash +echo -e "ζ°΄εœ¨ι›ΆεΊ¦ζ—ΆδΌš[MASK]" | transformers-cli run --task fill-mask --model junnyu/roformer_chinese_base --device 0 +``` + + + + +## Notes + +- The current RoFormer implementation is an encoder-only model. The original code can be found in the [ZhuiyiTechnology/roformer](https://github.com/ZhuiyiTechnology/roformer) repository. ## RoFormerConfig diff --git a/docs/source/en/model_doc/swinv2.md b/docs/source/en/model_doc/swinv2.md index a709af9712e..0f71023e382 100644 --- a/docs/source/en/model_doc/swinv2.md +++ b/docs/source/en/model_doc/swinv2.md @@ -14,37 +14,74 @@ rendered properly in your Markdown viewer. --> -# Swin Transformer V2 - -
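The rotary embedding idea summarized in the RoFormer overview above can be sketched in a few lines. This is a schematic, assumption-laden implementation of RoPE in the rotate-half formulation, not the exact computation used inside the model.

```py
import torch

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope(x, positions, base=10000.0):
    # x: (seq_len, dim) with an even dim; positions: (seq_len,)
    dim = x.shape[-1]
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    angles = positions[:, None].float() * inv_freq[None, :]  # (seq_len, dim/2)
    cos = torch.cat((angles.cos(), angles.cos()), dim=-1)
    sin = torch.cat((angles.sin(), angles.sin()), dim=-1)
    return x * cos + rotate_half(x) * sin  # rotate each 2D pair by its position angle

queries = torch.randn(6, 64)
rotated = apply_rope(queries, torch.arange(6))
print(rotated.shape)
```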
-PyTorch +
+
+ PyTorch +
-## Overview +# Swin Transformer V2 -The Swin Transformer V2 model was proposed in [Swin Transformer V2: Scaling Up Capacity and Resolution](https://arxiv.org/abs/2111.09883) by Ze Liu, Han Hu, Yutong Lin, Zhuliang Yao, Zhenda Xie, Yixuan Wei, Jia Ning, Yue Cao, Zheng Zhang, Li Dong, Furu Wei, Baining Guo. +[Swin Transformer V2](https://huggingface.co/papers/2111.09883) is a 3B parameter model that focuses on how to scale a vision model to billions of parameters. It introduces techniques like residual-post-norm combined with cosine attention for improved training stability, log-spaced continuous position bias to better handle varying image resolutions between pre-training and fine-tuning, and a new pre-training method (SimMIM) to reduce the need for large amounts of labeled data. These improvements enable efficiently training very large models (up to 3 billion parameters) capable of processing high-resolution images. -The abstract from the paper is the following: +You can find official Swin Transformer V2 checkpoints under the [Microsoft](https://huggingface.co/microsoft?search_models=swinv2) organization. -*Large-scale NLP models have been shown to significantly improve the performance on language tasks with no signs of saturation. They also demonstrate amazing few-shot capabilities like that of human beings. This paper aims to explore large-scale models in computer vision. We tackle three major issues in training and application of large vision models, including training instability, resolution gaps between pre-training and fine-tuning, and hunger on labelled data. Three main techniques are proposed: 1) a residual-post-norm method combined with cosine attention to improve training stability; 2) A log-spaced continuous position bias method to effectively transfer models pre-trained using low-resolution images to downstream tasks with high-resolution inputs; 3) A self-supervised pre-training method, SimMIM, to reduce the needs of vast labeled images. Through these techniques, this paper successfully trained a 3 billion-parameter Swin Transformer V2 model, which is the largest dense vision model to date, and makes it capable of training with images of up to 1,536Γ—1,536 resolution. It set new performance records on 4 representative vision tasks, including ImageNet-V2 image classification, COCO object detection, ADE20K semantic segmentation, and Kinetics-400 video action classification. Also note our training is much more efficient than that in Google's billion-level visual models, which consumes 40 times less labelled data and 40 times less training time.* +> [!TIP] +> Click on the Swin Transformer V2 models in the right sidebar for more examples of how to apply Swin Transformer V2 to vision tasks. -This model was contributed by [nandwalritik](https://huggingface.co/nandwalritik). -The original code can be found [here](https://github.com/microsoft/Swin-Transformer). + + -## Resources +```py +import torch +from transformers import pipeline -A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with Swin Transformer v2. 
+pipeline = pipeline( + task="image-classification", + model="microsoft/swinv2-tiny-patch4-window8-256", + torch_dtype=torch.float16, + device=0 +) +pipeline(images="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg") +``` - + -- [`Swinv2ForImageClassification`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-classification) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/image_classification.ipynb). -- See also: [Image classification task guide](../tasks/image_classification) + -Besides that: +```py +import torch +import requests +from PIL import Image +from transformers import AutoModelForImageClassification, AutoImageProcessor -- [`Swinv2ForMaskedImageModeling`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-pretraining). +image_processor = AutoImageProcessor.from_pretrained( + "microsoft/swinv2-tiny-patch4-window8-256", +) +model = AutoModelForImageClassification.from_pretrained( + "microsoft/swinv2-tiny-patch4-window8-256", + device_map="auto" +) -If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource. +url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg" +image = Image.open(requests.get(url, stream=True).raw) +inputs = image_processor(image, return_tensors="pt").to(model.device) + +with torch.no_grad(): + logits = model(**inputs).logits + +predicted_class_id = logits.argmax(dim=-1).item() +predicted_class_label = model.config.id2label[predicted_class_id] +print(f"The predicted class label is: {predicted_class_label}") +``` + + + + +## Notes + +- Swin Transformer V2 can pad the inputs for any input height and width divisible by `32`. +- Swin Transformer V2 can be used as a [backbone](../backbones). When `output_hidden_states = True`, it outputs both `hidden_states` and `reshaped_hidden_states`. The `reshaped_hidden_states` have a shape of `(batch, num_channels, height, width)` rather than `(batch_size, sequence_length, num_channels)`. ## Swinv2Config diff --git a/docs/source/en/model_doc/zoedepth.md b/docs/source/en/model_doc/zoedepth.md index fefadfba6aa..59bc483d8cf 100644 --- a/docs/source/en/model_doc/zoedepth.md +++ b/docs/source/en/model_doc/zoedepth.md @@ -14,100 +14,101 @@ rendered properly in your Markdown viewer. --> -# ZoeDepth -
-PyTorch +
+
+ PyTorch +
-## Overview +# ZoeDepth -The ZoeDepth model was proposed in [ZoeDepth: Zero-shot Transfer by Combining Relative and Metric Depth](https://arxiv.org/abs/2302.12288) by Shariq Farooq Bhat, Reiner Birkl, Diana Wofk, Peter Wonka, Matthias MΓΌller. ZoeDepth extends the [DPT](dpt) framework for metric (also called absolute) depth estimation. ZoeDepth is pre-trained on 12 datasets using relative depth and fine-tuned on two domains (NYU and KITTI) using metric depth. A lightweight head is used with a novel bin adjustment design called metric bins module for each domain. During inference, each input image is automatically routed to the appropriate head using a latent classifier. - -The abstract from the paper is the following: - -*This paper tackles the problem of depth estimation from a single image. Existing work either focuses on generalization performance disregarding metric scale, i.e. relative depth estimation, or state-of-the-art results on specific datasets, i.e. metric depth estimation. We propose the first approach that combines both worlds, leading to a model with excellent generalization performance while maintaining metric scale. Our flagship model, ZoeD-M12-NK, is pre-trained on 12 datasets using relative depth and fine-tuned on two datasets using metric depth. We use a lightweight head with a novel bin adjustment design called metric bins module for each domain. During inference, each input image is automatically routed to the appropriate head using a latent classifier. Our framework admits multiple configurations depending on the datasets used for relative depth pre-training and metric fine-tuning. Without pre-training, we can already significantly improve the state of the art (SOTA) on the NYU Depth v2 indoor dataset. Pre-training on twelve datasets and fine-tuning on the NYU Depth v2 indoor dataset, we can further improve SOTA for a total of 21% in terms of relative absolute error (REL). Finally, ZoeD-M12-NK is the first model that can jointly train on multiple datasets (NYU Depth v2 and KITTI) without a significant drop in performance and achieve unprecedented zero-shot generalization performance to eight unseen datasets from both indoor and outdoor domains.* +[ZoeDepth](https://huggingface.co/papers/2302.12288) is a depth estimation model that combines the generalization performance of relative depth estimation (how far objects are from each other) and metric depth estimation (precise depth measurement on metric scale) from a single image. It is pre-trained on 12 datasets using relative depth and 2 datasets (NYU Depth v2 and KITTI) for metric accuracy. A lightweight head with a metric bin module for each domain is used, and during inference, it automatically selects the appropriate head for each input image with a latent classifier. drawing - ZoeDepth architecture. Taken from the original paper. +You can find all the original ZoeDepth checkpoints under the [Intel](https://huggingface.co/Intel?search=zoedepth) organization. -This model was contributed by [nielsr](https://huggingface.co/nielsr). -The original code can be found [here](https://github.com/isl-org/ZoeDepth). +The example below demonstrates how to estimate depth with [`Pipeline`] or the [`AutoModel`] class. -## Usage tips + + -- ZoeDepth is an absolute (also called metric) depth estimation model, unlike DPT which is a relative depth estimation model. This means that ZoeDepth is able to estimate depth in metric units like meters. 
+```py +import requests +import torch +from transformers import pipeline +from PIL import Image -The easiest to perform inference with ZoeDepth is by leveraging the [pipeline API](../main_classes/pipelines.md): - -```python ->>> from transformers import pipeline ->>> from PIL import Image ->>> import requests - ->>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" ->>> image = Image.open(requests.get(url, stream=True).raw) - ->>> pipe = pipeline(task="depth-estimation", model="Intel/zoedepth-nyu-kitti") ->>> result = pipe(image) ->>> depth = result["depth"] +url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg" +image = Image.open(requests.get(url, stream=True).raw) +pipeline = pipeline( + task="depth-estimation", + model="Intel/zoedepth-nyu-kitti", + torch_dtype=torch.float16, + device=0 +) +results = pipeline(image) +results["depth"] ``` -Alternatively, one can also perform inference using the classes: + + -```python ->>> from transformers import AutoImageProcessor, ZoeDepthForDepthEstimation ->>> import torch ->>> import numpy as np ->>> from PIL import Image ->>> import requests +```py +import torch +import requests +from PIL import Image +from transformers import AutoModelForDepthEstimation, AutoImageProcessor ->>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" ->>> image = Image.open(requests.get(url, stream=True).raw) +image_processor = AutoImageProcessor.from_pretrained( + "Intel/zoedepth-nyu-kitti" +) +model = AutoModelForDepthEstimation.from_pretrained( + "Intel/zoedepth-nyu-kitti", + device_map="auto" +) +url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg" +image = Image.open(requests.get(url, stream=True).raw) +inputs = image_processor(image, return_tensors="pt").to("cuda") ->>> image_processor = AutoImageProcessor.from_pretrained("Intel/zoedepth-nyu-kitti") ->>> model = ZoeDepthForDepthEstimation.from_pretrained("Intel/zoedepth-nyu-kitti") +with torch.no_grad(): + outputs = model(inputs) ->>> # prepare image for the model ->>> inputs = image_processor(images=image, return_tensors="pt") +# interpolate to original size and visualize the prediction +## ZoeDepth dynamically pads the input image, so pass the original image size as argument +## to `post_process_depth_estimation` to remove the padding and resize to original dimensions. +post_processed_output = image_processor.post_process_depth_estimation( + outputs, + source_sizes=[(image.height, image.width)], +) ->>> with torch.no_grad(): -... outputs = model(inputs) - ->>> # interpolate to original size and visualize the prediction ->>> ## ZoeDepth dynamically pads the input image. Thus we pass the original image size as argument ->>> ## to `post_process_depth_estimation` to remove the padding and resize to original dimensions. ->>> post_processed_output = image_processor.post_process_depth_estimation( -... outputs, -... source_sizes=[(image.height, image.width)], -... 
) - ->>> predicted_depth = post_processed_output[0]["predicted_depth"] ->>> depth = (predicted_depth - predicted_depth.min()) / (predicted_depth.max() - predicted_depth.min()) ->>> depth = depth.detach().cpu().numpy() * 255 ->>> depth = Image.fromarray(depth.astype("uint8")) +predicted_depth = post_processed_output[0]["predicted_depth"] +depth = (predicted_depth - predicted_depth.min()) / (predicted_depth.max() - predicted_depth.min()) +depth = depth.detach().cpu().numpy() * 255 +Image.fromarray(depth.astype("uint8")) ``` - -
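Since ZoeDepth predicts metric (absolute) depth, the unnormalized `predicted_depth` tensor in the example above can be read directly in meters before it is rescaled for visualization. A minimal sketch, assuming `predicted_depth` is the 2D float tensor returned by `post_process_depth_estimation` above:

```py
# Read metric depth values (in meters) from `predicted_depth` as returned by
# `post_process_depth_estimation` in the example above (assumed shape: (height, width)).
depth_m = predicted_depth.detach().cpu()
height, width = depth_m.shape[-2:]
print(f"depth range: {depth_m.min().item():.2f} m to {depth_m.max().item():.2f} m")
print(f"depth at image center: {depth_m[..., height // 2, width // 2].item():.2f} m")
```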

-In the original implementation ZoeDepth model performs inference on both the original and flipped images and averages out the results. The `post_process_depth_estimation` function can handle this for us by passing the flipped outputs to the optional `outputs_flipped` argument:
-
-```python
->>> with torch.no_grad():
-...     outputs = model(pixel_values)
-...     outputs_flipped = model(pixel_values=torch.flip(inputs.pixel_values, dims=[3]))
->>> post_processed_output = image_processor.post_process_depth_estimation(
-...     outputs,
-...     source_sizes=[(image.height, image.width)],
-...     outputs_flipped=outputs_flipped,
-... )
-```
-
+
+
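For a quick qualitative check, the normalized depth map from the [`AutoModel`] example above can also be rendered with a color map rather than as a grayscale image. A small sketch, assuming `depth` is the float NumPy array scaled to the 0-255 range right before the `astype("uint8")` conversion:

```py
# Optional visualization sketch: apply a matplotlib colormap to the normalized
# depth array from the example above (assumes `depth` is a float array in [0, 255]).
from matplotlib import cm
from PIL import Image

colored = cm.magma(depth / 255.0)[..., :3]  # RGBA in [0, 1] -> keep only RGB channels
Image.fromarray((colored * 255).astype("uint8"))
```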
+## Notes + +- In the [original implementation](https://github.com/isl-org/ZoeDepth/blob/edb6daf45458569e24f50250ef1ed08c015f17a7/zoedepth/models/depth_model.py#L131) ZoeDepth performs inference on both the original and flipped images and averages the results. The `post_process_depth_estimation` function handles this by passing the flipped outputs to the optional `outputs_flipped` argument as shown below. + ```py + with torch.no_grad(): + outputs = model(pixel_values) + outputs_flipped = model(pixel_values=torch.flip(inputs.pixel_values, dims=[3])) + post_processed_output = image_processor.post_process_depth_estimation( + outputs, + source_sizes=[(image.height, image.width)], + outputs_flipped=outputs_flipped, + ) + ``` + ## Resources - -A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with ZoeDepth. - -- A demo notebook regarding inference with ZoeDepth models can be found [here](https://github.com/NielsRogge/Transformers-Tutorials/tree/master/ZoeDepth). 🌎 +- Refer to this [notebook](https://github.com/NielsRogge/Transformers-Tutorials/tree/master/ZoeDepth) for an inference example. ## ZoeDepthConfig diff --git a/docs/source/en/modular_transformers.md b/docs/source/en/modular_transformers.md index badeab0214a..84d365f9aad 100644 --- a/docs/source/en/modular_transformers.md +++ b/docs/source/en/modular_transformers.md @@ -243,13 +243,7 @@ class Olmo2Attention(OlmoAttention): attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, diff --git a/docs/source/en/reference/environment_variables.md b/docs/source/en/reference/environment_variables.md new file mode 100644 index 00000000000..fc20c08f9e6 --- /dev/null +++ b/docs/source/en/reference/environment_variables.md @@ -0,0 +1,58 @@ + + +# Environment Variables + +## HF_ENABLE_PARALLEL_LOADING + +By default this is disabled. Enables the loading of torch and safetensor based weights to be loaded in parallel. Can decrease the time to load large models significantly, often times producing speed ups around ~50%. + +Can be set to a string equal to `"false"` or `"true"`. e.g. `os.environ["HF_ENABLE_PARALLEL_LOADING"] = "true"`. + +e.g. `facebook/opt-30b` on an AWS EC2 g4dn.metal instance can be made to load in ~30s with this enabled vs ~55s without it. + +Profile before committing to using this environment variable, this will not produce speed ups for smaller models. + +```py +import os + +os.environ["HF_ENABLE_PARALLEL_LOADING"] = "true" + +from transformers import pipeline + +model = pipeline(task="text-generation", model="facebook/opt-30b", device_map="auto") +``` + +## HF_PARALLEL_LOADING_WORKERS + +Determines how many threads should be used when parallel loading is enabled. Default is `8`. + +If the number of files that are being loaded is less than the number of threads specified, the number that is actually spawned will be equal to the number of files. + +e.g. 
If you specify 8 workers, and there are only 2 files, only 2 workers will be spawned. + +Tune as you see fit. + +```py +import os + +os.environ["HF_ENABLE_PARALLEL_LOADING"] = "true" +os.environ["HF_PARALLEL_LOADING_WORKERS"] = "4" + +from transformers import pipeline + +model = pipeline(task="text-generation", model="facebook/opt-30b", device_map="auto") +``` diff --git a/docs/source/ja/model_doc/auto.md b/docs/source/ja/model_doc/auto.md index 492c46c79ea..27030a264f5 100644 --- a/docs/source/ja/model_doc/auto.md +++ b/docs/source/ja/model_doc/auto.md @@ -372,3 +372,10 @@ AutoModel.register(NewModelConfig, NewModel) ### AutoModelForImageTextToText [[autodoc]] AutoModelForImageTextToText + +## Time Series + +### AutoModelForTimeSeriesPrediction + +[[autodoc]] AutoModelForTimeSeriesPrediction + diff --git a/docs/source/ko/model_doc/auto.md b/docs/source/ko/model_doc/auto.md index cda00adc33a..45c2f917a42 100644 --- a/docs/source/ko/model_doc/auto.md +++ b/docs/source/ko/model_doc/auto.md @@ -373,3 +373,10 @@ AutoModel.register(NewModelConfig, NewModel) ### FlaxAutoModelForVision2Seq[[transformers.FlaxAutoModelForVision2Seq]] [[autodoc]] FlaxAutoModelForVision2Seq + +## Time Series + +### AutoModelForTimeSeriesPrediction[[transformers.AutoModelForTimeSeriesPrediction]] + +[[autodoc]] AutoModelForTimeSeriesPrediction + diff --git a/examples/metrics-monitoring/README.md b/examples/metrics-monitoring/README.md new file mode 100644 index 00000000000..64ef1160c66 --- /dev/null +++ b/examples/metrics-monitoring/README.md @@ -0,0 +1,4 @@ +# Metrics Monitoring + +## Continuous Batching Metrics in Transformers + diff --git a/examples/metrics-monitoring/continuous-batching-dashboard.json b/examples/metrics-monitoring/continuous-batching-dashboard.json new file mode 100644 index 00000000000..e0a293d0629 --- /dev/null +++ b/examples/metrics-monitoring/continuous-batching-dashboard.json @@ -0,0 +1,974 @@ +{ + "annotations": { + "list": [ + { + "builtIn": 1, + "datasource": { + "type": "grafana", + "uid": "-- Grafana --" + }, + "enable": true, + "hide": true, + "iconColor": "rgba(0, 211, 255, 1)", + "name": "Annotations & Alerts", + "target": { + "limit": 100, + "matchAny": false, + "tags": [], + "type": "dashboard" + }, + "type": "dashboard" + } + ] + }, + "editable": true, + "fiscalYearStartMonth": 0, + "graphTooltip": 0, + "id": 2, + "links": [], + "panels": [ + { + "datasource": { + "type": "prometheus", + "uid": "PBFA97CFB590B2093" + }, + "description": "Memory usage of the PagedAttentionCache", + "fieldConfig": { + "defaults": { + "color": { + "mode": "thresholds" + }, + "mappings": [], + "max": 10737418240, + "min": 0, + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green" + }, + { + "color": "yellow", + "value": 5368709120 + }, + { + "color": "red", + "value": 8589934592 + } + ] + }, + "unit": "bytes" + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 6, + "x": 0, + "y": 0 + }, + "id": 2, + "options": { + "minVizHeight": 75, + "minVizWidth": 75, + "orientation": "auto", + "reduceOptions": { + "calcs": [ + "lastNotNull" + ], + "fields": "", + "values": false + }, + "showThresholdLabels": false, + "showThresholdMarkers": true, + "sizing": "auto" + }, + "pluginVersion": "12.0.0", + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "PBFA97CFB590B2093" + }, + "disableTextWrap": false, + "editorMode": "builder", + "expr": "kv_cache_memory_bytes", + "fullMetaSearch": false, + "includeNullMetadata": true, + "legendFormat": "__auto", + "range": true, + 
"refId": "A", + "useBackend": false + } + ], + "title": "KV Cache Memory Usage", + "transparent": true, + "type": "gauge" + }, + { + "datasource": { + "type": "prometheus", + "uid": "PBFA97CFB590B2093" + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "thresholds" + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "dark-blue" + } + ] + } + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 6, + "x": 6, + "y": 0 + }, + "id": 13, + "options": { + "colorMode": "value", + "graphMode": "area", + "justifyMode": "auto", + "orientation": "auto", + "percentChangeColorMode": "standard", + "reduceOptions": { + "calcs": [ + "lastNotNull" + ], + "fields": "", + "values": false + }, + "showPercentChange": false, + "textMode": "auto", + "wideLayout": true + }, + "pluginVersion": "12.0.0", + "targets": [ + { + "disableTextWrap": false, + "editorMode": "builder", + "expr": "active_requests_count", + "fullMetaSearch": false, + "includeNullMetadata": true, + "legendFormat": "__auto", + "range": true, + "refId": "A", + "useBackend": false + } + ], + "title": "Active Requests", + "transparent": true, + "type": "stat" + }, + { + "datasource": { + "type": "prometheus", + "uid": "PBFA97CFB590B2093" + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "thresholds" + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "dark-orange" + } + ] + } + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 6, + "x": 12, + "y": 0 + }, + "id": 14, + "options": { + "colorMode": "value", + "graphMode": "area", + "justifyMode": "auto", + "orientation": "auto", + "percentChangeColorMode": "standard", + "reduceOptions": { + "calcs": [ + "lastNotNull" + ], + "fields": "", + "values": false + }, + "showPercentChange": false, + "textMode": "auto", + "wideLayout": true + }, + "pluginVersion": "12.0.0", + "targets": [ + { + "disableTextWrap": false, + "editorMode": "builder", + "expr": "waiting_requests_count", + "fullMetaSearch": false, + "includeNullMetadata": true, + "legendFormat": "__auto", + "range": true, + "refId": "A", + "useBackend": false + } + ], + "title": "Waiting Requests", + "transparent": true, + "type": "stat" + }, + { + "datasource": { + "type": "prometheus", + "uid": "PBFA97CFB590B2093" + }, + "description": "Ratio of decode tokens to prefill tokens in a batch", + "fieldConfig": { + "defaults": { + "color": { + "mode": "thresholds" + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "blue" + } + ] + } + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 6, + "x": 18, + "y": 0 + }, + "id": 6, + "options": { + "colorMode": "value", + "graphMode": "none", + "justifyMode": "auto", + "orientation": "auto", + "percentChangeColorMode": "standard", + "reduceOptions": { + "calcs": [ + "lastNotNull" + ], + "fields": "", + "values": false + }, + "showPercentChange": false, + "textMode": "auto", + "wideLayout": true + }, + "pluginVersion": "12.0.0", + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "PBFA97CFB590B2093" + }, + "disableTextWrap": false, + "editorMode": "builder", + "expr": "decode_prefill_ratio", + "fullMetaSearch": false, + "includeNullMetadata": true, + "legendFormat": "__auto", + "range": true, + "refId": "A", + "useBackend": false + } + ], + "title": "Decode/Prefill Ratio", + "transparent": true, + "type": "stat" + }, + { + "datasource": { + "type": "prometheus", + "uid": "PBFA97CFB590B2093" + }, + "fieldConfig": { + "defaults": { + 
"color": { + "mode": "palette-classic" + }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "barWidthFactor": 0.6, + "drawStyle": "line", + "fillOpacity": 0, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "insertNulls": false, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green" + }, + { + "color": "red", + "value": 80 + } + ] + } + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 8 + }, + "id": 10, + "options": { + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "hideZeros": false, + "mode": "single", + "sort": "none" + } + }, + "pluginVersion": "12.0.0", + "targets": [ + { + "editorMode": "code", + "expr": "rate(decode_tokens_processed_total[$__rate_interval])", + "legendFormat": "__auto", + "range": true, + "refId": "A" + } + ], + "title": "Decode tokens throupught tok/s", + "type": "timeseries" + }, + { + "datasource": { + "type": "prometheus", + "uid": "PBFA97CFB590B2093" + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "barWidthFactor": 0.6, + "drawStyle": "line", + "fillOpacity": 0, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "insertNulls": false, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green" + }, + { + "color": "red", + "value": 80 + } + ] + } + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 12, + "y": 8 + }, + "id": 11, + "options": { + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "hideZeros": false, + "mode": "single", + "sort": "none" + } + }, + "pluginVersion": "12.0.0", + "targets": [ + { + "editorMode": "code", + "expr": "rate(prefill_tokens_processed_total[$__rate_interval])", + "legendFormat": "__auto", + "range": true, + "refId": "A" + } + ], + "title": "Prefill rate tok/s", + "type": "timeseries" + }, + { + "datasource": { + "type": "prometheus", + "uid": "PBFA97CFB590B2093" + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "barWidthFactor": 0.6, + "drawStyle": "line", + "fillOpacity": 0, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "insertNulls": false, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": 
"auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green" + }, + { + "color": "red", + "value": 80 + } + ] + } + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 16 + }, + "id": 9, + "options": { + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "hideZeros": false, + "mode": "single", + "sort": "none" + } + }, + "pluginVersion": "12.0.0", + "targets": [ + { + "editorMode": "code", + "expr": "histogram_quantile(0.95, sum by(le) (rate(batch_fill_percentage_percent_bucket[$__rate_interval])))", + "legendFormat": "p95", + "range": true, + "refId": "A" + }, + { + "datasource": { + "type": "prometheus", + "uid": "PBFA97CFB590B2093" + }, + "editorMode": "code", + "expr": "histogram_quantile(0.99, sum by(le) (rate(batch_fill_percentage_percent_bucket[$__rate_interval])))", + "hide": false, + "instant": false, + "legendFormat": "p99", + "range": true, + "refId": "B" + }, + { + "datasource": { + "type": "prometheus", + "uid": "PBFA97CFB590B2093" + }, + "editorMode": "code", + "expr": "histogram_quantile(0.5, sum by(le) (rate(batch_fill_percentage_percent_bucket[$__rate_interval])))", + "hide": false, + "instant": false, + "legendFormat": "p50", + "range": true, + "refId": "C" + } + ], + "title": "Batch fill percentage percentiles", + "type": "timeseries" + }, + { + "datasource": { + "type": "prometheus", + "uid": "PBFA97CFB590B2093" + }, + "description": "KV Cache Memory Usage Over Time", + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "barWidthFactor": 0.6, + "drawStyle": "line", + "fillOpacity": 20, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "insertNulls": false, + "lineInterpolation": "linear", + "lineWidth": 2, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green" + }, + { + "color": "red", + "value": 80 + } + ] + }, + "unit": "bytes" + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 12, + "y": 16 + }, + "id": 4, + "options": { + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "hideZeros": false, + "mode": "single", + "sort": "none" + } + }, + "pluginVersion": "12.0.0", + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "PBFA97CFB590B2093" + }, + "disableTextWrap": false, + "editorMode": "builder", + "expr": "kv_cache_memory_bytes", + "fullMetaSearch": false, + "includeNullMetadata": true, + "legendFormat": "Used memory", + "range": true, + "refId": "A", + "useBackend": false + }, + { + "datasource": { + "type": "prometheus", + "uid": "PBFA97CFB590B2093" + }, + "disableTextWrap": false, + "editorMode": "builder", + "expr": "kv_cache_free_memory_bytes", + "fullMetaSearch": false, + "hide": false, + "includeNullMetadata": true, + "instant": false, + "legendFormat": "free memory", + "range": true, + "refId": "B", + 
"useBackend": false + } + ], + "title": "KV Cache Memory Usage Over Time", + "type": "timeseries" + }, + { + "datasource": { + "type": "prometheus", + "uid": "PBFA97CFB590B2093" + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "thresholds" + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green" + }, + { + "color": "red", + "value": 80 + } + ] + }, + "unit": "ms" + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 24 + }, + "id": 8, + "options": { + "displayMode": "gradient", + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom", + "showLegend": false + }, + "maxVizHeight": 300, + "minVizHeight": 10, + "minVizWidth": 0, + "namePlacement": "auto", + "orientation": "auto", + "reduceOptions": { + "calcs": [ + "lastNotNull" + ], + "fields": "", + "values": false + }, + "showUnfilled": true, + "sizing": "auto", + "valueMode": "color" + }, + "pluginVersion": "12.0.0", + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "PBFA97CFB590B2093" + }, + "disableTextWrap": false, + "editorMode": "builder", + "expr": "histogram_quantile(0.95, sum by(le) (rate(ttft_milliseconds_bucket[$__rate_interval])))", + "fullMetaSearch": false, + "includeNullMetadata": true, + "legendFormat": "p95", + "range": true, + "refId": "A", + "useBackend": false + }, + { + "datasource": { + "type": "prometheus", + "uid": "PBFA97CFB590B2093" + }, + "disableTextWrap": false, + "editorMode": "builder", + "expr": "histogram_quantile(0.5, sum by(le) (rate(ttft_milliseconds_bucket[$__rate_interval])))", + "fullMetaSearch": false, + "hide": false, + "includeNullMetadata": true, + "legendFormat": "p50", + "range": true, + "refId": "B", + "useBackend": false + }, + { + "datasource": { + "type": "prometheus", + "uid": "PBFA97CFB590B2093" + }, + "disableTextWrap": false, + "editorMode": "builder", + "expr": "histogram_quantile(0.99, sum by(le) (rate(ttft_milliseconds_bucket[$__rate_interval])))", + "fullMetaSearch": false, + "hide": false, + "includeNullMetadata": false, + "instant": false, + "legendFormat": "p99", + "range": true, + "refId": "C", + "useBackend": false + } + ], + "title": "Time to First Token (TTFT)", + "type": "bargauge" + }, + { + "datasource": { + "type": "prometheus", + "uid": "PBFA97CFB590B2093" + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "barWidthFactor": 0.6, + "drawStyle": "line", + "fillOpacity": 0, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "insertNulls": false, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green" + }, + { + "color": "red", + "value": 80 + } + ] + }, + "unit": "ms" + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 12, + "y": 24 + }, + "id": 12, + "options": { + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "hideZeros": false, + "mode": "single", + "sort": "none" + } + }, + "pluginVersion": "12.0.0", + "targets": 
[ + { + "editorMode": "code", + "expr": "histogram_quantile(0.5, sum by(le) (rate(request_latency_milliseconds_bucket[$__rate_interval])))", + "legendFormat": "p50", + "range": true, + "refId": "A" + }, + { + "datasource": { + "type": "prometheus", + "uid": "PBFA97CFB590B2093" + }, + "editorMode": "code", + "expr": "histogram_quantile(0.95, sum by(le) (rate(request_latency_milliseconds_bucket[$__rate_interval])))", + "hide": false, + "instant": false, + "legendFormat": "p95", + "range": true, + "refId": "B" + }, + { + "datasource": { + "type": "prometheus", + "uid": "PBFA97CFB590B2093" + }, + "editorMode": "code", + "expr": "histogram_quantile(0.99, sum by(le) (rate(request_latency_milliseconds_bucket[$__rate_interval])))", + "hide": false, + "instant": false, + "legendFormat": "p99", + "range": true, + "refId": "C" + } + ], + "title": "Request latency percentiles", + "type": "timeseries" + } + ], + "preload": false, + "refresh": "5s", + "schemaVersion": 41, + "tags": [], + "templating": { + "list": [] + }, + "time": { + "from": "now-15m", + "to": "now" + }, + "timepicker": {}, + "timezone": "", + "title": "Transformers Continuous Batching Metrics", + "uid": "Lw6CTvVSz", + "version": 5 +} \ No newline at end of file diff --git a/examples/metrics-monitoring/docker-compose.yml b/examples/metrics-monitoring/docker-compose.yml new file mode 100644 index 00000000000..936f4a894ce --- /dev/null +++ b/examples/metrics-monitoring/docker-compose.yml @@ -0,0 +1,55 @@ +services: + memcached: + image: memcached:1.6.29 + container_name: memcached + ports: + - "11211:11211" + environment: + - MEMCACHED_MAX_MEMORY=64m # Set the maximum memory usage + - MEMCACHED_THREADS=4 # Number of threads to use + + prometheus: + image: prom/prometheus:latest + command: + - "--config.file=/etc/prometheus/prometheus.yml" + - --web.enable-otlp-receiver # Enable OTLP receiver + - --web.enable-remote-write-receiver + - --enable-feature=exemplar-storage + - --enable-feature=native-histograms + volumes: + - ./prometheus.yml:/etc/prometheus/prometheus.yml + ports: + - "9090:9090" + + tempo: + image: grafana/tempo:latest + command: [ "-config.file=/etc/tempo.yaml" ] + volumes: + - ./tempo.yaml:/etc/tempo.yaml + ports: + - "14268:14268" # jaeger ingest + - "3200:3200" # tempo + - "9095:9095" # tempo grpc + - "4317:4317" # otlp grpc + - "4318:4318" # otlp http + - "9411:9411" # zipkin + depends_on: + - memcached + + grafana: + image: grafana/grafana:latest + volumes: + - ./continuous-batching-dashboard.json:/etc/grafana/provisioning/dashboards/continuous-batching-dashboard.json + - ./grafana-dashboard.yaml:/etc/grafana/provisioning/dashboards/grafana-dashboard.yaml + - ./grafana-datasources.yaml:/etc/grafana/provisioning/datasources/datasources.yaml + environment: + - GF_AUTH_ANONYMOUS_ENABLED=true + - GF_AUTH_ANONYMOUS_ORG_ROLE=Admin + - GF_AUTH_DISABLE_LOGIN_FORM=true + - GF_FEATURE_TOGGLES_ENABLE=traceqlEditor metricsSummary + - GF_INSTALL_PLUGINS=https://storage.googleapis.com/integration-artifacts/grafana-exploretraces-app/grafana-exploretraces-app-latest.zip;grafana-traces-app + ports: + - "3000:3000" + depends_on: + - prometheus + - tempo diff --git a/examples/metrics-monitoring/grafana-dashboard.yaml b/examples/metrics-monitoring/grafana-dashboard.yaml new file mode 100644 index 00000000000..6dd396d00e1 --- /dev/null +++ b/examples/metrics-monitoring/grafana-dashboard.yaml @@ -0,0 +1,11 @@ +apiVersion: 1 + +providers: + - name: 'Transformers Dashboards' + orgId: 1 + folder: 'Transformers' + type: file + disableDeletion: 
false + editable: true + options: + path: /etc/grafana/provisioning/dashboards diff --git a/examples/metrics-monitoring/grafana-datasources.yaml b/examples/metrics-monitoring/grafana-datasources.yaml new file mode 100644 index 00000000000..e3f2e78bece --- /dev/null +++ b/examples/metrics-monitoring/grafana-datasources.yaml @@ -0,0 +1,14 @@ +apiVersion: 1 + +datasources: + - name: Prometheus + type: prometheus + access: proxy + url: http://prometheus:9090 + isDefault: true + + - name: Tempo + type: tempo + access: proxy + url: http://tempo:3200 + uid: tempo diff --git a/examples/metrics-monitoring/metrics_example.py b/examples/metrics-monitoring/metrics_example.py new file mode 100644 index 00000000000..df3551b68d4 --- /dev/null +++ b/examples/metrics-monitoring/metrics_example.py @@ -0,0 +1,48 @@ +# Example usage of the trace and attach_tracer decorators + +from transformers.utils.metrics import attach_tracer, traced + + +@attach_tracer() +class ExampleClass: + def __init__(self, name): + # The attach_tracer decorator has already created self.tracer for us + self.name = name + + @traced # This method will use the tracer from the class instance + def process_data(self, data): + # This method is traced and can use self.tracer + return f"Processed {data} with {self.name}" + + @traced(span_name="custom_operation") # With custom span name + def special_operation(self, value): + # Also traced, with a custom span name + return value * 2 + + @traced( + additional_attributes=[ + ("name", "object.name", lambda x: x.upper()), # Using a transform function + ("name", "object.fixed_value", "static_value"), # Using a fixed value + ] + ) + def operation_with_attributes(self): + # This will add the specified attributes to the span + return "Operation completed" + + +# For functions without a class, the traced decorator still works +@traced +def standalone_function(arg1, arg2): + # For functions, a tracer is created based on the module name + return arg1 + arg2 + + +# Usage: +if __name__ == "__main__": + # With OpenTelemetry configured, these will produce traces + example = ExampleClass("test_object") + example.process_data("sample") + example.special_operation(42) + example.operation_with_attributes() + + result = standalone_function(1, 2) diff --git a/examples/metrics-monitoring/prometheus.yml b/examples/metrics-monitoring/prometheus.yml new file mode 100644 index 00000000000..6c578ad89f5 --- /dev/null +++ b/examples/metrics-monitoring/prometheus.yml @@ -0,0 +1,3 @@ +global: + scrape_interval: 15s + diff --git a/examples/metrics-monitoring/tempo.yaml b/examples/metrics-monitoring/tempo.yaml new file mode 100644 index 00000000000..353b83e1ccc --- /dev/null +++ b/examples/metrics-monitoring/tempo.yaml @@ -0,0 +1,90 @@ +stream_over_http_enabled: true +server: + http_listen_port: 3200 + log_level: info + + +cache: + background: + writeback_goroutines: 5 + caches: + - roles: + - frontend-search + memcached: + addresses: dns+memcached:11211 + +query_frontend: + search: + duration_slo: 5s + throughput_bytes_slo: 1.073741824e+09 + metadata_slo: + duration_slo: 5s + throughput_bytes_slo: 1.073741824e+09 + trace_by_id: + duration_slo: 100ms + metrics: + max_duration: 200h # maximum duration of a metrics query, increase for local setups + query_backend_after: 5m + duration_slo: 5s + throughput_bytes_slo: 1.073741824e+09 + +distributor: + receivers: # this configuration will listen on all ports and protocols that tempo is capable of. + jaeger: # the receives all come from the OpenTelemetry collector. 
more configuration information can + protocols: # be found there: https://github.com/open-telemetry/opentelemetry-collector/tree/main/receiver + thrift_http: # + endpoint: "tempo:14268" # for a production deployment you should only enable the receivers you need! + grpc: + endpoint: "tempo:14250" + thrift_binary: + endpoint: "tempo:6832" + thrift_compact: + endpoint: "tempo:6831" + zipkin: + endpoint: "tempo:9411" + otlp: + protocols: + grpc: + endpoint: "tempo:4317" + http: + endpoint: "tempo:4318" + opencensus: + endpoint: "tempo:55678" + +ingester: + max_block_duration: 5m # cut the headblock when this much time passes. this is being set for demo purposes and should probably be left alone normally + +compactor: + compaction: + block_retention: 720h # overall Tempo trace retention. set for demo purposes + +metrics_generator: + registry: + external_labels: + source: tempo + cluster: docker-compose + storage: + path: /var/tempo/generator/wal + remote_write: + - url: http://prometheus:9090/api/v1/write + send_exemplars: true + traces_storage: + path: /var/tempo/generator/traces + processor: + local_blocks: + filter_server_spans: false + flush_to_storage: true + +storage: + trace: + backend: local # backend configuration to use + wal: + path: /var/tempo/wal # where to store the wal locally + local: + path: /var/tempo/blocks + +overrides: + defaults: + metrics_generator: + processors: [service-graphs, span-metrics, local-blocks] # enables metrics generator + generate_native_histograms: both diff --git a/examples/pytorch/continuous_batching.py b/examples/pytorch/continuous_batching.py new file mode 100644 index 00000000000..9aaa836f7ba --- /dev/null +++ b/examples/pytorch/continuous_batching.py @@ -0,0 +1,109 @@ +import time + +import datasets +import torch + +from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers.generation import GenerationConfig + + +torch.set_float32_matmul_precision("high") + +model_id = "meta-llama/Llama-3.2-3b-Instruct" +model = AutoModelForCausalLM.from_pretrained( + model_id, attn_implementation="sdpa_paged", torch_dtype=torch.bfloat16, device_map="auto" +).eval() +tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left") + +generation_config = GenerationConfig( + max_new_tokens=512, + eos_token_id=tokenizer.eos_token_id, + pad_token_id=tokenizer.pad_token_id, + use_cache=False, + num_blocks=2048, + block_size=128, + do_sample=True, + max_batch_tokens=1024, # Maximum number of tokens to process in a single batch + scheduler="prefill_first", +) + +train_dataset = datasets.load_dataset("openai/gsm8k", "socratic", split="test") + +# --- Example 1: Simple Version using generate_batch --- +print("--- Running CB Generation Example ---") + + +def tokenize_function(examples): + return tokenizer(examples["question"]) + + +tokenized_datasets = train_dataset.map(tokenize_function, batched=True) +simple_batch_inputs = [item["input_ids"] for item in tokenized_datasets] + +start_time_simple = time.time() +# model.forward = torch.compile(model.forward, mode="max-autotune-no-cudagraphs", fullgraph=True) +batch_outputs = model.generate_batch( + inputs=simple_batch_inputs, + generation_config=generation_config, +) +end_time_simple = time.time() + +for request in batch_outputs: + input_text = tokenizer.decode(batch_outputs[request].prompt_ids, skip_special_tokens=False) + try: + output_text = tokenizer.decode(batch_outputs[request].generated_tokens, skip_special_tokens=False) + except Exception as e: + print(f"Decoding failed for request {request}: 
{e}") + output_text = tokenizer.decode(batch_outputs[request].generated_tokens[1:], skip_special_tokens=False) + if len(output_text) > 0: + print("-" * 20) + print(f"{request} Input: {input_text}") + print(f"{request} Output: {output_text}") + else: + print("", end="\r\r\r\r") +print("-" * 20) +print("--- Finished CB Generation Example ---\n\n") + + +print(f"CB generation took: {end_time_simple - start_time_simple:.2f} seconds") + + +# train_dataset = train_dataset.select(range(5)) # Use only 5 examples for the simple version + +# tokenized_test_prompts = tokenizer(_TEST_PROMPTS, padding=True, padding_side="left", truncation=True, max_length=512) +# simple_batch_inputs = list(tokenized_test_prompts["input_ids"]) + +# def tokenize_function(examples): +# # Truncate to avoid overly long prompts exceeding max context length +# return tokenizer(examples["question"], padding=True, truncation=True, max_length=512) + + +# tokenized_datasets = train_dataset.map(tokenize_function, batched=True) +# simple_batch_inputs = [item["input_ids"] for item in tokenized_datasets] + + +# model.config.attn_implementation = "sdpa" +# start_time_simple = time.time() +# batch_size = 64 +# full_outputs = [] +# from tqdm import tqdm + +# for i in tqdm(range(0, len(simple_batch_inputs)-batch_size, batch_size)): +# outputs = model.generate( +# torch.tensor(simple_batch_inputs[i:i+batch_size], device=model.device), +# generation_config=GenerationConfig( +# max_new_tokens=16, eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.pad_token_id +# ), +# ) +# full_outputs.extend(outputs.tolist()) + +# end_time_simple = time.time() +# print(f"\nSimple batch generation took: {end_time_simple - start_time_simple:.2f} seconds") + +# print("\nResults from simple generate_batch:") +# for i, request in enumerate(full_outputs): +# output_text = tokenizer.decode(request, skip_special_tokens=False) +# print("-" * 20) +# print(f" Output: {output_text}") +# print("-" * 20) +# print("--- Finished Simple Batch Generation Example ---\n\n") diff --git a/setup.py b/setup.py index 52024f77c12..2b74308081e 100644 --- a/setup.py +++ b/setup.py @@ -201,6 +201,9 @@ _deps = [ "pytest-rich", "libcst", "rich", + "opentelemetry-api", + "opentelemetry-exporter-otlp", + "opentelemetry-sdk", ] @@ -435,6 +438,9 @@ extras["torchhub"] = deps_list( extras["benchmark"] = deps_list("optimum-benchmark") +# OpenTelemetry dependencies for metrics collection in continuous batching +extras["open-telemetry"] = deps_list("opentelemetry-api", "opentelemetry-exporter-otlp", "opentelemetry-sdk") + # when modifying the following list, make sure to update src/transformers/dependency_versions_check.py install_requires = [ deps["filelock"], # filesystem locks, e.g., to prevent parallel downloads diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index c0bd42f2e39..1a3ba7f8df8 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -1716,6 +1716,19 @@ class EncoderDecoderCache(Cache): self.self_attention_cache.batch_select_indices(indices) self.cross_attention_cache.batch_select_indices(indices) + def get_max_cache_shape(self) -> Optional[int]: + """Returns the maximum sequence length (i.e. 
max capacity) of the cache object""" + return self.self_attention_cache.get_max_cache_shape() + + def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]: + """ + Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for + the given layer at `layer_idx`. + The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size), + for each layer. + """ + return self.self_attention_cache.get_mask_sizes(cache_position, layer_idx) + class HybridCache(Cache): """ @@ -1967,7 +1980,8 @@ class HybridChunkedCache(Cache): else: self.sliding_window = config.sliding_window self.max_cache_len = max_cache_len - self._sliding_window_max_len = min(self.sliding_window, max_cache_len) + # Sliding layers can't be larger than the overall max cache len + self.sliding_window = min(self.sliding_window, self.max_cache_len) self.max_batch_size = max_batch_size self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self._dtype = dtype @@ -1989,7 +2003,7 @@ class HybridChunkedCache(Cache): num_key_value_heads = key_states.shape[1] device = key_states.device global_cache_shape = (self.max_batch_size, num_key_value_heads, self.max_cache_len, self.head_dim) - sliding_cache_shape = (self.max_batch_size, num_key_value_heads, self._sliding_window_max_len, self.head_dim) + sliding_cache_shape = (self.max_batch_size, num_key_value_heads, self.sliding_window, self.head_dim) # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph # breaks when updating the cache. cache_shape = sliding_cache_shape if self.is_sliding[layer_idx] else global_cache_shape @@ -2163,7 +2177,7 @@ class OffloadedHybridCache(HybridChunkedCache): device = key_states.device if self.is_sliding[layer_idx] else self.offload_device pin_memory = not self.is_sliding[layer_idx] global_cache_shape = (self.max_batch_size, num_key_value_heads, self.max_cache_len, self.head_dim) - sliding_cache_shape = (self.max_batch_size, num_key_value_heads, self._sliding_window_max_len, self.head_dim) + sliding_cache_shape = (self.max_batch_size, num_key_value_heads, self.sliding_window, self.head_dim) # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph # breaks when updating the cache. cache_shape = sliding_cache_shape if self.is_sliding[layer_idx] else global_cache_shape @@ -2231,7 +2245,7 @@ class OffloadedHybridCache(HybridChunkedCache): def _prefetch_layer_in_context(self, layer_idx: int) -> None: """Performs the actual copy of the layer to device cache.""" - if len(self.key_cache) >= layer_idx: + if len(self.key_cache) > layer_idx: self.device_key_cache[self.active_device_layer].copy_(self.key_cache[layer_idx], non_blocking=True) self.device_value_cache[self.active_device_layer].copy_(self.value_cache[layer_idx], non_blocking=True) # The layer was not yet initialized diff --git a/src/transformers/commands/chat.py b/src/transformers/commands/chat.py index 5c9bd76bdb0..7ade958149a 100644 --- a/src/transformers/commands/chat.py +++ b/src/transformers/commands/chat.py @@ -13,9 +13,11 @@ # limitations under the License. 
+import copy import json import os import platform +import re import string import time import warnings @@ -25,7 +27,15 @@ from threading import Thread from typing import Optional import yaml +from huggingface_hub.utils import disable_progress_bars +from transformers import ( + AutoTokenizer, + GenerationConfig, + PreTrainedTokenizer, + TextIteratorStreamer, + logging, +) from transformers.utils import is_rich_available, is_torch_available from . import BaseTransformersCLICommand @@ -42,13 +52,7 @@ if is_rich_available(): if is_torch_available(): import torch - from transformers import ( - AutoModelForCausalLM, - AutoTokenizer, - BitsAndBytesConfig, - GenerationConfig, - TextIteratorStreamer, - ) + from transformers import AutoModelForCausalLM, BitsAndBytesConfig, PreTrainedModel ALLOWED_KEY_CHARS = set(string.ascii_letters + string.whitespace) @@ -68,6 +72,7 @@ DEFAULT_EXAMPLES = { "numbers": {"text": "Count to 10 but skip every number ending with an 'e'"}, "birds": {"text": "Why aren't birds real?"}, "socks": {"text": "Why is it important to eat socks after meditating?"}, + "numbers2": {"text": "Which number is larger, 9.9 or 9.11?"}, } # Printed at the start of a chat session @@ -76,7 +81,7 @@ HELP_STRING_MINIMAL = """ **TRANSFORMERS CHAT INTERFACE** Chat interface to try out a model. Besides chatting with the model, here are some basic commands: -- **!help**: shows all available commands +- **!help**: shows all available commands (set generation settings, save chat, etc.) - **!status**: shows the current status of the model and generation settings - **!clear**: clears the current conversation and starts a new one - **!exit**: closes the interface @@ -140,6 +145,9 @@ class RichInterface: for i, outputs in enumerate(output_stream): if not outputs or i == 0: continue + # Escapes single words encased in <>, e.g. -> \, for proper rendering in Markdown. + # It only escapes single words that may have `_`, optionally following a `/` (e.g. ) + outputs = re.sub(r"<(/*)(\w*)>", r"\<\1\2\>", outputs) text += outputs # Render the accumulated text as Markdown # NOTE: this is a workaround for the rendering "unstandard markdown" @@ -224,6 +232,7 @@ class ChatArguments: system_prompt: Optional[str] = field(default=None, metadata={"help": "System prompt."}) save_folder: str = field(default="./chat_history/", metadata={"help": "Folder to save chat history."}) examples_path: Optional[str] = field(default=None, metadata={"help": "Path to a yaml file with examples."}) + verbose: bool = field(default=False, metadata={"help": "Whether to show runtime warnings in the chat interface."}) # Generation settings generation_config: Optional[str] = field( @@ -246,7 +255,9 @@ class ChatArguments: repetition_penalty: float = field(default=1.0, metadata={"help": "Repetition penalty."}) eos_tokens: Optional[str] = field( default=None, - metadata={"help": "EOS tokens to stop the generation. If multiple they should be comma separated."}, + metadata={ + "help": "EOS tokens (text format) to stop the generation. If multiple they should be comma separated." + }, ) eos_token_ids: Optional[str] = field( default=None, @@ -431,6 +442,9 @@ class ChatCommand(BaseTransformersCLICommand): # 2. b. 
strings should be quoted def is_number(s: str) -> bool: + # handle negative numbers + if s.startswith("-"): + s = s[1:] return s.replace(".", "", 1).isdigit() generate_flags_as_dict = {k: f'"{v}"' if not is_number(v) else v for k, v in generate_flags_as_dict.items()} @@ -464,16 +478,19 @@ class ChatCommand(BaseTransformersCLICommand): return processed_generate_flags def get_generation_parameterization( - self, args: ChatArguments, tokenizer: AutoTokenizer + self, args: ChatArguments, tokenizer: AutoTokenizer, model: PreTrainedModel ) -> tuple[GenerationConfig, dict]: """ Returns a GenerationConfig object holding the generation parameters for the CLI command. """ - # No generation config arg provided -> use base generation config, apply CLI defaults + # No generation config arg provided -> use default generation config, apply CLI defaults if args.generation_config is None: - generation_config = GenerationConfig() + # We start off from the checkpoint's generation config + generation_config = copy.deepcopy(model.generation_config) # Apply deprecated CLI args on top of the default generation config - pad_token_id, eos_token_ids = self.parse_eos_tokens(tokenizer, args.eos_tokens, args.eos_token_ids) + pad_token_id, eos_token_ids = self.parse_eos_tokens( + tokenizer, generation_config, args.eos_tokens, args.eos_token_ids + ) deprecated_kwargs = { "max_new_tokens": args.max_new_tokens, "do_sample": args.do_sample, @@ -504,13 +521,16 @@ class ChatCommand(BaseTransformersCLICommand): @staticmethod def parse_eos_tokens( - tokenizer: AutoTokenizer, eos_tokens: Optional[str], eos_token_ids: Optional[str] + tokenizer: PreTrainedTokenizer, + generation_config: GenerationConfig, + eos_tokens: Optional[str], + eos_token_ids: Optional[str], ) -> tuple[int, list[int]]: """Retrieves the pad token ID and all possible EOS token IDs.""" - if tokenizer.pad_token_id is None: - pad_token_id = tokenizer.eos_token_id + if generation_config.pad_token_id is None: + pad_token_id = generation_config.eos_token_id else: - pad_token_id = tokenizer.pad_token_id + pad_token_id = generation_config.pad_token_id all_eos_token_ids = [] @@ -521,7 +541,7 @@ class ChatCommand(BaseTransformersCLICommand): all_eos_token_ids.extend([int(token_id) for token_id in eos_token_ids.split(",")]) if len(all_eos_token_ids) == 0: - all_eos_token_ids.append(tokenizer.eos_token_id) + all_eos_token_ids.append(generation_config.eos_token_id) return pad_token_id, all_eos_token_ids @@ -547,7 +567,7 @@ class ChatCommand(BaseTransformersCLICommand): return quantization_config - def load_model_and_tokenizer(self, args: ChatArguments) -> tuple[AutoModelForCausalLM, AutoTokenizer]: + def load_model_and_tokenizer(self, args: ChatArguments) -> tuple["AutoModelForCausalLM", AutoTokenizer]: tokenizer = AutoTokenizer.from_pretrained( args.model_name_or_path_positional, revision=args.model_revision, @@ -588,6 +608,7 @@ class ChatCommand(BaseTransformersCLICommand): Handles all user commands except for `!exit`. May update the chat history (e.g. reset it) or the generation config (e.g. set a new flag). """ + valid_command = True if user_input == "!clear": chat = self.clear_chat_history(args.system_prompt) @@ -649,10 +670,11 @@ class ChatCommand(BaseTransformersCLICommand): ) else: + valid_command = False interface.print_color(text=f"'{user_input}' is not a valid command. 
Showing help message.", color="red") interface.print_help() - return chat, generation_config, model_kwargs + return chat, valid_command, generation_config, model_kwargs # ----------------------------------------------------------------------------------------------------------------- # Main logic @@ -676,7 +698,12 @@ class ChatCommand(BaseTransformersCLICommand): model, tokenizer = self.load_model_and_tokenizer(args) generation_streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True) - generation_config, model_kwargs = self.get_generation_parameterization(args, tokenizer) + generation_config, model_kwargs = self.get_generation_parameterization(args, tokenizer, model) + + # if not verbose -> disable warnings, progress bars, etc in the chat interface + if not args.verbose: + logging.set_verbosity_error() + disable_progress_bars() interface = RichInterface(model_name=args.model_name_or_path_positional, user_name=user) interface.clear() @@ -694,7 +721,7 @@ class ChatCommand(BaseTransformersCLICommand): if user_input == "!exit": break else: - chat, generation_config, model_kwargs = self.handle_non_exit_user_commands( + chat, valid_command, generation_config, model_kwargs = self.handle_non_exit_user_commands( user_input=user_input, args=args, interface=interface, @@ -704,7 +731,7 @@ class ChatCommand(BaseTransformersCLICommand): chat=chat, ) # `!example` sends a user message to the model - if not user_input.startswith("!example"): + if not valid_command or not user_input.startswith("!example"): continue else: chat.append({"role": "user", "content": user_input}) diff --git a/src/transformers/commands/env.py b/src/transformers/commands/env.py index 4721f1ccf66..983a858cd95 100644 --- a/src/transformers/commands/env.py +++ b/src/transformers/commands/env.py @@ -32,6 +32,7 @@ from ..utils import ( is_torch_available, is_torch_hpu_available, is_torch_npu_available, + is_torch_xpu_available, ) from . 
import BaseTransformersCLICommand @@ -89,15 +90,25 @@ class EnvironmentCommand(BaseTransformersCLICommand): pt_version = "not installed" pt_cuda_available = "NA" + pt_accelerator = "NA" if is_torch_available(): import torch pt_version = torch.__version__ pt_cuda_available = torch.cuda.is_available() - pt_xpu_available = torch.xpu.is_available() + pt_xpu_available = is_torch_xpu_available() pt_npu_available = is_torch_npu_available() pt_hpu_available = is_torch_hpu_available() + if pt_cuda_available: + pt_accelerator = "CUDA" + elif pt_xpu_available: + pt_accelerator = "XPU" + elif pt_npu_available: + pt_accelerator = "NPU" + elif pt_hpu_available: + pt_accelerator = "HPU" + tf_version = "not installed" tf_cuda_available = "NA" if is_tf_available(): @@ -141,7 +152,7 @@ class EnvironmentCommand(BaseTransformersCLICommand): "Accelerate version": f"{accelerate_version}", "Accelerate config": f"{accelerate_config_str}", "DeepSpeed version": f"{deepspeed_version}", - "PyTorch version (GPU?)": f"{pt_version} ({pt_cuda_available})", + "PyTorch version (accelerator?)": f"{pt_version} ({pt_accelerator})", "Tensorflow version (GPU?)": f"{tf_version} ({tf_cuda_available})", "Flax version (CPU?/GPU?/TPU?)": f"{flax_version} ({jax_backend})", "Jax version": f"{jax_version}", diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 6e75fbfb54a..205a7dde8f2 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -28,8 +28,6 @@ from .modeling_gguf_pytorch_utils import load_gguf_checkpoint from .utils import ( CONFIG_NAME, PushToHubMixin, - add_model_info_to_auto_map, - add_model_info_to_custom_pipelines, cached_file, copy_func, download_url, @@ -214,7 +212,7 @@ class PretrainedConfig(PushToHubMixin): # Attributes with defaults self.return_dict = kwargs.pop("return_dict", True) self.output_hidden_states = kwargs.pop("output_hidden_states", False) - self.output_attentions = kwargs.pop("output_attentions", False) + self._output_attentions = kwargs.pop("output_attentions", False) self.torchscript = kwargs.pop("torchscript", False) # Only used by PyTorch models self.torch_dtype = kwargs.pop("torch_dtype", None) # Only used by PyTorch models self.use_bfloat16 = kwargs.pop("use_bfloat16", False) @@ -331,6 +329,22 @@ class PretrainedConfig(PushToHubMixin): def name_or_path(self, value): self._name_or_path = str(value) # Make sure that name_or_path is a string (for JSON encoding) + @property + def output_attentions(self): + """ + `bool`: Whether or not the model should returns all attentions. + """ + return self._output_attentions + + @output_attentions.setter + def output_attentions(self, value): + if self._attn_implementation != "eager": + raise ValueError( + "The `output_attentions` attribute is not supported when using the `attn_implementation` set to " + f"{self._attn_implementation}. Please set it to 'eager' instead." 
+ ) + self._output_attentions = value + @property def use_return_dict(self) -> bool: """ @@ -697,15 +711,6 @@ class PretrainedConfig(PushToHubMixin): else: logger.info(f"loading configuration file {configuration_file} from cache at {resolved_config_file}") - if "auto_map" in config_dict and not is_local: - config_dict["auto_map"] = add_model_info_to_auto_map( - config_dict["auto_map"], pretrained_model_name_or_path - ) - if "custom_pipelines" in config_dict and not is_local: - config_dict["custom_pipelines"] = add_model_info_to_custom_pipelines( - config_dict["custom_pipelines"], pretrained_model_name_or_path - ) - # timm models are not saved with the model_type in the config file if "model_type" not in config_dict and is_timm_config_dict(config_dict): config_dict["model_type"] = "timm_wrapper" @@ -1004,6 +1009,8 @@ class PretrainedConfig(PushToHubMixin): if "_auto_class" in d: del d["_auto_class"] + if "_output_attentions" in d: + d["output_attentions"] = d.pop("_output_attentions") if "_commit_hash" in d: del d["_commit_hash"] if "_attn_implementation_internal" in d: @@ -1026,11 +1033,7 @@ class PretrainedConfig(PushToHubMixin): Register this class with a given auto class. This should only be used for custom configurations as the ones in the library are already mapped with `AutoConfig`. - - This API is experimental and may have some slight breaking changes in the next releases. - - Args: auto_class (`str` or `type`, *optional*, defaults to `"AutoConfig"`): diff --git a/src/transformers/dependency_versions_table.py b/src/transformers/dependency_versions_table.py index c01f5bb388c..5c0ae6b772f 100644 --- a/src/transformers/dependency_versions_table.py +++ b/src/transformers/dependency_versions_table.py @@ -103,4 +103,7 @@ deps = { "pytest-rich": "pytest-rich", "libcst": "libcst", "rich": "rich", + "opentelemetry-api": "opentelemetry-api", + "opentelemetry-exporter-otlp": "opentelemetry-exporter-otlp", + "opentelemetry-sdk": "opentelemetry-sdk", } diff --git a/src/transformers/dynamic_module_utils.py b/src/transformers/dynamic_module_utils.py index eec01749b65..660d0ac6d8d 100644 --- a/src/transformers/dynamic_module_utils.py +++ b/src/transformers/dynamic_module_utils.py @@ -667,7 +667,9 @@ def _raise_timeout_error(signum, frame): TIME_OUT_REMOTE_CODE = 15 -def resolve_trust_remote_code(trust_remote_code, model_name, has_local_code, has_remote_code, error_message=None): +def resolve_trust_remote_code( + trust_remote_code, model_name, has_local_code, has_remote_code, error_message=None, upstream_repo=None +): """ Resolves the `trust_remote_code` argument. If there is remote code to be loaded, the user must opt-in to loading it. @@ -688,11 +690,25 @@ def resolve_trust_remote_code(trust_remote_code, model_name, has_local_code, has Returns: The resolved `trust_remote_code` value. """ - # Originally, `trust_remote_code` was used to load models with custom code. - error_message = ( - error_message - or f"The repository `{model_name}` contains custom code which must be executed to correctly load the model." - ) + if error_message is None: + if upstream_repo is not None: + error_message = ( + f"The repository {model_name} references custom code contained in {upstream_repo} which " + f"must be executed to correctly load the model. You can inspect the repository " + f"content at https://hf.co/{upstream_repo} .\n" + ) + elif os.path.isdir(model_name): + error_message = ( + f"The repository {model_name} contains custom code which must be executed " + f"to correctly load the model. 
You can inspect the repository " + f"content at {os.path.abspath(model_name)} .\n" + ) + else: + error_message = ( + f"The repository {model_name} contains custom code which must be executed " + f"to correctly load the model. You can inspect the repository " + f"content at https://hf.co/{model_name} .\n" + ) if trust_remote_code is None: if has_local_code: diff --git a/src/transformers/feature_extraction_utils.py b/src/transformers/feature_extraction_utils.py index ca2a3b5fde3..51e882aefa8 100644 --- a/src/transformers/feature_extraction_utils.py +++ b/src/transformers/feature_extraction_utils.py @@ -29,8 +29,6 @@ from .utils import ( FEATURE_EXTRACTOR_NAME, PushToHubMixin, TensorType, - add_model_info_to_auto_map, - add_model_info_to_custom_pipelines, cached_file, copy_func, download_url, @@ -551,16 +549,6 @@ class FeatureExtractionMixin(PushToHubMixin): f"loading configuration file {feature_extractor_file} from cache at {resolved_feature_extractor_file}" ) - if not is_local: - if "auto_map" in feature_extractor_dict: - feature_extractor_dict["auto_map"] = add_model_info_to_auto_map( - feature_extractor_dict["auto_map"], pretrained_model_name_or_path - ) - if "custom_pipelines" in feature_extractor_dict: - feature_extractor_dict["custom_pipelines"] = add_model_info_to_custom_pipelines( - feature_extractor_dict["custom_pipelines"], pretrained_model_name_or_path - ) - return feature_extractor_dict, kwargs @classmethod @@ -673,11 +661,7 @@ class FeatureExtractionMixin(PushToHubMixin): Register this class with a given auto class. This should only be used for custom feature extractors as the ones in the library are already mapped with `AutoFeatureExtractor`. - - This API is experimental and may have some slight breaking changes in the next releases. - - Args: auto_class (`str` or `type`, *optional*, defaults to `"AutoFeatureExtractor"`): diff --git a/src/transformers/generation/__init__.py b/src/transformers/generation/__init__.py index cf1fa3661e0..64ebfe6fc7c 100644 --- a/src/transformers/generation/__init__.py +++ b/src/transformers/generation/__init__.py @@ -97,6 +97,9 @@ else: "validate_stopping_criteria", "StopStringCriteria", ] + _import_structure["continuous_batching"] = [ + "ContinuousMixin", + ] _import_structure["utils"] = [ "GenerationMixin", "GreedySearchEncoderDecoderOutput", @@ -213,6 +216,7 @@ if TYPE_CHECKING: EarlyExitCandidateGenerator, PromptLookupCandidateGenerator, ) + from .continuous_batching import ContinuousMixin from .logits_process import ( AlternatingCodebooksLogitsProcessor, ClassifierFreeGuidanceLogitsProcessor, diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index 9bfa5a64d77..99239b760d4 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -821,7 +821,7 @@ class GenerationConfig(PushToHubMixin): warning_message = ( f"The following generation flags are not valid and may be ignored: {attributes_with_issues}." ) - if logger.getEffectiveLevel() >= logging.WARNING: + if logging.get_verbosity() >= logging.WARNING: warning_message += " Set `TRANSFORMERS_VERBOSITY=info` for more details." 
logger.warning(warning_message) logger.info(info_message) diff --git a/src/transformers/generation/continuous_batching.py b/src/transformers/generation/continuous_batching.py new file mode 100644 index 00000000000..faafe2c6122 --- /dev/null +++ b/src/transformers/generation/continuous_batching.py @@ -0,0 +1,1444 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import queue +import statistics +import threading +import time +from abc import ABC, abstractmethod +from collections import deque +from dataclasses import dataclass, field +from enum import Enum +from functools import partial +from typing import Deque, Dict, List, Optional, Set, Tuple, Union + +import torch +import torch.nn as nn +from torch.profiler import profile, schedule, tensorboard_trace_handler +from tqdm import tqdm + +from ..cache_utils import Cache +from ..configuration_utils import PretrainedConfig +from ..generation.configuration_utils import GenerationConfig +from ..utils.metrics import ContinuousBatchProcessorMetrics, attach_tracer, traced + + +class RequestStatus(Enum): + """Status of a generation request through its lifecycle.""" + + PENDING = "pending" + PREFILLING = "prefilling" + PREFILLING_SPLIT = "prefilling_split" + SPLIT_PENDING_REMAINDER = "split_pending_remainder" + DECODING = "decoding" + FINISHED = "finished" + FAILED = "failed" + + +# Setup your logger +logger = logging.getLogger(__name__) +logger.setLevel(logging.WARNING) + + +@dataclass +class GenerationOutput: + """Tracks the output of a generation request. + + Attributes: + request_id (str): The ID of the generation request. + prompt_ids (List[int]): The IDs of the prompt tokens. + generated_tokens (List[int]): The generated tokens. + logprobs (List[float]): The log probabilities of the generated tokens. + error (Optional[str]): Any error message associated with the request. When None, the request was successful. + """ + + request_id: str + prompt_ids: List[int] = field(default_factory=list) + generated_tokens: List[int] = field(default_factory=list) + logprobs: List[float] = field(default_factory=list) + error: Optional[str] = None + status: RequestStatus = RequestStatus.PENDING + created_time: float = field(default_factory=time.time) + + +@dataclass +class RequestState: + """Tracks the state of a generation request through its lifecycle. 
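+
+    A request typically moves PENDING -> PREFILLING -> DECODING -> FINISHED; prompts that do not
+    fit in a single batch go through PREFILLING_SPLIT / SPLIT_PENDING_REMAINDER until the whole
+    prompt has been prefilled, and FAILED is used when an error is reported.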
+ + Attributes: + status (RequestStatus): can be one of PENDING, PREFILLING, PREFILLING_SPLIT, + SPLIT_PENDING_REMAINDER, DECODING, FINISHED, FAILED + """ + + # Required fields + request_id: str + prompt_ids: Optional[List[int]] = None # the one being processed + full_prompt_ids: Optional[List[int]] = None # the full prompt + remaining_prompt_ids: List[int] = field(default_factory=list) # For split requests + static_outputs: List[int] = field(default_factory=list) + allocated_blocks: List[int] = field(default_factory=list) + position_offset: int = 0 # Current position in the sequence for position_ids + status: RequestStatus = RequestStatus.PENDING + max_new_tokens: int = 20 + eos_token_id: int = -1 + created_time: float = field(default_factory=time.time) + error: Optional[str] = None + + def current_len(self) -> int: + """Get the current length of the sequence (prompt + generated tokens).""" + return self.position_offset + + def generated_len(self) -> int: + """Get the number of tokens generated so far.""" + return len(self.static_outputs) + + @traced + def update_with_token(self, token_id: int) -> bool: + """Update the request with a newly generated token and check for completion. + + Args: + token_id: The token ID to add to the output sequence + + Returns: + bool: True if the request is now complete, False otherwise + """ + # Only update if we're in decoding state + if self.status != RequestStatus.DECODING: + return False + + is_eos = token_id == self.eos_token_id and self.eos_token_id != -1 + is_max_len = self.generated_len() >= self.max_new_tokens + + if is_eos or is_max_len: + self.status = RequestStatus.FINISHED + return True + return False + + def __repr__(self): + return f"RequestState(\n\trequest_id={self.request_id},\n\tstatus={self.status},\n\tout_tokens={self.generated_len()},\n\tquery_length={len(self.prompt_ids)}, \n\tremaining_tokens={len(self.remaining_prompt_ids)}, \n\tkv_length={self.position_offset}\n\tfull_prompt_lenght={len(self.full_prompt_ids)},\n\tallocated_blocks={self.allocated_blocks},\n\tgenerated_tokens={self.static_outputs}\n)" + + def to_generation_output(self): + """Convert the request state to a GenerationOutput object.""" + return GenerationOutput( + request_id=self.request_id, + prompt_ids=self.full_prompt_ids, + status=self.status, + generated_tokens=self.static_outputs, + logprobs=[], + error=self.error, + ) + + +@attach_tracer() +class PagedAttentionCache(Cache): + def __init__( + self, + config: PretrainedConfig, + generation_config: GenerationConfig, + device: torch.device, + dtype: torch.dtype = torch.float16, + layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None, + initial_prompt_shapes: Optional[List[List[int]]] = None, + ) -> None: + """Initialize a paged attention cache for efficient memory usage. 
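+
+        Each layer's cache is a fixed pool of `num_blocks` blocks of `block_size` token slots.
+        Blocks are handed out to requests on demand (`allocate_blocks`) and returned when a
+        request finishes (`free_blocks`), so memory is shared across sequences of varying length.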
+ + Args: + config: Model configuration + generation_config: Generation configuration containing cache parameters + device: Device for the cache tensors + dtype: Data type for the cache tensors + layer_device_map: Optional mapping of layer indices to devices + initial_prompt_shapes: Optional sample prompts to help calculate optimal cache size + """ + # Extract model dimensions + self.num_key_value_heads = ( + config.num_attention_heads + if getattr(config, "num_key_value_heads", None) is None + else config.num_key_value_heads + ) + self.head_dim = ( + config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads + ) + self.num_hidden_layers = config.num_hidden_layers + + # Calculate optimal block size and number if not provided + num_blocks = getattr(generation_config, "num_blocks", None) + block_size = getattr(generation_config, "block_size", None) + if num_blocks is None or block_size is None: + logger.info("Calculating optimal block size and number...") + num_blocks, block_size = compute_optimal_blocks( + device, config, generation_config, initial_prompt_shapes or [], dtype, median_prefill_length=200 + ) + logger.info(f"Using calculated num_blocks={num_blocks}, block_size={block_size}") + + self.block_size = block_size + self.num_blocks = num_blocks + self.cache_shape = (self.num_key_value_heads, num_blocks, self.block_size, self.head_dim) + + self.dtype = dtype + self.device = device + + self.key_cache: List[torch.Tensor] = [] + self.value_cache: List[torch.Tensor] = [] + for idx in range(config.num_hidden_layers): + layer_device = layer_device_map[idx] if layer_device_map is not None else device + new_layer_key_cache = torch.zeros(self.cache_shape, dtype=self.dtype, device=layer_device) + new_layer_value_cache = torch.zeros(self.cache_shape, dtype=self.dtype, device=layer_device) + # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, + # preventing compiled graph breaks when updating the cache. 
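+            # A stable data pointer is also what lets the CUDA graph captured in
+            # `ContinuousBatchingManager.warmup` replay reads/writes against these buffers.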
+ torch._dynamo.mark_static_address(new_layer_key_cache) + torch._dynamo.mark_static_address(new_layer_value_cache) + self.key_cache.append(new_layer_key_cache) + self.value_cache.append(new_layer_value_cache) + + # Block management data structures + self._free_blocks = deque(range(num_blocks)) + self._block_tables: Dict[str, List[int]] = {} + + @traced + def allocate_blocks(self, n_blocks: int, request_id: str) -> List[int]: + """Allocates n_blocks for a given request_id.""" + if len(self._free_blocks) < n_blocks: + return False + + allocated = [] + for _ in range(n_blocks): + allocated.append(self._free_blocks.popleft()) + + if request_id not in self._block_tables: + self._block_tables[request_id] = [] + self._block_tables[request_id].extend(allocated) + return allocated + + @traced + def free_blocks(self, request_id: str) -> None: + """Frees all blocks associated with a request_id.""" + if request_id in self._block_tables: + blocks_to_free = self._block_tables.pop(request_id) + self._free_blocks.extend(blocks_to_free) + else: + logger.warning(f"Attempted to free blocks for non-existent request_id: {request_id}") + + def get_num_free_blocks(self) -> int: + """Returns the number of free blocks available.""" + return len(self._free_blocks) + + def get_block_table(self, request_id: str) -> List[int]: + """Returns the block table for a request.""" + return self._block_tables.get(request_id, []) + + @traced + def _get_physical_indices(self, state: RequestState, logical_indices: List[int]) -> List[int]: + """ + Maps logical sequence indices to physical cache indices using the block table, using PyTorch. + + Args: + request_id: The request ID. + logical_indices: A list of logical indices. + + Returns: + A list of physical indices. + + Raises: + ValueError: If no block table is found for the request ID. + IndexError: If a logical index maps to a block index that is out of bounds. + """ + request_id = state.request_id + block_table = self._block_tables.get(request_id) + if not block_table: + raise ValueError(f"No block table found for request {request_id}") + + block_size = self.block_size + physical_indices = [] + + for idx in logical_indices: + block_idx = idx // block_size + block_offset = idx % block_size + + if block_idx >= len(block_table): + raise IndexError( + f"Logical index {idx} maps to block index {block_idx} which is out of bounds " + f"for request {request_id}" + ) + + physical_block_num = block_table[block_idx] + physical_index = physical_block_num * block_size + block_offset + physical_indices.append(physical_index) + + return physical_indices + + @traced + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + read_index, + write_index, + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Reshape cache for easier indexing + total_slots = self.num_blocks * self.block_size + k_cache_flat = self.key_cache[layer_idx].view(self.num_key_value_heads, total_slots, self.head_dim) + v_cache_flat = self.value_cache[layer_idx].view(self.num_key_value_heads, total_slots, self.head_dim) + k_cache_flat[:, write_index, :] = key_states[0] + v_cache_flat[:, write_index, :] = value_states[0] + return k_cache_flat[None, :, read_index, :], v_cache_flat[None, :, read_index, :] + + +class Scheduler(ABC): + """ + Abstract base class for scheduling requests in the continuous batch processor. + It is expected that cache allocation and scheduling logic will be implemented in subclasses. 
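+    Two implementations are provided: `FIFOScheduler` (the fallback default, which prioritizes
+    decoding requests) and `PrefillFirstScheduler` (which prioritizes pending prefill splits);
+    the scheduler is selected by name through `SCHEDULER_MAPPING`.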
+ """ + + def __init__(self, cache: PagedAttentionCache): + self.active_requests: Dict[str, RequestState] = {} + self.waiting_requests: Dict[str, RequestState] = {} + self.waiting_requests_order: Deque[str] = deque() + self.cache = cache + + @abstractmethod + def add_waiting_request(self, state: RequestState): + """Add a request to the waiting list.""" + pass + + @abstractmethod + def schedule_batch(self, token_budget: int) -> List[RequestState]: + pass + + @traced + def has_pending_requests(self) -> bool: + """Check if there are requests ready to be processed.""" + return self.active_requests or self.waiting_requests + + @abstractmethod + def finish_request(self, state: RequestState): + """Finish processing a request and free its allocated blocks.""" + pass + + @traced + def get_active_request_static_outputs(self, request_id: str) -> List[int]: + if request_id in self.active_requests: + return self.active_requests[request_id].static_outputs + return [] + + +@attach_tracer() +class FIFOScheduler(Scheduler): + @traced + def _allocate_blocks_if_needed(self, state: RequestState, len_next_tokens: int): + # 1. we check that the occupancy is less than the requested length + # 2. we allocate enough blocks to cover the requested length + current_len = state.current_len() + occupancy = len(state.allocated_blocks) * self.cache.block_size - current_len + if occupancy < len_next_tokens or (len(state.allocated_blocks) == 0): + blocks_needed = ((len_next_tokens - occupancy + 1) // self.cache.block_size) + 1 + allocated = self.cache.allocate_blocks(blocks_needed, state.request_id) + if not allocated: + return False + state.allocated_blocks.extend(allocated) + return True + + @traced(span_name="prepare_request") + def _prepare_request_for_processing( + self, state: RequestState, token_budget: int, request_ids_to_remove_from_waiting: Set[str] + ): + """Prepare a request for processing in the current batch.""" + request_tokens = ( + state.remaining_prompt_ids if state.status == RequestStatus.SPLIT_PENDING_REMAINDER else state.prompt_ids + ) + if len(request_tokens) < token_budget: + # Can process the entire prompt/remainder + if state.status == RequestStatus.PENDING: + self.active_requests[state.request_id] = state + state.status = RequestStatus.PREFILLING + request_ids_to_remove_from_waiting.add(state.request_id) + elif state.status == RequestStatus.SPLIT_PENDING_REMAINDER: + state.status = RequestStatus.PREFILLING + state.prompt_ids = state.remaining_prompt_ids + state.remaining_prompt_ids = [] + else: + # Need to split the request + if state.status == RequestStatus.PENDING: + self.active_requests[state.request_id] = state + state.status = RequestStatus.PREFILLING_SPLIT + request_ids_to_remove_from_waiting.add(state.request_id) + elif state.status == RequestStatus.SPLIT_PENDING_REMAINDER: + state.status = RequestStatus.PREFILLING_SPLIT + state.remaining_prompt_ids = request_tokens[token_budget:] + state.prompt_ids = request_tokens[:token_budget] + + @traced + def add_waiting_request(self, state: RequestState): + """Add a request to the waiting list.""" + self.waiting_requests[state.request_id] = state + self.waiting_requests_order.append(state.request_id) + + @traced + def schedule_batch(self, token_budget: int) -> List[RequestState]: + priority_states: List[RequestState] = [] + second_priority_states: List[RequestState] = [] + scheduled_requests = [] + + for state in self.active_requests.values(): + if state.status == RequestStatus.DECODING: + priority_states.append(state) + if state.status == 
RequestStatus.SPLIT_PENDING_REMAINDER: + second_priority_states.append(state) + + # Add waiting requests to second priority + for req_id in self.waiting_requests_order: + second_priority_states.append(self.waiting_requests[req_id]) + + candidates = priority_states + second_priority_states + request_ids_to_remove_from_waiting = set() + + for state in candidates: + self._prepare_request_for_processing(state, token_budget, request_ids_to_remove_from_waiting) + request_len = len(state.prompt_ids) + if not self._allocate_blocks_if_needed( + state, len(state.prompt_ids) + ): # don't schedule if we can't allocate blocks + if len(self.cache._free_blocks) == 0: + break + continue + + @traced + def _add_to_scheduled_requests(state: RequestState): + scheduled_requests.append(state) + + _add_to_scheduled_requests(state) + + token_budget -= request_len + + @traced + def _remove_from_waiting_requests(state: RequestState): + req_id = state.request_id + if req_id in self.waiting_requests: + del self.waiting_requests[req_id] + request_ids_to_remove_from_waiting.add(req_id) + + _remove_from_waiting_requests(state) + + if token_budget == 0: + break + + self.waiting_requests_order = deque( + [req_id for req_id in self.waiting_requests_order if req_id not in request_ids_to_remove_from_waiting] + ) + + return scheduled_requests + + @traced + def finish_request(self, state: RequestState): + request_id = state.request_id + self.cache.free_blocks(request_id) + if request_id in self.active_requests: + del self.active_requests[request_id] + + +@attach_tracer() +class PrefillFirstScheduler(Scheduler): + @traced + def _allocate_blocks_if_needed(self, state: RequestState, len_next_tokens: int): + # 1. we check that the occupancy is less than the requested length + # 2. we allocate enough blocks to cover the requested length + current_len = state.current_len() + occupancy = len(state.allocated_blocks) * self.cache.block_size - current_len + if occupancy < len_next_tokens or (len(state.allocated_blocks) == 0): + blocks_needed = ((len_next_tokens - occupancy + 1) // self.cache.block_size) + 1 + allocated = self.cache.allocate_blocks(blocks_needed, state.request_id) + if not allocated: + return False + state.allocated_blocks.extend(allocated) + return True + + @traced(span_name="prepare_request") + def _prepare_request_for_processing( + self, state: RequestState, token_budget: int, request_ids_to_remove_from_waiting: Set[str] + ): + """Prepare a request for processing in the current batch.""" + request_tokens = ( + state.remaining_prompt_ids if state.status == RequestStatus.SPLIT_PENDING_REMAINDER else state.prompt_ids + ) + if len(request_tokens) < token_budget: + # Can process the entire prompt/remainder + if state.status == RequestStatus.PENDING: + self.active_requests[state.request_id] = state + state.status = RequestStatus.PREFILLING + request_ids_to_remove_from_waiting.add(state.request_id) + elif state.status == RequestStatus.SPLIT_PENDING_REMAINDER: + state.status = RequestStatus.PREFILLING + state.prompt_ids = state.remaining_prompt_ids + state.remaining_prompt_ids = [] + else: + # Need to split the request + if state.status == RequestStatus.PENDING: + self.active_requests[state.request_id] = state + state.status = RequestStatus.PREFILLING_SPLIT + request_ids_to_remove_from_waiting.add(state.request_id) + elif state.status == RequestStatus.SPLIT_PENDING_REMAINDER: + state.status = RequestStatus.PREFILLING_SPLIT + state.remaining_prompt_ids = request_tokens[token_budget:] + state.prompt_ids = 
request_tokens[:token_budget] + + @traced + def add_waiting_request(self, state: RequestState): + """Add a request to the waiting list.""" + self.waiting_requests[state.request_id] = state + self.waiting_requests_order.append(state.request_id) + + @traced + def schedule_batch(self, token_budget: int) -> List[RequestState]: + priority_states: List[RequestState] = [] + second_priority_states: List[RequestState] = [] + scheduled_requests = [] + + for state in self.active_requests.values(): + if state.status == RequestStatus.SPLIT_PENDING_REMAINDER: + priority_states.append(state) + elif state.status == RequestStatus.DECODING: + second_priority_states.append(state) + + for req_id in self.waiting_requests_order: + second_priority_states.append(self.waiting_requests[req_id]) + + candidates = priority_states + second_priority_states + + request_ids_to_remove_from_waiting = set() + + for state in candidates: + self._prepare_request_for_processing(state, token_budget, request_ids_to_remove_from_waiting) + request_len = len(state.prompt_ids) + if not self._allocate_blocks_if_needed( + state, len(state.prompt_ids) + ): # don't schedule if we can't allocate blocks + if len(self.cache._free_blocks) == 0: + break + continue + + @traced + def _add_to_scheduled_requests(state: RequestState): + scheduled_requests.append(state) + + _add_to_scheduled_requests(state) + + token_budget -= request_len + + @traced + def _remove_from_waiting_requests(state: RequestState): + req_id = state.request_id + if req_id in self.waiting_requests: + del self.waiting_requests[req_id] + request_ids_to_remove_from_waiting.add(req_id) + + _remove_from_waiting_requests(state) + + if token_budget == 0: + break + + self.waiting_requests_order = deque( + [req_id for req_id in self.waiting_requests_order if req_id not in request_ids_to_remove_from_waiting] + ) + + return scheduled_requests + + @traced + def finish_request(self, state: RequestState): + request_id = state.request_id + self.cache.free_blocks(request_id) + if request_id in self.active_requests: + del self.active_requests[request_id] + + +@traced(standalone=True) +def compute_optimal_blocks( + device: torch.device, + config: PretrainedConfig, + generation_config: GenerationConfig, + inputs: List[List[int]], + dtype: torch.dtype = torch.bfloat16, + safety_margin: float = 0.9, + median_prefill_length: Optional[int] = None, +): + """Calculate optimal number and size of blocks for the KV cache. 
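+
+    The heuristic takes the free device memory (scaled by `safety_margin`), estimates the
+    per-token KV footprint as `2 * num_kv_heads * head_dim * dtype_size * num_layers`, and
+    splits the resulting token budget into power-of-two sized blocks.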
+ + Args: + device: The device where the model runs + config: The model configuration + generation_config: The generation configuration + inputs: Sample input sequences to estimate memory requirements + dtype: Data type for cache tensors + safety_margin: Fraction of available memory to use + median_prefill_length: Override for median prefill length calculation + + Returns: + Tuple of (num_blocks, block_size) + """ + # Extract model dimensions + head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + num_kv_heads = getattr(config, "num_key_value_heads", config.num_attention_heads) + num_hidden_layers = getattr(config, "num_hidden_layers", 40) + + # Get available device memory + if device.type == "cuda": + device_properties = torch.cuda.get_device_properties(device) + total_memory = device_properties.total_memory + allocated_memory = torch.cuda.memory_allocated(device) + reserved_memory = torch.cuda.memory_reserved(device) + available_memory = total_memory - max(allocated_memory, reserved_memory) + elif device.type == "mps": + logger.warning("MPS memory estimation is approximate. Using conservative defaults.") + return 2048, 256 + else: + logger.warning(f"Unsupported device type {device.type} for optimal block calculation. Using defaults.") + return 32, 128 + + # Apply safety margin + available_memory = int(available_memory * safety_margin) + if available_memory <= 0: + logger.warning("Not enough available memory. Using minimum configuration.") + return 8, 128 # Minimum viable configuration + + # Calculate memory per token + dtype_size = torch.tensor([], dtype=dtype).element_size() + memory_per_token = 2 * num_kv_heads * head_dim * dtype_size * num_hidden_layers # For K and V caches + + # Estimate sequence length requirements + tokens_to_generate = getattr(generation_config, "max_new_tokens", 20) + + if median_prefill_length is None and inputs: + non_empty_inputs = [len(seq) for seq in inputs if seq] + median_prefill_length = int(statistics.median(non_empty_inputs)) if non_empty_inputs else 64 + elif median_prefill_length is None: + median_prefill_length = 64 # Reasonable default if no inputs provided + + # Total sequence length including generated tokens + seq_length = median_prefill_length + tokens_to_generate + + # Calculate block parameters + MIN_BLOCK_SIZE = 16 + + # Estimate number of concurrent sequences + per_sequence_memory = seq_length * memory_per_token + max_concurrent_sequences = max(1, int(available_memory // per_sequence_memory)) + + # Total tokens that can fit in memory + total_tokens = available_memory // memory_per_token + + # Calculate block size (rounded to power of 2) + initial_block_size = max(MIN_BLOCK_SIZE, total_tokens // (max_concurrent_sequences * 2)) + block_size = 1 << (initial_block_size - 1).bit_length() # Round to power of 2 + + # Calculate number of blocks + num_blocks = max(1, total_tokens // block_size) + + logger.info( + f"Optimal cache: {num_blocks} blocks of size {block_size} " + f"(can handle ~{num_blocks * block_size // seq_length} sequences of length {seq_length})" + ) + + return int(num_blocks), int(block_size) + + +@dataclass +class PagedAttentionArgs: + input_ids: torch.Tensor + attention_mask: torch.Tensor + position_ids: torch.Tensor + cumulative_seqlens_q: torch.Tensor + cumulative_seqlens_k: torch.Tensor + max_seqlen_q: int + max_seqlen_k: int + write_index: torch.Tensor + read_index: torch.Tensor + logits_indices: torch.Tensor + block_tables: Dict[str, List[int]] + cache: PagedAttentionCache + use_cache: bool 
= False + + +@traced +def create_document_mask(cumulative_seqlens_q, cumulative_seqlens_k): + # Number of documents + valid_docs_q = cumulative_seqlens_q[1:] > cumulative_seqlens_q[:-1] + valid_docs_k = cumulative_seqlens_k[1:] > cumulative_seqlens_k[:-1] + num_valid_docs = min(valid_docs_q.sum(), valid_docs_k.sum()) + + # Trim to valid docs + cumulative_seqlens_q = cumulative_seqlens_q[: num_valid_docs + 1] + cumulative_seqlens_k = cumulative_seqlens_k[: num_valid_docs + 1] + + total_q = cumulative_seqlens_q[-1] + total_k = cumulative_seqlens_k[-1] + + q_indices = torch.arange(total_q, device=cumulative_seqlens_q.device) + k_indices = torch.arange(total_k, device=cumulative_seqlens_k.device) + + q_doc_ids = torch.bucketize(q_indices, cumulative_seqlens_q[1:], right=True) + k_doc_ids = torch.bucketize(k_indices, cumulative_seqlens_k[1:], right=False) + doc_mask = q_doc_ids[:, None] == k_doc_ids[None, :] + # apply causal mask where no decoding (same nb of q than k) + + is_causal = ~(cumulative_seqlens_q[1:] - cumulative_seqlens_q[:-1] == 1) * cumulative_seqlens_q[1:] + apply_causal = torch.bucketize(q_indices, is_causal, right=True)[:, None] == k_doc_ids + # TODO don't apply on prefill splitting + causal_mask = torch.triu(torch.ones(total_q, total_k, device=q_doc_ids.device), diagonal=1).bool() + doc_mask.masked_fill_((apply_causal & causal_mask), False) + return doc_mask + + +# Continuous Batch Processor (Internal Logic) +@attach_tracer() +class ContinuousBatchProcessor: + def __init__( + self, + cache: PagedAttentionCache, + config: PretrainedConfig, + generation_config: GenerationConfig, + input_queue: queue.Queue, + output_queue: queue.Queue, + stop_event: threading.Event, + model_device: torch.device, + model_dtype: torch.dtype, + scheduler: Scheduler, + streaming: bool = False, + ): + """Initialize the continuous batch processor. 
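+
+        The processor owns the pre-allocated (static) batch tensors and, on every step, pulls
+        new requests from `input_queue`, asks the scheduler for a batch within `max_batch_tokens`,
+        packs it into those tensors, and pushes finished (or streamed) results to `output_queue`.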
+ + Args: + cache: The paged attention cache to use + generation_config: The generation configuration + input_queue: Queue for incoming requests + output_queue: Queue for outgoing results + stop_event: Event to signal processing should stop + model_device: Device for model inputs/outputs + model_dtype: Data type for model inputs/outputs + streaming: Whether to stream tokens as they're generated + """ + self.cache = cache + self.config = config + self.generation_config = generation_config + self.input_queue = input_queue + self.output_queue = output_queue + self.stop_event = stop_event + self.model_device = model_device + self.model_dtype = model_dtype + self.scheduler = scheduler + self.streaming = streaming + + self.requests_in_batch: List[RequestState] = [] + + # Get batch size parameters from generation config + self._configure_batch_parameters() + + # Set up metrics collector + self.metrics = ContinuousBatchProcessorMetrics(self.max_batch_tokens) + + self.setup_static_tensors() + + @traced(standalone=True) + def setup_static_tensors(self): + T = self.max_batch_tokens + max_token_budget = self.cache.num_blocks * self.cache.block_size + tensor_metadata = {"dtype": torch.int32, "device": self.model_device} + self.tensor_metadata = tensor_metadata + self.input_ids = torch.zeros((1, T), **tensor_metadata) + self.position_ids = torch.zeros((1, T), **tensor_metadata) + self.attention_mask = torch.zeros( + (1, 1, T, max_token_budget), dtype=self.model_dtype, device=self.model_device + ) + self.cumulative_seqlens_q = torch.zeros((T + 1,), **tensor_metadata) + self.cumulative_seqlens_k = torch.zeros((T + 1,), **tensor_metadata) + self.write_index = torch.zeros((T,), **tensor_metadata) + self.read_index = torch.zeros((max_token_budget,), **tensor_metadata) + self.logits_indices = torch.full((T,), -1, **tensor_metadata) + self.max_seqlen_q = 0 + self.max_seqlen_k = 0 + self.output_ids = torch.full((1, T), -1, **tensor_metadata) + + @traced + @torch.no_grad() + def reset_static_tensors(self): + """Reset static tensors for the next batch.""" + self.input_ids.zero_() + self.position_ids.zero_() + self.attention_mask.fill_(torch.finfo(self.model_dtype).min) + self.cumulative_seqlens_q.zero_() + self.cumulative_seqlens_k.zero_() + self.write_index.fill_(-1) + self.read_index.fill_(-1) + self.logits_indices.fill_(-1) + self.max_seqlen_q = 0 + self.max_seqlen_k = 0 + self.output_ids.zero_() + + def get_model_kwargs(self) -> PagedAttentionArgs: + """Get model keyword arguments for the current batch.""" + # torch.set_printoptions(threshold=100000,linewidth=10000) + return { + "input_ids": self.input_ids, + "position_ids": self.position_ids, + "attention_mask": self.attention_mask, + "cumulative_seqlens_q": self.cumulative_seqlens_q, + "cumulative_seqlens_k": self.cumulative_seqlens_k, + "write_index": self.write_index, + "read_index": self.read_index, + "logits_indices": self.logits_indices, + "max_seqlen_q": self.max_seqlen_q, + "max_seqlen_k": self.max_seqlen_k, + "block_tables": self.cache._block_tables, + "cache": self.cache, + "use_cache": False, + } + + def __repr__(self): + return ( + f"ContinuousBatchProcessor(input_queue={self.input_queue}, output_queue={self.output_queue}, active_requests={self.scheduler.active_requests}, waiting_requests={self.scheduler.waiting_requests})" + + self.get_model_kwargs().__repr__() + ) + + @traced(standalone=True) + def _configure_batch_parameters(self): + """Set up batch processing parameters based on generation config.""" + # Calculate total cache capacity + 
total_cache_tokens = self.cache.num_blocks * self.cache.block_size + + # Get or calculate max tokens per batch + user_batch_tokens = getattr(self.generation_config, "max_batch_tokens", None) + if user_batch_tokens is not None: + self.max_batch_tokens = user_batch_tokens + else: + # Default to 1/8 of total cache capacity, adjusted for context + self.max_context_len = getattr(self.generation_config, "max_position_embeddings", 2048) + recommended_batch_size = min(total_cache_tokens // 8, self.max_context_len) + self.max_batch_tokens = max(64, recommended_batch_size) + + # Context length and EOS token + self.max_context_len = getattr(self.generation_config, "max_position_embeddings", 2048) + + @traced + def _get_new_requests(self): + """Pull new requests from the input queue and add to waiting list.""" + while not self.input_queue.empty(): + try: + state = self.input_queue.get_nowait() + if state is None: # Sentinel value + continue + self.scheduler.add_waiting_request(state) + + except queue.Empty: + break + except Exception as e: + logger.error(f"Error processing new request: {e}", exc_info=True) + state: RequestState = locals().get("state") + if state is not None: + self._handle_request_error(e, state) + + @traced + def _handle_request_error(self, error, state: RequestState): + """Handle general request processing error.""" + state.status = RequestStatus.FAILED + state.error = str(error) + + # Include any generated tokens if this is an active request + if isinstance(state.request_id, str): + state.static_outputs = self.scheduler.get_active_request_static_outputs(state.request_id) + else: + state.static_outputs = [] + + self.metrics.record_request_completion(state.created_time, state.request_id) + self.output_queue.put(state.to_generation_output()) + + @traced + def prepare_next_batch(self): + """Prepare tensors and metadata for the next model forward pass.""" + # Get new requests from the queue + self._get_new_requests() + if not self.scheduler.has_pending_requests(): + return None + + self.metrics.record_queue_metrics(len(self.scheduler.active_requests), len(self.scheduler.waiting_requests)) + + self.requests_in_batch = self.scheduler.schedule_batch(self.max_batch_tokens) + if not self.requests_in_batch: + return None + + # Get the request objects for this batch + self.reset_static_tensors() + position_ids = [] + input_ids = [] + read_index = [] + write_index = [] + cumulative_seqlens_q = [0] + cumulative_seqlens_k = [0] + logits_indices = [] + self.metrics.record_batch_metrics(self.requests_in_batch) + + for state in self.requests_in_batch: + next_input_ids = state.prompt_ids + input_ids.extend(next_input_ids) + past_length = state.position_offset + query_length = len(next_input_ids) + key_length = query_length + past_length + cache_index = list(range(key_length)) + + positions_to_add = cache_index[past_length:] + read_indices = self.cache._get_physical_indices(state, cache_index) + write_indices = read_indices[-query_length:] + + position_ids.extend(positions_to_add) + read_index.extend(read_indices) + write_index.extend(write_indices) + cumulative_seqlens_q.append(cumulative_seqlens_q[-1] + query_length) + cumulative_seqlens_k.append(cumulative_seqlens_k[-1] + key_length) + if len(state.remaining_prompt_ids) == 0: + logits_indices.append(cumulative_seqlens_q[-1] - 1) + self.max_seqlen_q = max(self.max_seqlen_q, query_length) + self.max_seqlen_k = max(self.max_seqlen_k, key_length) + state.position_offset += query_length + + logger.warning( + f"Scheduled: {len(self.requests_in_batch)}, 
Waiting: {len(self.scheduler.waiting_requests)}, Active: {len(self.scheduler.active_requests)}. cum Q: {cumulative_seqlens_q[-1]}. cum KV: {cumulative_seqlens_k[-1]}, free blocks: {self.cache.get_num_free_blocks()}" + ) + self._build_tensors( + input_ids, + position_ids, + read_index, + write_index, + cumulative_seqlens_q, + cumulative_seqlens_k, + logits_indices, + ) + + self.metrics.record_kv_cache_memory_metrics(self.cache) + + @traced + def _build_tensors( + self, + input_ids, + position_ids, + read_index, + write_index, + cumulative_seqlens_q, + cumulative_seqlens_k, + logits_indices, + ): + to_tensor = partial(torch.tensor, **self.tensor_metadata) + self.input_ids[:, : len(input_ids)] = to_tensor(input_ids) + self.position_ids[:, : len(position_ids)] = to_tensor(position_ids) + self.write_index[: len(write_index)] = to_tensor(write_index) + self.read_index[: len(read_index)] = to_tensor(read_index) + self.cumulative_seqlens_q[: len(cumulative_seqlens_q)] = to_tensor(cumulative_seqlens_q) + self.cumulative_seqlens_k[: len(cumulative_seqlens_k)] = to_tensor(cumulative_seqlens_k) + self.logits_indices[: len(logits_indices)] = to_tensor(logits_indices) + min_value = torch.finfo(self.model_dtype).min + if self.config._attn_implementation != "paged_attention": # we set `is_causal` to True in paged call` + for i in range(len(cumulative_seqlens_q) - 1): + if ( + cumulative_seqlens_q[i + 1] - cumulative_seqlens_q[i] + < cumulative_seqlens_k[i + 1] - cumulative_seqlens_k[i] + and cumulative_seqlens_q[i + 1] - cumulative_seqlens_q[i] >= 1 + ): + diagonal = ( + cumulative_seqlens_k[i + 1] - (cumulative_seqlens_q[i + 1] - cumulative_seqlens_q[i]) + 1 + ) + diagonal = diagonal - cumulative_seqlens_k[i] + else: + diagonal = 1 + query_range = slice(cumulative_seqlens_q[i], cumulative_seqlens_q[i + 1]) + key_range = slice(cumulative_seqlens_k[i], cumulative_seqlens_k[i + 1]) + + mask = torch.triu( + torch.full( + self.attention_mask[..., query_range, key_range].shape, + min_value, + dtype=self.model_dtype, + device=self.model_device, + ), + diagonal=diagonal, + ) + self.attention_mask[..., query_range, key_range] = mask + + @traced + def _sync(self): + return self.output_ids.tolist()[0] # should be the only synch we do + + @traced + def _maybe_send_output(self, state: RequestState, token: int): + """Send output to the queue based on streaming mode and request state.""" + if self.streaming: + state.next_token = token + self.output_queue.put(state.to_generation_output()) + elif state.status == RequestStatus.FINISHED: + self.output_queue.put(state.to_generation_output()) + + @traced + def update_batch(self): + """Update request states based on generated tokens.""" + out_tokens = self._sync() + finished_request_ids = [] + for i, state in enumerate(self.requests_in_batch): + req_id = state.request_id + if len(state.remaining_prompt_ids) == 0: + self.metrics.record_ttft_metric(state.created_time, state.request_id) + state.status = RequestStatus.DECODING + token = out_tokens[self.logits_indices[i]] + state.static_outputs.extend([token]) + state.prompt_ids = [token] + if state.update_with_token(token): + self.metrics.record_request_completion(state.created_time, state.request_id) + self.scheduler.finish_request(state) + finished_request_ids.append(req_id) + self._maybe_send_output(state, token) + elif state.status == RequestStatus.PREFILLING_SPLIT: + state.status = RequestStatus.SPLIT_PENDING_REMAINDER + + @traced + def has_pending_requests(self) -> bool: + """Check if there are any active or waiting 
requests.""" + return self.scheduler.has_pending_requests() + + @traced + def handle_batch_error(self, error): + """Handle errors during batch processing.""" + failed_reqs = self.requests_in_batch + for req in failed_reqs: + self._handle_request_error(error, req) + self.scheduler.finish_request(req) + + @traced + def fail_all_requests(self, error): + """Fail all active requests with the given error. + + Args: + error: The error to report in the failure message + """ + for state in self.scheduler.active_requests.values(): + self._handle_request_error(error, state) + self.scheduler.finish_request(state) + + # Also fail any requests in the waiting queue + for req_id in list(self.scheduler.waiting_requests.keys()): + state = self.scheduler.waiting_requests.pop(req_id) + self._handle_request_error(error, state) + + # Clear the ordering queue + self.scheduler.waiting_requests_order.clear() + + +SCHEDULER_MAPPING = { + "fifo": FIFOScheduler, + "prefill_first": PrefillFirstScheduler, +} + + +# Manager Class (User Interface) +@attach_tracer() +class ContinuousBatchingManager: + """Manager for handling continuous batching of generation requests. + + This class provides the user interface for submitting generation requests, + retrieving results, and managing the background generation thread. + """ + + def __init__(self, model, generation_config: GenerationConfig, max_queue_size=0, streaming: bool = True): + """Initialize the continuous batching manager. + + Args: + model: The language model for generation + generation_config: Configuration for generation parameters + max_queue_size: Maximum size of the request queue (0 = unlimited) + streaming: Whether to stream tokens as they are generated + """ + self.model = model + self.generation_config = generation_config + self.input_queue = queue.Queue(maxsize=max_queue_size) + self.output_queue = queue.Queue() + self.stop_event = threading.Event() + self.streaming = streaming + self.log_prob_generation = getattr(generation_config, "log_prob_generation", False) + self._generation_thread = None + self._request_counter = 0 + self._request_lock = threading.Lock() + self.model.generation_config.top_p = None + self.do_sample = getattr(generation_config, "do_sample", True) + self.logit_processor = self.model._get_logits_processor(self.model.generation_config) + self.use_cuda_graph = getattr(generation_config, "use_cuda_graph", True) + self.profile = getattr(generation_config, "profile", False) + + @traced + def start(self): + """Start the background generation thread.""" + if self._generation_thread is not None and self._generation_thread.is_alive(): + logger.warning("Manager thread is already running.") + return + + self._result_queue = queue.Queue() + self._generation_thread = threading.Thread(target=self._run_generation_loop) + self._generation_thread.start() + logger.info("Continuous batching manager started.") + + def is_running(self): + """Check if the background generation thread is running.""" + return self._generation_thread is not None and self._generation_thread.is_alive() + + def stop(self, block: bool = False, timeout: Optional[float] = None): + """Signal the background thread to stop. 
+ + Args: + block: Whether to wait for the thread to stop + timeout: Maximum time to wait for the thread to stop + """ + if self._generation_thread is None: + logger.warning("Manager not started.") + return + + if not self.stop_event.is_set(): + self.stop_event.set() + logger.info("Stopping continuous batching manager...") + + if block: + self.join(timeout) + + def join(self, timeout: Optional[float] = None): + """Wait for the background thread to finish. + + Args: + timeout: Maximum time to wait for the thread to stop + """ + if self._generation_thread is not None: + self._generation_thread.join(timeout=timeout) + if self._generation_thread.is_alive(): + logger.warning("Generation thread did not exit after join timeout.") + else: + logger.info("Continuous Batching Manager stopped.") + self._generation_thread = None + + def add_request( + self, input_ids: List[int], request_id: Optional[str] = None, max_new_tokens: Optional[int] = None + ) -> str: + """Add a new generation request to the queue. + + Args: + input_ids: Input token IDs to use as prompt + request_id: Optional custom request ID (auto-generated if None) + **kwargs: Additional generation parameters + + Returns: + str: The request ID + """ + if request_id is None: + with self._request_lock: + request_id = f"req_{self._request_counter}" + self._request_counter += 1 + + max_new_tokens = self.generation_config.max_new_tokens if max_new_tokens is None else max_new_tokens + + state = RequestState( + request_id=request_id, + prompt_ids=list(input_ids), + full_prompt_ids=list(input_ids), + max_new_tokens=max_new_tokens, + eos_token_id=self.generation_config.eos_token_id, + ) + + # Use block=True with timeout to handle backpressure if queue is full + self.input_queue.put(state, block=True, timeout=10) # XXX: pass timeout as fn arg? + logger.debug(f"Added request {request_id} to queue.") + return request_id + + def add_requests(self, inputs: List[List[int]], **kwargs): + for i, input_ids in enumerate(inputs): + # Assign a predictable request ID for ordering results later + req_id = f"batch_req_{i}" + self.add_request(input_ids, request_id=req_id, **kwargs) + + def get_result(self, timeout=None) -> Optional[GenerationOutput]: + """Retrieve one result from the output queue. + + Args: + timeout: Maximum time to wait for a result + + Returns: + Optional[Dict]: The result data or None if timeout + """ + if self._generation_thread is None and self.output_queue.empty(): + return None + try: + result = self.output_queue.get(block=True, timeout=timeout) + logger.debug(f"Retrieved result for request {result.request_id}") + return result + except queue.Empty: + return None + + def __iter__(self): + """Iterate over results as they become available.""" + while ( + self._generation_thread is not None and self._generation_thread.is_alive() or not self.output_queue.empty() + ): + result = self.get_result(timeout=0.1) # allow the model to run for 10 seconds + if result is not None: + yield result + + @traced + def warmup(self, batch_processor): + stream = torch.cuda.Stream() + stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(stream): + # Warmup the model with a dummy forward pass + self._generation_step(batch_processor) + torch.cuda.current_stream().wait_stream(stream) + + self.graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(self.graph): + self._generation_step(batch_processor) + + @traced + # @torch.compile + def _generation_step(self, batch_processor: ContinuousBatchProcessor): + """Perform a single generation step. 
This is cuda graphed""" + batch_data = batch_processor.get_model_kwargs() + with torch.no_grad(): + logits = self._model_forward(batch_data) + if self.log_prob_generation: + batch_processor.output_probs.copy_(logits) # TODO + probs = self._process_logit(batch_data, logits) + self._sample(batch_processor, probs) + + @traced(span_name="model_forward") + def _model_forward(self, batch_data): + return self.model(**batch_data).logits + + @traced(span_name="logit_processing") + def _process_logit(self, batch_data, logits): + return self.logit_processor(batch_data["input_ids"], logits) + + @traced(span_name="sampling") + def _sample(self, batch_processor: ContinuousBatchProcessor, probs): + if self.do_sample: # sample + probs = nn.functional.softmax(probs, dim=-1) + next_tokens = torch.multinomial(probs[0], num_samples=1).squeeze(1) + else: + next_tokens = torch.argmax(probs, dim=-1) + batch_processor.output_ids.copy_(next_tokens) + + def _run_generation_loop(self): + """Main processing loop running in the background thread.""" + batch_processor = None + try: + paged_attention_cache = PagedAttentionCache( + self.model.config, + self.generation_config, + self.model.device, + self.model.dtype, + ) + + scheduler = SCHEDULER_MAPPING.get(self.generation_config.scheduler) + if scheduler is None: + logger.warning(f"Scheduler '{scheduler}' not found. Defaulting to FIFO.") + scheduler = FIFOScheduler + + batch_processor = ContinuousBatchProcessor( + paged_attention_cache, + self.model.config, + self.generation_config, + self.input_queue, + self.output_queue, + self.stop_event, + self.model.device, + self.model.dtype, + scheduler(paged_attention_cache), + self.streaming, + ) + is_first = True + + if self.profile: + tracing_schedule = schedule(skip_first=2, warmup=3, active=200, repeat=100, wait=1) + trace_handler = tensorboard_trace_handler( + dir_name="/fsx/arthur/transformers", use_gzip=True, worker_name="paged_compile" + ) + activities = [ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ] + with profile( + activities=activities, + schedule=tracing_schedule, + on_trace_ready=trace_handler, + record_shapes=False, + with_stack=True, + ) as prof: + while not self.stop_event.is_set() or batch_processor.has_pending_requests(): + self._inner_generation_loop(batch_processor, is_first) + if is_first: + is_first = False + prof.step() + else: + while not self.stop_event.is_set() or batch_processor.has_pending_requests(): + self._inner_generation_loop(batch_processor, is_first) + if is_first: + is_first = False + + except Exception as e: + logger.error(f"Error in generation loop: {e}", exc_info=True) + self._handle_critical_error(e, batch_processor) + finally: + logger.info("Generation loop finished.") + + @traced(span_name="generation_loop") + def _inner_generation_loop(self, batch_processor: ContinuousBatchProcessor, is_first: bool = False): + if torch.cuda.is_available(): + torch.cuda.synchronize() + batch_processor.prepare_next_batch() + if torch.cuda.is_available() and self.use_cuda_graph: + if is_first: + self.warmup(batch_processor) + elif hasattr(self, "graph"): + try: + self._graph_replay() + except Exception as e: + logger.error(f"Model forward pass failed: {e}", exc_info=True) + batch_processor.handle_batch_error(e) + return + else: + self._generation_step(batch_processor) + else: + self._generation_step(batch_processor) + if torch.cuda.is_available(): + torch.cuda.synchronize() + batch_processor.update_batch() + + @traced(span_name="graph_replay") + def 
_graph_replay(self): + self.graph.replay() + + @traced + def _handle_critical_error(self, error, batch_processor: Optional[ContinuousBatchProcessor]): + """Handle critical errors that terminate the generation loop.""" + # Signal stop + self.stop_event.set() + + # Fail pending requests in input queue + try: + while True: + req_data = self.input_queue.get_nowait() + if batch_processor is not None: + batch_processor._handle_request_error(error, req_data) + except queue.Empty: + pass + + # Fail active requests + if batch_processor is not None: + batch_processor.fail_all_requests(error) + + +class ContinuousMixin: + """Mixin class for models to add continuous batching capabilities.""" + + def init_continuous_batching( + self, + generation_config: Optional[GenerationConfig] = None, + max_queue_size: int = 0, + scheduler: str = "fifo", + streaming: bool = False, + ) -> ContinuousBatchingManager: + """Initialize a manager for continuous batching inference. + + Args: + generation_config: Custom generation configuration + max_queue_size: Maximum size of the input request queue + streaming: Whether to stream tokens as they are generated + + Returns: + `ContinuousBatchingManager`: The manager instance to add requests and retrieve results. + """ + if not hasattr(self, "config") or not hasattr(self, "device") or not hasattr(self, "dtype"): + raise AttributeError("Model must have 'config', 'device', and 'dtype' attributes.") + + gen_config = generation_config if generation_config is not None else self.generation_config + if gen_config is None: + raise ValueError("A GenerationConfig must be provided or set in the model.") + + if gen_config.eos_token_id is None: + logger.warning("`eos_token_id` not set in GenerationConfig. Setting to -1 (disabled).") + gen_config.eos_token_id = -1 + + # Create and return the manager + return ContinuousBatchingManager( + model=self, generation_config=gen_config, max_queue_size=max_queue_size, streaming=streaming + ) + + @traced + @torch.inference_mode() + def generate_batch( + self, + inputs: List[List[int]], + generation_config: Optional[GenerationConfig] = None, + progress_bar: bool = True, + **kwargs, + ) -> List[List[int]]: + """Generate sequences for a batch of prompts using continuous batching. + + Args: + inputs: List of input token sequences (prompts) + generation_config: Optional generation configuration + **kwargs: Additional generation parameters + + Returns: + `List[List[int]]`: A list containing the generated sequences (including prompt tokens + if not handled otherwise) for each input prompt, in the same order. + Returns an empty list `[]` for requests that failed. 
+ """ + if not inputs: + return [] + + # Initialize manager with the batch inputs + manager = self.init_continuous_batching(generation_config=generation_config) + manager.start() + results = {} + num_requests = len(inputs) + try: + from tqdm.contrib.logging import logging_redirect_tqdm + + with logging_redirect_tqdm([logger]): + with tqdm( + total=num_requests, + disable=(not progress_bar), + desc=f"Solving {num_requests} requests", + unit="request", + ) as pbar: + manager.add_requests(inputs, **kwargs) + finished_count = 0 + while finished_count < num_requests: + result = manager.get_result(timeout=1) + if result: + req_id = result.request_id + if result.status == RequestStatus.FINISHED: + results[req_id] = result + finished_count += 1 + pbar.update(1) + else: + if not manager.is_running(): + logger.error("Generation thread terminated unexpectedly.") + break + + except Exception as e: + logger.error(f"Error during batch generation: {e}", exc_info=True) + finally: + manager.stop(block=True, timeout=5.0) + return results diff --git a/src/transformers/generation/flax_utils.py b/src/transformers/generation/flax_utils.py index ddd718cbb8a..3c8c4795a84 100644 --- a/src/transformers/generation/flax_utils.py +++ b/src/transformers/generation/flax_utils.py @@ -531,13 +531,16 @@ class FlaxGenerationMixin: if (input_ids_seq_length > 1 or generation_config.forced_bos_token_id is None) else begin_index + 1 ) - if generation_config.forced_decoder_ids is not None and len(generation_config.forced_decoder_ids) > 0: + if ( + getattr(generation_config, "forced_decoder_ids", None) is not None + and len(generation_config.forced_decoder_ids) > 0 + ): # generation starts after the last token that is forced begin_index += generation_config.forced_decoder_ids[-1][0] processors.append( FlaxSuppressTokensAtBeginLogitsProcessor(generation_config.begin_suppress_tokens, begin_index) ) - if generation_config.forced_decoder_ids is not None: + if getattr(generation_config, "forced_decoder_ids", None) is not None: forced_decoder_ids = [ [input_ids_seq_length + i[0] - 1, i[1]] for i in generation_config.forced_decoder_ids ] diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index 6e0f0154abd..a4e8b5eda0d 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -2051,6 +2051,10 @@ class WhisperNoSpeechDetection(LogitsProcessor): self.inputs = {**self.model.prepare_inputs_for_generation(**inputs), **inputs} self.inputs["input_features"] = self.inputs.pop("inputs") + # Whisper encoder-decoder does not accept the input_ids as input + if "input_ids" not in inspect.signature(self.model.forward).parameters: + self.inputs.pop("input_ids", None) + @property def no_speech_prob(self): return self._no_speech_prob diff --git a/src/transformers/generation/tf_utils.py b/src/transformers/generation/tf_utils.py index 510186cafc0..ae77f32e269 100644 --- a/src/transformers/generation/tf_utils.py +++ b/src/transformers/generation/tf_utils.py @@ -1490,14 +1490,14 @@ class TFGenerationMixin: if (input_ids_seq_length > 1 or generation_config.forced_bos_token_id is None) else begin_index + 1 ) - if generation_config.forced_decoder_ids is not None: + if getattr(generation_config, "forced_decoder_ids", None) is not None: begin_index += generation_config.forced_decoder_ids[-1][ 0 ] # generation starts after the last token that is forced processors.append( TFSuppressTokensAtBeginLogitsProcessor(generation_config.begin_suppress_tokens, 
begin_index) ) - if generation_config.forced_decoder_ids is not None: + if getattr(generation_config, "forced_decoder_ids", None) is not None: processors.append(TFForceTokensLogitsProcessor(generation_config.forced_decoder_ids)) processors = self._merge_criteria_processor_list(processors, logits_processor) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 49dc4b8df72..713d57a8994 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -79,6 +79,7 @@ from .configuration_utils import ( GenerationConfig, GenerationMode, ) +from .continuous_batching import ContinuousMixin from .logits_process import ( EncoderNoRepeatNGramLogitsProcessor, EncoderRepetitionPenaltyLogitsProcessor, @@ -352,7 +353,7 @@ GenerateBeamOutput = Union[GenerateBeamDecoderOnlyOutput, GenerateBeamEncoderDec GenerateOutput = Union[GenerateNonBeamOutput, GenerateBeamOutput] -class GenerationMixin: +class GenerationMixin(ContinuousMixin): """ A class containing all functions for auto-regressive text generation, to be used as a mixin in model classes. Inheriting from this class causes the model to have special generation-related behavior, such as loading a @@ -635,7 +636,7 @@ class GenerationMixin: and attention_mask is not None and attention_mask.ndim == 2 ): - if model_inputs["inputs_embeds"] is not None: + if not self.config.is_encoder_decoder and model_inputs["inputs_embeds"] is not None: batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape else: batch_size, sequence_length = model_inputs[input_ids_key].shape[:2] @@ -654,7 +655,6 @@ class GenerationMixin: # If it's not defined, it means the model uses the new general mask API if causal_mask_creation_function is None: # can't be found - output_attentions = kwargs.get("output_attentions", False) token_type_ids = getattr(model_input, "token_type_ids", None) # Some models may overwrite the general one causal_mask_creation_function = getattr(self, "create_masks_for_generate", create_masks_for_generate) @@ -665,7 +665,6 @@ class GenerationMixin: attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, - output_attentions=output_attentions, token_type_ids=token_type_ids, ) else: @@ -1099,10 +1098,10 @@ class GenerationMixin: def _get_logits_processor( self, generation_config: GenerationConfig, - input_ids_seq_length: int, - encoder_input_ids: torch.LongTensor, - prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]], - logits_processor: Optional[LogitsProcessorList], + input_ids_seq_length: Optional[int] = None, + encoder_input_ids: torch.LongTensor = None, + prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, + logits_processor: Optional[LogitsProcessorList] = None, device: Optional[str] = None, model_kwargs: Optional[Dict[str, Any]] = None, negative_prompt_ids: Optional[torch.Tensor] = None, @@ -1114,6 +1113,8 @@ class GenerationMixin: """ # instantiate processors list processors = LogitsProcessorList() + if logits_processor is None: + logits_processor = [] if generation_config.guidance_scale is not None and generation_config.guidance_scale != 1: processors.append( @@ -1183,7 +1184,7 @@ class GenerationMixin: ) if ( generation_config.min_length is not None - and generation_config._eos_token_tensor is not None + and getattr(generation_config, "_eos_token_tensor", None) is not None and generation_config.min_length > 0 ): processors.append( @@ -1195,7 +1196,7 @@ class GenerationMixin: ) if ( 
generation_config.min_new_tokens is not None - and generation_config._eos_token_tensor is not None + and getattr(generation_config, "_eos_token_tensor", None) is not None and generation_config.min_new_tokens > 0 ): processors.append( @@ -2344,9 +2345,15 @@ class GenerationMixin: if custom_generate is not None: trust_remote_code = kwargs.pop("trust_remote_code", None) # Get all `generate` arguments in a single variable. Custom functions are responsible for handling them: - # they receive the same inputs as `generate`, only with `model` instead of `self`. They can access to - # methods from `GenerationMixin` through `model`. - global_keys_to_exclude = {"self", "kwargs"} + # they receive the same inputs as `generate`, with `model` instead of `self` and excluding the arguments to + # trigger the custom generation. They can access to methods from `GenerationMixin` through `model`. + global_keys_to_exclude = { + "self", + "kwargs", + "global_keys_to_exclude", + "trust_remote_code", + "custom_generate", + } generate_arguments = {key: value for key, value in locals().items() if key not in global_keys_to_exclude} generate_arguments.update(kwargs) diff --git a/src/transformers/image_processing_base.py b/src/transformers/image_processing_base.py index 4446aaa6470..42a6b785841 100644 --- a/src/transformers/image_processing_base.py +++ b/src/transformers/image_processing_base.py @@ -28,8 +28,6 @@ from .feature_extraction_utils import BatchFeature as BaseBatchFeature from .utils import ( IMAGE_PROCESSOR_NAME, PushToHubMixin, - add_model_info_to_auto_map, - add_model_info_to_custom_pipelines, cached_file, copy_func, download_url, @@ -380,14 +378,6 @@ class ImageProcessingMixin(PushToHubMixin): logger.info( f"loading configuration file {image_processor_file} from cache at {resolved_image_processor_file}" ) - if "auto_map" in image_processor_dict: - image_processor_dict["auto_map"] = add_model_info_to_auto_map( - image_processor_dict["auto_map"], pretrained_model_name_or_path - ) - if "custom_pipelines" in image_processor_dict: - image_processor_dict["custom_pipelines"] = add_model_info_to_custom_pipelines( - image_processor_dict["custom_pipelines"], pretrained_model_name_or_path - ) return image_processor_dict, kwargs @@ -508,11 +498,7 @@ class ImageProcessingMixin(PushToHubMixin): Register this class with a given auto class. This should only be used for custom image processors as the ones in the library are already mapped with `AutoImageProcessor `. - - This API is experimental and may have some slight breaking changes in the next releases. - - Args: auto_class (`str` or `type`, *optional*, defaults to `"AutoImageProcessor "`): diff --git a/src/transformers/integrations/eager_paged.py b/src/transformers/integrations/eager_paged.py new file mode 100644 index 00000000000..9893e10c89a --- /dev/null +++ b/src/transformers/integrations/eager_paged.py @@ -0,0 +1,45 @@ +from typing import Optional + +import torch +from torch import nn + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_paged_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + cache = kwargs.pop("cache", None) + if cache is not None: + key, value = cache.update(key, value, module.layer_idx, **kwargs) + + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights diff --git a/src/transformers/integrations/executorch.py b/src/transformers/integrations/executorch.py index eb17dab55af..bd4b30a3d12 100644 --- a/src/transformers/integrations/executorch.py +++ b/src/transformers/integrations/executorch.py @@ -11,7 +11,6 @@ # specific language governing permissions and limitations under the License. import logging -from contextlib import contextmanager from typing import Callable, Optional import torch @@ -110,14 +109,13 @@ class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module): example_input_ids = input_ids if input_ids is not None else torch.tensor([[1]], dtype=torch.long) example_cache_position = cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long) - with patch_mask_interface(): - exported_program = torch.export.export( - self.model, - args=(example_input_ids, example_cache_position), - kwargs={}, - dynamic_shapes=dynamic_shapes, - strict=strict if strict is not None else True, - ) + exported_program = torch.export.export( + self.model, + args=(example_input_ids, example_cache_position), + kwargs={}, + dynamic_shapes=dynamic_shapes, + strict=strict if strict is not None else True, + ) return exported_program @staticmethod @@ -456,24 +454,6 @@ class TorchExportableModuleWithHybridCache(torch.nn.Module): return outputs.logits -@contextmanager -def patch_mask_interface(): - """ - Context manager to locally use a simple dict instead of `AttentionMaskInterface`, as otherwise export will fail - with `strict=True` due to dynamo skip rules, i.e. `torch._dynamo.exc.Unsupported: 'inline in skipfiles: - Mapping.__contains__ | __contains__, skipped according trace_rules.lookup SKIP_DIRS'`. - Note that this seem to be an issue only for python<3.11. 
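As a quick illustration of the `repeat_kv` helper added in `eager_paged.py` above (shapes are arbitrary), the expand/reshape it performs matches `torch.repeat_interleave` on the head dimension, e.g. 2 KV heads expanded 4x to serve 8 query heads:

```python
import torch

kv = torch.randn(1, 2, 16, 64)  # (batch, num_key_value_heads, seq_len, head_dim)
expanded = kv[:, :, None, :, :].expand(1, 2, 4, 16, 64).reshape(1, 8, 16, 64)

assert torch.equal(expanded, torch.repeat_interleave(kv, repeats=4, dim=1))
```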
- """ - import transformers - - original = transformers.masking_utils.ALL_MASK_ATTENTION_FUNCTIONS - transformers.masking_utils.ALL_MASK_ATTENTION_FUNCTIONS = ALL_MASK_ATTENTION_FUNCTIONS._global_mapping - try: - yield - finally: - transformers.masking_utils.ALL_MASK_ATTENTION_FUNCTIONS = original - - def convert_and_export_with_cache( model: PreTrainedModel, example_input_ids: Optional[torch.Tensor] = None, @@ -515,14 +495,13 @@ def convert_and_export_with_cache( ) if is_torch_greater_or_equal("2.6.0"): - with patch_mask_interface(): - exported_program = torch.export.export( - TorchExportableModuleWithStaticCache(model), - args=(example_input_ids, example_cache_position), - kwargs={}, - dynamic_shapes=dynamic_shapes, - strict=strict if strict is not None else True, - ) + exported_program = torch.export.export( + TorchExportableModuleWithStaticCache(model), + args=(example_input_ids, example_cache_position), + kwargs={}, + dynamic_shapes=dynamic_shapes, + strict=strict if strict is not None else True, + ) else: if dynamic_shapes is not None: logging.warning( @@ -534,14 +513,13 @@ def convert_and_export_with_cache( # # Due to issue https://github.com/pytorch/pytorch/issues/128394, we need to switch to use an internal # export API and pre_dispatch=False. Switch to use the public API once the issue is included in 2.5 release. - with patch_mask_interface(): - exported_program = torch.export._trace._export( - TorchExportableModuleWithStaticCache(model), - args=(example_input_ids,), - kwargs={"cache_position": example_cache_position}, - pre_dispatch=False, - strict=True, - ) + exported_program = torch.export._trace._export( + TorchExportableModuleWithStaticCache(model), + args=(example_input_ids,), + kwargs={"cache_position": example_cache_position}, + pre_dispatch=False, + strict=True, + ) return exported_program @@ -634,10 +612,9 @@ class Seq2SeqLMExportableModule(torch.nn.Module): # Export the encoder with torch.no_grad(): - with patch_mask_interface(): - exported_encoder = torch.export.export( - wrapped_encoder, (encoder_input_ids,), dynamic_shapes={"input_ids": {1: seq_len_dim}}, strict=True - ) + exported_encoder = torch.export.export( + wrapped_encoder, (encoder_input_ids,), dynamic_shapes={"input_ids": {1: seq_len_dim}}, strict=True + ) return exported_encoder @@ -657,17 +634,16 @@ class Seq2SeqLMExportableModule(torch.nn.Module): # Export the decoder with torch.no_grad(): - with patch_mask_interface(): - exported_decoder = torch.export.export( - wrapped_decoder, - (decoder_input_ids, encoder_hidden_states, cache_position), - dynamic_shapes={ - "decoder_input_ids": None, - "encoder_hidden_states": {1: encoder_seq_len_dim}, - "cache_position": None, - }, - strict=True, - ) + exported_decoder = torch.export.export( + wrapped_decoder, + (decoder_input_ids, encoder_hidden_states, cache_position), + dynamic_shapes={ + "decoder_input_ids": None, + "encoder_hidden_states": {1: encoder_seq_len_dim}, + "cache_position": None, + }, + strict=True, + ) return exported_decoder diff --git a/src/transformers/integrations/flash_attention.py b/src/transformers/integrations/flash_attention.py index a78166ed040..4f76e65a847 100644 --- a/src/transformers/integrations/flash_attention.py +++ b/src/transformers/integrations/flash_attention.py @@ -3,8 +3,11 @@ from typing import Optional, Tuple import torch from ..modeling_flash_attention_utils import _flash_attention_forward, flash_attn_supports_top_left_mask +from ..utils import logging +logger = logging.get_logger(__name__) + _use_top_left_mask = 
flash_attn_supports_top_left_mask() @@ -20,6 +23,12 @@ def flash_attention_forward( softcap: Optional[float] = None, **kwargs, ) -> Tuple[torch.Tensor, None]: + if kwargs.get("output_attentions", False) or kwargs.get("head_mask", None) is not None: + logger.warning_once( + "`flash_attention_2` does not support `output_attentions=True` or `head_mask`." + " Please set your attention to `eager` if you want any of these features." + ) + # This is before the transpose seq_len = query.shape[2] diff --git a/src/transformers/integrations/flash_paged.py b/src/transformers/integrations/flash_paged.py new file mode 100644 index 00000000000..b0463d95248 --- /dev/null +++ b/src/transformers/integrations/flash_paged.py @@ -0,0 +1,64 @@ +import torch + +from ..generation.continuous_batching import PagedAttentionCache +from ..utils import is_flash_attn_2_available + + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_varlen_func + + +def paged_attention_forward( + module: torch.nn.Module, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + attention_mask: torch.Tensor = None, + cache: PagedAttentionCache = None, + cumulative_seqlens_q=None, + cumulative_seqlens_k=None, + max_seqlen_q=None, + max_seqlen_k=None, + block_tables=None, + **kwargs, +) -> torch.Tensor: + r"""Perform the forward pass of attention with paged key-value cache. + + This function handles the cache updates and performs the attention computation + using the flash_attn_varlen_func for efficient processing. + + Args: + q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch. + k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. but if there is a block table it can be the full k + v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. but if there is a block table it can be the full v + cumulative_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths + of the sequences in the batch, used to index into q. + cumulative_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths + of the sequences in the batch, used to index into kv. + max_seqlen_q: int. Maximum query sequence length in the batch. + max_seqlen_k: int. Maximum key sequence length in the batch. + dropout_p: float. Dropout probability. + softmax_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + window_size: (left, right). If not (-1, -1), implements sliding window local attention. + softcap: float. Anything > 0 activates softcapping attention. 
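The cumulative-length arguments above follow the usual varlen packing convention; for instance, a packed batch of three requests with 3, 5 and 2 tokens would be described as follows (an illustration, not code from this patch):

```python
import torch

seq_lens = torch.tensor([3, 5, 2])  # tokens per request in the packed batch
cumulative_seqlens = torch.nn.functional.pad(seq_lens.cumsum(0), (1, 0)).to(torch.int32)
# tensor([ 0,  3,  8, 10], dtype=torch.int32): request i owns tokens cumulative_seqlens[i]:cumulative_seqlens[i+1]
# max_seqlen_q / max_seqlen_k would both be seq_lens.max() == 5 for a prefill of these requests
```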
+ """ + k, v = cache.update(k, v, module.layer_idx, cumulative_seqlens_k=cumulative_seqlens_k, **kwargs) + + attn_output = flash_attn_varlen_func( + q.transpose(1, 2).squeeze(0), + k.transpose(1, 2).squeeze(0), + v.transpose(1, 2).squeeze(0), + cumulative_seqlens_q.to(torch.int32), + cumulative_seqlens_k.to(torch.int32), + max_seqlen_q, + max_seqlen_k, + softmax_scale=module.scaling, + causal=True, # kind of a must, it automatically aligns the mask for q < k + window_size=(-1, -1), # -1 means infinite context window + # block_table=block_tables, -> torch.Tensor + # **kwargs, + ) + + return attn_output, None diff --git a/src/transformers/integrations/flex_attention.py b/src/transformers/integrations/flex_attention.py index afdaba5199d..1e1228873f1 100644 --- a/src/transformers/integrations/flex_attention.py +++ b/src/transformers/integrations/flex_attention.py @@ -31,13 +31,15 @@ from typing import Optional, Tuple, Union import torch from packaging import version -from ..utils import is_torch_flex_attn_available +from ..utils import is_torch_flex_attn_available, logging from ..utils.import_utils import _torch_version, is_torchdynamo_compiling if is_torch_flex_attn_available(): - from torch.nn.attention.flex_attention import BlockMask, flex_attention - from torch.nn.attention.flex_attention import create_block_mask as create_block_causal_mask_flex + from torch.nn.attention.flex_attention import BlockMask, create_block_mask, flex_attention + + +logger = logging.get_logger(__name__) class WrappedFlexAttention: @@ -98,21 +100,23 @@ def compile_friendly_flex_attention( Offset = Union[torch.Tensor, int] +# TODO: deprecate / rename to make_flex_block_mask for clarity as it's not only causal anymore def make_flex_block_causal_mask( attention_mask_2d: torch.Tensor, attention_chunk_size: Optional[int] = None, query_length=None, key_length=None, offsets: Optional[Tuple[Offset, Offset]] = None, + is_causal: Optional[bool] = True, ) -> "BlockMask": """ IMPORTANT NOTICE: This function is deprecated in favor of using the mask primitives in `masking_utils.py`, and will be removed in a future version without warnings. New code should not use it. It is only kept here for BC for now, while models using it are being patched accordingly. - Create a block causal document mask for a batch of sequences, both packed and unpacked. - Create Block causal logic and passing it into :func:`torch.nn.attention.flex_attention.create_block_mask`. - The resultant BlockMask is a compressed representation of the full block causal + Create a block (causal) document mask for a batch of sequences, both packed and unpacked. + Create Block (causal) logic and passing it into :func:`torch.nn.attention.flex_attention.create_block_mask`. + The resultant BlockMask is a compressed representation of the full (causal) block mask. BlockMask is essential for performant computation of flex attention. See: https://pytorch.org/blog/flexattention/ @@ -170,7 +174,21 @@ def make_flex_block_causal_mask( causal_doc_mask = causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx) return chunk_mask & causal_doc_mask - mask_mod_maybe_combined = causal_mask_mod if attention_chunk_size is None else chunk_causal_mask_mod + def default_mask_mod(batch_idx, head_idx, q_idx, kv_idx): + """ + Utilizes default attention mask to enable encoder and encoder-decoder + attention masks. 
+ """ + document_mask = document_ids[batch_idx, q_idx] == document_ids[batch_idx, kv_idx] + # kv indexing is crucial in order to work correctly + padding_mask = attention_mask_2d[batch_idx, kv_idx] > 0 + final_mask = padding_mask & document_mask + return final_mask + + if not is_causal: + mask_mod_maybe_combined = default_mask_mod + else: + mask_mod_maybe_combined = causal_mask_mod if attention_chunk_size is None else chunk_causal_mask_mod if offsets is not None: q_offset = offsets[0] @@ -182,7 +200,8 @@ def make_flex_block_causal_mask( return mask_mod_maybe_combined(batch_idx, head_idx, offset_q, offset_kv) else: mask_mod = mask_mod_maybe_combined - return create_block_causal_mask_flex( + + return create_block_mask( mask_mod=mask_mod, B=batch_size, H=None, # attention head @@ -216,21 +235,32 @@ def flex_attention_forward( head_mask: Optional[torch.Tensor] = None, **kwargs, ) -> Tuple[torch.Tensor, torch.Tensor]: + if head_mask is not None: + logger.warning_once( + "`flex_attention` does not support `head_mask`. Please set your attention to `eager` if you want this feature." + ) + + if kwargs.get("dropout", 0.0) > 0: + raise ValueError( + "`flex_attention` does not support `dropout`. Please use it with inference" + " only (`model.eval()`) or turn off the attention dropout in the respective config." + ) + block_mask = None - causal_mask = None + score_mask = None if isinstance(attention_mask, BlockMask): block_mask = attention_mask else: - causal_mask = attention_mask + score_mask = attention_mask - if causal_mask is not None: - causal_mask = causal_mask[:, :, :, : key.shape[-2]] + if score_mask is not None: + score_mask = score_mask[:, :, :, : key.shape[-2]] def score_mod(score, batch_idx, head_idx, q_idx, kv_idx): if softcap is not None: score = softcap * torch.tanh(score / softcap) - if causal_mask is not None: - score = score + causal_mask[batch_idx][0][q_idx][kv_idx] + if score_mask is not None: + score = score + score_mask[batch_idx][0][q_idx][kv_idx] if head_mask is not None: score = score + head_mask[batch_idx][head_idx][0][0] return score diff --git a/src/transformers/integrations/npu_flash_attention.py b/src/transformers/integrations/npu_flash_attention.py index bb515540d14..e32af9f4bc9 100644 --- a/src/transformers/integrations/npu_flash_attention.py +++ b/src/transformers/integrations/npu_flash_attention.py @@ -37,6 +37,8 @@ if SPARSE_MODE not in [TOP_LEFT_ALIGNED_CAUSAL_MASK_MODE, DOWN_RIGHT_ALIGNED_CAU "or 3 (down-right aligned causal mask)." 
) +ATTN_MASK_NPU = None + def is_npu_fa2_top_left_aligned_causal_mask(): return SPARSE_MODE == TOP_LEFT_ALIGNED_CAUSAL_MASK_MODE if is_torch_npu_available() else False @@ -171,7 +173,9 @@ def npu_flash_attn_func( head_num = q.shape[2] output = torch_npu.npu_fusion_attention(q, k, v, head_num, "BSND", keep_prob=keep_prob, scale=softmax_scale)[0] else: - attn_mask_npu = torch.triu(torch.ones([2048, 2048], device=q.device), diagonal=1).bool() + global ATTN_MASK_NPU + if ATTN_MASK_NPU is None: + ATTN_MASK_NPU = torch.triu(torch.ones([2048, 2048], device=q.device), diagonal=1).bool() head_num = q.shape[2] output = torch_npu.npu_fusion_attention( q, @@ -181,7 +185,7 @@ def npu_flash_attn_func( "BSND", keep_prob=keep_prob, scale=softmax_scale, - atten_mask=attn_mask_npu, + atten_mask=ATTN_MASK_NPU, sparse_mode=SPARSE_MODE, )[0] @@ -222,7 +226,9 @@ def npu_flash_attn_varlen_func( actual_seq_kvlen=tuple(cu_seqlens_k[1:].cpu().numpy().tolist()), )[0] else: - attn_mask_npu = torch.triu(torch.ones([2048, 2048], device=q.device), diagonal=1).bool() + global ATTN_MASK_NPU + if ATTN_MASK_NPU is None: + ATTN_MASK_NPU = torch.triu(torch.ones([2048, 2048], device=q.device), diagonal=1).bool() head_num = q.shape[1] output = torch_npu.npu_fusion_attention( q, @@ -231,7 +237,7 @@ def npu_flash_attn_varlen_func( head_num, pse=None, padding_mask=None, - atten_mask=attn_mask_npu, + atten_mask=ATTN_MASK_NPU, scale=softmax_scale, keep_prob=keep_prob, input_layout="TND", diff --git a/src/transformers/integrations/sdpa_attention.py b/src/transformers/integrations/sdpa_attention.py index 9c924c048ad..247cd282167 100644 --- a/src/transformers/integrations/sdpa_attention.py +++ b/src/transformers/integrations/sdpa_attention.py @@ -2,6 +2,11 @@ from typing import Optional, Tuple import torch +from ..utils import logging + + +logger = logging.get_logger(__name__) + def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ @@ -26,13 +31,18 @@ def sdpa_attention_forward( is_causal: Optional[bool] = None, **kwargs, ) -> Tuple[torch.Tensor, None]: + if kwargs.get("output_attentions", False) or kwargs.get("head_mask", None) is not None: + logger.warning_once( + "`sdpa` attention does not support `output_attentions=True` or `head_mask`." + " Please set your attention to `eager` if you want any of these features." + ) + if hasattr(module, "num_key_value_groups"): key = repeat_kv(key, module.num_key_value_groups) value = repeat_kv(value, module.num_key_value_groups) - causal_mask = attention_mask - if attention_mask is not None and causal_mask.ndim == 4: - causal_mask = causal_mask[:, :, :, : key.shape[-2]] + if attention_mask is not None and attention_mask.ndim == 4: + attention_mask = attention_mask[:, :, :, : key.shape[-2]] # SDPA with memory-efficient backend is bugged with non-contiguous inputs and custom attn_mask for some torch versions # Reference: https://github.com/pytorch/pytorch/issues/112577. @@ -44,7 +54,9 @@ def sdpa_attention_forward( # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. 
# Note that it is important to check first for the shape, otherwise compile will fail with `argument 'is_causal' must be bool, not SymBool` if is_causal is None: - is_causal = query.shape[2] > 1 and causal_mask is None + # The last condition is for encoder (decoder) models which specify this by passing their own `is_causal` flag + # This is mainly due to those models having mixed implementations for encoder, decoder, and encoder-decoder attns + is_causal = query.shape[2] > 1 and attention_mask is None and getattr(module, "is_causal", True) # Shapes (e.g. query.shape[2]) are tensors during jit tracing, resulting in `is_causal` being a tensor. # We convert it to a bool for the SDPA kernel that only accepts bools. @@ -55,7 +67,7 @@ def sdpa_attention_forward( query, key, value, - attn_mask=causal_mask, + attn_mask=attention_mask, dropout_p=dropout, scale=scaling, is_causal=is_causal, diff --git a/src/transformers/integrations/sdpa_paged.py b/src/transformers/integrations/sdpa_paged.py new file mode 100644 index 00000000000..640db16d0de --- /dev/null +++ b/src/transformers/integrations/sdpa_paged.py @@ -0,0 +1,51 @@ +from typing import Optional, Tuple + +import torch + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def sdpa_attention_paged_forward( + module: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + dropout: float = 0.0, + scaling: Optional[float] = None, + is_causal: Optional[bool] = None, + **kwargs, +) -> Tuple[torch.Tensor, None]: + cache = kwargs.pop("cache", None) + if cache is not None: + key, value = cache.update(key, value, module.layer_idx, **kwargs) + if hasattr(module, "num_key_value_groups"): + key = repeat_kv(key, module.num_key_value_groups) + value = repeat_kv(value, module.num_key_value_groups) + + causal_mask = attention_mask + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + attn_output = torch.nn.functional.scaled_dot_product_attention( + query, + key, + value, + attn_mask=causal_mask, + dropout_p=dropout, + scale=scaling, + is_causal=False, + ) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, None diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index a9f8940e72e..769845e7e8e 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -900,7 +900,7 @@ def verify_tp_plan(expected_keys: list[str], tp_plan: Optional[dict[str, str]]): unused_rules = tp_plan for key in generic_keys: - param_name, _ = key.rsplit(".", 1) if "." in key else key + param_name = key.rsplit(".", 1)[0] if "." 
in key else key generic_param_name = re.sub(r"\d+", "*", param_name) if generic_param_name in tp_plan: diff --git a/src/transformers/masking_utils.py b/src/transformers/masking_utils.py index 36538882af5..cb502206d78 100644 --- a/src/transformers/masking_utils.py +++ b/src/transformers/masking_utils.py @@ -25,11 +25,16 @@ from .utils.import_utils import is_torch_flex_attn_available, is_torch_greater_o if is_torch_flex_attn_available(): - from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex from torch.nn.attention.flex_attention import BlockMask, create_block_mask - +else: + # Register a fake type to avoid crashing for annotations and `isinstance` checks + BlockMask = torch.Tensor _is_torch_greater_or_equal_than_2_5 = is_torch_greater_or_equal("2.5", accept_dev=True) +_is_torch_greater_or_equal_than_2_6 = is_torch_greater_or_equal("2.6", accept_dev=True) + +if _is_torch_greater_or_equal_than_2_6: + from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex def and_masks(*mask_functions: list[Callable]) -> Callable: @@ -415,14 +420,14 @@ def sdpa_mask_older_torch( # Due to a bug in versions of torch<2.5, we need to update the mask in case a query is not attending to any # tokens (due to padding). See details in https://github.com/pytorch/pytorch/issues/110213 - if allow_torch_fix: + if not _is_torch_greater_or_equal_than_2_5 and allow_torch_fix: causal_mask |= torch.all(~causal_mask, dim=-1, keepdim=True) return causal_mask # We use the version with newer torch whenever possible, as it is more general and can handle arbitrary mask functions # (especially mask_function indexing a tensor, such as the padding mask function) -sdpa_mask = sdpa_mask_recent_torch if is_torch_flex_attn_available() else sdpa_mask_older_torch +sdpa_mask = sdpa_mask_recent_torch if _is_torch_greater_or_equal_than_2_6 else sdpa_mask_older_torch def eager_mask( @@ -522,7 +527,7 @@ def flex_attention_mask( mask_function: Callable = causal_mask_function, attention_mask: Optional[torch.Tensor] = None, **kwargs, -) -> "BlockMask": +) -> BlockMask: """ Create a 4D block mask which is a compressed representation of the full 4D block causal mask. BlockMask is essential for performant computation of flex attention. See: https://pytorch.org/blog/flexattention/ @@ -623,7 +628,11 @@ def _preprocess_mask_arguments( return True, attention_mask, None, None # For TGI/vLLM backends, or other custom attention without equivalent mask creation: we don't need a mask! - if config._attn_implementation not in ALL_MASK_ATTENTION_FUNCTIONS: + # Note: it's not ideal to check the `_global_mapping` attribute instead of the object itself, however otherwise + # full graph dynamo tracing (i.e. torch.export or compile with `fullgraph=True`) will fail on Python<3.11 + # with `torch._dynamo.exc.Unsupported: 'inline in skipfiles:Mapping.__contains__ | __contains__, skipped + # according trace_rules.lookup SKIP_DIRS'` -- can be removed when we require Python>=3.11 + if config._attn_implementation not in ALL_MASK_ATTENTION_FUNCTIONS._global_mapping: return True, None, None, None # Move the mask to correct device, and potentially switch dtype for efficiency @@ -640,33 +649,15 @@ def _preprocess_mask_arguments( return False, attention_mask, kv_length, kv_offset -def _get_mask_interface(config: PretrainedConfig, output_attentions: bool = False) -> Callable: - """ - Return the mask interface (a function) to be used, based on the type of attention found in the config. 
- - Args: - config (`PretrainedConfig`): - The model config. - output_attentions (`bool`, optional): - Whether we return the attention scores or not. By default `False`. - """ - mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[config._attn_implementation] - # Sdpa fallbacks to eager in the Attention modules if `output_attentions=True` - if config._attn_implementation == "sdpa" and output_attentions: - mask_interface = ALL_MASK_ATTENTION_FUNCTIONS["eager"] - return mask_interface - - def create_causal_mask( config: PretrainedConfig, input_embeds: torch.Tensor, attention_mask: Optional[torch.Tensor], cache_position: torch.Tensor, past_key_values: Optional[Cache], - output_attentions: bool = False, or_mask_function: Optional[Callable] = None, and_mask_function: Optional[Callable] = None, -) -> Optional[Union[torch.Tensor, "BlockMask"]]: +) -> Optional[Union[torch.Tensor, BlockMask]]: """ Create a standard causal mask based on the attention implementation used (stored in the config). If `past_key_values` has an HybridCache structure, this function will return the mask corresponding to one of the "full_attention" layers (to align @@ -685,8 +676,6 @@ def create_causal_mask( A tensor of shape (query_length,) indicating the current indices of the input sequence elements. past_key_values (`Cache`, optional): The past key values, if we use a cache. - output_attentions (`bool`, optional): - Whether we return the attention scores or not. By default `False`. or_mask_function (`Callable`, optional): An optional mask function to combine with the causal mask function (by doing the union of both). This is useful to easily overlay another mask on top of the causal one, for example for image tokens handling. @@ -708,7 +697,7 @@ def create_causal_mask( batch_size, dtype = input_embeds.shape[0], input_embeds.dtype mask_factory_function = causal_mask_function - mask_interface = _get_mask_interface(config, output_attentions) + mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[config._attn_implementation] # Do not allow skip if we are compiling (this is to match BC) # TODO: cyril -> probably revisit and remove this, but a lot of tests rely on it @@ -716,13 +705,13 @@ def create_causal_mask( # Allow slight deviations from causal mask if or_mask_function is not None: - if not _is_torch_greater_or_equal_than_2_5: - raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.5") + if not _is_torch_greater_or_equal_than_2_6: + raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6") mask_factory_function = or_masks(mask_factory_function, or_mask_function) allow_is_causal_skip = False if and_mask_function is not None: - if not _is_torch_greater_or_equal_than_2_5: - raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.5") + if not _is_torch_greater_or_equal_than_2_6: + raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6") mask_factory_function = and_masks(mask_factory_function, and_mask_function) allow_is_causal_skip = False @@ -747,10 +736,9 @@ def create_sliding_window_causal_mask( attention_mask: Optional[torch.Tensor], cache_position: torch.Tensor, past_key_values: Optional[Cache], - output_attentions: bool = False, or_mask_function: Optional[Callable] = None, and_mask_function: Optional[Callable] = None, -) -> Optional[Union[torch.Tensor, "BlockMask"]]: +) -> Optional[Union[torch.Tensor, BlockMask]]: """ Create a sliding window causal mask based on the attention 
implementation used (stored in the config). This type of attention pattern was mostly democratized by Mistral. If `past_key_values` has an HybridCache structure, this @@ -770,8 +758,6 @@ def create_sliding_window_causal_mask( A tensor of shape (query_length,) indicating the current indices of the input sequence elements. past_key_values (`Cache`, optional): The past key values, if we use a cache. - output_attentions (`bool`, optional): - Whether we return the attention scores or not. By default `False`. or_mask_function (`Callable`, optional): An optional mask function to combine with the sliding causal mask function (by doing the union of both). This is useful to easily overlay another mask on top of the sliding causal one, for example for image tokens handling. @@ -797,7 +783,7 @@ def create_sliding_window_causal_mask( batch_size, dtype = input_embeds.shape[0], input_embeds.dtype mask_factory_function = sliding_window_causal_mask_function(sliding_window) - mask_interface = _get_mask_interface(config, output_attentions) + mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[config._attn_implementation] # Do not allow skip if we are compiling (this is to match BC) # TODO: cyril -> probably revisit and remove this, but a lot of tests rely on it @@ -805,13 +791,13 @@ def create_sliding_window_causal_mask( # Allow slight deviations from sliding causal mask if or_mask_function is not None: - if not _is_torch_greater_or_equal_than_2_5: - raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.5") + if not _is_torch_greater_or_equal_than_2_6: + raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6") mask_factory_function = or_masks(mask_factory_function, or_mask_function) allow_is_causal_skip = False if and_mask_function is not None: - if not _is_torch_greater_or_equal_than_2_5: - raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.5") + if not _is_torch_greater_or_equal_than_2_6: + raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6") mask_factory_function = and_masks(mask_factory_function, and_mask_function) allow_is_causal_skip = False @@ -837,10 +823,9 @@ def create_chunked_causal_mask( attention_mask: Optional[torch.Tensor], cache_position: torch.Tensor, past_key_values: Optional[Cache], - output_attentions: bool = False, or_mask_function: Optional[Callable] = None, and_mask_function: Optional[Callable] = None, -) -> Optional[Union[torch.Tensor, "BlockMask"]]: +) -> Optional[Union[torch.Tensor, BlockMask]]: """ Create a chunked attention causal mask based on the attention implementation used (stored in the config). This type of attention pattern was mostly democratized by Llama4. If `past_key_values` has an HybridCache structure, this @@ -860,8 +845,6 @@ def create_chunked_causal_mask( A tensor of shape (query_length,) indicating the current indices of the input sequence elements. past_key_values (`Cache`, optional): The past key values, if we use a cache. - output_attentions (`bool`, optional): - Whether we return the attention scores or not. By default `False`. or_mask_function (`Callable`, optional): An optional mask function to combine with the chunked causal mask function (by doing the union of both). This is useful to easily overlay another mask on top of the chunked causal one, for example for image tokens handling. 
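These `or_mask_function` / `and_mask_function` hooks are shared by `create_causal_mask`, `create_sliding_window_causal_mask` and `create_chunked_causal_mask`. A sketch of overlaying an extra constraint (the `extra_token_mask` tensor is hypothetical, `model`/`inputs_embeds`/`attention_mask`/`cache_position` are assumed to exist, and per the checks above this path requires torch>=2.6):

```python
import torch

# Hypothetical (batch, seq_len) bool tensor: positions marked True are hidden from every query
extra_token_mask = torch.zeros(2, 8, dtype=torch.bool)
extra_token_mask[:, :2] = True

def hide_extra_tokens(batch_idx, head_idx, q_idx, kv_idx):
    # mask functions follow the (batch_idx, head_idx, q_idx, kv_idx) -> bool convention used in this file
    return ~extra_token_mask[batch_idx, kv_idx]

mask = create_causal_mask(
    config=model.config,
    input_embeds=inputs_embeds,
    attention_mask=attention_mask,
    cache_position=cache_position,
    past_key_values=None,
    and_mask_function=hide_extra_tokens,  # intersected with the causal mask via `and_masks`
)
```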
@@ -894,7 +877,7 @@ def create_chunked_causal_mask( batch_size, dtype = input_embeds.shape[0], input_embeds.dtype mask_factory_function = chunked_causal_mask_function(chunk_size) - mask_interface = _get_mask_interface(config, output_attentions) + mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[config._attn_implementation] # Do not allow skip if we are compiling (this is to match BC) # TODO: cyril -> probably revisit and remove this, but a lot of tests rely on it @@ -902,13 +885,13 @@ def create_chunked_causal_mask( # Allow slight deviations from chunked causal mask if or_mask_function is not None: - if not _is_torch_greater_or_equal_than_2_5: - raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.5") + if not _is_torch_greater_or_equal_than_2_6: + raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6") mask_factory_function = or_masks(mask_factory_function, or_mask_function) allow_is_causal_skip = False if and_mask_function is not None: - if not _is_torch_greater_or_equal_than_2_5: - raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.5") + if not _is_torch_greater_or_equal_than_2_6: + raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6") mask_factory_function = and_masks(mask_factory_function, and_mask_function) allow_is_causal_skip = False @@ -941,7 +924,6 @@ def create_masks_for_generate( attention_mask: Optional[torch.Tensor], cache_position: torch.Tensor, past_key_values: Optional[Cache], - output_attentions: bool = False, or_mask_function: Optional[Callable] = None, and_mask_function: Optional[Callable] = None, **kwargs, @@ -963,8 +945,6 @@ def create_masks_for_generate( A tensor of shape (query_length,) indicating the current indices of the input sequence elements. past_key_values (`Cache`, optional): The past key values, if we use a cache. - output_attentions (`bool`, optional): - Whether we return the attention scores or not. By default `False`. or_mask_function (`Callable`, optional): An optional mask function to combine with the other mask function (by doing the union of both). This is useful to easily overlay another mask on top of the causal one, for example for image tokens handling. @@ -981,7 +961,6 @@ def create_masks_for_generate( "attention_mask": attention_mask, "cache_position": cache_position, "past_key_values": past_key_values, - "output_attentions": output_attentions, "or_mask_function": or_mask_function, "and_mask_function": and_mask_function, } diff --git a/src/transformers/model_debugging_utils.py b/src/transformers/model_debugging_utils.py index 009ac0c6b2d..d09cfa24a72 100644 --- a/src/transformers/model_debugging_utils.py +++ b/src/transformers/model_debugging_utils.py @@ -21,6 +21,8 @@ from contextlib import contextmanager, redirect_stdout from io import StringIO from typing import Optional +from safetensors.torch import save_file + from transformers.utils.import_utils import requires from .utils import is_torch_available @@ -65,64 +67,94 @@ def _dtensor_repr(x): return "DTensor(non-rank0)" -def _serialize_io(value): +def _serialize_tensor_like_io( + value, debug_path: Optional[str] = None, use_repr: bool = True, path_to_value: Optional[str] = None +): + """ + Converts Tensors and DTensors to a JSON-serializable dictionary representation. + + Args: + value: Any Python object, often including torch Tensors, lists, dicts, etc. 
debug_path (`str`, *optional*, defaults to `None`): Directory to dump debug JSON and SafeTensors files. + use_repr (bool, *optional*, defaults to `True`): Whether to save a `repr()`-ized version of the tensor as the + `value` property in the associated FULL_TENSORS.json file, or to store the full tensors in a separate + SafeTensors file and store the relative path to that file in the `value` property in the dictionary. + path_to_value (`str`, *optional*, defaults to `None`): The file name for the SafeTensors file holding the full + tensor value if `use_repr=False`. + + Returns: + A nested Python structure (list, dict, or sanitized string) that is safe to json.dump. + """ + torch.set_printoptions(sci_mode=True) + + if use_repr: + value_out = _repr_to_list(value) + elif path_to_value: + if not path_to_value.endswith(".safetensors"): + path_to_value += ".safetensors" + + filepath = os.path.join(debug_path, path_to_value) if debug_path else path_to_value + save_file({"data": value.contiguous().detach().cpu()}, filepath) + value_out = f"./{path_to_value}" + else: + raise ValueError(f"{use_repr=} and {path_to_value=} cannot both be falsy.") + + out = { + "shape": repr(value.shape), + "dtype": repr(value.dtype), + "value": value_out, + } + if value.dtype in {torch.float16, torch.float32, torch.bfloat16}: + out.update( + { + "mean": _sanitize_repr_for_diff(repr(value.mean())), + "std": _sanitize_repr_for_diff(repr(value.std())), + "min": _sanitize_repr_for_diff(repr(value.min())), + "max": _sanitize_repr_for_diff(repr(value.max())), + } + ) + return out + + +def _serialize_io(value, debug_path: Optional[str] = None, use_repr: bool = True, path_to_value: Optional[str] = None): """ Recursively build a JSON-serializable Python structure from `value`. - Tensors and DTensors become sanitized repr strings. + Tensors and DTensors become either sanitized repr strings, or are saved to disk as SafeTensors files and their + relative paths are recorded in the returned Python structure. Lists/tuples/dicts are recursed into. All memory addresses are replaced with a stable placeholder. Args: value: Any Python object, often including torch Tensors, lists, dicts, etc. + debug_path (`str`, *optional*, defaults to `None`): Directory to dump debug JSON and SafeTensors files. + use_repr (bool, *optional*, defaults to `True`): Whether to save a `repr()`-ized version of the tensors as the + `value` property in the associated FULL_TENSORS.json file, or to store full tensors in separate SafeTensors + files and store the relative path to that file in the `value` property. + path_to_value (`str`, *optional*, defaults to `None`): The file name for the SafeTensors file holding the full + tensor value if `use_repr=False`. Returns: A nested Python structure (list, dict, or sanitized string) that is safe to json.dump.
""" if isinstance(value, (list, tuple)): - return [_serialize_io(v) for v in value] + return [ + _serialize_io(v, debug_path=debug_path, use_repr=use_repr, path_to_value=f"{path_to_value}_{i}") + for i, v in enumerate(value) + ] if isinstance(value, dict): - return {k: _serialize_io(v) for k, v in value.items()} + return { + k: _serialize_io(v, debug_path=debug_path, use_repr=use_repr, path_to_value=f"{path_to_value}_{k}") + for k, v in value.items() + } if hasattr(value, "_local_tensor"): - # DTensor-like handling, just use local tensor attribute - torch.set_printoptions(sci_mode=True) - val_repr = _repr_to_list(value) - out = { - "shape": repr(value._local_tensor.shape), - "dtype": repr(value._local_tensor.dtype), - "value": val_repr, - } - if value._local_tensor.dtype in {torch.float16, torch.float32, torch.bfloat16}: - value = value._local_tensor.clone() - out.update( - { - "mean": _sanitize_repr_for_diff(repr(value.mean())), - "std": _sanitize_repr_for_diff(repr(value.std())), - "min": _sanitize_repr_for_diff(repr(value.min())), - "max": _sanitize_repr_for_diff(repr(value.max())), - } - ) - return out + return _serialize_tensor_like_io( + value._local_tensor, debug_path=debug_path, use_repr=use_repr, path_to_value=path_to_value + ) if isinstance(value, torch.Tensor): - torch.set_printoptions(sci_mode=True) - val_repr = _repr_to_list(value) - out = { - "shape": repr(value.shape), - "dtype": repr(value.dtype), - "value": val_repr, - } - if value.dtype in {torch.float16, torch.float32, torch.bfloat16}: - out.update( - { - "mean": _sanitize_repr_for_diff(repr(value.mean())), - "std": _sanitize_repr_for_diff(repr(value.std())), - "min": _sanitize_repr_for_diff(repr(value.min())), - "max": _sanitize_repr_for_diff(repr(value.max())), - } - ) - return out + return _serialize_tensor_like_io(value, debug_path=debug_path, use_repr=use_repr, path_to_value=path_to_value) return _sanitize_repr_for_diff(repr(value)) @@ -199,7 +231,7 @@ def log_model_debug_trace(debug_path, model): os.makedirs(debug_path, exist_ok=True) base = os.path.join(debug_path, model._debugger_module_dump_name + "_debug_tree") except Exception as e: - raise ValueError(f"Unexpected or existing debug_path={debug_path}. {e}") + raise ValueError(f"Unexpected or existing debug_path={debug_path}.") from e else: base = model._debugger_module_dump_name + "_debug_tree" @@ -240,6 +272,7 @@ def _attach_debugger_logic( model, debug_path: Optional[str] = ".", do_prune_layers: Optional[bool] = True, + use_repr: bool = True, ): """ Attaches a debugging wrapper to every module in the model. @@ -250,6 +283,9 @@ def _attach_debugger_logic( model (`PreTrainedModel`, `nn.Module`): Model to wrap. debug_path (`str`): Optional directory to dump debug JSON files. do_prune_layers (`bool`, *optional*, defaults to `True`): Whether to prune intermediate layers. + use_repr (bool, *optional*, defaults to `True`): Whether to save a `repr()`-ized version of the tensors as the + `value` property in the associated FULL_TENSORS.json file, or to store full tensors in separate SafeTensors + files and store the relative path to that file in the `value` property.
""" class_name = model.__class__.__name__ @@ -258,6 +294,12 @@ def _attach_debugger_logic( model._debugger_model_call_stack = [] model._debugger_module_dump_name = class_name # used for final JSON filename + if debug_path: + try: + os.makedirs(debug_path, exist_ok=True) + except Exception as e: + raise ValueError(f"Unexpected or existing debug_path={debug_path}.") from e + def wrap_forward(module, full_path): orig_forward = module.forward @@ -268,7 +310,12 @@ def _attach_debugger_logic( dict_inputs = {k: dict_inputs[k] for k in dict_inputs if len(dict_inputs[k]) > 0} node = { "module_path": full_path, - "inputs": _serialize_io(dict_inputs), + "inputs": _serialize_io( + dict_inputs, + debug_path=debug_path, + use_repr=use_repr, + path_to_value=f"{full_path}_inputs", + ), "outputs": None, "children": [], } @@ -280,7 +327,12 @@ def _attach_debugger_logic( if sum(1 for _ in module.named_children()) > 0: node["outputs"] = None else: - node["outputs"] = _serialize_io(out) + node["outputs"] = _serialize_io( + out, + debug_path=debug_path, + use_repr=use_repr, + path_to_value=f"{full_path}_outputs", + ) finished = model._debugger_model_call_stack.pop() # prune empty vertices here as well (mostly empty children nodes) @@ -307,7 +359,12 @@ def _attach_debugger_logic( if _is_rank_zero(): top_node = { "module_path": f"{class_name} (top-level)", - "inputs": _serialize_io({"args": inps, "kwargs": kws}), + "inputs": _serialize_io( + {"args": inps, "kwargs": kws}, + debug_path=debug_path, + use_repr=use_repr, + path_to_value=f"{class_name}_inputs", + ), "outputs": None, "children": [], } @@ -315,7 +372,12 @@ def _attach_debugger_logic( out = real_top_forward(*inps, **kws) if _is_rank_zero() and model._debugger_model_call_stack: - top_node["outputs"] = _serialize_io(out) + top_node["outputs"] = _serialize_io( + out, + debug_path=debug_path, + use_repr=use_repr, + path_to_value=f"{class_name}_outputs", + ) finished = model._debugger_model_call_stack.pop() model._call_tree["inputs"] = finished["inputs"] model._call_tree["outputs"] = finished["outputs"] @@ -335,11 +397,21 @@ def _attach_debugger_logic( @requires(backends=("torch",)) @contextmanager -def model_addition_debugger_context(model, debug_path: Optional[str] = None, do_prune_layers: Optional[bool] = True): +def model_addition_debugger_context( + model, + debug_path: Optional[str] = None, + do_prune_layers: Optional[bool] = True, + use_repr: Optional[bool] = True, +): """ # Model addition debugger - context manager for model adders This context manager is a power user tool intended for model adders. - It tracks all forward calls within a model forward and logs a slice of each input and output on a nested Json. + + It tracks all forward calls within a model forward and logs a slice of each input and output in a nested JSON file. + If `use_repr=True` (the default), the JSON file will record a `repr()`-ized version of the tensors as a list of + strings. If `use_repr=False`, the full tensors will be stored in separate SafeTensors files and the JSON file will + provide a relative path to that file. + + To note, this context manager enforces `torch.no_grad()`.
## Usage @@ -348,10 +420,10 @@ def model_addition_debugger_context(model, debug_path: Optional[str] = None, do_ ```python import torch + from PIL import Image - import requests - from transformers import LlavaProcessor, LlavaForConditionalGeneration - from transformers.model_debugging_utils import model_addition_debugger_context + from transformers import LlavaProcessor, LlavaForConditionalGeneration, model_addition_debugger_context + torch.random.manual_seed(673) # load pretrained model and processor @@ -376,7 +448,7 @@ def model_addition_debugger_context(model, debug_path: Optional[str] = None, do_ """ orig_forwards = {m: m.forward for _, m in model.named_modules()} orig_forwards[model] = model.forward - _attach_debugger_logic(model, debug_path, do_prune_layers) + _attach_debugger_logic(model, debug_path, do_prune_layers, use_repr) try: yield model finally: diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py index 678ee983da5..2f00d9b6c0e 100644 --- a/src/transformers/modeling_flash_attention_utils.py +++ b/src/transformers/modeling_flash_attention_utils.py @@ -427,9 +427,9 @@ class FlashAttentionKwargs(TypedDict, total=False): Keyword arguments for Flash Attention with Compile. Attributes: - cu_seq_lens_q (`torch.LongTensor`, *optional*) + cumulative_seqlens_q (`torch.LongTensor`, *optional*) Gets cumulative sequence length for query state. - cu_seq_lens_k (`torch.LongTensor`, *optional*) + cumulative_seqlens_k (`torch.LongTensor`, *optional*) Gets cumulative sequence length for key state. max_length_q (`int`, *optional*): Maximum sequence length for query state. @@ -437,7 +437,7 @@ class FlashAttentionKwargs(TypedDict, total=False): Maximum sequence length for key state. """ - cu_seq_lens_q: Optional[torch.LongTensor] - cu_seq_lens_k: Optional[torch.LongTensor] + cumulative_seqlens_q: Optional[torch.LongTensor] + cumulative_seqlens_k: Optional[torch.LongTensor] max_length_q: Optional[int] max_length_k: Optional[int] diff --git a/src/transformers/modeling_flax_utils.py b/src/transformers/modeling_flax_utils.py index ac7a47e29bc..7a200bdda96 100644 --- a/src/transformers/modeling_flax_utils.py +++ b/src/transformers/modeling_flax_utils.py @@ -1218,11 +1218,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin): Register this class with a given auto class. This should only be used for custom models as the ones in the library are already mapped with an auto class. - - This API is experimental and may have some slight breaking changes in the next releases. 
- - Args: auto_class (`str` or `type`, *optional*, defaults to `"FlaxAutoModel"`): diff --git a/src/transformers/modeling_tf_pytorch_utils.py b/src/transformers/modeling_tf_pytorch_utils.py index 24bdf4faa06..5b9f38e1bc2 100644 --- a/src/transformers/modeling_tf_pytorch_utils.py +++ b/src/transformers/modeling_tf_pytorch_utils.py @@ -78,6 +78,9 @@ def convert_tf_weight_name_to_pt_weight_name( tf_name = tf_name[len(name_scope) :] tf_name = tf_name.lstrip("/") tf_name = tf_name.replace(":0", "") # device ids + if (len(tf_name) > 2048 and "___" in tf_name) or tf_name.count("___") > 10: + # ReDOS check + raise ValueError("TF variable name is too long or contains too many ___ separators: " + tf_name) tf_name = re.sub( r"/[^/]*___([^/]*)/", r"/\1/", tf_name ) # '$1___$2' is replaced by $2 (can be used to duplicate or remove layers in TF2.0 vs PyTorch) diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index 67aed15f0e5..ed7b018d89d 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -3229,11 +3229,7 @@ class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushT Register this class with a given auto class. This should only be used for custom models as the ones in the library are already mapped with an auto class. - - This API is experimental and may have some slight breaking changes in the next releases. - - Args: auto_class (`str` or `type`, *optional*, defaults to `"TFAutoModel"`): diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 97e95b4161b..bd09c1ae57d 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -27,6 +27,7 @@ import shutil import tempfile import warnings from collections import defaultdict +from concurrent.futures import ThreadPoolExecutor, as_completed from contextlib import contextmanager from dataclasses import dataclass from enum import Enum @@ -57,9 +58,12 @@ from .generation import CompileConfig, GenerationConfig from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled from .integrations.accelerate import find_tied_parameters, init_empty_weights from .integrations.deepspeed import _load_state_dict_into_zero3_model +from .integrations.eager_paged import eager_paged_attention_forward from .integrations.flash_attention import flash_attention_forward +from .integrations.flash_paged import paged_attention_forward from .integrations.flex_attention import flex_attention_forward from .integrations.sdpa_attention import sdpa_attention_forward +from .integrations.sdpa_paged import sdpa_attention_paged_forward from .integrations.tensor_parallel import ( ALL_PARALLEL_STYLES, _get_parameter_tp_plan, @@ -165,6 +169,7 @@ if is_safetensors_available(): if is_kernels_available(): from kernels import get_kernel + logger = logging.get_logger(__name__) @@ -319,7 +324,8 @@ def get_torch_context_manager_or_global_device(): is not "cpu". This is used to infer the correct device to load the model on, in case `device_map` is not provided. 
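On the guard added to `convert_tf_weight_name_to_pt_weight_name` above: it is a cheap pre-check that keeps pathological names away from the backtracking regex that follows. A self-contained sketch of the same check (thresholds copied from the diff, sample name invented):

```python
import re


def strip_tf_scope_markers(tf_name: str) -> str:
    # Reject suspiciously long names or names with many "___" separators
    # before feeding them to the regex below (same thresholds as in the diff).
    if (len(tf_name) > 2048 and "___" in tf_name) or tf_name.count("___") > 10:
        raise ValueError("TF variable name is too long or contains too many ___ separators: " + tf_name)
    # '$1___$2' -> '$2', used to duplicate or remove layers between TF2 and PyTorch names
    return re.sub(r"/[^/]*___([^/]*)/", r"/\1/", tf_name)


print(strip_tf_scope_markers("model/decoder___encoder/kernel"))  # -> model/encoder/kernel
```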
""" device_in_context = torch.tensor([]).device - default_device = torch.get_default_device() + # `get_default_device` was only introduced in torch>=2.3 - use cpu otherwise to align the behavior + default_device = torch.get_default_device() if is_torch_greater_or_equal("2.3") else torch.device("cpu") # This case means no context manager was used -> we still check if the default that was potentially set is not cpu if device_in_context == default_device: if default_device != torch.device("cpu"): @@ -866,6 +872,116 @@ def _load_state_dict_into_meta_model( return disk_offload_index, cpu_offload_index +def load_shard_file(args): + ( + shard_file, + state_dict, + disk_only_shard_files, + is_hqq_or_bnb, + is_quantized, + device_map, + hf_quantizer, + key_renaming_mapping, + weights_only, + model_to_load, + expected_keys, + reverse_key_renaming_mapping, + disk_offload_folder, + disk_offload_index, + cpu_offload_folder, + cpu_offload_index, + is_offloaded_safetensors, + keep_in_fp32_regex, + unexpected_keys, + device_mesh, + ) = args + + # Skip the load for shards that only contain disk-offloaded weights + if shard_file in disk_only_shard_files: + return [], disk_offload_index, cpu_offload_index + + map_location = "cpu" + if ( + shard_file.endswith(".safetensors") + and not is_hqq_or_bnb + and not (is_deepspeed_zero3_enabled() and not is_quantized) + ): + map_location = "meta" + elif ( + device_map is not None + and hf_quantizer is not None + and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO + and ( + hf_quantizer.quantization_config.quant_type in ["int4_weight_only", "autoquant"] + or isinstance(hf_quantizer.quantization_config.quant_type, Int4WeightOnlyConfig) + ) + ): + map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0]) + + # If shard_file is "", we use the existing state_dict instead of loading it + if shard_file != "": + state_dict = load_state_dict( + shard_file, is_quantized=is_quantized, map_location=map_location, weights_only=weights_only + ) + + # Fix the key names + state_dict = {key_renaming_mapping[k]: v for k, v in state_dict.items() if k in key_renaming_mapping} + + error_msgs = [] + + if is_deepspeed_zero3_enabled() and not is_quantized: + error_msgs += _load_state_dict_into_zero3_model(model_to_load, state_dict) + # Skip it with fsdp on ranks other than 0 + elif not (is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized): + disk_offload_index, cpu_offload_index = _load_state_dict_into_meta_model( + model_to_load, + state_dict, + shard_file, + expected_keys, + reverse_key_renaming_mapping, + device_map=device_map, + disk_offload_folder=disk_offload_folder, + disk_offload_index=disk_offload_index, + cpu_offload_folder=cpu_offload_folder, + cpu_offload_index=cpu_offload_index, + hf_quantizer=hf_quantizer, + is_safetensors=is_offloaded_safetensors, + keep_in_fp32_regex=keep_in_fp32_regex, + unexpected_keys=unexpected_keys, + device_mesh=device_mesh, + ) + + return error_msgs, disk_offload_index, cpu_offload_index + + +def load_shard_files_with_threadpool(args_list): + num_workers = int(os.environ.get("HF_PARALLEL_LOADING_WORKERS", "8")) + + # Do not spawn anymore workers than you need + num_workers = min(len(args_list), num_workers) + + logger.info(f"Loading model weights in parallel with {num_workers} workers...") + + error_msgs = [] + + with ThreadPoolExecutor(max_workers=num_workers) as executor: + with logging.tqdm(total=len(args_list), desc="Loading checkpoint shards") as pbar: + futures = 
[executor.submit(load_shard_file, arg) for arg in args_list] + for future in as_completed(futures): + result = future.result() + ( + _error_msgs, + disk_offload_index, + cpu_offload_index, + ) = result + + error_msgs += _error_msgs + + pbar.update(1) + + return error_msgs, disk_offload_index, cpu_offload_index + + def _add_variant(weights_name: str, variant: Optional[str] = None) -> str: if variant is not None: path, name = weights_name.rsplit(".", 1) @@ -889,6 +1005,7 @@ def _get_resolved_checkpoint_files( user_agent: dict, revision: str, commit_hash: Optional[str], + is_remote_code: bool, # Because we can't determine this inside this function, we need it to be passed in transformers_explicit_filename: Optional[str] = None, ) -> Tuple[Optional[List[str]], Optional[Dict]]: """Get all the checkpoint filenames based on `pretrained_model_name_or_path`, and optional metadata if the @@ -1085,7 +1202,10 @@ def _get_resolved_checkpoint_files( "_commit_hash": commit_hash, **has_file_kwargs, } - if not has_file(pretrained_model_name_or_path, safe_weights_name, **has_file_kwargs): + if ( + not has_file(pretrained_model_name_or_path, safe_weights_name, **has_file_kwargs) + and not is_remote_code + ): Thread( target=auto_conversion, args=(pretrained_model_name_or_path,), @@ -1457,7 +1577,8 @@ def _find_mismatched_keys( # This skips size mismatches for 4-bit weights. Two 4-bit values share an 8-bit container, causing size differences. # Without matching with module type or parameter type it seems like a practical way to detect valid 4bit weights. if not ( - new_state_dict[key].shape[-1] == 1 + is_quantized + and new_state_dict[key].shape[-1] == 1 and new_state_dict[key].numel() * 2 == model_state_dict[key].numel() ): mismatched_keys.append(key) @@ -1963,7 +2084,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi if plan := getattr(module, "_tp_plan", None): self._tp_plan.update({f"{name}.{k}": v for k, v in plan.copy().items()}) - if self._tp_plan is not None and is_torch_greater_or_equal("2.3"): + if self._tp_plan is not None and is_torch_greater_or_equal("2.5"): for _, v in self._tp_plan.items(): if v not in ALL_PARALLEL_STYLES: raise ValueError( @@ -2542,7 +2663,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi def smart_apply(self, fn): for module in self.children(): # We found a sub-model: recursively dispatch its own init function now! 
- if hasattr(module, "_init_weights"): + if isinstance(module, PreTrainedModel): module.smart_apply(module._initialize_weights) else: module.smart_apply(fn) @@ -3532,7 +3653,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi for key, value in state_dict.items(): for pattern, replacement in reverse_key_mapping.items(): replacement = replacement.lstrip("^") # strip off un-needed chars and patterns - replacement = re.sub(r"\(.*?\)", "", pattern) + replacement = re.sub(r"\(.*\)", "", replacement) key, n_replace = re.subn(pattern, replacement, key) # Early exit of the loop if n_replace > 0: @@ -4434,6 +4555,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi user_agent=user_agent, revision=revision, commit_hash=commit_hash, + is_remote_code=cls._auto_class is not None, transformers_explicit_filename=transformers_explicit_filename, ) @@ -4969,9 +5091,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi cpu_offload_folder = tempfile.mkdtemp() cpu_offload_index = {} - # For nice tqdm bars - if checkpoint_files is not None and len(checkpoint_files) > 1: - checkpoint_files = logging.tqdm(checkpoint_files, desc="Loading checkpoint shards") # To be able to iterate, even if we don't use it if the state_dict is already provided elif state_dict is not None: checkpoint_files = [""] @@ -4989,64 +5108,48 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi expanded_device_map = expand_device_map(device_map, expected_keys) caching_allocator_warmup(model_to_load, expanded_device_map, hf_quantizer) + # Prepare arguments compatible with both serial and parallel shard loading + args_list = [ + ( + shard_file, + state_dict, + disk_only_shard_files, + is_hqq_or_bnb, + is_quantized, + device_map, + hf_quantizer, + key_renaming_mapping, + weights_only, + model_to_load, + expected_keys, + reverse_key_renaming_mapping, + disk_offload_folder, + disk_offload_index, + cpu_offload_folder, + cpu_offload_index, + is_offloaded_safetensors, + keep_in_fp32_regex, + unexpected_keys, + device_mesh, + ) + for shard_file in checkpoint_files + ] + error_msgs = [] - # Iterate on all the shards to load the weights - for shard_file in checkpoint_files: - # Skip the load for shards that only contain disk-offloaded weights - if shard_file in disk_only_shard_files: - continue - map_location = "cpu" - if ( - shard_file.endswith(".safetensors") - and not is_hqq_or_bnb - and not (is_deepspeed_zero3_enabled() and not is_quantized) - ): - map_location = "meta" - elif ( - device_map is not None - and hf_quantizer is not None - and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO - and ( - hf_quantizer.quantization_config.quant_type in ["int4_weight_only", "autoquant"] - or isinstance(hf_quantizer.quantization_config.quant_type, Int4WeightOnlyConfig) - ) - ): - map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0]) + if ( + os.environ.get("HF_ENABLE_PARALLEL_LOADING", "").upper() in ENV_VARS_TRUE_VALUES + and not is_deepspeed_zero3_enabled() + ): + _error_msgs, disk_offload_index, cpu_offload_index = load_shard_files_with_threadpool(args_list) + error_msgs += _error_msgs + else: + if len(args_list) > 1: + args_list = logging.tqdm(args_list, desc="Loading checkpoint shards") - # If shard_file is "", we use the existing state_dict instead of loading it - if shard_file != "": - state_dict = load_state_dict( - shard_file, is_quantized=is_quantized,
map_location=map_location, weights_only=weights_only - ) - - # Fix the key names - state_dict = {key_renaming_mapping[k]: v for k, v in state_dict.items() if k in key_renaming_mapping} - - if is_deepspeed_zero3_enabled() and not is_quantized: - error_msgs += _load_state_dict_into_zero3_model(model_to_load, state_dict) - # Skip it with fsdp on ranks other than 0 - elif not (is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized): - disk_offload_index, cpu_offload_index = _load_state_dict_into_meta_model( - model_to_load, - state_dict, - shard_file, - expected_keys, - reverse_key_renaming_mapping, - device_map=device_map, - disk_offload_folder=disk_offload_folder, - disk_offload_index=disk_offload_index, - cpu_offload_folder=cpu_offload_folder, - cpu_offload_index=cpu_offload_index, - hf_quantizer=hf_quantizer, - is_safetensors=is_offloaded_safetensors, - keep_in_fp32_regex=keep_in_fp32_regex, - unexpected_keys=unexpected_keys, - device_mesh=device_mesh, - ) - - # force memory release if loading multiple shards, to avoid having 2 state dicts in memory in next loop - del state_dict + for args in args_list: + _error_msgs, disk_offload_index, cpu_offload_index = load_shard_file(args) + error_msgs += _error_msgs # Adjust offloaded weights name and save if needed if disk_offload_index is not None and len(disk_offload_index) > 0: @@ -5224,11 +5327,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi Register this class with a given auto class. This should only be used for custom models as the ones in the library are already mapped with an auto class. - - This API is experimental and may have some slight breaking changes in the next releases. - - Args: auto_class (`str` or `type`, *optional*, defaults to `"AutoModel"`): @@ -5478,8 +5577,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi def get_parameter_or_buffer(self, target: str): """ Return the parameter or buffer given by `target` if it exists, otherwise throw an error. This combines - `get_parameter()` and `get_buffer()` in a single handy function. Note that it only work if `target` is a - leaf of the model. + `get_parameter()` and `get_buffer()` in a single handy function. If the target is an `_extra_state` attribute, + it will return the extra state provided by the module. Note that it only works if `target` is a leaf of the model.
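The serial/parallel dispatch above is controlled entirely by environment variables. A hedged sketch of opting in (the checkpoint is a placeholder; any multi-shard model exercises the thread-pool path, and `HF_PARALLEL_LOADING_WORKERS` falls back to 8 when unset):

```python
import os

# Must be set before from_pretrained() is called, since they are read at load time.
os.environ["HF_ENABLE_PARALLEL_LOADING"] = "1"   # any of the usual truthy values works
os.environ["HF_PARALLEL_LOADING_WORKERS"] = "4"

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")  # placeholder sharded checkpoint
```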
""" try: return self.get_parameter(target) @@ -5489,7 +5588,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi return self.get_buffer(target) except AttributeError: pass - raise AttributeError(f"`{target}` is neither a parameter nor a buffer.") + module, param_name = get_module_from_name(self, target) + if ( + param_name == "_extra_state" + and getattr(module.__class__, "get_extra_state", torch.nn.Module.get_extra_state) + is not torch.nn.Module.get_extra_state + ): + return module.get_extra_state() + + raise AttributeError(f"`{target}` is neither a parameter, buffer, nor extra state.") PreTrainedModel.push_to_hub = copy_func(PreTrainedModel.push_to_hub) @@ -6088,7 +6195,10 @@ class AttentionInterface(GeneralInterface): _global_mapping = { "flash_attention_2": flash_attention_forward, "flex_attention": flex_attention_forward, + "paged_attention": paged_attention_forward, "sdpa": sdpa_attention_forward, + "sdpa_paged": sdpa_attention_paged_forward, + "eager_paged": eager_paged_attention_forward, } diff --git a/src/transformers/models/albert/tokenization_albert_fast.py b/src/transformers/models/albert/tokenization_albert_fast.py index 6e7b110b0af..05712eeb6eb 100644 --- a/src/transformers/models/albert/tokenization_albert_fast.py +++ b/src/transformers/models/albert/tokenization_albert_fast.py @@ -130,10 +130,6 @@ class AlbertTokenizerFast(PreTrainedTokenizerFast): self.keep_accents = keep_accents self.vocab_file = vocab_file - @property - def can_save_slow_tokenizer(self) -> bool: - return os.path.isfile(self.vocab_file) if self.vocab_file else False - def build_inputs_with_special_tokens( self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None ) -> List[int]: diff --git a/src/transformers/models/aria/image_processing_aria.py b/src/transformers/models/aria/image_processing_aria.py index 54a2ec9488c..eb9badef1ef 100644 --- a/src/transformers/models/aria/image_processing_aria.py +++ b/src/transformers/models/aria/image_processing_aria.py @@ -500,5 +500,26 @@ class AriaImageProcessor(BaseImageProcessor): ] return patches + def get_number_of_image_patches(self, height: int, width: int, images_kwargs=None): + """ + A utility that returns number of image patches for a given image size. + + Args: + height (`int`): + Height of the input image. + width (`int`): + Width of the input image. + images_kwargs (`dict`, *optional*) + Any kwargs to override defaults of the image processor. + Returns: + `int`: Number of patches per image. 
+ """ + split_image = images_kwargs.get("split_image", None) or self.split_image + max_image_size = images_kwargs.get("max_image_size", None) or self.max_image_size + + resized_height, resized_width = select_best_resolution((height, width), self.split_resolutions) + num_patches = 1 if not split_image else resized_height // max_image_size * resized_width // max_image_size + return num_patches + __all__ = ["AriaImageProcessor"] diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index cd794846275..8f552cfc815 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -549,15 +549,8 @@ class AriaTextAttention(nn.Module): key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, @@ -812,7 +805,6 @@ class AriaTextModel(AriaTextPreTrainedModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, - output_attentions=output_attentions, ) hidden_states = inputs_embeds @@ -1227,6 +1219,12 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin): def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + # Make modules available throught conditional class for BC @property def language_model(self): diff --git a/src/transformers/models/aria/modular_aria.py b/src/transformers/models/aria/modular_aria.py index 5afc05e9159..561f94e4e73 100644 --- a/src/transformers/models/aria/modular_aria.py +++ b/src/transformers/models/aria/modular_aria.py @@ -34,7 +34,7 @@ from ...image_utils import ( ) from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_utils import PreTrainedModel -from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack +from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack from ...tokenization_utils import PreTokenizedInput, TextInput from ...utils import LossKwargs, TensorType, auto_docstring, can_return_tuple, logging from ...utils.import_utils import is_torch_available @@ -884,11 +884,33 @@ class AriaImageProcessor(BaseImageProcessor): ] return patches + def get_number_of_image_patches(self, height: int, width: int, images_kwargs=None): + """ + A utility that returns number of image patches for a given image size. + + Args: + height (`int`): + Height of the input image. + width (`int`): + Width of the input image. + images_kwargs (`dict`, *optional*) + Any kwargs to override defaults of the image processor. + Returns: + `int`: Number of patches per image. 
+ """ + split_image = images_kwargs.get("split_image", None) or self.split_image + max_image_size = images_kwargs.get("max_image_size", None) or self.max_image_size + + resized_height, resized_width = select_best_resolution((height, width), self.split_resolutions) + num_patches = 1 if not split_image else resized_height // max_image_size * resized_width // max_image_size + return num_patches + class AriaProcessorKwargs(ProcessingKwargs, total=False): _defaults = { "text_kwargs": { "padding": False, + "return_mm_token_type_ids": False, }, "images_kwargs": { "max_image_size": 980, @@ -914,7 +936,6 @@ class AriaProcessor(ProcessorMixin): """ attributes = ["image_processor", "tokenizer"] - valid_kwargs = ["chat_template", "size_conversion"] image_processor_class = "AriaImageProcessor" tokenizer_class = "AutoTokenizer" @@ -978,10 +999,7 @@ class AriaProcessor(ProcessorMixin): raise ValueError("Invalid input text. Please provide a string, or a list of strings") if images is not None: - image_inputs = self.image_processor( - images, - **output_kwargs["images_kwargs"], - ) + image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"]) # expand the image_token according to the num_crops and tokens per image tokens_per_image = self.size_conversion[image_inputs.pixel_values.shape[2]] prompt_strings = [] @@ -995,11 +1013,44 @@ class AriaProcessor(ProcessorMixin): prompt_strings = text return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None) - text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"]) + return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False) + text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"], return_tensors=None) self._check_special_mm_tokens(prompt_strings, text_inputs, modalities=["image"]) + if return_mm_token_type_ids: + array_ids = np.array(text_inputs["input_ids"]) + mm_token_type_ids = np.zeros_like(text_inputs["input_ids"]) + mm_token_type_ids[array_ids == self.image_token_id] = 1 + text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist() + return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors) + def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs): + """ + Computes the number of placeholder tokens needed for multimodal inputs with the given sizes. + Args: + image_sizes (`List[List[int]]`, *optional*): + The input sizes formatted as (height, width) per each image. + Returns: + `MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided + input modalities, along with other useful data. + """ + + vision_data = {} + if image_sizes is not None: + images_kwargs = AriaProcessorKwargs._defaults.get("images_kwargs", {}) + images_kwargs.update(kwargs) + + max_size = images_kwargs.get("max_image_size", None) or self.image_processor.max_image_size + num_image_patches = [ + self.image_processor.get_number_of_image_patches(*image_size, images_kwargs) + for image_size in image_sizes + ] + num_image_tokens = [self.size_conversion[max_size] * num_patches for num_patches in num_image_patches] + vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches}) + + return MultiModalData(**vision_data) + def batch_decode(self, *args, **kwargs): """ This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. 
Please diff --git a/src/transformers/models/aria/processing_aria.py b/src/transformers/models/aria/processing_aria.py index 7d307624755..7ecf3af670c 100644 --- a/src/transformers/models/aria/processing_aria.py +++ b/src/transformers/models/aria/processing_aria.py @@ -20,9 +20,11 @@ # limitations under the License. from typing import Dict, List, Optional, Union +import numpy as np + from ...image_processing_utils import BatchFeature from ...image_utils import ImageInput -from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack +from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack from ...tokenization_utils import PreTokenizedInput, TextInput from ...utils import TensorType from ..auto import AutoTokenizer @@ -32,6 +34,7 @@ class AriaProcessorKwargs(ProcessingKwargs, total=False): _defaults = { "text_kwargs": { "padding": False, + "return_mm_token_type_ids": False, }, "images_kwargs": { "max_image_size": 980, @@ -57,7 +60,6 @@ class AriaProcessor(ProcessorMixin): """ attributes = ["image_processor", "tokenizer"] - valid_kwargs = ["chat_template", "size_conversion"] image_processor_class = "AriaImageProcessor" tokenizer_class = "AutoTokenizer" @@ -121,10 +123,7 @@ class AriaProcessor(ProcessorMixin): raise ValueError("Invalid input text. Please provide a string, or a list of strings") if images is not None: - image_inputs = self.image_processor( - images, - **output_kwargs["images_kwargs"], - ) + image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"]) # expand the image_token according to the num_crops and tokens per image tokens_per_image = self.size_conversion[image_inputs.pixel_values.shape[2]] prompt_strings = [] @@ -138,11 +137,44 @@ class AriaProcessor(ProcessorMixin): prompt_strings = text return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None) - text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"]) + return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False) + text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"], return_tensors=None) self._check_special_mm_tokens(prompt_strings, text_inputs, modalities=["image"]) + if return_mm_token_type_ids: + array_ids = np.array(text_inputs["input_ids"]) + mm_token_type_ids = np.zeros_like(text_inputs["input_ids"]) + mm_token_type_ids[array_ids == self.image_token_id] = 1 + text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist() + return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors) + def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs): + """ + Computes the number of placeholder tokens needed for multimodal inputs with the given sizes. + Args: + image_sizes (`List[List[int]]`, *optional*): + The input sizes formatted as (height, width) per each image. + Returns: + `MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided + input modalities, along with other useful data. 
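For intuition on the patch arithmetic used by `get_number_of_image_patches` above, a tiny sketch that mirrors the formula with hard-coded numbers (the resolution is illustrative, not a value from Aria's actual `split_resolutions` grid):

```python
max_image_size = 980                        # images_kwargs default from the diff
resized_height, resized_width = 1960, 980   # pretend select_best_resolution() picked a 2x1 tiling
split_image = True

num_patches = 1 if not split_image else (resized_height // max_image_size) * (resized_width // max_image_size)
print(num_patches)  # -> 2
```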
+ """ + + vision_data = {} + if image_sizes is not None: + images_kwargs = AriaProcessorKwargs._defaults.get("images_kwargs", {}) + images_kwargs.update(kwargs) + + max_size = images_kwargs.get("max_image_size", None) or self.image_processor.max_image_size + num_image_patches = [ + self.image_processor.get_number_of_image_patches(*image_size, images_kwargs) + for image_size in image_sizes + ] + num_image_tokens = [self.size_conversion[max_size] * num_patches for num_patches in num_image_patches] + vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches}) + + return MultiModalData(**vision_data) + def batch_decode(self, *args, **kwargs): """ This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please diff --git a/src/transformers/models/auto/auto_factory.py b/src/transformers/models/auto/auto_factory.py index 3b4ead50134..b32f2b711f1 100644 --- a/src/transformers/models/auto/auto_factory.py +++ b/src/transformers/models/auto/auto_factory.py @@ -420,17 +420,23 @@ class _BaseAutoModelClass: trust_remote_code = kwargs.pop("trust_remote_code", None) has_remote_code = hasattr(config, "auto_map") and cls.__name__ in config.auto_map has_local_code = type(config) in cls._model_mapping.keys() - trust_remote_code = resolve_trust_remote_code( - trust_remote_code, config._name_or_path, has_local_code, has_remote_code - ) + if has_remote_code: + class_ref = config.auto_map[cls.__name__] + if "--" in class_ref: + upstream_repo = class_ref.split("--")[0] + else: + upstream_repo = None + trust_remote_code = resolve_trust_remote_code( + trust_remote_code, config._name_or_path, has_local_code, has_remote_code, upstream_repo=upstream_repo + ) if has_remote_code and trust_remote_code: - class_ref = config.auto_map[cls.__name__] if "--" in class_ref: repo_id, class_ref = class_ref.split("--") else: repo_id = config.name_or_path model_class = get_class_from_dynamic_module(class_ref, repo_id, **kwargs) + model_class.register_for_auto_class(auto_class=cls) cls.register(config.__class__, model_class, exist_ok=True) _ = kwargs.pop("code_revision", None) model_class = add_generation_mixin_to_remote_model(model_class) @@ -545,8 +551,17 @@ class _BaseAutoModelClass: has_remote_code = hasattr(config, "auto_map") and cls.__name__ in config.auto_map has_local_code = type(config) in cls._model_mapping.keys() + upstream_repo = None + if has_remote_code: + class_ref = config.auto_map[cls.__name__] + if "--" in class_ref: + upstream_repo = class_ref.split("--")[0] trust_remote_code = resolve_trust_remote_code( - trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code + trust_remote_code, + pretrained_model_name_or_path, + has_local_code, + has_remote_code, + upstream_repo=upstream_repo, ) kwargs["trust_remote_code"] = trust_remote_code @@ -554,12 +569,12 @@ class _BaseAutoModelClass: kwargs["adapter_kwargs"] = adapter_kwargs if has_remote_code and trust_remote_code: - class_ref = config.auto_map[cls.__name__] model_class = get_class_from_dynamic_module( class_ref, pretrained_model_name_or_path, code_revision=code_revision, **hub_kwargs, **kwargs ) _ = hub_kwargs.pop("code_revision", None) cls.register(config.__class__, model_class, exist_ok=True) + model_class.register_for_auto_class(auto_class=cls) model_class = add_generation_mixin_to_remote_model(model_class) return model_class.from_pretrained( pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs diff --git 
a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 6c94eef83eb..726d173ba10 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -15,7 +15,6 @@ """Auto Config class.""" import importlib -import os import re import warnings from collections import OrderedDict @@ -1160,17 +1159,21 @@ class AutoConfig: config_dict, unused_kwargs = PretrainedConfig.get_config_dict(pretrained_model_name_or_path, **kwargs) has_remote_code = "auto_map" in config_dict and "AutoConfig" in config_dict["auto_map"] has_local_code = "model_type" in config_dict and config_dict["model_type"] in CONFIG_MAPPING - trust_remote_code = resolve_trust_remote_code( - trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code - ) + if has_remote_code: + class_ref = config_dict["auto_map"]["AutoConfig"] + if "--" in class_ref: + upstream_repo = class_ref.split("--")[0] + else: + upstream_repo = None + trust_remote_code = resolve_trust_remote_code( + trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code, upstream_repo + ) if has_remote_code and trust_remote_code: - class_ref = config_dict["auto_map"]["AutoConfig"] config_class = get_class_from_dynamic_module( class_ref, pretrained_model_name_or_path, code_revision=code_revision, **kwargs ) - if os.path.isdir(pretrained_model_name_or_path): - config_class.register_for_auto_class() + config_class.register_for_auto_class() return config_class.from_pretrained(pretrained_model_name_or_path, **kwargs) elif "model_type" in config_dict: try: diff --git a/src/transformers/models/auto/feature_extraction_auto.py b/src/transformers/models/auto/feature_extraction_auto.py index 86dc8703c42..a0f171af245 100644 --- a/src/transformers/models/auto/feature_extraction_auto.py +++ b/src/transformers/models/auto/feature_extraction_auto.py @@ -371,17 +371,21 @@ class AutoFeatureExtractor: has_remote_code = feature_extractor_auto_map is not None has_local_code = feature_extractor_class is not None or type(config) in FEATURE_EXTRACTOR_MAPPING - trust_remote_code = resolve_trust_remote_code( - trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code - ) + if has_remote_code: + if "--" in feature_extractor_auto_map: + upstream_repo = feature_extractor_auto_map.split("--")[0] + else: + upstream_repo = None + trust_remote_code = resolve_trust_remote_code( + trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code, upstream_repo + ) if has_remote_code and trust_remote_code: feature_extractor_class = get_class_from_dynamic_module( feature_extractor_auto_map, pretrained_model_name_or_path, **kwargs ) _ = kwargs.pop("code_revision", None) - if os.path.isdir(pretrained_model_name_or_path): - feature_extractor_class.register_for_auto_class() + feature_extractor_class.register_for_auto_class() return feature_extractor_class.from_dict(config_dict, **kwargs) elif feature_extractor_class is not None: return feature_extractor_class.from_dict(config_dict, **kwargs) diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index 76f1ca87ca0..52c009a23e4 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -541,26 +541,29 @@ class AutoImageProcessor: has_remote_code = image_processor_auto_map is not None has_local_code = image_processor_class is not None 
or type(config) in IMAGE_PROCESSOR_MAPPING - trust_remote_code = resolve_trust_remote_code( - trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code - ) - - if image_processor_auto_map is not None and not isinstance(image_processor_auto_map, tuple): - # In some configs, only the slow image processor class is stored - image_processor_auto_map = (image_processor_auto_map, None) + if has_remote_code: + if image_processor_auto_map is not None and not isinstance(image_processor_auto_map, tuple): + # In some configs, only the slow image processor class is stored + image_processor_auto_map = (image_processor_auto_map, None) + if use_fast and image_processor_auto_map[1] is not None: + class_ref = image_processor_auto_map[1] + else: + class_ref = image_processor_auto_map[0] + if "--" in class_ref: + upstream_repo = class_ref.split("--")[0] + else: + upstream_repo = None + trust_remote_code = resolve_trust_remote_code( + trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code, upstream_repo + ) if has_remote_code and trust_remote_code: if not use_fast and image_processor_auto_map[1] is not None: _warning_fast_image_processor_available(image_processor_auto_map[1]) - if use_fast and image_processor_auto_map[1] is not None: - class_ref = image_processor_auto_map[1] - else: - class_ref = image_processor_auto_map[0] image_processor_class = get_class_from_dynamic_module(class_ref, pretrained_model_name_or_path, **kwargs) _ = kwargs.pop("code_revision", None) - if os.path.isdir(pretrained_model_name_or_path): - image_processor_class.register_for_auto_class() + image_processor_class.register_for_auto_class() return image_processor_class.from_dict(config_dict, **kwargs) elif image_processor_class is not None: return image_processor_class.from_dict(config_dict, **kwargs) diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index b7b97d88f8c..9a0af0e9849 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -2106,6 +2106,7 @@ __all__ = [ "AutoModelForTableQuestionAnswering", "AutoModelForTextToSpectrogram", "AutoModelForTextToWaveform", + "AutoModelForTimeSeriesPrediction", "AutoModelForTokenClassification", "AutoModelForUniversalSegmentation", "AutoModelForVideoClassification", diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py index de14dab53f6..e5a675c6da7 100644 --- a/src/transformers/models/auto/processing_auto.py +++ b/src/transformers/models/auto/processing_auto.py @@ -17,7 +17,6 @@ import importlib import inspect import json -import os import warnings from collections import OrderedDict @@ -359,17 +358,21 @@ class AutoProcessor: has_remote_code = processor_auto_map is not None has_local_code = processor_class is not None or type(config) in PROCESSOR_MAPPING - trust_remote_code = resolve_trust_remote_code( - trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code - ) + if has_remote_code: + if "--" in processor_auto_map: + upstream_repo = processor_auto_map.split("--")[0] + else: + upstream_repo = None + trust_remote_code = resolve_trust_remote_code( + trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code, upstream_repo + ) if has_remote_code and trust_remote_code: processor_class = get_class_from_dynamic_module( processor_auto_map, pretrained_model_name_or_path, **kwargs ) _ = kwargs.pop("code_revision", None) - if 
os.path.isdir(pretrained_model_name_or_path): - processor_class.register_for_auto_class() + processor_class.register_for_auto_class() return processor_class.from_pretrained( pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs ) diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 34251174893..e2e21ff8247 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -989,19 +989,23 @@ class AutoTokenizer: or tokenizer_class_from_name(config_tokenizer_class + "Fast") is not None ) ) - trust_remote_code = resolve_trust_remote_code( - trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code - ) - - if has_remote_code and trust_remote_code: + if has_remote_code: if use_fast and tokenizer_auto_map[1] is not None: class_ref = tokenizer_auto_map[1] else: class_ref = tokenizer_auto_map[0] + if "--" in class_ref: + upstream_repo = class_ref.split("--")[0] + else: + upstream_repo = None + trust_remote_code = resolve_trust_remote_code( + trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code, upstream_repo + ) + + if has_remote_code and trust_remote_code: tokenizer_class = get_class_from_dynamic_module(class_ref, pretrained_model_name_or_path, **kwargs) _ = kwargs.pop("code_revision", None) - if os.path.isdir(pretrained_model_name_or_path): - tokenizer_class.register_for_auto_class() + tokenizer_class.register_for_auto_class() return tokenizer_class.from_pretrained( pretrained_model_name_or_path, *inputs, trust_remote_code=trust_remote_code, **kwargs ) diff --git a/src/transformers/models/auto/video_processing_auto.py b/src/transformers/models/auto/video_processing_auto.py index e7d08239fe9..507930df720 100644 --- a/src/transformers/models/auto/video_processing_auto.py +++ b/src/transformers/models/auto/video_processing_auto.py @@ -339,8 +339,7 @@ class AutoVideoProcessor: class_ref = video_processor_auto_map video_processor_class = get_class_from_dynamic_module(class_ref, pretrained_model_name_or_path, **kwargs) _ = kwargs.pop("code_revision", None) - if os.path.isdir(pretrained_model_name_or_path): - video_processor_class.register_for_auto_class() + video_processor_class.register_for_auto_class() return video_processor_class.from_dict(config_dict, **kwargs) elif video_processor_class is not None: return video_processor_class.from_dict(config_dict, **kwargs) diff --git a/src/transformers/models/autoformer/modeling_autoformer.py b/src/transformers/models/autoformer/modeling_autoformer.py index 9f7f1515a27..0a41692f69c 100644 --- a/src/transformers/models/autoformer/modeling_autoformer.py +++ b/src/transformers/models/autoformer/modeling_autoformer.py @@ -26,14 +26,21 @@ import torch.utils.checkpoint from torch import nn from ...activations import ACT2FN -from ...modeling_attn_mask_utils import _prepare_4d_attention_mask +from ...modeling_attn_mask_utils import ( + _prepare_4d_attention_mask, + _prepare_4d_attention_mask_for_sdpa, +) from ...modeling_outputs import BaseModelOutput, ModelOutput, SampleTSPredictionOutput, Seq2SeqTSPredictionOutput from ...modeling_utils import PreTrainedModel from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput -from ...utils import auto_docstring, logging +from ...utils import auto_docstring, is_torch_flex_attn_available, logging from .configuration_autoformer import AutoformerConfig +if is_torch_flex_attn_available(): + from 
...integrations.flex_attention import make_flex_block_causal_mask + + logger = logging.get_logger(__name__) @@ -904,6 +911,29 @@ class AutoformerPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + # Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesTransformerEncoder with TimeSeriesTransformer->Autoformer,TimeSeries->Autoformer class AutoformerEncoder(AutoformerPreTrainedModel): @@ -983,10 +1013,10 @@ class AutoformerEncoder(AutoformerPreTrainedModel): hidden_states = self.layernorm_embedding(hidden_states + embed_pos) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - # expand attention_mask - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + attention_mask = self._update_full_mask( + attention_mask, + inputs_embeds, + ) encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None diff --git a/src/transformers/models/aya_vision/modeling_aya_vision.py b/src/transformers/models/aya_vision/modeling_aya_vision.py index a851d4d0a0f..e074d4b1193 100644 --- a/src/transformers/models/aya_vision/modeling_aya_vision.py +++ b/src/transformers/models/aya_vision/modeling_aya_vision.py @@ -389,6 +389,12 @@ class AyaVisionForConditionalGeneration(AyaVisionPreTrainedModel, GenerationMixi def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + # Make modules available throught conditional class for BC @property def language_model(self): diff --git a/src/transformers/models/aya_vision/processing_aya_vision.py b/src/transformers/models/aya_vision/processing_aya_vision.py index 3b9afecda50..be3f04a1819 100644 --- a/src/transformers/models/aya_vision/processing_aya_vision.py +++ b/src/transformers/models/aya_vision/processing_aya_vision.py @@ -13,22 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
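The `_update_full_mask` helper copied into Autoformer above ultimately turns a `[bsz, seq_len]` padding mask into an additive 4D mask on the eager path. A hand-written sketch of that expansion (the real `_prepare_4d_attention_mask` also materializes the `tgt_seq_len` dimension instead of relying on broadcasting):

```python
import torch

attention_mask = torch.tensor([[1, 1, 1, 0]])   # bsz=1, seq_len=4, last position is padding
dtype = torch.float32

expanded = attention_mask[:, None, None, :].to(torch.bool)   # [bsz, 1, 1, src_seq_len], broadcast over tgt_seq_len
additive = torch.zeros(expanded.shape, dtype=dtype).masked_fill(~expanded, torch.finfo(dtype).min)

print(additive.shape)      # torch.Size([1, 1, 1, 4])
print(additive[0, 0, 0])   # ~[0, 0, 0, -3.4e38]
```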
- from typing import List, Optional, Union -from transformers.processing_utils import ( - ImagesKwargs, - ProcessingKwargs, - ProcessorMixin, - Unpack, -) -from transformers.tokenization_utils_base import PreTokenizedInput, TextInput +import numpy as np from ...image_processing_utils import BatchFeature -from ...image_utils import ( - ImageInput, - make_flat_list_of_images, -) +from ...image_utils import ImageInput, make_flat_list_of_images +from ...processing_utils import ImagesKwargs, MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack +from ...tokenization_utils_base import PreTokenizedInput, TextInput class AyaVisionImagesKwargs(ImagesKwargs, total=False): @@ -43,6 +35,7 @@ class AyaVisionProcessorKwargs(ProcessingKwargs, total=False): "text_kwargs": { "padding_side": "left", "padding": True, + "return_mm_token_type_ids": False, }, "images_kwargs": { "crop_to_patches": True, @@ -85,19 +78,6 @@ class AyaVisionProcessor(ProcessorMixin): """ attributes = ["image_processor", "tokenizer"] - valid_kwargs = [ - "chat_template", - "image_token", - "patch_size", - "img_size", - "downsample_factor", - "start_of_img_token", - "end_of_img_token", - "img_patch_token", - "img_line_break_token", - "tile_token", - "tile_global_token", - ] image_processor_class = "AutoImageProcessor" tokenizer_class = "AutoTokenizer" @@ -121,7 +101,6 @@ class AyaVisionProcessor(ProcessorMixin): super().__init__(image_processor, tokenizer, chat_template=chat_template) self.image_token = image_token - self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token) self.patch_size = patch_size * downsample_factor self.img_size = img_size @@ -131,6 +110,10 @@ class AyaVisionProcessor(ProcessorMixin): self.img_line_break_token = img_line_break_token self.tile_token = tile_token self.tile_global_token = tile_global_token + self.image_token_id = tokenizer.convert_tokens_to_ids(self.img_patch_token) + self.image_ids = tokenizer.convert_tokens_to_ids( + [img_patch_token, tile_token, tile_global_token, start_of_img_token, end_of_img_token] + ) def _prompt_split_image(self, num_patches): """ @@ -226,11 +209,49 @@ class AyaVisionProcessor(ProcessorMixin): text = processed_text return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None) - text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"]) - self._check_special_mm_tokens(text, text_inputs, modalities=["image"]) + return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False) + text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"], return_tensors=None) + + if return_mm_token_type_ids: + array_ids = np.array(text_inputs["input_ids"]) + mm_token_type_ids = np.zeros_like(text_inputs["input_ids"]) + mm_token_type_ids[np.isin(array_ids, self.image_ids)] = 1 + text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist() return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors) + def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs): + """ + Computes the number of placeholder tokens needed for multimodal inputs with the given sizes. + + Args: + image_sizes (`List[List[int]]`, *optional*): + The input sizes formatted as (height, width) per each image. + + Returns: + `MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided + input modalities, along with other useful data. 
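The `return_mm_token_type_ids` handling above simply marks which positions of `input_ids` belong to image tokens; the same NumPy pattern in isolation (token ids invented for the example):

```python
import numpy as np

input_ids = [[5, 42, 901, 901, 901, 77]]   # pretend 901 is the image-patch token id
image_ids = [901]                          # ids treated as "image" positions

array_ids = np.array(input_ids)
mm_token_type_ids = np.zeros_like(array_ids)
mm_token_type_ids[np.isin(array_ids, image_ids)] = 1
print(mm_token_type_ids.tolist())          # [[0, 0, 1, 1, 1, 0]]
```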
+ """ + + vision_data = {} + if image_sizes is not None: + images_kwargs = AyaVisionProcessorKwargs._defaults.get("images_kwargs", {}) + images_kwargs.update(kwargs) + + num_image_patches = [ + self.image_processor.get_number_of_image_patches(*image_size, images_kwargs) + for image_size in image_sizes + ] + + token_per_patch = (self.img_size // self.patch_size) ** 2 + num_image_tokens = [ + token_per_patch + 3 + sum(token_per_patch + 1 for _ in range(1, num_patches)) + for num_patches in num_image_patches + ] # Add +3 and +1 for BOI/EOI and image tile tokens + vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches}) + + return MultiModalData(**vision_data) + def batch_decode(self, *args, **kwargs): """ This method forwards all its arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please diff --git a/src/transformers/models/bamba/modeling_bamba.py b/src/transformers/models/bamba/modeling_bamba.py index 11742b1a321..1b8e12d1c3b 100644 --- a/src/transformers/models/bamba/modeling_bamba.py +++ b/src/transformers/models/bamba/modeling_bamba.py @@ -24,7 +24,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from functools import partial from typing import Callable, Optional, Tuple, TypedDict, Union import torch @@ -38,6 +37,7 @@ from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel @@ -313,15 +313,8 @@ class BambaAttention(nn.Module): key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
- ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, @@ -945,7 +938,7 @@ class BambaRMSNorm(nn.Module): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" -class BambaDecoderLayer(nn.Module): +class BambaDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: BambaConfig, layer_idx: int, layer_type: str = "mamba"): super().__init__() @@ -1161,30 +1154,17 @@ class BambaModel(BambaPreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - partial(decoder_layer.__call__, **kwargs), - hidden_states, - layer_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=layer_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=layer_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/bamba/modular_bamba.py b/src/transformers/models/bamba/modular_bamba.py index 7e0090b3945..9db52ebfbc5 100644 --- a/src/transformers/models/bamba/modular_bamba.py +++ b/src/transformers/models/bamba/modular_bamba.py @@ -19,7 +19,6 @@ # limitations under the License. 
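The Bamba refactor above relies on `GradientCheckpointingLayer`: checkpointing moves into the layer's own `__call__`, so the model loop no longer branches on `self.gradient_checkpointing`. A simplified stand-in illustrating the idea (this is not the actual implementation in `modeling_layers.py`):

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class TinyCheckpointingLayer(nn.Module):
    """Simplified stand-in for GradientCheckpointingLayer."""

    gradient_checkpointing = False

    def __call__(self, *args, **kwargs):
        # When enabled during training, reroute the usual nn.Module call through activation checkpointing.
        if self.gradient_checkpointing and self.training:
            return checkpoint(super().__call__, *args, use_reentrant=False, **kwargs)
        return super().__call__(*args, **kwargs)


class ToyDecoderLayer(TinyCheckpointingLayer):
    def __init__(self, hidden_size: int = 16):
        super().__init__()
        self.mlp = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states):
        return self.mlp(hidden_states).relu()


layer = ToyDecoderLayer()
layer.gradient_checkpointing = True
layer.train()
out = layer(torch.randn(2, 4, 16, requires_grad=True))
out.sum().backward()  # gradients flow through the recomputed forward
```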
"""PyTorch Bamba model.""" -from functools import partial from typing import Optional, Tuple, TypedDict, Union import torch @@ -928,30 +927,17 @@ class BambaModel(BambaPreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - partial(decoder_layer.__call__, **kwargs), - hidden_states, - layer_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=layer_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=layer_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 01f7f19a79e..2442baa2436 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -17,7 +17,7 @@ import copy import math import warnings -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch import torch.utils.checkpoint @@ -32,7 +32,7 @@ from ...modeling_attn_mask_utils import ( _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa, ) -from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -42,7 +42,8 @@ from ...modeling_outputs import ( Seq2SeqQuestionAnsweringModelOutput, Seq2SeqSequenceClassifierOutput, ) -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack from ...utils import ( auto_docstring, is_torch_flex_attn_available, @@ -53,13 +54,7 @@ from .configuration_bart import BartConfig if is_torch_flex_attn_available(): - from torch.nn.attention.flex_attention import BlockMask - - from ...integrations.flex_attention import make_flex_block_causal_mask - - -if is_flash_attn_available(): - from ...modeling_flash_attention_utils import _flash_attention_forward + from ...integrations.flex_attention import BlockMask, make_flex_block_causal_mask logger = logging.get_logger(__name__) @@ -119,6 +114,36 @@ class BartScaledWordEmbedding(nn.Embedding): return super().forward(input_ids) * self.embed_scale +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: Optional[float] = None, + dropout: float = 0.0, + head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + if scaling is None: + scaling = query.size(-1) ** -0.5 + + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if head_mask is not None: + attn_weights 
= attn_weights * head_mask.view(1, -1, 1, 1) + + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + class BartAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -170,17 +195,25 @@ class BartAttention(nn.Module): layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, cache_position: Optional[torch.Tensor] = None, + # TODO: we need a refactor so that the different attention modules can get their specific kwargs + # ATM, we have mixed things encoder, decoder, and encoder-decoder attn + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() + + # determine input shapes + bsz, tgt_len = hidden_states.shape[:-1] + src_len = key_value_states.shape[1] if is_cross_attention else tgt_len + + q_input_shape = (bsz, tgt_len, -1, self.head_dim) + kv_input_shape = (bsz, src_len, -1, self.head_dim) # get query proj - query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - query_states = query_states * self.scaling + query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) if past_key_value is not None: if isinstance(past_key_value, EncoderDecoderCache): @@ -201,8 +234,8 @@ class BartAttention(nn.Module): else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) - key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(*kv_input_shape).transpose(1, 2) + value_states = value_states.view(*kv_input_shape).transpose(1, 2) if past_key_value is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation @@ -214,297 +247,27 @@ class BartAttention(nn.Module): if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = query_states.reshape(*proj_shape) - key_states = key_states.reshape(*proj_shape) - value_states = value_states.reshape(*proj_shape) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - attention_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if layer_head_mask is not None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError( - f"Head mask for a single layer 
should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}" - ) - attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to be reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) - else: - attn_weights_reshaped = None - - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) - - attn_output = self.out_proj(attn_output) - - return attn_output, attn_weights_reshaped, past_key_value - - -class BartFlashAttention2(BartAttention): - """ - Bart flash attention module. This module inherits from `BartAttention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() - - def forward( - self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - cache_position: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - # BartFlashAttention2 attention does not support output_attentions - if output_attentions: - raise ValueError( - "BartSdpaAttention2 attention does not support `output_attentions`. " - "Use the argument `attn_implementation='eager'` when loading the model." 
- ) - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - bsz, q_len, _ = hidden_states.size() - - # get query proj - query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim) - - if past_key_value is not None: - if isinstance(past_key_value, EncoderDecoderCache): - is_updated = past_key_value.is_updated.get(self.layer_idx) - if is_cross_attention: - # after the first generated id, we can subsequently re-use all key/value_states from cache - curr_past_key_value = past_key_value.cross_attention_cache - else: - curr_past_key_value = past_key_value.self_attention_cache - else: - curr_past_key_value = past_key_value - - current_states = key_value_states if is_cross_attention else hidden_states - if is_cross_attention and past_key_value is not None and is_updated: - # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] - else: - key_states = self.k_proj(current_states) - value_states = self.v_proj(current_states) - key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - - if past_key_value is not None: - # save all key/value_states to cache to be re-used for fast auto-regressive generation - cache_position = cache_position if not is_cross_attention else None - key_states, value_states = curr_past_key_value.update( - key_states, value_states, self.layer_idx, {"cache_position": cache_position} - ) - # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls - if is_cross_attention: - past_key_value.is_updated[self.layer_idx] = True - - # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache - # to be able to avoid many of these transpose/reshape/view. - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (LlamaRMSNorm handles it correctly) - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." 
- ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = _flash_attention_forward( + attn_output, attn_weights = attention_interface( + self, query_states, key_states, value_states, attention_mask, - q_len, - dropout=self.dropout if self.training else 0.0, - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + output_attentions=output_attentions, + head_mask=layer_head_mask, + **kwargs, ) - attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, None, past_key_value - - -class BartSdpaAttention(BartAttention): - def forward( - self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - cache_position: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel""" - if output_attentions: - # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "BartModel is using BartSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` . Falling back to the manual attention" - ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
- ) - return super().forward( - hidden_states, - key_value_states=key_value_states, - past_key_value=past_key_value, - attention_mask=attention_mask, - output_attentions=output_attentions, - cache_position=cache_position, - ) - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - - bsz, tgt_len, _ = hidden_states.size() - - # get query proj - query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - - if past_key_value is not None: - if isinstance(past_key_value, EncoderDecoderCache): - is_updated = past_key_value.is_updated.get(self.layer_idx) - if is_cross_attention: - # after the first generated id, we can subsequently re-use all key/value_states from cache - curr_past_key_value = past_key_value.cross_attention_cache - else: - curr_past_key_value = past_key_value.self_attention_cache - else: - curr_past_key_value = past_key_value - - current_states = key_value_states if is_cross_attention else hidden_states - if is_cross_attention and past_key_value is not None and is_updated: - # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] - else: - key_states = self.k_proj(current_states) - value_states = self.v_proj(current_states) - key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - - if past_key_value is not None: - # save all key/value_states to cache to be re-used for fast auto-regressive generation - cache_position = cache_position if not is_cross_attention else None - key_states, value_states = curr_past_key_value.update( - key_states, value_states, self.layer_idx, {"cache_position": cache_position} - ) - # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls - if is_cross_attention: - past_key_value.is_updated[self.layer_idx] = True - - causal_mask = None - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and causal_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. - is_causal = True if self.is_causal and causal_mask is None and tgt_len > 1 else False - - # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, - # but we are fine here as `_shape` do call `.contiguous()`. 
Reference: https://github.com/pytorch/pytorch/issues/112577 - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=self.dropout if self.training else 0.0, - is_causal=is_causal, - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.view(bsz, tgt_len, self.embed_dim) - attn_output = self.out_proj(attn_output) - - return attn_output, None, past_key_value - - -BART_ATTENTION_CLASSES = { - "eager": BartAttention, - "sdpa": BartSdpaAttention, - "flash_attention_2": BartFlashAttention2, -} + return attn_output, attn_weights, past_key_value class BartEncoderLayer(nn.Module): @@ -512,7 +275,7 @@ class BartEncoderLayer(nn.Module): super().__init__() self.embed_dim = config.d_model - self.self_attn = BART_ATTENTION_CLASSES[config._attn_implementation]( + self.self_attn = BartAttention( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, @@ -583,7 +346,7 @@ class BartDecoderLayer(nn.Module): super().__init__() self.embed_dim = config.d_model - self.self_attn = BART_ATTENTION_CLASSES[config._attn_implementation]( + self.self_attn = BartAttention( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, @@ -597,7 +360,7 @@ class BartDecoderLayer(nn.Module): self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = BART_ATTENTION_CLASSES[config._attn_implementation]( + self.encoder_attn = BartAttention( self.embed_dim, config.decoder_attention_heads, dropout=config.attention_dropout, @@ -671,6 +434,7 @@ class BartDecoderLayer(nn.Module): layer_head_mask=cross_attn_layer_head_mask, past_key_value=past_key_value, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states @@ -730,6 +494,7 @@ class BartPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True _supports_sdpa = True + _supports_flex_attn = True _supports_cache_class = True _supports_static_cache = True @@ -757,23 +522,53 @@ class BartPreTrainedModel(PreTrainedModel): } return dummy_inputs - # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. 
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + def _update_causal_mask( self, - attention_mask: Union[torch.Tensor, "BlockMask"], + attention_mask: Optional[Union[torch.Tensor, "BlockMask"]], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - output_attentions: bool = False, ): + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + # Other attention flavors support in-built causal (when `mask is None`) + # while we need to create our specific block mask regardless + elif attention_mask is None: + attention_mask = make_flex_block_causal_mask( + torch.ones( + size=(input_tensor.shape[0], input_tensor.shape[1]), + device=attention_mask.device, + ) + ) + return attention_mask + if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None - if self.config._attn_implementation == "flex_attention": - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask) - return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail @@ -782,7 +577,7 @@ class BartPreTrainedModel(PreTrainedModel): using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: + if self.config._attn_implementation == "sdpa" and not using_compilable_cache: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -816,7 +611,6 @@ class BartPreTrainedModel(PreTrainedModel): self.config._attn_implementation == "sdpa" and attention_mask is not None and attention_mask.device.type in ["cuda", "xpu", "npu"] - and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 
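The Bart hunks above collapse the separate eager/SDPA/flash attention subclasses into a single BartAttention that resolves its kernel by name at call time. Below is a minimal, self-contained sketch of that dispatch pattern, not the transformers implementation: the registry name ATTN_FUNCTIONS, the TinyAttention module, and all sizes are invented for illustration (the diff itself looks up ALL_ATTENTION_FUNCTIONS[config._attn_implementation] and falls back to eager_attention_forward).

import torch
import torch.nn as nn


def eager_attention_forward(module, query, key, value, attention_mask=None,
                            scaling=None, dropout=0.0, head_mask=None, **kwargs):
    # query/key/value come in as [bsz, num_heads, seq_len, head_dim]
    if scaling is None:
        scaling = query.size(-1) ** -0.5
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask
    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
    if head_mask is not None:
        attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value)
    # hand back [bsz, seq_len, num_heads, head_dim] so the caller can reshape and project
    return attn_output.transpose(1, 2).contiguous(), attn_weights


def sdpa_attention_forward(module, query, key, value, attention_mask=None,
                           scaling=None, dropout=0.0, **kwargs):
    # SDPA's default scale is already 1/sqrt(head_dim), which matches `scaling` below
    attn_output = nn.functional.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, dropout_p=dropout
    )
    return attn_output.transpose(1, 2).contiguous(), None


# toy stand-in for the ALL_ATTENTION_FUNCTIONS registry used in the diff
ATTN_FUNCTIONS = {"eager": eager_attention_forward, "sdpa": sdpa_attention_forward}


class TinyAttention(nn.Module):
    def __init__(self, embed_dim=64, num_heads=4, attn_implementation="sdpa"):
        super().__init__()
        self.head_dim = embed_dim // num_heads
        self.scaling = self.head_dim ** -0.5
        self.attn_implementation = attn_implementation
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, hidden_states):
        bsz, tgt_len, _ = hidden_states.shape
        shape = (bsz, tgt_len, -1, self.head_dim)
        # project, then move heads next to the batch dim: [bsz, num_heads, tgt_len, head_dim]
        query = self.q_proj(hidden_states).view(*shape).transpose(1, 2)
        key = self.k_proj(hidden_states).view(*shape).transpose(1, 2)
        value = self.v_proj(hidden_states).view(*shape).transpose(1, 2)
        # single dispatch point instead of one subclass per backend
        attn_fn = ATTN_FUNCTIONS[self.attn_implementation]
        attn_output, _ = attn_fn(self, query, key, value, scaling=self.scaling,
                                 dropout=0.0 if not self.training else 0.1)
        return self.out_proj(attn_output.reshape(bsz, tgt_len, -1))


hidden = torch.randn(2, 5, 64)
print(TinyAttention(attn_implementation="eager")(hidden).shape)  # torch.Size([2, 5, 64])
print(TinyAttention(attn_implementation="sdpa")(hidden).shape)   # torch.Size([2, 5, 64])

Swapping the attn_implementation string is the only change needed to move between backends, which is what lets this refactor delete BartSdpaAttention, BartFlashAttention2, and the BART_ATTENTION_CLASSES mapping.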
@@ -882,6 +676,41 @@ class BartPreTrainedModel(PreTrainedModel): return causal_mask + def _update_cross_attn_mask( + self, + encoder_hidden_states: Union[torch.Tensor, None], + encoder_attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + ): + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1], + ) + elif self.config._attn_implementation == "flex_attention": + if isinstance(encoder_attention_mask, torch.Tensor): + encoder_attention_mask = make_flex_block_causal_mask( + encoder_attention_mask, + query_length=input_shape[-1], + is_causal=False, + ) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + return encoder_attention_mask + class PretrainedBartModel(BartPreTrainedModel): def __init_subclass__(self): @@ -932,8 +761,6 @@ class BartEncoder(BartPreTrainedModel): embed_dim, ) self.layers = nn.ModuleList([BartEncoderLayer(config, layer_idx=i) for i in range(config.encoder_layers)]) - self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" - self._use_sdpa = config._attn_implementation == "sdpa" self.layernorm_embedding = nn.LayerNorm(embed_dim) self.gradient_checkpointing = False @@ -1019,18 +846,10 @@ class BartEncoder(BartPreTrainedModel): hidden_states = self.layernorm_embedding(hidden_states) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - # expand attention_mask - if attention_mask is not None: - if self._use_flash_attention_2: - attention_mask = attention_mask if 0 in attention_mask else None - elif self._use_sdpa and head_mask is None and not output_attentions: - # output_attentions=True & head_mask can not be supported when using SDPA, fall back to - # the manual implementation that requires a 4D causal mask in all cases. 
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) - else: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + attention_mask = self._update_full_mask( + attention_mask, + inputs_embeds, + ) encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None @@ -1116,8 +935,6 @@ class BartDecoder(BartPreTrainedModel): config.d_model, ) self.layers = nn.ModuleList([BartDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)]) - self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" - self._use_sdpa = config._attn_implementation == "sdpa" self.layernorm_embedding = nn.LayerNorm(config.d_model) @@ -1232,12 +1049,18 @@ class BartDecoder(BartPreTrainedModel): # retrieve input_ids and inputs_embeds if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - - if input_ids is not None: - input_ids = input_ids.view(-1, input_ids.shape[-1]) + elif input_ids is not None: + input = input_ids + input_shape = input.shape + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + input = inputs_embeds[:, :, -1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) + inputs_embeds = self.embed_tokens(input) # initialize `past_key_values` return_legacy_cache = False @@ -1267,38 +1090,25 @@ class BartDecoder(BartPreTrainedModel): if isinstance(past_key_values, EncoderDecoderCache) else past_key_values ) - causal_mask = self._update_causal_mask( + + attention_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, self_attn_cache, - output_attentions, + ) + encoder_attention_mask = self._update_cross_attn_mask( + encoder_hidden_states, + encoder_attention_mask, + input_shape, + inputs_embeds, ) - # expand encoder attention mask - if encoder_hidden_states is not None and encoder_attention_mask is not None: - if self._use_flash_attention_2: - encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None - elif self._use_sdpa and cross_attn_head_mask is None and not output_attentions: - # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. 
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( - encoder_attention_mask, - inputs_embeds.dtype, - tgt_len=seq_length, - ) - else: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _prepare_4d_attention_mask( - encoder_attention_mask, inputs_embeds.dtype, tgt_len=seq_length - ) - # embed positions - position_ids = self.embed_positions(input, past_key_values_length, position_ids=cache_position) - position_ids = position_ids.to(inputs_embeds.device) + positions = self.embed_positions(input, past_key_values_length, position_ids=cache_position) + positions = positions.to(inputs_embeds.device) - hidden_states = inputs_embeds + position_ids + hidden_states = inputs_embeds + positions hidden_states = self.layernorm_embedding(hidden_states) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) @@ -1331,7 +1141,7 @@ class BartDecoder(BartPreTrainedModel): layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, - causal_mask, + attention_mask, encoder_hidden_states, encoder_attention_mask, head_mask[idx] if head_mask is not None else None, @@ -1344,7 +1154,7 @@ class BartDecoder(BartPreTrainedModel): else: layer_outputs = decoder_layer( hidden_states, - attention_mask=causal_mask, + attention_mask=attention_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), diff --git a/src/transformers/models/barthez/tokenization_barthez_fast.py b/src/transformers/models/barthez/tokenization_barthez_fast.py index a1d95ef03e4..70c301eee91 100644 --- a/src/transformers/models/barthez/tokenization_barthez_fast.py +++ b/src/transformers/models/barthez/tokenization_barthez_fast.py @@ -122,10 +122,6 @@ class BarthezTokenizerFast(PreTrainedTokenizerFast): self.vocab_file = vocab_file - @property - def can_save_slow_tokenizer(self) -> bool: - return os.path.isfile(self.vocab_file) if self.vocab_file else False - def build_inputs_with_special_tokens( self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None ) -> List[int]: diff --git a/src/transformers/models/big_bird/tokenization_big_bird_fast.py b/src/transformers/models/big_bird/tokenization_big_bird_fast.py index 83f2fac07fa..18383a7ddb1 100644 --- a/src/transformers/models/big_bird/tokenization_big_bird_fast.py +++ b/src/transformers/models/big_bird/tokenization_big_bird_fast.py @@ -119,10 +119,6 @@ class BigBirdTokenizerFast(PreTrainedTokenizerFast): self.vocab_file = vocab_file - @property - def can_save_slow_tokenizer(self) -> bool: - return os.path.isfile(self.vocab_file) if self.vocab_file else False - def build_inputs_with_special_tokens( self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None ) -> List[int]: diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index 4ff34b9ef25..d49d4e65bd7 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -16,7 +16,7 @@ import copy import math -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import numpy as np import torch @@ -29,7 +29,9 @@ from ...generation import GenerationMixin from ...modeling_attn_mask_utils import ( AttentionMaskConverter, 
_prepare_4d_attention_mask, + _prepare_4d_attention_mask_for_sdpa, ) +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -39,20 +41,14 @@ from ...modeling_outputs import ( Seq2SeqQuestionAnsweringModelOutput, Seq2SeqSequenceClassifierOutput, ) -from ...modeling_utils import PreTrainedModel -from ...utils import ( - auto_docstring, - is_torch_flex_attn_available, - is_torchdynamo_compiling, - logging, -) +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import auto_docstring, is_torch_flex_attn_available, is_torchdynamo_compiling, logging from .configuration_bigbird_pegasus import BigBirdPegasusConfig if is_torch_flex_attn_available(): - from torch.nn.attention.flex_attention import BlockMask - - from ...integrations.flex_attention import make_flex_block_causal_mask + from ...integrations.flex_attention import BlockMask, make_flex_block_causal_mask logger = logging.get_logger(__name__) @@ -1179,6 +1175,37 @@ class BigBirdPegasusEncoderAttention(nn.Module): return outputs +# Copied from transformers.models.bart.modeling_bart.eager_attention_forward +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: Optional[float] = None, + dropout: float = 0.0, + head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + if scaling is None: + scaling = query.size(-1) ** -0.5 + + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) + + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + # Copied from transformers.models.bart.modeling_bart.BartAttention with BartConfig->BigBirdPegasusConfig, Bart->BigBirdPegasusDecoder class BigBirdPegasusDecoderAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -1231,17 +1258,25 @@ class BigBirdPegasusDecoderAttention(nn.Module): layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, cache_position: Optional[torch.Tensor] = None, + # TODO: we need a refactor so that the different attention modules can get their specific kwargs + # ATM, we have mixed things encoder, decoder, and encoder-decoder attn + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() + + # determine input shapes + bsz, tgt_len = hidden_states.shape[:-1] + src_len = key_value_states.shape[1] if is_cross_attention else tgt_len + + q_input_shape = (bsz, tgt_len, -1, self.head_dim) + kv_input_shape = (bsz, src_len, -1, self.head_dim) # get query proj - query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - query_states = query_states * self.scaling + query_states = 
self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) if past_key_value is not None: if isinstance(past_key_value, EncoderDecoderCache): @@ -1262,8 +1297,8 @@ class BigBirdPegasusDecoderAttention(nn.Module): else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) - key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(*kv_input_shape).transpose(1, 2) + value_states = value_states.view(*kv_input_shape).transpose(1, 2) if past_key_value is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation @@ -1275,66 +1310,27 @@ class BigBirdPegasusDecoderAttention(nn.Module): if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = query_states.reshape(*proj_shape) - key_states = key_states.reshape(*proj_shape) - value_states = value_states.reshape(*proj_shape) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - attention_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if layer_head_mask is not None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError( - f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}" - ) - attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to be reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) - else: - attn_weights_reshaped = None - - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. 
- attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + output_attentions=output_attentions, + head_mask=layer_head_mask, + **kwargs, + ) + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights_reshaped, past_key_value + return attn_output, attn_weights, past_key_value class BigBirdPegasusEncoderLayer(nn.Module): @@ -1434,6 +1430,7 @@ class BigBirdPegasusDecoderLayer(nn.Module): dropout=config.attention_dropout, is_decoder=True, bias=config.use_bias, + config=config, layer_idx=layer_idx, ) self.dropout = config.dropout @@ -1447,6 +1444,7 @@ class BigBirdPegasusDecoderLayer(nn.Module): dropout=config.attention_dropout, is_decoder=True, bias=config.use_bias, + config=config, layer_idx=layer_idx, ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) @@ -1510,7 +1508,6 @@ class BigBirdPegasusDecoderLayer(nn.Module): residual = hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) - # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple hidden_states, cross_attn_weights, past_key_value = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, @@ -1602,23 +1599,32 @@ class BigBirdPegasusPreTrainedModel(PreTrainedModel): } return dummy_inputs - # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask def _update_causal_mask( self, - attention_mask: Union[torch.Tensor, "BlockMask"], + attention_mask: Optional[Union[torch.Tensor, "BlockMask"]], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - output_attentions: bool = False, ): + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + # Other attention flavors support in-built causal (when `mask is None`) + # while we need to create our specific block mask regardless + elif attention_mask is None: + attention_mask = make_flex_block_causal_mask( + torch.ones( + size=(input_tensor.shape[0], input_tensor.shape[1]), + device=attention_mask.device, + ) + ) + return attention_mask + if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None - if self.config._attn_implementation == "flex_attention": - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask) - return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. 
This feature is not compatible with static cache, as SDPA will fail @@ -1627,7 +1633,7 @@ class BigBirdPegasusPreTrainedModel(PreTrainedModel): using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: + if self.config._attn_implementation == "sdpa" and not using_compilable_cache: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -1661,7 +1667,6 @@ class BigBirdPegasusPreTrainedModel(PreTrainedModel): self.config._attn_implementation == "sdpa" and attention_mask is not None and attention_mask.device.type in ["cuda", "xpu", "npu"] - and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. @@ -1727,6 +1732,42 @@ class BigBirdPegasusPreTrainedModel(PreTrainedModel): return causal_mask + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_cross_attn_mask + def _update_cross_attn_mask( + self, + encoder_hidden_states: Union[torch.Tensor, None], + encoder_attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + ): + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. 
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1], + ) + elif self.config._attn_implementation == "flex_attention": + if isinstance(encoder_attention_mask, torch.Tensor): + encoder_attention_mask = make_flex_block_causal_mask( + encoder_attention_mask, + query_length=input_shape[-1], + is_causal=False, + ) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + return encoder_attention_mask + class BigBirdPegasusEncoder(BigBirdPegasusPreTrainedModel): """ @@ -2172,9 +2213,13 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel): # retrieve input_ids and inputs_embeds if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - - if input_ids is not None: - input_ids = input_ids.view(-1, input_ids.shape[-1]) + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) @@ -2207,28 +2252,26 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel): if isinstance(past_key_values, EncoderDecoderCache) else past_key_values ) - causal_mask = self._update_causal_mask( + + attention_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, self_attn_cache, - output_attentions, ) - - # expand encoder attention mask - if encoder_hidden_states is not None and encoder_attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _prepare_4d_attention_mask( - encoder_attention_mask, inputs_embeds.dtype, tgt_len=seq_length - ) + encoder_attention_mask = self._update_cross_attn_mask( + encoder_hidden_states, + encoder_attention_mask, + input_shape, + inputs_embeds, + ) # embed positions - position_ids = cache_position.unsqueeze(0) - position_ids = self.embed_positions( - (batch_size, seq_length), past_key_values_length, position_ids=position_ids - ) - position_ids = position_ids.to(inputs_embeds.device) - hidden_states = inputs_embeds + position_ids + positions = self.embed_positions(input, past_key_values_length, position_ids=cache_position) + positions = positions.to(inputs_embeds.device) + + hidden_states = inputs_embeds + positions + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) # decoder layers @@ -2258,7 +2301,7 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel): layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, - causal_mask, + attention_mask, encoder_hidden_states, encoder_attention_mask, head_mask[idx] if head_mask is not None else None, @@ -2271,7 +2314,7 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel): else: layer_outputs = decoder_layer( hidden_states, - attention_mask=causal_mask, + attention_mask=attention_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), @@ -2979,7 +3022,7 @@ class 
BigBirdPegasusForCausalLM(BigBirdPegasusPreTrainedModel, GenerationMixin): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - cache_position: Optional[torch.Tensor] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: r""" cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): diff --git a/src/transformers/models/biogpt/modeling_biogpt.py b/src/transformers/models/biogpt/modeling_biogpt.py index d93b6f6ae2d..f12eeac6973 100755 --- a/src/transformers/models/biogpt/modeling_biogpt.py +++ b/src/transformers/models/biogpt/modeling_biogpt.py @@ -1,3 +1,9 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/biogpt/modular_biogpt.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_biogpt.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # coding=utf-8 # Copyright 2022 The HuggingFace Team and Microsoft Research AI4Science All rights reserved. # @@ -12,56 +18,46 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""PyTorch BioGPT model.""" import math -from typing import Optional, Tuple, Union +from functools import partial +from typing import Callable, Optional, Tuple, Union import torch -import torch.utils.checkpoint -from torch import nn +import torch.nn as nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin -from ...modeling_attn_mask_utils import ( - AttentionMaskConverter, -) +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, SequenceClassifierOutputWithPast, TokenClassifierOutput, ) -from ...modeling_utils import PreTrainedModel -from ...utils import ( - auto_docstring, - is_torch_flex_attn_available, - is_torchdynamo_compiling, - logging, -) +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import LossKwargs, auto_docstring, is_torch_flex_attn_available, logging from .configuration_biogpt import BioGptConfig if is_torch_flex_attn_available(): - from torch.nn.attention.flex_attention import BlockMask - - from ...integrations.flex_attention import make_flex_block_causal_mask + from ...integrations.flex_attention import BlockMask, make_flex_block_causal_mask logger = logging.get_logger(__name__) -# copied from transformers.models.opt.modeling_opt.OPTLearnedPositionalEmbedding with OPT->BioGpt -# TODO @ArthurZucker bring copied from back class BioGptLearnedPositionalEmbedding(nn.Embedding): """ This module learns positional embeddings up to a fixed maximum size. 
""" def __init__(self, num_embeddings: int, embedding_dim: int): - # BioGpt is set up so that if padding_idx is specified then offset the embedding ids by 2 + # BIOGPT is set up so that if padding_idx is specified then offset the embedding ids by 2 # and adjust num_embeddings appropriately. Other models don't have this hack self.offset = 2 super().__init__(num_embeddings + self.offset, embedding_dim) @@ -70,22 +66,19 @@ class BioGptLearnedPositionalEmbedding(nn.Embedding): self, attention_mask: torch.LongTensor, past_key_values_length: int = 0, - position_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, ): """`input_ids_shape` is expected to be [bsz x seqlen].""" + if position_ids is None: - attention_mask = attention_mask.long() - - # create positions depending on attention_mask - positions = (torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask).long() - 1 - + position_ids = torch.cumsum(attention_mask, dim=1) + position_ids = (position_ids * attention_mask - 1).long() # cut positions if `past_key_values_length` is > 0 - position_ids = positions[:, past_key_values_length:] + position_ids = position_ids[:, past_key_values_length:] return super().forward(position_ids + self.offset) -# Copied from transformers.models.bart.modeling_bart.BartScaledWordEmbedding with Bart->BioGpt class BioGptScaledWordEmbedding(nn.Embedding): """ This module overrides nn.Embeddings' forward by multiplying with embeddings scale. @@ -99,7 +92,36 @@ class BioGptScaledWordEmbedding(nn.Embedding): return super().forward(input_ids) * self.embed_scale -# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->BioGpt +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: Optional[float] = None, + dropout: float = 0.0, + head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + if scaling is None: + scaling = query.size(-1) ** -0.5 + + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) + + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + class BioGptAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -151,17 +173,25 @@ class BioGptAttention(nn.Module): layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, cache_position: Optional[torch.Tensor] = None, + # TODO: we need a refactor so that the different attention modules can get their specific kwargs + # ATM, we have mixed things encoder, decoder, and encoder-decoder attn + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() + + # determine input shapes + bsz, tgt_len = hidden_states.shape[:-1] + src_len = key_value_states.shape[1] if is_cross_attention else tgt_len + + q_input_shape = 
(bsz, tgt_len, -1, self.head_dim) + kv_input_shape = (bsz, src_len, -1, self.head_dim) # get query proj - query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - query_states = query_states * self.scaling + query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) if past_key_value is not None: if isinstance(past_key_value, EncoderDecoderCache): @@ -182,8 +212,8 @@ class BioGptAttention(nn.Module): else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) - key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(*kv_input_shape).transpose(1, 2) + value_states = value_states.view(*kv_input_shape).transpose(1, 2) if past_key_value is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation @@ -195,178 +225,27 @@ class BioGptAttention(nn.Module): if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = query_states.reshape(*proj_shape) - key_states = key_states.reshape(*proj_shape) - value_states = value_states.reshape(*proj_shape) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - attention_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if layer_head_mask is not None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError( - f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}" - ) - attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. 
- # In order to do so, attn_weights have to be reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) - else: - attn_weights_reshaped = None - - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) - - attn_output = self.out_proj(attn_output) - - return attn_output, attn_weights_reshaped, past_key_value - - -# Copied from transformers.models.bart.modeling_bart.BartSdpaAttention with Bart->BioGpt -class BioGptSdpaAttention(BioGptAttention): - def forward( - self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - cache_position: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel""" - if output_attentions: - # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "BioGptModel is using BioGptSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` . Falling back to the manual attention" - ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
- ) - return super().forward( - hidden_states, - key_value_states=key_value_states, - past_key_value=past_key_value, - attention_mask=attention_mask, - output_attentions=output_attentions, - cache_position=cache_position, - ) - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - - bsz, tgt_len, _ = hidden_states.size() - - # get query proj - query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - - if past_key_value is not None: - if isinstance(past_key_value, EncoderDecoderCache): - is_updated = past_key_value.is_updated.get(self.layer_idx) - if is_cross_attention: - # after the first generated id, we can subsequently re-use all key/value_states from cache - curr_past_key_value = past_key_value.cross_attention_cache - else: - curr_past_key_value = past_key_value.self_attention_cache - else: - curr_past_key_value = past_key_value - - current_states = key_value_states if is_cross_attention else hidden_states - if is_cross_attention and past_key_value is not None and is_updated: - # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] - else: - key_states = self.k_proj(current_states) - value_states = self.v_proj(current_states) - key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - - if past_key_value is not None: - # save all key/value_states to cache to be re-used for fast auto-regressive generation - cache_position = cache_position if not is_cross_attention else None - key_states, value_states = curr_past_key_value.update( - key_states, value_states, self.layer_idx, {"cache_position": cache_position} - ) - # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls - if is_cross_attention: - past_key_value.is_updated[self.layer_idx] = True - - causal_mask = None - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and causal_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. - is_causal = True if self.is_causal and causal_mask is None and tgt_len > 1 else False - - # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, - # but we are fine here as `_shape` do call `.contiguous()`. 
Reference: https://github.com/pytorch/pytorch/issues/112577 - attn_output = torch.nn.functional.scaled_dot_product_attention( + attn_output, attn_weights = attention_interface( + self, query_states, key_states, value_states, - attn_mask=causal_mask, - dropout_p=self.dropout if self.training else 0.0, - is_causal=is_causal, + attention_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + output_attentions=output_attentions, + head_mask=layer_head_mask, + **kwargs, ) - attn_output = attn_output.transpose(1, 2).contiguous() - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.view(bsz, tgt_len, self.embed_dim) + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, None, past_key_value - - -BIOGPT_ATTENTION_CLASSES = { - "eager": BioGptAttention, - "sdpa": BioGptSdpaAttention, -} + return attn_output, attn_weights, past_key_value class BioGptDecoderLayer(nn.Module): @@ -374,12 +253,13 @@ class BioGptDecoderLayer(nn.Module): super().__init__() self.embed_dim = config.hidden_size - self.self_attn = BIOGPT_ATTENTION_CLASSES[config._attn_implementation]( + self.self_attn = BioGptAttention( embed_dim=self.embed_dim, num_heads=config.num_attention_heads, dropout=config.attention_probs_dropout_prob, is_decoder=True, is_causal=True, + config=config, layer_idx=layer_idx, ) self.dropout = config.hidden_dropout_prob @@ -400,7 +280,9 @@ class BioGptDecoderLayer(nn.Module): past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, + position_ids: Optional[torch.LongTensor] = None, cache_position: Optional[torch.Tensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -431,7 +313,9 @@ class BioGptDecoderLayer(nn.Module): attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, + position_ids=position_ids, cache_position=cache_position, + **flash_attn_kwargs, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states @@ -462,7 +346,9 @@ class BioGptPreTrainedModel(PreTrainedModel): config_class = BioGptConfig base_model_prefix = "biogpt" supports_gradient_checkpointing = True + _supports_flash_attn_2 = True _supports_sdpa = True + _supports_flex_attn = True _supports_cache_class = True _supports_static_cache = True @@ -482,23 +368,32 @@ class BioGptPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask def _update_causal_mask( self, - attention_mask: Union[torch.Tensor, "BlockMask"], + attention_mask: Optional[Union[torch.Tensor, "BlockMask"]], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - output_attentions: bool = False, ): + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + # Other attention flavors support in-built causal (when `mask is None`) + # while we need to create our specific block mask 
regardless + elif attention_mask is None: + attention_mask = make_flex_block_causal_mask( + torch.ones( + size=(input_tensor.shape[0], input_tensor.shape[1]), + device=attention_mask.device, + ) + ) + return attention_mask + if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None - if self.config._attn_implementation == "flex_attention": - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask) - return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail @@ -507,7 +402,7 @@ class BioGptPreTrainedModel(PreTrainedModel): using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: + if self.config._attn_implementation == "sdpa" and not using_compilable_cache: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -541,7 +436,6 @@ class BioGptPreTrainedModel(PreTrainedModel): self.config._attn_implementation == "sdpa" and attention_mask is not None and attention_mask.device.type in ["cuda", "xpu", "npu"] - and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 
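The hunks above funnel every attention backend through a single callable: `BioGptAttention.forward` now calls `attention_interface(...)`, which stays `eager_attention_forward` unless `config._attn_implementation` selects another entry from `ALL_ATTENTION_FUNCTIONS`, and the per-backend `BioGptSdpaAttention` subclass plus the `BIOGPT_ATTENTION_CLASSES` mapping are deleted. As a rough standalone sketch of that dispatch pattern (pure PyTorch; `ATTN_IMPLEMENTATIONS` and `dispatch_attention` are illustrative names, not the library's actual registry or API):

import torch
import torch.nn as nn


def eager_attention_forward(module, query, key, value, attention_mask=None,
                            scaling=None, dropout=0.0, head_mask=None, **kwargs):
    # query/key/value: (batch, num_heads, seq_len, head_dim); the mask is additive
    # (0 where attention is allowed, a large negative value where it is not).
    if scaling is None:
        scaling = query.size(-1) ** -0.5
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask
    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
    if head_mask is not None:
        attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value).transpose(1, 2).contiguous()
    return attn_output, attn_weights


# Illustrative stand-in for ALL_ATTENTION_FUNCTIONS; only the eager path is implemented here.
ATTN_IMPLEMENTATIONS = {"eager": eager_attention_forward}


def dispatch_attention(module, impl, query, key, value, attention_mask=None, **kwargs):
    attention_interface = ATTN_IMPLEMENTATIONS.get(impl, eager_attention_forward)
    return attention_interface(module, query, key, value, attention_mask=attention_mask, **kwargs)


layer = nn.Module()  # any nn.Module supplies `.training` for the dropout call
q = k = v = torch.randn(2, 4, 5, 8)  # (batch, heads, seq, head_dim)
out, weights = dispatch_attention(layer, "eager", q, k, v)
print(out.shape)  # torch.Size([2, 5, 4, 8]), transposed back to (batch, seq, heads, head_dim)

Registering SDPA, flash-attention and flex-attention variants under this one signature is what lets each model keep a single attention class and drop the per-implementation subclasses, as the hunks above do for BioGPT.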
@@ -628,7 +522,6 @@ class BioGptModel(BioGptPreTrainedModel): self.layer_norm = nn.LayerNorm(self.embed_dim) self.gradient_checkpointing = False - self._use_sdpa = config._attn_implementation == "sdpa" # Initialize weights and apply final processing self.post_init() @@ -652,7 +545,7 @@ class BioGptModel(BioGptPreTrainedModel): output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.Tensor] = None, - **kwargs, # NOOP kwargs, for now + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -663,18 +556,24 @@ class BioGptModel(BioGptPreTrainedModel): # retrieve input_ids and inputs_embeds if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - - if input_ids is not None: - input_ids = input_ids.view(-1, input_ids.shape[-1]) + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + input = input_ids + input_shape = input.shape + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + input = inputs_embeds[:, :, -1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) + inputs_embeds = self.embed_tokens(input) if self.gradient_checkpointing and self.training: if use_cache: logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + "`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..." ) use_cache = False @@ -696,7 +595,7 @@ class BioGptModel(BioGptPreTrainedModel): past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device ) - if attention_mask is None and not is_torchdynamo_compiling(): + if attention_mask is None: # required mask seq length can be calculated via length of past cache mask_seq_length = past_key_values_length + seq_length attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) @@ -706,27 +605,37 @@ class BioGptModel(BioGptPreTrainedModel): if isinstance(past_key_values, EncoderDecoderCache) else past_key_values ) + causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, self_attn_cache, - output_attentions, ) # embed positions if position_ids is None: - position_ids = cache_position.unsqueeze(0) + # position_ids = cache_position.unsqueeze(0) + position_ids = torch.cumsum(attention_mask, dim=1) + position_ids = (position_ids * attention_mask - 1).long() + # cut positions if `past_seen_tokens` is > 0 + position_ids = position_ids[:, past_key_values_length:] - position_ids = self.embed_positions(attention_mask, past_key_values_length, position_ids=position_ids) - - hidden_states = inputs_embeds + position_ids + positions = self.embed_positions(attention_mask, past_key_values_length, position_ids=position_ids) + hidden_states = inputs_embeds + positions hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. 
Setting `use_cache=False`..." + ) + use_cache = False + all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = None - next_decoder_cache = None + next_decoder_cache = () if use_cache else None for idx, decoder_layer in enumerate(self.layers): # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) @@ -739,13 +648,14 @@ class BioGptModel(BioGptPreTrainedModel): if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, + partial(decoder_layer.__call__, **flash_attn_kwargs), hidden_states, causal_mask, head_mask[idx] if head_mask is not None else None, None, output_attentions, use_cache, + position_ids, cache_position, ) else: @@ -756,7 +666,9 @@ class BioGptModel(BioGptPreTrainedModel): past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, + position_ids=position_ids, cache_position=cache_position, + **flash_attn_kwargs, ) hidden_states = layer_outputs[0] @@ -792,6 +704,9 @@ class BioGptModel(BioGptPreTrainedModel): ) +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + @auto_docstring( custom_intro=""" BioGPT Model with a `language modeling` head on top for CLM fine-tuning. @@ -830,7 +745,7 @@ class BioGptForCausalLM(BioGptPreTrainedModel, GenerationMixin): output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.Tensor] = None, - **kwargs, + **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -852,6 +767,7 @@ class BioGptForCausalLM(BioGptPreTrainedModel, GenerationMixin): output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, + **kwargs, ) sequence_output = outputs[0] @@ -916,9 +832,11 @@ class BioGptForTokenClassification(BioGptPreTrainedModel): inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, + position_ids: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple, TokenClassifierOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): @@ -935,9 +853,11 @@ class BioGptForTokenClassification(BioGptPreTrainedModel): head_mask=head_mask, inputs_embeds=inputs_embeds, use_cache=use_cache, + position_ids=position_ids, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) hidden_states = transformer_outputs[0] @@ -1004,9 +924,11 @@ class BioGptForSequenceClassification(BioGptPreTrainedModel): inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, + position_ids: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple, SequenceClassifierOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): @@ -1023,9 +945,11 @@ class BioGptForSequenceClassification(BioGptPreTrainedModel): head_mask=head_mask, inputs_embeds=inputs_embeds, use_cache=use_cache, + 
position_ids=position_ids, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) hidden_states = transformer_outputs[0] logits = self.score(hidden_states) diff --git a/src/transformers/models/biogpt/modular_biogpt.py b/src/transformers/models/biogpt/modular_biogpt.py new file mode 100644 index 00000000000..78d6da134b8 --- /dev/null +++ b/src/transformers/models/biogpt/modular_biogpt.py @@ -0,0 +1,850 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Team and Microsoft Research AI4Science All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch BioGPT model.""" + +import math +from functools import partial +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...cache_utils import Cache, EncoderDecoderCache +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import ( + AttentionMaskConverter, +) +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from ...modeling_utils import PreTrainedModel +from ...processing_utils import Unpack +from ...utils import ( + LossKwargs, + auto_docstring, + is_torch_flex_attn_available, + logger, +) +from ..bart.modeling_bart import ( + BartAttention, + BartDecoderLayer, + BartScaledWordEmbedding, +) +from ..opt.modeling_opt import OPTLearnedPositionalEmbedding +from .configuration_biogpt import BioGptConfig + + +if is_torch_flex_attn_available(): + from ...integrations.flex_attention import BlockMask, make_flex_block_causal_mask + + +class BioGptLearnedPositionalEmbedding(OPTLearnedPositionalEmbedding): + def forward( + self, + attention_mask: torch.LongTensor, + past_key_values_length: int = 0, + position_ids: Optional[torch.LongTensor] = None, + ): + """`input_ids_shape` is expected to be [bsz x seqlen].""" + super().forward(attention_mask, past_key_values_length, position_ids) + + +class BioGptScaledWordEmbedding(BartScaledWordEmbedding): + pass + + +class BioGptAttention(BartAttention): + pass + + +class BioGptDecoderLayer(BartDecoderLayer): + def __init__(self, config: BioGptConfig, layer_idx: Optional[int] = None): + super().__init__(config) + self.embed_dim = config.hidden_size + + self.self_attn = BioGptAttention( + embed_dim=self.embed_dim, + num_heads=config.num_attention_heads, + dropout=config.attention_probs_dropout_prob, + is_decoder=True, + is_causal=True, + config=config, + layer_idx=layer_idx, + ) + self.dropout = config.hidden_dropout_prob + self.activation_fn = ACT2FN[config.hidden_act] + + self.fc1 = nn.Linear(self.embed_dim, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, self.embed_dim) + + del self.encoder_attn + del 
self.encoder_attn_layer_norm + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = True, + position_ids: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.Tensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. + """ + residual = hidden_states + + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, past_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + position_ids=position_ids, + cache_position=cache_position, + **flash_attn_kwargs, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (past_key_value,) + + return outputs + + +@auto_docstring +class BioGptPreTrainedModel(PreTrainedModel): + config_class = BioGptConfig + base_model_prefix = "biogpt" + supports_gradient_checkpointing = True + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True + _supports_cache_class = True + _supports_static_cache = True + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + 
module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: Optional[Union[torch.Tensor, "BlockMask"]], + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + ): + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + # Other attention flavors support in-built causal (when `mask is None`) + # while we need to create our specific block mask regardless + elif attention_mask is None: + attention_mask = make_flex_block_causal_mask( + torch.ones( + size=(input_tensor.shape[0], input_tensor.shape[1]), + device=attention_mask.device, + ) + ) + return attention_mask + + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_compilable_cache: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype = input_tensor.dtype + sequence_length = input_tensor.shape[1] + if using_compilable_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 
+ # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +@auto_docstring +class BioGptModel(BioGptPreTrainedModel): + def __init__(self, config: BioGptConfig): + super().__init__(config) + self.config = config + self.layerdrop = config.layerdrop + self.dropout = config.hidden_dropout_prob + self.embed_dim = config.hidden_size + self.padding_idx = config.pad_token_id + embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0 + + self.embed_tokens = BioGptScaledWordEmbedding( + config.vocab_size, self.embed_dim, self.padding_idx, embed_scale=embed_scale + ) + self.embed_positions = BioGptLearnedPositionalEmbedding(config.max_position_embeddings, self.embed_dim) + + self.layers = nn.ModuleList([BioGptDecoderLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]) + self.layer_norm = nn.LayerNorm(self.embed_dim) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + 
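Both copies of `_prepare_4d_causal_attention_mask_with_cache_position` are fairly dense, so a toy rendering may help. The sketch below (pure PyTorch, float32 only, no device handling, no static-cache padding, and skipping the already-4D pass-through branch; `toy_causal_4d_mask` is an illustrative name) walks the same steps for a 3-token prompt whose first token is left padding and with no past cache:

import torch


def toy_causal_4d_mask(attention_mask_2d, sequence_length, target_length, cache_position, dtype=torch.float32):
    min_dtype = torch.finfo(dtype).min
    # Start fully masked, then keep only the strictly-upper triangle masked (the causal pattern).
    mask = torch.full((sequence_length, target_length), min_dtype, dtype=dtype)
    if sequence_length != 1:
        mask = torch.triu(mask, diagonal=1)
    # Shift the triangle by the cache position so already-cached key slots stay visible.
    mask *= torch.arange(target_length) > cache_position.reshape(-1, 1)
    mask = mask[None, None, :, :].expand(attention_mask_2d.shape[0], 1, -1, -1).clone()
    # Fold in the 2D padding mask: key positions that are padding also get the min value.
    mask_length = attention_mask_2d.shape[-1]
    padding = mask[:, :, :, :mask_length] + attention_mask_2d[:, None, None, :].to(dtype)
    mask[:, :, :, :mask_length] = mask[:, :, :, :mask_length].masked_fill(padding == 0, min_dtype)
    return mask


attention_mask = torch.tensor([[0, 1, 1]])   # first token is left padding
cache_position = torch.arange(3)             # prefill, no past key/values
mask = toy_causal_4d_mask(attention_mask, 3, 3, cache_position)
print(mask.shape)                            # torch.Size([1, 1, 3, 3]) additive mask

In this toy case the query row belonging to the pad token ends up fully masked, which is exactly the situation the SDPA branch above patches with `AttentionMaskConverter._unmask_unattended` so the memory-efficient kernel does not softmax over an all-masked row.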
def set_input_embeddings(self, value): + self.embed_tokens = value + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + use_cache: Optional[bool] = None, + position_ids: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + input = input_ids + input_shape = input.shape + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + input = inputs_embeds[:, :, -1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..." + ) + use_cache = False + + # initialize past_key_values + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." 
+ ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + batch_size, seq_length = inputs_embeds.size()[:-1] + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device + ) + + if attention_mask is None: + # required mask seq length can be calculated via length of past cache + mask_seq_length = past_key_values_length + seq_length + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) + + self_attn_cache = ( + past_key_values.self_attention_cache + if isinstance(past_key_values, EncoderDecoderCache) + else past_key_values + ) + + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + self_attn_cache, + ) + + # embed positions + if position_ids is None: + # position_ids = cache_position.unsqueeze(0) + position_ids = torch.cumsum(attention_mask, dim=1) + position_ids = (position_ids * attention_mask - 1).long() + # cut positions if `past_seen_tokens` is > 0 + position_ids = position_ids[:, past_key_values_length:] + + positions = self.embed_positions(attention_mask, past_key_values_length, position_ids=position_ids) + hidden_states = inputs_embeds + positions + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = None + next_decoder_cache = () if use_cache else None + + for idx, decoder_layer in enumerate(self.layers): + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + if output_hidden_states: + all_hidden_states += (hidden_states,) + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: + continue + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + partial(decoder_layer.__call__, **flash_attn_kwargs), + hidden_states, + causal_mask, + head_mask[idx] if head_mask is not None else None, + None, + output_attentions, + use_cache, + position_ids, + cache_position, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + position_ids=position_ids, + cache_position=cache_position, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + hidden_states = self.layer_norm(hidden_states) + + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + 
last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + ) + + +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + +@auto_docstring( + custom_intro=""" + BioGPT Model with a `language modeling` head on top for CLM fine-tuning. + """ +) +class BioGptForCausalLM(BioGptPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["output_projection.weight"] + + def __init__(self, config): + super().__init__(config) + + self.biogpt = BioGptModel(config) + self.output_projection = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.output_projection + + def set_output_embeddings(self, new_embeddings): + self.output_projection = new_embeddings + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + position_ids: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, + **kwargs: Unpack[KwargsForCausalLM], + ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. 
you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.biogpt( + input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + past_key_values=past_key_values, + use_cache=use_cache, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + sequence_output = outputs[0] + prediction_scores = self.output_projection(sequence_output) + + lm_loss = None + if labels is not None: + lm_loss = self.loss_function( + prediction_scores, + labels, + vocab_size=self.config.vocab_size, + **kwargs, + ) + + if not return_dict: + output = (prediction_scores,) + outputs[1:] + return ((lm_loss,) + output) if lm_loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=lm_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), + ) + return reordered_past + + +@auto_docstring +class BioGptForTokenClassification(BioGptPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.biogpt = BioGptModel(config) + if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None: + classifier_dropout = config.classifier_dropout + else: + classifier_dropout = config.hidden_dropout_prob + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + position_ids: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, + ) -> Union[Tuple, TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.biogpt( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = transformer_outputs[0] + hidden_states = self.dropout(hidden_states) + logits = self.classifier(hidden_states) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + # Only keep active parts of the loss + if attention_mask is not None: + active_loss = attention_mask.view(-1) == 1 + active_logits = logits.view(-1, self.num_labels) + active_labels = torch.where( + active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels) + ) + loss = loss_fct(active_logits, active_labels) + else: + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + transformer_outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@auto_docstring( + custom_intro=""" + The BioGpt Model transformer with a sequence classification head on top (linear layer). + + [`BioGptForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it is required to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """ +) +class BioGptForSequenceClassification(BioGptPreTrainedModel): + def __init__(self, config: BioGptConfig): + super().__init__(config) + self.num_labels = config.num_labels + self.biogpt = BioGptModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + position_ids: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.biogpt( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size, sequence_length = input_ids.shape[:2] + else: + batch_size, sequence_length = inputs_embeds.shape[:2] + + if self.config.pad_token_id is None: + sequence_length = -1 + else: + if input_ids is not None: + sequence_length = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device) + else: + sequence_length = -1 + logger.warning_once( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_length] + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + def get_input_embeddings(self): + return self.biogpt.embed_tokens + + def set_input_embeddings(self, value): + self.biogpt.embed_tokens = value + + +__all__ = [ + "BioGptForCausalLM", + "BioGptForTokenClassification", + "BioGptForSequenceClassification", + "BioGptModel", + "BioGptPreTrainedModel", +] diff --git a/src/transformers/models/bitnet/modeling_bitnet.py b/src/transformers/models/bitnet/modeling_bitnet.py index e98f9ed1162..661a3c9bb60 100644 --- a/src/transformers/models/bitnet/modeling_bitnet.py +++ b/src/transformers/models/bitnet/modeling_bitnet.py @@ -205,13 +205,7 @@ class BitNetAttention(nn.Module): attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
- ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, @@ -425,7 +419,6 @@ class BitNetModel(BitNetPreTrainedModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, - output_attentions=output_attentions, ) hidden_states = inputs_embeds diff --git a/src/transformers/models/bitnet/modular_bitnet.py b/src/transformers/models/bitnet/modular_bitnet.py index 0c0d133cb5d..c57b7217f1d 100644 --- a/src/transformers/models/bitnet/modular_bitnet.py +++ b/src/transformers/models/bitnet/modular_bitnet.py @@ -85,13 +85,7 @@ class BitNetAttention(LlamaAttention): attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index 8eb282ac6fa..4c001a35446 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -18,7 +18,7 @@ import copy import math import os import warnings -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch import torch.utils.checkpoint @@ -31,7 +31,9 @@ from ...generation import GenerationMixin from ...modeling_attn_mask_utils import ( AttentionMaskConverter, _prepare_4d_attention_mask, + _prepare_4d_attention_mask_for_sdpa, ) +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -39,7 +41,8 @@ from ...modeling_outputs import ( Seq2SeqLMOutput, Seq2SeqModelOutput, ) -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack from ...utils import ( auto_docstring, is_torch_flex_attn_available, @@ -51,9 +54,7 @@ from .configuration_blenderbot import BlenderbotConfig if is_torch_flex_attn_available(): - from torch.nn.attention.flex_attention import BlockMask - - from ...integrations.flex_attention import make_flex_block_causal_mask + from ...integrations.flex_attention import BlockMask, make_flex_block_causal_mask logger = logging.get_logger(__name__) @@ -110,6 +111,37 @@ class BlenderbotScaledWordEmbedding(nn.Embedding): return super().forward(input_ids) * self.embed_scale +# Copied from transformers.models.bart.modeling_bart.eager_attention_forward +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: Optional[float] = None, + dropout: float = 0.0, + head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + if scaling is None: + scaling = query.size(-1) ** 
-0.5 + + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) + + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + # Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->Blenderbot class BlenderbotAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -162,17 +194,25 @@ class BlenderbotAttention(nn.Module): layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, cache_position: Optional[torch.Tensor] = None, + # TODO: we need a refactor so that the different attention modules can get their specific kwargs + # ATM, we have mixed things encoder, decoder, and encoder-decoder attn + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() + + # determine input shapes + bsz, tgt_len = hidden_states.shape[:-1] + src_len = key_value_states.shape[1] if is_cross_attention else tgt_len + + q_input_shape = (bsz, tgt_len, -1, self.head_dim) + kv_input_shape = (bsz, src_len, -1, self.head_dim) # get query proj - query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - query_states = query_states * self.scaling + query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) if past_key_value is not None: if isinstance(past_key_value, EncoderDecoderCache): @@ -193,8 +233,8 @@ class BlenderbotAttention(nn.Module): else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) - key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(*kv_input_shape).transpose(1, 2) + value_states = value_states.view(*kv_input_shape).transpose(1, 2) if past_key_value is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation @@ -206,69 +246,27 @@ class BlenderbotAttention(nn.Module): if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = query_states.reshape(*proj_shape) - key_states = key_states.reshape(*proj_shape) - value_states = value_states.reshape(*proj_shape) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - attention_mask = attention_mask[:, :, :, : 
key_states.shape[-2]] - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if layer_head_mask is not None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError( - f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}" - ) - attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to be reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) - else: - attn_weights_reshaped = None - - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + output_attentions=output_attentions, + head_mask=layer_head_mask, + **kwargs, + ) + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights_reshaped, past_key_value - - -BLENDERBOT_ATTENTION_CLASSES = {"eager": BlenderbotAttention} + return attn_output, attn_weights, past_key_value # Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Blenderbot, MBART->BLENDERBOT @@ -277,7 +275,7 @@ class BlenderbotEncoderLayer(nn.Module): super().__init__() self.embed_dim = config.d_model - self.self_attn = BLENDERBOT_ATTENTION_CLASSES[config._attn_implementation]( + self.self_attn = BlenderbotAttention( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, @@ -346,7 +344,7 @@ class BlenderbotDecoderLayer(nn.Module): super().__init__() self.embed_dim = config.d_model - self.self_attn = BLENDERBOT_ATTENTION_CLASSES[config._attn_implementation]( + self.self_attn = BlenderbotAttention( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, @@ -360,7 +358,7 @@ class BlenderbotDecoderLayer(nn.Module): self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = BLENDERBOT_ATTENTION_CLASSES[config._attn_implementation]( + self.encoder_attn = BlenderbotAttention( self.embed_dim, config.decoder_attention_heads, 
dropout=config.attention_dropout, @@ -428,7 +426,6 @@ class BlenderbotDecoderLayer(nn.Module): residual = hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) - # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple hidden_states, cross_attn_weights, past_key_value = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, @@ -465,6 +462,9 @@ class BlenderbotPreTrainedModel(PreTrainedModel): config_class = BlenderbotConfig base_model_prefix = "model" supports_gradient_checkpointing = True + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True _supports_cache_class = True _supports_static_cache = True @@ -493,23 +493,55 @@ class BlenderbotPreTrainedModel(PreTrainedModel): } return dummy_inputs - # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask def _update_causal_mask( self, - attention_mask: Union[torch.Tensor, "BlockMask"], + attention_mask: Optional[Union[torch.Tensor, "BlockMask"]], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - output_attentions: bool = False, ): + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + # Other attention flavors support in-built causal (when `mask is None`) + # while we need to create our specific block mask regardless + elif attention_mask is None: + attention_mask = make_flex_block_causal_mask( + torch.ones( + size=(input_tensor.shape[0], input_tensor.shape[1]), + device=attention_mask.device, + ) + ) + return attention_mask + if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None - if self.config._attn_implementation == "flex_attention": - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask) - return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. 
This feature is not compatible with static cache, as SDPA will fail @@ -518,7 +550,7 @@ class BlenderbotPreTrainedModel(PreTrainedModel): using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: + if self.config._attn_implementation == "sdpa" and not using_compilable_cache: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -552,7 +584,6 @@ class BlenderbotPreTrainedModel(PreTrainedModel): self.config._attn_implementation == "sdpa" and attention_mask is not None and attention_mask.device.type in ["cuda", "xpu", "npu"] - and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. @@ -618,6 +649,42 @@ class BlenderbotPreTrainedModel(PreTrainedModel): return causal_mask + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_cross_attn_mask + def _update_cross_attn_mask( + self, + encoder_hidden_states: Union[torch.Tensor, None], + encoder_attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + ): + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. 
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1], + ) + elif self.config._attn_implementation == "flex_attention": + if isinstance(encoder_attention_mask, torch.Tensor): + encoder_attention_mask = make_flex_block_causal_mask( + encoder_attention_mask, + query_length=input_shape[-1], + is_causal=False, + ) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + return encoder_attention_mask + class BlenderbotEncoder(BlenderbotPreTrainedModel): """ @@ -730,10 +797,10 @@ class BlenderbotEncoder(BlenderbotPreTrainedModel): hidden_states = inputs_embeds + embed_pos hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - # expand attention_mask - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + attention_mask = self._update_full_mask( + attention_mask, + inputs_embeds, + ) encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None @@ -927,22 +994,28 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel): use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - ## retrieve input_ids and inputs_embeds + # retrieve input_ids and inputs_embeds if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - - if input_ids is not None: - input_ids = input_ids.view(-1, input_ids.shape[-1]) + elif input_ids is not None: + input = input_ids + input_shape = input.shape + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + input = inputs_embeds[:, :, -1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) + inputs_embeds = self.embed_tokens(input) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..." 
+ ) + use_cache = False # initialize `past_key_values` return_legacy_cache = False @@ -972,20 +1045,19 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel): if isinstance(past_key_values, EncoderDecoderCache) else past_key_values ) + causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, self_attn_cache, - output_attentions, ) - - # expand encoder attention mask - if encoder_hidden_states is not None and encoder_attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _prepare_4d_attention_mask( - encoder_attention_mask, inputs_embeds.dtype, tgt_len=seq_length - ) + encoder_attention_mask = self._update_cross_attn_mask( + encoder_hidden_states, + encoder_attention_mask, + input_shape, + inputs_embeds, + ) # embed positions position_ids = self.embed_positions( diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index 2f778d72939..49cff8f620e 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -16,7 +16,7 @@ import copy import math -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch import torch.utils.checkpoint @@ -29,7 +29,9 @@ from ...generation import GenerationMixin from ...modeling_attn_mask_utils import ( AttentionMaskConverter, _prepare_4d_attention_mask, + _prepare_4d_attention_mask_for_sdpa, ) +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -37,7 +39,8 @@ from ...modeling_outputs import ( Seq2SeqLMOutput, Seq2SeqModelOutput, ) -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack from ...utils import ( auto_docstring, is_torch_flex_attn_available, @@ -48,9 +51,7 @@ from .configuration_blenderbot_small import BlenderbotSmallConfig if is_torch_flex_attn_available(): - from torch.nn.attention.flex_attention import BlockMask - - from ...integrations.flex_attention import make_flex_block_causal_mask + from ...integrations.flex_attention import BlockMask, make_flex_block_causal_mask logger = logging.get_logger(__name__) @@ -94,6 +95,37 @@ class BlenderbotSmallLearnedPositionalEmbedding(nn.Embedding): return super().forward(position_ids) +# Copied from transformers.models.bart.modeling_bart.eager_attention_forward +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: Optional[float] = None, + dropout: float = 0.0, + head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + if scaling is None: + scaling = query.size(-1) ** -0.5 + + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) + + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + # Copied from 
transformers.models.bart.modeling_bart.BartAttention with Bart->BlenderbotSmall class BlenderbotSmallAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -146,17 +178,25 @@ class BlenderbotSmallAttention(nn.Module): layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, cache_position: Optional[torch.Tensor] = None, + # TODO: we need a refactor so that the different attention modules can get their specific kwargs + # ATM, we have mixed things encoder, decoder, and encoder-decoder attn + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() + + # determine input shapes + bsz, tgt_len = hidden_states.shape[:-1] + src_len = key_value_states.shape[1] if is_cross_attention else tgt_len + + q_input_shape = (bsz, tgt_len, -1, self.head_dim) + kv_input_shape = (bsz, src_len, -1, self.head_dim) # get query proj - query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - query_states = query_states * self.scaling + query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) if past_key_value is not None: if isinstance(past_key_value, EncoderDecoderCache): @@ -177,8 +217,8 @@ class BlenderbotSmallAttention(nn.Module): else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) - key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(*kv_input_shape).transpose(1, 2) + value_states = value_states.view(*kv_input_shape).transpose(1, 2) if past_key_value is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation @@ -190,66 +230,27 @@ class BlenderbotSmallAttention(nn.Module): if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = query_states.reshape(*proj_shape) - key_states = key_states.reshape(*proj_shape) - value_states = value_states.reshape(*proj_shape) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - attention_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if layer_head_mask is not None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError( - f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}" - ) - attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, 
self.num_heads, tgt_len, src_len) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to be reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) - else: - attn_weights_reshaped = None - - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + output_attentions=output_attentions, + head_mask=layer_head_mask, + **kwargs, + ) + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights_reshaped, past_key_value + return attn_output, attn_weights, past_key_value # Copied from transformers.models.bart.modeling_bart.BartEncoderLayer with Bart->BlenderbotSmall, BART->BLENDERBOT_SMALL @@ -258,7 +259,7 @@ class BlenderbotSmallEncoderLayer(nn.Module): super().__init__() self.embed_dim = config.d_model - self.self_attn = BLENDERBOT_SMALL_ATTENTION_CLASSES[config._attn_implementation]( + self.self_attn = BlenderbotSmallAttention( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, @@ -324,19 +325,13 @@ class BlenderbotSmallEncoderLayer(nn.Module): return outputs -# TODO: Implement attention with SDPA for TimeSeriesTransformer. 
-BLENDERBOT_SMALL_ATTENTION_CLASSES = { - "eager": BlenderbotSmallAttention, -} - - # Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->BlenderbotSmall, BART->BLENDERBOT_SMALL class BlenderbotSmallDecoderLayer(nn.Module): def __init__(self, config: BlenderbotSmallConfig, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.d_model - self.self_attn = BLENDERBOT_SMALL_ATTENTION_CLASSES[config._attn_implementation]( + self.self_attn = BlenderbotSmallAttention( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, @@ -350,7 +345,7 @@ class BlenderbotSmallDecoderLayer(nn.Module): self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = BLENDERBOT_SMALL_ATTENTION_CLASSES[config._attn_implementation]( + self.encoder_attn = BlenderbotSmallAttention( self.embed_dim, config.decoder_attention_heads, dropout=config.attention_dropout, @@ -424,6 +419,7 @@ class BlenderbotSmallDecoderLayer(nn.Module): layer_head_mask=cross_attn_layer_head_mask, past_key_value=past_key_value, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states @@ -454,6 +450,9 @@ class BlenderbotSmallPreTrainedModel(PreTrainedModel): config_class = BlenderbotSmallConfig base_model_prefix = "model" supports_gradient_checkpointing = True + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True _supports_cache_class = True _supports_static_cache = True @@ -482,23 +481,55 @@ class BlenderbotSmallPreTrainedModel(PreTrainedModel): } return dummy_inputs - # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. 
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask def _update_causal_mask( self, - attention_mask: Union[torch.Tensor, "BlockMask"], + attention_mask: Optional[Union[torch.Tensor, "BlockMask"]], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - output_attentions: bool = False, ): + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + # Other attention flavors support in-built causal (when `mask is None`) + # while we need to create our specific block mask regardless + elif attention_mask is None: + attention_mask = make_flex_block_causal_mask( + torch.ones( + size=(input_tensor.shape[0], input_tensor.shape[1]), + device=attention_mask.device, + ) + ) + return attention_mask + if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None - if self.config._attn_implementation == "flex_attention": - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask) - return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail @@ -507,7 +538,7 @@ class BlenderbotSmallPreTrainedModel(PreTrainedModel): using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: + if self.config._attn_implementation == "sdpa" and not using_compilable_cache: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -541,7 +572,6 @@ class BlenderbotSmallPreTrainedModel(PreTrainedModel): self.config._attn_implementation == "sdpa" and attention_mask is not None and attention_mask.device.type in ["cuda", "xpu", "npu"] - and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 
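Editor's note on the mask helpers being added in this hunk: the new `_update_full_mask` / `_update_causal_mask` / `_update_cross_attn_mask` methods centralize how a 2D padding mask is adapted to each attention backend. Flash Attention 2 keeps the raw 2D mask (and only when padding is actually present), SDPA and eager expand it to an additive 4D mask, and flex attention wraps it into a block mask. Below is a minimal, self-contained sketch of that dispatch in plain torch; `expand_to_4d` and `update_full_mask` are illustrative names, not the transformers `_prepare_4d_attention_mask*` utilities, and the flex-attention branch is omitted.

```python
from typing import Optional

import torch


def expand_to_4d(padding_mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None) -> torch.Tensor:
    """Turn a [bsz, src_len] padding mask (1 = keep, 0 = pad) into the additive
    [bsz, 1, tgt_len, src_len] mask that eager/SDPA kernels consume."""
    bsz, src_len = padding_mask.shape
    tgt_len = tgt_len if tgt_len is not None else src_len
    expanded = padding_mask[:, None, None, :].to(dtype).expand(bsz, 1, tgt_len, src_len)
    # 0 where attention is allowed, a very large negative number at pad positions
    return (1.0 - expanded) * torch.finfo(dtype).min


def update_full_mask(attention_mask, inputs_embeds, attn_implementation: str):
    """Sketch of the backend dispatch, minus the flex-attention BlockMask branch."""
    if attention_mask is None:
        return None
    if attn_implementation == "flash_attention_2":
        # FA2 takes the 2D mask directly, and only when real padding exists.
        return attention_mask if (attention_mask == 0).any() else None
    # eager / sdpa consume the expanded additive mask.
    return expand_to_4d(attention_mask, inputs_embeds.dtype)


# Example: batch of 2, second sequence has one padded position.
mask = torch.tensor([[1, 1, 1, 1], [1, 1, 1, 0]])
x = torch.zeros(2, 4, 8)
print(update_full_mask(mask, x, "sdpa").shape)         # torch.Size([2, 1, 4, 4])
print(update_full_mask(mask, x, "flash_attention_2"))  # 2D mask kept (padding present)
```

Returning `None` for the flash path when no padding exists lets the kernel skip masking entirely, which is why the diff checks `0 in attention_mask` before keeping the mask.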
@@ -607,6 +637,42 @@ class BlenderbotSmallPreTrainedModel(PreTrainedModel): return causal_mask + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_cross_attn_mask + def _update_cross_attn_mask( + self, + encoder_hidden_states: Union[torch.Tensor, None], + encoder_attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + ): + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1], + ) + elif self.config._attn_implementation == "flex_attention": + if isinstance(encoder_attention_mask, torch.Tensor): + encoder_attention_mask = make_flex_block_causal_mask( + encoder_attention_mask, + query_length=input_shape[-1], + is_causal=False, + ) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + return encoder_attention_mask + class BlenderbotSmallEncoder(BlenderbotSmallPreTrainedModel): """ @@ -718,10 +784,10 @@ class BlenderbotSmallEncoder(BlenderbotSmallPreTrainedModel): hidden_states = self.layernorm_embedding(hidden_states) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - # expand attention_mask - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + attention_mask = self._update_full_mask( + attention_mask, + inputs_embeds, + ) encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None @@ -909,24 +975,28 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel): use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
- ) - use_cache = False - # retrieve input_ids and inputs_embeds if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - - if input_ids is not None: - input_ids = input_ids.view(-1, input_ids.shape[-1]) + elif input_ids is not None: + input = input_ids + input_shape = input.shape + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + input = inputs_embeds[:, :, -1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) + inputs_embeds = self.embed_tokens(input) - inputs_embeds = inputs_embeds * self.embed_scale + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..." + ) + use_cache = False # initialize `past_key_values` return_legacy_cache = False @@ -956,20 +1026,19 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel): if isinstance(past_key_values, EncoderDecoderCache) else past_key_values ) + causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, self_attn_cache, - output_attentions, ) - - # expand encoder attention mask - if encoder_hidden_states is not None and encoder_attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _prepare_4d_attention_mask( - encoder_attention_mask, inputs_embeds.dtype, tgt_len=seq_length - ) + encoder_attention_mask = self._update_cross_attn_mask( + encoder_hidden_states, + encoder_attention_mask, + input_shape, + inputs_embeds, + ) # embed positions position_ids = self.embed_positions( diff --git a/src/transformers/models/blip/modeling_blip.py b/src/transformers/models/blip/modeling_blip.py index 548a362ebfd..356f48eaf94 100644 --- a/src/transformers/models/blip/modeling_blip.py +++ b/src/transformers/models/blip/modeling_blip.py @@ -25,6 +25,7 @@ from torch.nn.functional import normalize from ...activations import ACT2FN from ...generation import GenerationMixin +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from ...modeling_utils import PreTrainedModel from ...utils import ModelOutput, auto_docstring, logging, torch_int @@ -405,7 +406,7 @@ class BlipMLP(nn.Module): return hidden_states -class BlipEncoderLayer(nn.Module): +class BlipEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: BlipConfig): super().__init__() self.embed_dim = config.hidden_size @@ -548,19 +549,12 @@ class BlipEncoder(nn.Module): for idx, encoder_layer in enumerate(self.layers): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - output_attentions=output_attentions, - ) + + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/blip/modeling_blip_text.py b/src/transformers/models/blip/modeling_blip_text.py index f26f269c7b9..ffbca32eb9d 100644 --- 
a/src/transformers/models/blip/modeling_blip_text.py +++ b/src/transformers/models/blip/modeling_blip_text.py @@ -24,6 +24,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN from ...generation import GenerationMixin +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -317,7 +318,7 @@ class BlipTextOutput(nn.Module): return hidden_states -class BlipTextLayer(nn.Module): +class BlipTextLayer(GradientCheckpointingLayer): def __init__(self, config, layer_num): super().__init__() self.config = config @@ -421,27 +422,15 @@ class BlipTextEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/blip/processing_blip.py b/src/transformers/models/blip/processing_blip.py index c65ff6b66fd..5970e5edbb1 100644 --- a/src/transformers/models/blip/processing_blip.py +++ b/src/transformers/models/blip/processing_blip.py @@ -55,7 +55,6 @@ class BlipProcessor(ProcessorMixin): """ attributes = ["image_processor", "tokenizer"] - valid_kwargs = [] image_processor_class = ("BlipImageProcessor", "BlipImageProcessorFast") tokenizer_class = ("BertTokenizer", "BertTokenizerFast") diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py index 3ca38af6add..ea591bf730d 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ -26,6 +26,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN from ...generation import GenerationMixin from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -373,7 +374,7 @@ class Blip2MLP(nn.Module): # Copied from transformers.models.blip.modeling_blip.BlipEncoderLayer with Blip->Blip2 -class Blip2EncoderLayer(nn.Module): +class Blip2EncoderLayer(GradientCheckpointingLayer): def __init__(self, config: Blip2Config): super().__init__() self.embed_dim = config.hidden_size @@ -527,19 +528,12 @@ class Blip2Encoder(nn.Module): for idx, encoder_layer in enumerate(self.layers): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - output_attentions=output_attentions, - ) + + layer_outputs = encoder_layer( + hidden_states, + 
attention_mask, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] @@ -847,7 +841,7 @@ class Blip2QFormerOutput(nn.Module): return hidden_states -class Blip2QFormerLayer(nn.Module): +class Blip2QFormerLayer(GradientCheckpointingLayer): def __init__(self, config, layer_idx): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -988,31 +982,22 @@ class Blip2QFormerEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: - if use_cache: - logger.warning( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - query_length, + if getattr(self.config, "gradient_checkpointing", False) and self.training and use_cache: + logger.warning( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) + use_cache = False + + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + query_length, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/blip_2/processing_blip_2.py b/src/transformers/models/blip_2/processing_blip_2.py index 36b663dccb7..d94525f6b6f 100644 --- a/src/transformers/models/blip_2/processing_blip_2.py +++ b/src/transformers/models/blip_2/processing_blip_2.py @@ -21,12 +21,7 @@ from typing import List, Optional, Union from ...image_processing_utils import BatchFeature from ...image_utils import ImageInput from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack -from ...tokenization_utils_base import ( - AddedToken, - BatchEncoding, - PreTokenizedInput, - TextInput, -) +from ...tokenization_utils_base import AddedToken, BatchEncoding, PreTokenizedInput, TextInput from ...utils import logging @@ -67,7 +62,6 @@ class Blip2Processor(ProcessorMixin): """ attributes = ["image_processor", "tokenizer"] - valid_kwargs = ["num_query_tokens"] image_processor_class = ("BlipImageProcessor", "BlipImageProcessorFast") tokenizer_class = "AutoTokenizer" diff --git a/src/transformers/models/camembert/tokenization_camembert_fast.py b/src/transformers/models/camembert/tokenization_camembert_fast.py index c04b5618390..05d0073da6b 100644 --- a/src/transformers/models/camembert/tokenization_camembert_fast.py +++ b/src/transformers/models/camembert/tokenization_camembert_fast.py @@ -125,10 +125,6 @@ class CamembertTokenizerFast(PreTrainedTokenizerFast): self.vocab_file = vocab_file - @property - def can_save_slow_tokenizer(self) -> bool: - return os.path.isfile(self.vocab_file) if self.vocab_file else False - def build_inputs_with_special_tokens( self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None ) -> List[int]: diff --git a/src/transformers/models/chameleon/processing_chameleon.py b/src/transformers/models/chameleon/processing_chameleon.py index f0c592180e9..5a364cdc34d 100644 --- 
a/src/transformers/models/chameleon/processing_chameleon.py +++ b/src/transformers/models/chameleon/processing_chameleon.py @@ -18,9 +18,18 @@ Processor class for Chameleon. from typing import List, Optional, Union +import numpy as np + from ...feature_extraction_utils import BatchFeature from ...image_utils import ImageInput -from ...processing_utils import ProcessingKwargs, ProcessorMixin, TextKwargs, Unpack, _validate_images_text_input_order +from ...processing_utils import ( + MultiModalData, + ProcessingKwargs, + ProcessorMixin, + TextKwargs, + Unpack, + _validate_images_text_input_order, +) from ...tokenization_utils_base import PreTokenizedInput, TextInput @@ -34,6 +43,7 @@ class ChameleonProcessorKwargs(ProcessingKwargs, total=False): "text_kwargs": { "padding": False, "return_for_text_completion": False, + "return_mm_token_type_ids": False, }, "common_kwargs": { "return_tensors": "pt", @@ -62,7 +72,6 @@ class ChameleonProcessor(ProcessorMixin): attributes = ["image_processor", "tokenizer"] tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast") - valid_kwargs = ["image_seq_length", "image_token"] image_processor_class = "ChameleonImageProcessor" def __init__(self, image_processor, tokenizer, image_seq_length: int = 1024, image_token: str = ""): @@ -73,6 +82,10 @@ class ChameleonProcessor(ProcessorMixin): tokenizer.boi_token if hasattr(tokenizer, "boi_token") else "" ) # fixed tokens for start and end, so can hardcode self.image_end_token = tokenizer.eoi_token if hasattr(tokenizer, "eoi_token") else "" + self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token) + self.image_start_token_id = tokenizer.convert_tokens_to_ids(self.image_start_token) + self.image_end_token_id = tokenizer.convert_tokens_to_ids(self.image_end_token) + self.image_ids = [self.image_token_id, self.image_start_token_id, self.image_end_token_id] super().__init__(image_processor, tokenizer) @@ -141,14 +154,45 @@ class ChameleonProcessor(ProcessorMixin): sample += self.tokenizer.sep_token # special Chameleon treatment to add sep for chat mode prompt_strings.append(sample) - return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None) - data = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"]) - self._check_special_mm_tokens(prompt_strings, data, modalities=["image"]) - + image_inputs = {} if images is not None: - data["pixel_values"] = self.image_processor(images, **output_kwargs["images_kwargs"])["pixel_values"] + image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"]) - return BatchFeature(data=data, tensor_type=return_tensors) + return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None) + return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False) + text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"], return_tensors=None) + self._check_special_mm_tokens(prompt_strings, text_inputs, modalities=["image"]) + + if return_mm_token_type_ids: + array_ids = np.array(text_inputs["input_ids"]) + mm_token_type_ids = np.zeros_like(text_inputs["input_ids"]) + mm_token_type_ids[np.isin(array_ids, self.image_ids)] = 1 + text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist() + + return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors) + + def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs): + """ + Computes the number of placeholder tokens needed for multimodal inputs with the given sizes. 
+ + Args: + image_sizes (`List[List[int]]`, *optional*): + The input sizes formatted as (height, width) per each image. + + Returns: + `MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided + input modalities, along with other useful data. + """ + + vision_data = {} + if image_sizes is not None: + # add 2 for BOI and EOI tokens + num_image_tokens = [self.image_seq_length + 2] * len(image_sizes) + num_image_patches = [1] * len(image_sizes) + + vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches}) + + return MultiModalData(**vision_data) # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama def batch_decode(self, *args, **kwargs): diff --git a/src/transformers/models/code_llama/tokenization_code_llama_fast.py b/src/transformers/models/code_llama/tokenization_code_llama_fast.py index f6a17ebc6d1..089c5c066e7 100644 --- a/src/transformers/models/code_llama/tokenization_code_llama_fast.py +++ b/src/transformers/models/code_llama/tokenization_code_llama_fast.py @@ -168,10 +168,6 @@ class CodeLlamaTokenizerFast(PreTrainedTokenizerFast): self._eot_token = eot_token self.fill_token = fill_token - @property - def can_save_slow_tokenizer(self) -> bool: - return os.path.isfile(self.vocab_file) if self.vocab_file else False - # Copied from transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast.update_post_processor def update_post_processor(self): """ diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index 37f698a86ec..0700eb8e9f6 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -261,13 +261,7 @@ class CohereAttention(nn.Module): attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, @@ -462,7 +456,6 @@ class CohereModel(CoherePreTrainedModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, - output_attentions=output_attentions, ) hidden_states = inputs_embeds diff --git a/src/transformers/models/cohere/modular_cohere.py b/src/transformers/models/cohere/modular_cohere.py index a44aebcead7..e37c875be38 100644 --- a/src/transformers/models/cohere/modular_cohere.py +++ b/src/transformers/models/cohere/modular_cohere.py @@ -184,13 +184,7 @@ class CohereAttention(LlamaAttention): attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
- ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/cohere2/modeling_cohere2.py b/src/transformers/models/cohere2/modeling_cohere2.py index 144667f1e3d..5690864cfc5 100644 --- a/src/transformers/models/cohere2/modeling_cohere2.py +++ b/src/transformers/models/cohere2/modeling_cohere2.py @@ -222,13 +222,7 @@ class Cohere2Attention(nn.Module): attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, @@ -439,7 +433,6 @@ class Cohere2Model(Cohere2PreTrainedModel): "attention_mask": attention_mask, "cache_position": cache_position, "past_key_values": past_key_values, - "output_attentions": output_attentions, } # Create the masks causal_mask_mapping = { diff --git a/src/transformers/models/cohere2/modular_cohere2.py b/src/transformers/models/cohere2/modular_cohere2.py index 792d278cc0a..7a5cab506e2 100644 --- a/src/transformers/models/cohere2/modular_cohere2.py +++ b/src/transformers/models/cohere2/modular_cohere2.py @@ -309,13 +309,7 @@ class Cohere2Attention(CohereAttention, nn.Module): attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
- ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, @@ -461,7 +455,6 @@ class Cohere2Model(Gemma2Model): "attention_mask": attention_mask, "cache_position": cache_position, "past_key_values": past_key_values, - "output_attentions": output_attentions, } # Create the masks causal_mask_mapping = { diff --git a/src/transformers/models/colpali/processing_colpali.py b/src/transformers/models/colpali/processing_colpali.py index 2e6a68ca7af..f34681c1d4f 100644 --- a/src/transformers/models/colpali/processing_colpali.py +++ b/src/transformers/models/colpali/processing_colpali.py @@ -24,7 +24,7 @@ from typing import ClassVar, List, Optional, Union from ...feature_extraction_utils import BatchFeature from ...image_utils import ImageInput, is_valid_image, make_flat_list_of_images -from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack +from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack from ...tokenization_utils_base import AddedToken, PreTokenizedInput, TextInput from ...utils import is_torch_available @@ -90,7 +90,6 @@ class ColPaliProcessor(ProcessorMixin): """ attributes = ["image_processor", "tokenizer"] - valid_kwargs = ["chat_template"] image_processor_class = ("SiglipImageProcessor", "SiglipImageProcessorFast") tokenizer_class = ("GemmaTokenizer", "GemmaTokenizerFast") @@ -256,6 +255,25 @@ class ColPaliProcessor(ProcessorMixin): return batch_query + def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs): + """ + Computes the number of placeholder tokens needed for multimodal inputs with the given sizes. + + Args: + image_sizes (List[List[str]], *optional*): + The input sizes formatted as (height, width) per each image. + Returns: + Dict[str, List[int]]: A dictionary mapping each modality ("image", "video", "audio") + to a list containing the number of placeholder tokens required. If the model doesn't accept + a certain modality or no input sizes are provided, the dict value is set to an empty list. + """ + vision_data = {} + if image_sizes is not None: + num_image_tokens = [self.image_seq_length] * len(image_sizes) + num_image_patches = [1] * len(image_sizes) + vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches}) + return MultiModalData(**vision_data) + def batch_decode(self, *args, **kwargs): """ This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. 
Please diff --git a/src/transformers/models/cpm/tokenization_cpm_fast.py b/src/transformers/models/cpm/tokenization_cpm_fast.py index ef933e084dd..48caf28c0a1 100644 --- a/src/transformers/models/cpm/tokenization_cpm_fast.py +++ b/src/transformers/models/cpm/tokenization_cpm_fast.py @@ -144,10 +144,6 @@ class CpmTokenizerFast(PreTrainedTokenizerFast): self.jieba = jieba self.translator = str.maketrans(" \n", "\u2582\u2583") - @property - def can_save_slow_tokenizer(self) -> bool: - return os.path.isfile(self.vocab_file) if self.vocab_file else False - # Copied from transformers.models.xlnet.tokenization_xlnet_fast.XLNetTokenizerFast.build_inputs_with_special_tokens def build_inputs_with_special_tokens( self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None diff --git a/src/transformers/models/csm/configuration_csm.py b/src/transformers/models/csm/configuration_csm.py index e6d6d2e27c6..b13b9d2a873 100644 --- a/src/transformers/models/csm/configuration_csm.py +++ b/src/transformers/models/csm/configuration_csm.py @@ -28,7 +28,7 @@ class CsmDepthDecoderConfig(PretrainedConfig): model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the csm-1b. - e.g. [eustlb/csm-1b](https://huggingface.co/eustlb/csm-1b) + e.g. [sesame/csm-1b](https://huggingface.co/sesame/csm-1b) Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. @@ -210,7 +210,7 @@ class CsmConfig(PretrainedConfig): model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the csm-1b. - e.g. [eustlb/csm-1b](https://huggingface.co/eustlb/csm-1b) + e.g. [sesame/csm-1b](https://huggingface.co/sesame/csm-1b) Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. diff --git a/src/transformers/models/csm/generation_csm.py b/src/transformers/models/csm/generation_csm.py index 2fec3ea8919..7afc7c2d60c 100644 --- a/src/transformers/models/csm/generation_csm.py +++ b/src/transformers/models/csm/generation_csm.py @@ -415,7 +415,7 @@ class CsmGenerationMixin(GenerationMixin): >>> from transformers import CsmProcessor, CsmForConditionalGeneration >>> from datasets import load_dataset, Audio - >>> model_id = "eustlb/csm-1b" + >>> model_id = "sesame/csm-1b" >>> torch_device = "cuda" if torch.cuda.is_available() else "cpu" >>> processor = AutoProcessor.from_pretrained(model_id) diff --git a/src/transformers/models/csm/modeling_csm.py b/src/transformers/models/csm/modeling_csm.py index 6f8fd7a487f..c0c4f5927a5 100644 --- a/src/transformers/models/csm/modeling_csm.py +++ b/src/transformers/models/csm/modeling_csm.py @@ -337,15 +337,8 @@ class CsmAttention(nn.Module): key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. 
This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, @@ -516,7 +509,6 @@ class CsmDepthDecoderModel(CsmPreTrainedModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, - output_attentions=output_attentions, ) hidden_states = inputs_embeds @@ -828,7 +820,6 @@ class CsmBackboneModel(CsmPreTrainedModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, - output_attentions=output_attentions, ) hidden_states = inputs_embeds @@ -990,22 +981,23 @@ class CsmForConditionalGeneration(CsmPreTrainedModel, CsmGenerationMixin): # ======================================= # TODO: @eustlb, this should be batched !!! # but requires making sure batched inference of the codec model works as intended - audio_tokens_list = [] - for batch_input_values, batch_input_values_cutoffs in zip(input_values, input_values_cutoffs): - batch_input_values_cutoffs = batch_input_values_cutoffs[batch_input_values_cutoffs >= 0] - for i in range(batch_input_values_cutoffs.shape[0] - 1): - start_idx = batch_input_values_cutoffs[i] - end_idx = batch_input_values_cutoffs[i + 1] - audio_batch = batch_input_values[..., start_idx:end_idx] - codec_outputs = self.codec_model.encode(audio_batch.unsqueeze(0)) - codebook_ids = codec_outputs.audio_codes.transpose(1, -1) - audio_tokens_list.append(codebook_ids[0]) + with torch.no_grad(): + audio_tokens_list = [] + for batch_input_values, batch_input_values_cutoffs in zip(input_values, input_values_cutoffs): + batch_input_values_cutoffs = batch_input_values_cutoffs[batch_input_values_cutoffs >= 0] + for i in range(batch_input_values_cutoffs.shape[0] - 1): + start_idx = batch_input_values_cutoffs[i] + end_idx = batch_input_values_cutoffs[i + 1] + audio_batch = batch_input_values[..., start_idx:end_idx] + codec_outputs = self.codec_model.encode(audio_batch.unsqueeze(0)) + codebook_ids = codec_outputs.audio_codes.transpose(1, -1) + audio_tokens_list.append(codebook_ids[0]) - max_audio_frames = max(el.shape[0] for el in audio_tokens_list) - batched_audio_token_ids = torch.stack( - [nn.functional.pad(el, (0, 0, 0, max_audio_frames - el.shape[0])) for el in audio_tokens_list] - ) - audio_codes_mask = self.codec_model.get_audio_codes_mask(input_values_mask) + max_audio_frames = max(el.shape[0] for el in audio_tokens_list) + batched_audio_token_ids = torch.stack( + [nn.functional.pad(el, (0, 0, 0, max_audio_frames - el.shape[0])) for el in audio_tokens_list] + ) + audio_codes_mask = self.codec_model.get_audio_codes_mask(input_values_mask) # ======================================= audio_token_id = self.config.audio_token_id audio_token_mask = input_ids == audio_token_id @@ -1027,6 +1019,7 @@ class CsmForConditionalGeneration(CsmPreTrainedModel, CsmGenerationMixin): if labels is not None: labels_expanded = labels.unsqueeze(-1).repeat(1, 1, self.config.num_codebooks) labels_expanded[audio_token_mask] = batched_audio_token_ids[audio_codes_mask] + labels_expanded[audio_eos_token_mask] = audio_eos_frame_ids # mask depth decoder depth_decoder_ignore_frames_idxs = (labels == -101).nonzero(as_tuple=True) labels_expanded[depth_decoder_ignore_frames_idxs[0], depth_decoder_ignore_frames_idxs[1], 1:] = -100 @@ -1120,7 +1113,7 @@ class 
CsmForConditionalGeneration(CsmPreTrainedModel, CsmGenerationMixin): >>> from transformers import CsmForConditionalGeneration, AutoProcessor >>> from datasets import load_dataset, Audio - >>> model_id = "eustlb/csm-1b" + >>> model_id = "sesame/csm-1b" >>> torch_device = "cuda" if torch.cuda.is_available() else "cpu" >>> processor = AutoProcessor.from_pretrained(model_id) diff --git a/src/transformers/models/csm/modular_csm.py b/src/transformers/models/csm/modular_csm.py index 35fdf127fcd..4322a2a07f8 100644 --- a/src/transformers/models/csm/modular_csm.py +++ b/src/transformers/models/csm/modular_csm.py @@ -247,7 +247,6 @@ class CsmDepthDecoderModel(LlamaModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, - output_attentions=output_attentions, ) hidden_states = inputs_embeds @@ -596,22 +595,23 @@ class CsmForConditionalGeneration(CsmPreTrainedModel, CsmGenerationMixin): # ======================================= # TODO: @eustlb, this should be batched !!! # but requires making sure batched inference of the codec model works as intended - audio_tokens_list = [] - for batch_input_values, batch_input_values_cutoffs in zip(input_values, input_values_cutoffs): - batch_input_values_cutoffs = batch_input_values_cutoffs[batch_input_values_cutoffs >= 0] - for i in range(batch_input_values_cutoffs.shape[0] - 1): - start_idx = batch_input_values_cutoffs[i] - end_idx = batch_input_values_cutoffs[i + 1] - audio_batch = batch_input_values[..., start_idx:end_idx] - codec_outputs = self.codec_model.encode(audio_batch.unsqueeze(0)) - codebook_ids = codec_outputs.audio_codes.transpose(1, -1) - audio_tokens_list.append(codebook_ids[0]) + with torch.no_grad(): + audio_tokens_list = [] + for batch_input_values, batch_input_values_cutoffs in zip(input_values, input_values_cutoffs): + batch_input_values_cutoffs = batch_input_values_cutoffs[batch_input_values_cutoffs >= 0] + for i in range(batch_input_values_cutoffs.shape[0] - 1): + start_idx = batch_input_values_cutoffs[i] + end_idx = batch_input_values_cutoffs[i + 1] + audio_batch = batch_input_values[..., start_idx:end_idx] + codec_outputs = self.codec_model.encode(audio_batch.unsqueeze(0)) + codebook_ids = codec_outputs.audio_codes.transpose(1, -1) + audio_tokens_list.append(codebook_ids[0]) - max_audio_frames = max(el.shape[0] for el in audio_tokens_list) - batched_audio_token_ids = torch.stack( - [nn.functional.pad(el, (0, 0, 0, max_audio_frames - el.shape[0])) for el in audio_tokens_list] - ) - audio_codes_mask = self.codec_model.get_audio_codes_mask(input_values_mask) + max_audio_frames = max(el.shape[0] for el in audio_tokens_list) + batched_audio_token_ids = torch.stack( + [nn.functional.pad(el, (0, 0, 0, max_audio_frames - el.shape[0])) for el in audio_tokens_list] + ) + audio_codes_mask = self.codec_model.get_audio_codes_mask(input_values_mask) # ======================================= audio_token_id = self.config.audio_token_id audio_token_mask = input_ids == audio_token_id @@ -633,6 +633,7 @@ class CsmForConditionalGeneration(CsmPreTrainedModel, CsmGenerationMixin): if labels is not None: labels_expanded = labels.unsqueeze(-1).repeat(1, 1, self.config.num_codebooks) labels_expanded[audio_token_mask] = batched_audio_token_ids[audio_codes_mask] + labels_expanded[audio_eos_token_mask] = audio_eos_frame_ids # mask depth decoder depth_decoder_ignore_frames_idxs = (labels == -101).nonzero(as_tuple=True) labels_expanded[depth_decoder_ignore_frames_idxs[0], depth_decoder_ignore_frames_idxs[1], 1:] = -100 @@ 
-726,7 +727,7 @@ class CsmForConditionalGeneration(CsmPreTrainedModel, CsmGenerationMixin): >>> from transformers import CsmForConditionalGeneration, AutoProcessor >>> from datasets import load_dataset, Audio - >>> model_id = "eustlb/csm-1b" + >>> model_id = "sesame/csm-1b" >>> torch_device = "cuda" if torch.cuda.is_available() else "cpu" >>> processor = AutoProcessor.from_pretrained(model_id) diff --git a/src/transformers/models/csm/processing_csm.py b/src/transformers/models/csm/processing_csm.py index 486c5eda4c7..955f73cb363 100644 --- a/src/transformers/models/csm/processing_csm.py +++ b/src/transformers/models/csm/processing_csm.py @@ -31,10 +31,7 @@ if is_soundfile_available(): from ...audio_utils import AudioInput, make_list_of_audio from ...feature_extraction_utils import BatchFeature from ...processing_utils import AudioKwargs, ProcessingKwargs, ProcessorMixin, Unpack -from ...tokenization_utils_base import ( - PreTokenizedInput, - TextInput, -) +from ...tokenization_utils_base import PreTokenizedInput, TextInput class CsmAudioKwargs(AudioKwargs, total=False): @@ -76,7 +73,7 @@ class CsmProcessor(ProcessorMixin): ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train") audio = ds[0]["audio"]["array"] - processor = CsmProcessor.from_pretrained("eustlb/csm-1b") + processor = CsmProcessor.from_pretrained("sesame/csm-1b") processor( text=["<|begin_of_text|>[0]What are you working on?<|end_of_text|><|AUDIO|><|audio_eos|><|begin_of_text|>[1]I'm figuring out my budget.<|end_of_text|>"], @@ -99,7 +96,6 @@ class CsmProcessor(ProcessorMixin): """ attributes = ["feature_extractor", "tokenizer"] - valid_kwargs = ["chat_template"] feature_extractor_class = "EncodecFeatureExtractor" tokenizer_class = "PreTrainedTokenizerFast" @@ -353,7 +349,11 @@ class CsmProcessor(ProcessorMixin): else: skip_frames_idxs = audio_frame_idxs - labels = torch.where(data["input_ids"] == self.audio_token_id, data["input_ids"], -100) + labels = torch.where( + (data["input_ids"] == self.audio_token_id) | (data["input_ids"] == self.audio_eos_token_id), + data["input_ids"], + -100, + ) labels[skip_frames_idxs[:, 0], skip_frames_idxs[:, 1]] = -101 data["labels"] = labels diff --git a/src/transformers/models/data2vec/modeling_data2vec_audio.py b/src/transformers/models/data2vec/modeling_data2vec_audio.py index e1a822ea037..eafcbff89ae 100755 --- a/src/transformers/models/data2vec/modeling_data2vec_audio.py +++ b/src/transformers/models/data2vec/modeling_data2vec_audio.py @@ -4,9 +4,24 @@ # the file from the modular. If any change should be done, please apply the change to the # modular_data2vec_audio.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ import math import warnings -from typing import Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union import numpy as np import torch @@ -16,7 +31,8 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.fsdp import is_fsdp_managed_module -from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutput, CausalLMOutput, @@ -25,16 +41,14 @@ from ...modeling_outputs import ( Wav2Vec2BaseModelOutput, XVectorOutput, ) -from ...modeling_utils import PreTrainedModel -from ...utils import auto_docstring, is_peft_available, logging +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import auto_docstring, is_peft_available, is_torch_flex_attn_available from .configuration_data2vec_audio import Data2VecAudioConfig -if is_flash_attn_available(): - from ...modeling_flash_attention_utils import _flash_attention_forward - - -logger = logging.get_logger(__name__) +if is_torch_flex_attn_available(): + from ...integrations.flex_attention import make_flex_block_causal_mask class Data2VecAudioConvLayer(nn.Module): @@ -167,6 +181,36 @@ class Data2VecAudioFeatureProjection(nn.Module): return hidden_states, norm_hidden_states +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: Optional[float] = None, + dropout: float = 0.0, + head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + if scaling is None: + scaling = query.size(-1) ** -0.5 + + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) + + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + class Data2VecAudioAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -201,9 +245,6 @@ class Data2VecAudioAttention(nn.Module): self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - def forward( self, hidden_states: torch.Tensor, @@ -212,6 +253,9 @@ class Data2VecAudioAttention(nn.Module): attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + # TODO: we need a refactor so that the different attention modules can get their specific kwargs + # ATM, we have mixed things encoder, decoder, and encoder-decoder attn + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -219,10 +263,16 @@ class 
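The new `eager_attention_forward` helper centralizes the plain softmax-attention math. A shape-level sketch of the same computation with random tensors (no dropout, no additive mask), just to make the `(batch, heads, seq, head_dim)` convention explicit:

```python
import torch

bsz, num_heads, seq_len, head_dim = 2, 4, 6, 16
query = torch.randn(bsz, num_heads, seq_len, head_dim)
key = torch.randn(bsz, num_heads, seq_len, head_dim)
value = torch.randn(bsz, num_heads, seq_len, head_dim)

scaling = head_dim ** -0.5
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling  # (bsz, heads, q_len, kv_len)
attn_weights = torch.softmax(attn_weights, dim=-1)

# Optional per-head mask, broadcast over batch and sequence dims.
head_mask = torch.ones(num_heads)
attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)

attn_output = torch.matmul(attn_weights, value)         # (bsz, heads, q_len, head_dim)
attn_output = attn_output.transpose(1, 2).contiguous()  # (bsz, q_len, heads, head_dim)
print(attn_output.shape)
```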
Data2VecAudioAttention(nn.Module): # for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() + # determine input shapes + bsz, tgt_len = hidden_states.shape[:-1] + src_len = key_value_states.shape[1] if is_cross_attention else tgt_len + + q_input_shape = (bsz, tgt_len, -1, self.head_dim) + kv_input_shape = (bsz, src_len, -1, self.head_dim) # get query proj - query_states = self.q_proj(hidden_states) * self.scaling + query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) + # get key, value proj # `past_key_value[0].shape[2] == key_value_states.shape[1]` # is checking that the `sequence_length` of the `past_key_value` is the same as @@ -237,18 +287,18 @@ class Data2VecAudioAttention(nn.Module): value_states = past_key_value[1] elif is_cross_attention: # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) elif past_key_value is not None: # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) else: # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) if self.is_decoder: # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. 
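The refactor above drops the `_shape` helper in favour of an inline `view(...).transpose(1, 2)`. A quick check that the two formulations produce the same layout:

```python
import torch

bsz, seq_len, embed_dim, num_heads = 2, 5, 32, 4
head_dim = embed_dim // num_heads
hidden = torch.randn(bsz, seq_len, embed_dim)

# Old helper: view(bsz, seq, heads, head_dim).transpose(1, 2).contiguous()
old = hidden.view(bsz, seq_len, num_heads, head_dim).transpose(1, 2).contiguous()

# New inline form: view with -1 for the head dimension, then transpose.
new = hidden.view(bsz, seq_len, -1, head_dim).transpose(1, 2)

assert torch.equal(old, new)  # same values and layout: (bsz, heads, seq, head_dim)
```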
@@ -260,298 +310,29 @@ class Data2VecAudioAttention(nn.Module): # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) - key_states = key_states.reshape(*proj_shape) - value_states = value_states.reshape(*proj_shape) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if layer_head_mask is not None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError( - f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}" - ) - attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to be reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) - else: - attn_weights_reshaped = None - - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) - - attn_output = self.out_proj(attn_output) - - return attn_output, attn_weights_reshaped, past_key_value - - -class Data2VecAudioFlashAttention2(Data2VecAudioAttention): - """ - Data2VecAudio flash attention module. This module inherits from `Data2VecAudioAttention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. 
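The removed manual attention path is replaced by a lookup into `ALL_ATTENTION_FUNCTIONS` keyed on `config._attn_implementation`. A toy registry illustrating the same dispatch pattern (the backend functions here are simplified stand-ins, not the library's implementations):

```python
from typing import Callable, Dict

import torch

def eager_backend(module, q, k, v, mask, **kwargs):
    # plain softmax attention, weights returned for inspection
    w = torch.softmax(q @ k.transpose(-2, -1) * q.shape[-1] ** -0.5, dim=-1)
    return w @ v, w

def sdpa_backend(module, q, k, v, mask, **kwargs):
    # fused kernel path; no attention weights are materialized
    return torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask), None

ATTENTION_BACKENDS: Dict[str, Callable] = {"eager": eager_backend, "sdpa": sdpa_backend}

attn_implementation = "sdpa"  # would come from config._attn_implementation
attention_interface = ATTENTION_BACKENDS.get(attn_implementation, eager_backend)

q = k = v = torch.randn(1, 4, 8, 16)
out, weights = attention_interface(None, q, k, v, None)
print(out.shape, weights is None)
```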
- """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() - - def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) - - def forward( - self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - - bsz, q_len, _ = hidden_states.size() - - # get query proj - query_states = self._reshape(self.q_proj(hidden_states), -1, bsz) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): - # reuse k,v, cross_attentions - key_states = past_key_value[0].transpose(1, 2) - value_states = past_key_value[1].transpose(1, 2) - elif is_cross_attention: - # cross_attentions - key_states = self._reshape(self.k_proj(key_value_states), -1, bsz) - value_states = self._reshape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) - value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1) - value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1) - else: - # self_attention - key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) - value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. 
Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2)) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (LlamaRMSNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = _flash_attention_forward( + attn_output, attn_weights = attention_interface( + self, query_states, key_states, value_states, attention_mask, - q_len, - dropout=self.dropout if self.training else 0.0, - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + output_attentions=output_attentions, + head_mask=layer_head_mask, + **kwargs, ) - attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - if not output_attentions: - attn_weights = None - return attn_output, attn_weights, past_key_value -class Data2VecAudioSdpaAttention(Data2VecAudioAttention): - def forward( - self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel""" - if output_attentions: - # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "Data2VecAudioModel is using Data2VecAudioSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` . Falling back to the manual attention" - ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
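With the per-backend classes gone, the backend is selected once from the config; users typically pick it when loading the model. A short usage sketch (the checkpoint name is just an example):

```python
from transformers import AutoModel

# Backend is picked once at load time and stored on the config; the modules above
# only read `config._attn_implementation` to choose the matching kernel.
model = AutoModel.from_pretrained(
    "facebook/data2vec-audio-base-960h",  # example audio checkpoint
    attn_implementation="sdpa",           # alternatives: "eager", "flash_attention_2", "flex_attention"
)
print(model.config._attn_implementation)
```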
- ) - return super().forward( - hidden_states, - key_value_states=key_value_states, - past_key_value=past_key_value, - attention_mask=attention_mask, - output_attentions=output_attentions, - ) - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - - bsz, tgt_len, _ = hidden_states.size() - - # get query proj - query_states = self.q_proj(hidden_states) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): - # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) - - query_states = self._shape(query_states, tgt_len, bsz) - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. - is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False - - # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, - # but we are fine here as `_shape` do call `.contiguous()`. 
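The deleted SDPA class chose `is_causal` dynamically rather than always passing a mask. The same rule can be reproduced in isolation with `torch.nn.functional.scaled_dot_product_attention` (toy shapes, non-causal encoder case assumed):

```python
import torch
import torch.nn.functional as F

q = torch.randn(1, 4, 5, 16)
k = v = torch.randn(1, 4, 5, 16)

# Causal only for self-attention without an explicit mask and with more than
# one query position, mirroring the condition in the removed code.
self_is_causal, attention_mask, tgt_len = False, None, q.shape[2]
is_causal = self_is_causal and attention_mask is None and tgt_len > 1

out = F.scaled_dot_product_attention(q, k, v, attn_mask=attention_mask, is_causal=is_causal)
print(out.shape)
```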
Reference: https://github.com/pytorch/pytorch/issues/112577 - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=attention_mask, - dropout_p=self.dropout if self.training else 0.0, - is_causal=is_causal, - ) - - if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) - - attn_output = self.out_proj(attn_output) - - return attn_output, None, past_key_value - - class Data2VecAudioFeedForward(nn.Module): def __init__(self, config): super().__init__() @@ -576,21 +357,15 @@ class Data2VecAudioFeedForward(nn.Module): return hidden_states -DATA2VEC_AUDIO_ATTENTION_CLASSES = { - "eager": Data2VecAudioAttention, - "sdpa": Data2VecAudioSdpaAttention, - "flash_attention_2": Data2VecAudioFlashAttention2, -} - - class Data2VecAudioEncoderLayer(nn.Module): def __init__(self, config): super().__init__() - self.attention = DATA2VEC_AUDIO_ATTENTION_CLASSES[config._attn_implementation]( + self.attention = Data2VecAudioAttention( embed_dim=config.hidden_size, num_heads=config.num_attention_heads, dropout=config.attention_dropout, is_decoder=False, + config=config, ) self.dropout = nn.Dropout(config.hidden_dropout) @@ -627,7 +402,6 @@ class Data2VecAudioEncoder(nn.Module): self.dropout = nn.Dropout(config.hidden_dropout) self.layers = nn.ModuleList([Data2VecAudioEncoderLayer(config) for _ in range(config.num_hidden_layers)]) self.gradient_checkpointing = False - self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" def forward( self, @@ -644,16 +418,11 @@ class Data2VecAudioEncoder(nn.Module): # make sure padded tokens output 0 expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) hidden_states[~expand_attention_mask] = 0 - if self._use_flash_attention_2: - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - else: - # extend attention_mask - attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype) - attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min - attention_mask = attention_mask.expand( - attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1] - ) + + attention_mask = self._update_full_mask( + attention_mask, + hidden_states, + ) position_embeddings = self.pos_conv_embed(hidden_states) hidden_states = hidden_states + position_embeddings @@ -702,6 +471,28 @@ class Data2VecAudioEncoder(nn.Module): attentions=all_self_attentions, ) + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. 
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + class Data2VecAudioAdapterLayer(nn.Module): def __init__(self, config): @@ -760,6 +551,7 @@ class Data2VecAudioPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _supports_flash_attn_2 = True _supports_sdpa = True + _supports_flex_attn = True def _init_weights(self, module): """Initialize the weights""" diff --git a/src/transformers/models/data2vec/modular_data2vec_audio.py b/src/transformers/models/data2vec/modular_data2vec_audio.py index 58934d2e86a..0b4695c1e28 100644 --- a/src/transformers/models/data2vec/modular_data2vec_audio.py +++ b/src/transformers/models/data2vec/modular_data2vec_audio.py @@ -1,3 +1,19 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Data2VecText model.""" + import math import torch @@ -124,6 +140,7 @@ class Data2VecAudioPreTrainedModel(PreTrainedModel, Wav2Vec2PreTrainedModel): supports_gradient_checkpointing = True _supports_flash_attn_2 = True _supports_sdpa = True + _supports_flex_attn = True def _init_weights(self, module): """Initialize the weights""" diff --git a/src/transformers/models/deberta_v2/tokenization_deberta_v2_fast.py b/src/transformers/models/deberta_v2/tokenization_deberta_v2_fast.py index 784e8299541..096b4b239c6 100644 --- a/src/transformers/models/deberta_v2/tokenization_deberta_v2_fast.py +++ b/src/transformers/models/deberta_v2/tokenization_deberta_v2_fast.py @@ -119,10 +119,6 @@ class DebertaV2TokenizerFast(PreTrainedTokenizerFast): self.split_by_punct = split_by_punct self.vocab_file = vocab_file - @property - def can_save_slow_tokenizer(self) -> bool: - return os.path.isfile(self.vocab_file) if self.vocab_file else False - def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): """ Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and diff --git a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py index b15301e2884..5804eeee4b1 100644 --- a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py +++ b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py @@ -412,13 +412,7 @@ class DeepseekV3Attention(nn.Module): attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does 
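`_update_full_mask` above converts a 2D padding mask into whatever the chosen backend expects; the eager branch expands it to an additive 4D mask. A minimal sketch of that expansion (the SDPA and flex-attention branches delegate to dedicated utilities and are not reproduced here):

```python
import torch

def expand_padding_mask(mask_2d, dtype):
    # (bsz, src_len) of 1s/0s  ->  (bsz, 1, src_len, src_len) additive mask where
    # padded positions get the most negative representable value.
    bsz, src_len = mask_2d.shape
    inverted = 1.0 - mask_2d[:, None, None, :].to(dtype)
    additive = inverted * torch.finfo(dtype).min
    return additive.expand(bsz, 1, src_len, src_len)

mask = torch.tensor([[1, 1, 1, 0]])
print(expand_padding_mask(mask, torch.float32)[0, 0])
```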
not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, @@ -608,7 +602,6 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, - output_attentions=output_attentions, ) hidden_states = inputs_embeds diff --git a/src/transformers/models/deepseek_v3/modular_deepseek_v3.py b/src/transformers/models/deepseek_v3/modular_deepseek_v3.py index b4905c62011..e7d5eaded7e 100644 --- a/src/transformers/models/deepseek_v3/modular_deepseek_v3.py +++ b/src/transformers/models/deepseek_v3/modular_deepseek_v3.py @@ -293,13 +293,7 @@ class DeepseekV3Attention(nn.Module): attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/deprecated/xlm_prophetnet/tokenization_xlm_prophetnet.py b/src/transformers/models/deprecated/xlm_prophetnet/tokenization_xlm_prophetnet.py index 1a5da12859f..4b452554ea0 100644 --- a/src/transformers/models/deprecated/xlm_prophetnet/tokenization_xlm_prophetnet.py +++ b/src/transformers/models/deprecated/xlm_prophetnet/tokenization_xlm_prophetnet.py @@ -169,10 +169,6 @@ class XLMProphetNetTokenizer(PreTrainedTokenizer): **kwargs, ) - @property - def can_save_slow_tokenizer(self) -> bool: - return os.path.isfile(self.vocab_file) if self.vocab_file else False - def __getstate__(self): state = self.__dict__.copy() state["sp_model"] = None diff --git a/src/transformers/models/diffllama/modeling_diffllama.py b/src/transformers/models/diffllama/modeling_diffllama.py index 84df7b4d41f..68aa54180ca 100644 --- a/src/transformers/models/diffllama/modeling_diffllama.py +++ b/src/transformers/models/diffllama/modeling_diffllama.py @@ -397,23 +397,6 @@ class DiffLlamaSdpaAttention(DiffLlamaAttention): cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "DiffLlamaModel is using DiffLlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
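With the per-module SDPA warning removed in DeepseekV3 (and DiffLlama below), attention weights are still only materialized on the eager path, so it is safest to load with eager attention when `output_attentions=True` is needed. A sketch using a tiny test checkpoint (assumed to ship a tokenizer):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "hf-internal-testing/tiny-random-LlamaForCausalLM"  # tiny test checkpoint, illustrative only
model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="eager")
tokenizer = AutoTokenizer.from_pretrained(model_id)

inputs = tokenizer("Hello", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs, output_attentions=True)

# One attention tensor per layer, each (batch, heads, seq, seq)
print(len(outputs.attentions), outputs.attentions[0].shape)
```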
- ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) - bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) @@ -708,7 +691,6 @@ class DiffLlamaModel(DiffLlamaPreTrainedModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, - output_attentions=output_attentions, ) hidden_states = inputs_embeds diff --git a/src/transformers/models/diffllama/modular_diffllama.py b/src/transformers/models/diffllama/modular_diffllama.py index f7bc2d2c5ac..b772a9f04d5 100644 --- a/src/transformers/models/diffllama/modular_diffllama.py +++ b/src/transformers/models/diffllama/modular_diffllama.py @@ -330,23 +330,6 @@ class DiffLlamaSdpaAttention(DiffLlamaAttention): cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "DiffLlamaModel is using DiffLlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) - bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) diff --git a/src/transformers/models/dpt/image_processing_dpt.py b/src/transformers/models/dpt/image_processing_dpt.py index a22548f5cd9..b948f9886e4 100644 --- a/src/transformers/models/dpt/image_processing_dpt.py +++ b/src/transformers/models/dpt/image_processing_dpt.py @@ -213,8 +213,6 @@ class DPTImageProcessor(BaseImageProcessor): resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`): Defines the resampling filter to use if resizing the image. Otherwise, the image is resized to size specified in `size`. - resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`): - Resampling filter to use when resiizing the image. data_format (`str` or `ChannelDimension`, *optional*): The channel dimension format of the image. If not provided, it will be the same as the input image. 
input_data_format (`str` or `ChannelDimension`, *optional*): diff --git a/src/transformers/models/emu3/convert_emu3_weights_to_hf.py b/src/transformers/models/emu3/convert_emu3_weights_to_hf.py index 8ac8db7e429..ddc907adb4c 100644 --- a/src/transformers/models/emu3/convert_emu3_weights_to_hf.py +++ b/src/transformers/models/emu3/convert_emu3_weights_to_hf.py @@ -211,14 +211,13 @@ def convert_tiktoken(tokenizer, output_dir): KEYS_TO_MODIFY_MAPPING = { + "^model": "model.text_model", "^encoder": "model.vqmodel.encoder", "^decoder": "model.vqmodel.decoder", "^post_quant_conv": "model.vqmodel.post_quant_conv", "^quant_conv": "model.vqmodel.quant_conv", "^quantize": "model.vqmodel.quantize", - "^model": "text_model.model", - r"lm_head\.weight": "text_model.lm_head.weight", - r"^text_model\.model\.vqmodel": "vqmodel", + r"lm_head\.weight": "lm_head.weight", # rename QKV proj for the VQ-VAE model because we use SiglipAttention r"\.q\.": ".q_proj.", r"\.k\.": ".k_proj.", diff --git a/src/transformers/models/emu3/modeling_emu3.py b/src/transformers/models/emu3/modeling_emu3.py index 3b570fd1f26..31f01db1b5a 100644 --- a/src/transformers/models/emu3/modeling_emu3.py +++ b/src/transformers/models/emu3/modeling_emu3.py @@ -206,15 +206,8 @@ class Emu3Attention(nn.Module): key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
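The Emu3 conversion script reorders `KEYS_TO_MODIFY_MAPPING` so text-model weights land under `model.text_model` and `lm_head` stays at the top level. A toy version of the regex-driven key renaming it performs (only a subset of patterns, for illustration):

```python
import re

# Subset of the mapping, for illustration only; order matters because later
# patterns run on the already-renamed key.
KEY_MAPPING = {
    r"^model": "model.text_model",
    r"^encoder": "model.vqmodel.encoder",
    r"\.q\.": ".q_proj.",
    r"\.k\.": ".k_proj.",
}

def rename_key(key):
    for pattern, replacement in KEY_MAPPING.items():
        key = re.sub(pattern, replacement, key)
    return key

state_dict = {"model.layers.0.attn.q.weight": 0, "encoder.conv_in.weight": 1}
print({rename_key(k): v for k, v in state_dict.items()})
# {'model.text_model.layers.0.attn.q_proj.weight': 0, 'model.vqmodel.encoder.conv_in.weight': 1}
```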
- ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, @@ -1279,7 +1272,6 @@ class Emu3TextModel(Emu3PreTrainedModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, - output_attentions=output_attentions, ) hidden_states = inputs_embeds @@ -1446,9 +1438,6 @@ class Emu3Model(Emu3PreTrainedModel): def __init__(self, config): super().__init__(config) self.text_model = Emu3TextModel._from_config(config.text_config) - if self.text_model._tied_weights_keys is not None: - self._tied_weights_keys = [f"text_model.{k}" for k in self.text_model._tied_weights_keys] - self.vqmodel = Emu3VQVAE(config.vq_config) self.vocabulary_mapping = Emu3ImageVocabularyMapping(config.vocabulary_map) @@ -1569,6 +1558,7 @@ class Emu3Model(Emu3PreTrainedModel): class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin): base_model_prefix = "" + _tied_weights_keys = ["lm_head.weight"] _checkpoint_conversion_mapping = { "^text_model.model": "model.text_model", "^vqmodel": "model.vqmodel", @@ -1589,6 +1579,18 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin): def set_input_embeddings(self, value): self.model.set_input_embeddings(value) + def get_output_embeddings(self) -> nn.Module: + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + # Make modules available throught conditional class for BC @property def text_model(self): @@ -1598,6 +1600,13 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin): def vqmodel(self): return self.model.vqmodel + @property + def vocabulary_mapping(self): + return self.model.vocabulary_mapping + + def decode_image_tokens(self, **kwargs): + return self.model.decode_image_tokens(**kwargs) + @can_return_tuple @auto_docstring def forward( diff --git a/src/transformers/models/emu3/modular_emu3.py b/src/transformers/models/emu3/modular_emu3.py index bf2e6a5efa7..8c86f81d523 100644 --- a/src/transformers/models/emu3/modular_emu3.py +++ b/src/transformers/models/emu3/modular_emu3.py @@ -925,9 +925,6 @@ class Emu3Model(Emu3PreTrainedModel): def __init__(self, config): super().__init__(config) self.text_model = Emu3TextModel._from_config(config.text_config) - if self.text_model._tied_weights_keys is not None: - self._tied_weights_keys = [f"text_model.{k}" for k in self.text_model._tied_weights_keys] - self.vqmodel = Emu3VQVAE(config.vq_config) self.vocabulary_mapping = Emu3ImageVocabularyMapping(config.vocabulary_map) @@ -1048,6 +1045,7 @@ class Emu3Model(Emu3PreTrainedModel): class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin): base_model_prefix = "" + _tied_weights_keys = ["lm_head.weight"] _checkpoint_conversion_mapping = { "^text_model.model": "model.text_model", "^vqmodel": "model.vqmodel", @@ -1068,6 +1066,18 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin): def set_input_embeddings(self, value): self.model.set_input_embeddings(value) + def get_output_embeddings(self) -> nn.Module: + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + # Make modules 
available throught conditional class for BC @property def text_model(self): @@ -1077,6 +1087,13 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin): def vqmodel(self): return self.model.vqmodel + @property + def vocabulary_mapping(self): + return self.model.vocabulary_mapping + + def decode_image_tokens(self, **kwargs): + return self.model.decode_image_tokens(**kwargs) + @can_return_tuple @auto_docstring def forward( diff --git a/src/transformers/models/emu3/processing_emu3.py b/src/transformers/models/emu3/processing_emu3.py index a94dc08cd97..61b40217723 100644 --- a/src/transformers/models/emu3/processing_emu3.py +++ b/src/transformers/models/emu3/processing_emu3.py @@ -16,10 +16,17 @@ from typing import List, Optional, Union +import numpy as np + from ...image_processing_utils import BatchFeature from ...image_utils import ImageInput -from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, TextKwargs, Unpack +from ...processing_utils import ImagesKwargs, MultiModalData, ProcessingKwargs, ProcessorMixin, TextKwargs, Unpack from ...tokenization_utils_base import PreTokenizedInput, TextInput +from ...utils import is_vision_available + + +if is_vision_available(): + from .image_processing_emu3 import smart_resize class Emu3TextKwargs(TextKwargs, total=False): @@ -37,6 +44,7 @@ class Emu3ProcessorKwargs(ProcessingKwargs, total=False): _defaults = { "text_kwargs": { "return_for_image_generation": False, + "return_mm_token_type_ids": False, }, "images_kwargs": { "ratio": "1:1", @@ -63,7 +71,6 @@ class Emu3Processor(ProcessorMixin): """ attributes = ["image_processor", "tokenizer"] - valid_kwargs = ["chat_template"] tokenizer_class = ("GPT2Tokenizer", "GPT2TokenizerFast") image_processor_class = "Emu3ImageProcessor" @@ -166,7 +173,7 @@ class Emu3Processor(ProcessorMixin): image_placeholder = f"{image_start_tokens}{height}*{width}{self.fake_token_around_image}{'' * image_seq_length}{image_end_tokens}" sample = sample.replace(self.image_token, image_placeholder, 1) - sample = f"{self.bos_token}{sample}" # add BOS because PT tokenizer doesn't add it + sample = f"{self.bos_token}{sample}" # add BOS because GPT tokenizer doesn't add it prompt_strings.append(sample) text = [sample.replace("", self.image_token) for sample in prompt_strings] @@ -179,12 +186,51 @@ class Emu3Processor(ProcessorMixin): # else just generate from text-only input, and we do no special treatment for text return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None) - data = self.tokenizer(text, **output_kwargs["text_kwargs"]) - self._check_special_mm_tokens(text, data, modalities=["image"]) + return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False) + text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"], return_tensors=None) + self._check_special_mm_tokens(text, text_inputs, modalities=["image"]) - data.update(**image_features) + if return_mm_token_type_ids: + array_ids = np.array(text_inputs["input_ids"]) + mm_token_type_ids = np.zeros_like(text_inputs["input_ids"]) + mm_token_type_ids[array_ids == self.image_token_id] = 1 + text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist() - return BatchFeature(data=data, tensor_type=return_tensors) + return BatchFeature(data={**text_inputs, **image_features}, tensor_type=return_tensors) + + def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs): + """ + Computes the number of placeholder tokens needed for multimodal inputs with the given sizes. 
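Moving `_tied_weights_keys` and the output-embedding accessors onto `Emu3ForConditionalGeneration` lets the generic tying and resizing machinery find `lm_head`. A minimal, hypothetical sketch of the contract those accessors satisfy (not the actual `PreTrainedModel` implementation):

```python
import torch.nn as nn

# Generic utilities (weight tying, embedding resizing, generation) only need
# get/set_output_embeddings and _tied_weights_keys to exist on the top-level model.
class TinyLM(nn.Module):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, vocab_size=100, hidden=16):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden)
        self.lm_head = nn.Linear(hidden, vocab_size, bias=False)

    def get_input_embeddings(self):
        return self.embed_tokens

    def get_output_embeddings(self):
        return self.lm_head

    def tie_weights(self):
        # share the matrix between the input embeddings and the LM head
        self.lm_head.weight = self.embed_tokens.weight

model = TinyLM()
model.tie_weights()
assert model.get_output_embeddings().weight is model.get_input_embeddings().weight
```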
+ + Args: + image_sizes (`List[List[int]]`, *optional*): + The input sizes formatted as (height, width) per each image. + + Returns: + `MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided + input modalities, along with other useful data. + """ + + vision_data = {} + if image_sizes is not None: + num_image_tokens = [] + for height, width in image_sizes: + height, width = smart_resize( + height, + width, + self.image_processor.spatial_factor, + self.image_processor.min_pixels, + self.image_processor.max_pixels, + ) + height = height // self.downsample_ratio + width = width // self.downsample_ratio + image_seq_length = height * (width + 1) # +1 for extra row when converting to BPE in modeling code + num_image_tokens.append(image_seq_length) + + num_image_patches = [1] * len(image_sizes) + vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches}) + + return MultiModalData(**vision_data) def calculate_generate_size(self, ratio, image_area, spatial_factor): width, height = map(int, ratio.split(":")) diff --git a/src/transformers/models/falcon_h1/modeling_falcon_h1.py b/src/transformers/models/falcon_h1/modeling_falcon_h1.py index 3a2e20e7cc1..0a4d8f43277 100644 --- a/src/transformers/models/falcon_h1/modeling_falcon_h1.py +++ b/src/transformers/models/falcon_h1/modeling_falcon_h1.py @@ -45,8 +45,7 @@ from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import auto_docstring, is_torchdynamo_compiling, logging, replace_return_docstrings -from ...utils.deprecation import deprecate_kwarg +from ...utils import auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging from ...utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available from .configuration_falcon_h1 import FalconH1Config @@ -65,8 +64,6 @@ else: logger = logging.get_logger(__name__) -_CONFIG_FOR_DOC = "FalconH1Config" - class FalconHybridMambaAttentionDynamicCache(DynamicCache): """ @@ -383,13 +380,7 @@ class FalconH1Attention(nn.Module): attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, @@ -610,9 +601,10 @@ class FalconH1Mixer(nn.Module): ): # 1. 
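`_get_num_multimodal_tokens` above derives the placeholder count from the resized image grid. A sketch of the same arithmetic with a stand-in for `smart_resize` and assumed factors (`spatial_factor=8`, `downsample_ratio=8` are illustrative, not read from a real processor config):

```python
# Stand-in for `smart_resize`: the real helper snaps (height, width) to the spatial
# factor within a pixel budget; here we only round to a multiple of the factor.
def fake_smart_resize(height, width, factor):
    return round(height / factor) * factor, round(width / factor) * factor

spatial_factor, downsample_ratio = 8, 8  # assumed values
height, width = fake_smart_resize(512, 768, spatial_factor)

height //= downsample_ratio
width //= downsample_ratio
image_seq_length = height * (width + 1)  # +1 token per row, as in the hunk above
print(image_seq_length)                  # 64 * 97 = 6208
```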
Gated MLP's linear projection hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask) + # Add Multipliers hidden_states = hidden_states * self.ssm_in_multiplier projected_states = self.in_proj(hidden_states) - projected_states = projected_states * self.mup_vector + projected_states = projected_states * self.mup_vector # ADD Mup Multipliers d_to_remove = 2 * self.intermediate_size + 2 * self.n_groups * self.ssm_state_size + self.num_heads # Set up dimensions for reshapes later @@ -806,10 +798,13 @@ class FalconH1Mixer(nn.Module): # 1. Gated MLP's linear projection input_states = apply_mask_to_padding_states(input_states, attention_mask) + # Add Multipliers + input_states = input_states * self.ssm_in_multiplier projected_states = self.in_proj(input_states) - gate, hidden_states_B_C, dt = projected_states.split( - [self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 - ) + projected_states = projected_states * self.mup_vector # ADD Mup Multipliers + gate, hidden_states_B_C, dt = projected_states.split([ + self.intermediate_size, self.conv_dim, self.num_heads + ], dim=-1) use_precomputed_states = ( cache_params is not None @@ -920,8 +915,8 @@ class FalconH1Mixer(nn.Module): hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float() B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float() C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float() - B = B.repeat(1, 1, self.num_heads // self.n_groups, 1) - C = C.repeat(1, 1, self.num_heads // self.n_groups, 1) + B = B.repeat_interleave(self.num_heads // self.n_groups, dim=2, output_size=self.num_heads) + C = C.repeat_interleave(self.num_heads // self.n_groups, dim=2, output_size=self.num_heads) pad_size = (self.chunk_size - seq_len % self.chunk_size) % self.chunk_size D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size) @@ -1226,13 +1221,6 @@ def compute_mup_vector(config): @auto_docstring # Adapted from transformers.models.jamba.modeling_jamba.JambaModel class FalconH1Model(FalconH1PreTrainedModel): - """ - Transformer decoder consisting of *config.num_hidden_layers* layers. 
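Replacing `repeat` with `repeat_interleave` changes how the grouped `B`/`C` SSM parameters are broadcast across heads: each group now maps to a contiguous block of heads instead of being tiled. A tiny demo of the difference:

```python
import torch

num_heads, n_groups = 4, 2
B = torch.arange(n_groups).view(1, 1, n_groups, 1)  # one value per group: 0, 1

tiled = B.repeat(1, 1, num_heads // n_groups, 1)                 # whole block tiled
interleaved = B.repeat_interleave(num_heads // n_groups, dim=2)  # each group duplicated in place

print(tiled.flatten().tolist())        # [0, 1, 0, 1]
print(interleaved.flatten().tolist())  # [0, 0, 1, 1]
```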
Each layer is a [`FalconH1DecoderLayer`] - - Args: - config: FalconH1Config - """ - def __init__(self, config: FalconH1Config): super().__init__(config) self.padding_idx = config.pad_token_id @@ -1266,6 +1254,7 @@ class FalconH1Model(FalconH1PreTrainedModel): def set_input_embeddings(self, value): self.embed_tokens = value + @can_return_tuple @auto_docstring def forward( self, @@ -1277,7 +1266,6 @@ class FalconH1Model(FalconH1PreTrainedModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, # NOOP kwargs, for now ) -> Union[Tuple, BaseModelOutputWithPast]: @@ -1287,8 +1275,6 @@ class FalconH1Model(FalconH1PreTrainedModel): ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") @@ -1358,8 +1344,6 @@ class FalconH1Model(FalconH1PreTrainedModel): next_cache = None if not use_cache else past_key_values - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=next_cache, @@ -1528,9 +1512,8 @@ class FalconH1ForCausalLM(FalconH1PreTrainedModel, GenerationMixin): def get_decoder(self): return self.model - @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + @can_return_tuple @auto_docstring - @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( self, input_ids: torch.LongTensor = None, @@ -1542,7 +1525,6 @@ class FalconH1ForCausalLM(FalconH1PreTrainedModel, GenerationMixin): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs, @@ -1553,15 +1535,6 @@ class FalconH1ForCausalLM(FalconH1PreTrainedModel, GenerationMixin): config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - logits_to_keep (`int` or `torch.Tensor`, *optional*): - If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all - `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that - token can save memory, which becomes pretty significant for long sequences or large vocabulary size. - If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. - This is useful when using packed tensor format (single dimension for batch and sequence length). 
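Dropping `return_dict` and decorating `forward` with `@can_return_tuple` means callers always receive a `ModelOutput`; tuple-style access remains available. A small sketch using the output dataclass directly:

```python
import torch
from transformers.modeling_outputs import BaseModelOutputWithPast

# Forward methods wrapped as above return a ModelOutput dataclass.
out = BaseModelOutputWithPast(last_hidden_state=torch.zeros(1, 3, 8), past_key_values=None)

print(out.last_hidden_state.shape)  # attribute access
print(out[0].shape)                 # index access still works
print(len(out.to_tuple()))          # explicit tuple conversion drops None fields
```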
- - Returns: - Example: ```python @@ -1582,7 +1555,6 @@ class FalconH1ForCausalLM(FalconH1PreTrainedModel, GenerationMixin): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( @@ -1594,7 +1566,6 @@ class FalconH1ForCausalLM(FalconH1PreTrainedModel, GenerationMixin): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, cache_position=cache_position, **kwargs, ) @@ -1608,10 +1579,6 @@ class FalconH1ForCausalLM(FalconH1PreTrainedModel, GenerationMixin): if labels is not None: loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - return CausalLMOutputWithPast( loss=loss, logits=logits, diff --git a/src/transformers/models/falcon_h1/modular_falcon_h1.py b/src/transformers/models/falcon_h1/modular_falcon_h1.py index 07b9e540848..bd0ecb1804d 100644 --- a/src/transformers/models/falcon_h1/modular_falcon_h1.py +++ b/src/transformers/models/falcon_h1/modular_falcon_h1.py @@ -51,24 +51,11 @@ from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import ( - auto_docstring, - is_torchdynamo_compiling, - logging, - replace_return_docstrings, -) -from ...utils.deprecation import deprecate_kwarg -from ...utils.import_utils import ( - is_causal_conv1d_available, - is_flash_attn_2_available, - is_mamba_2_ssm_available, -) +from ...utils import auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging +from ...utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available from .configuration_falcon_h1 import FalconH1Config -if is_flash_attn_2_available(): - pass - if is_mamba_2_ssm_available(): from mamba_ssm.ops.triton.selective_state_update import selective_state_update from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined @@ -85,8 +72,6 @@ is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_c logger = logging.get_logger(__name__) -_CONFIG_FOR_DOC = "FalconH1Config" - class FalconHybridMambaAttentionDynamicCache(HybridMambaAttentionDynamicCache): """ @@ -251,13 +236,7 @@ class FalconH1Attention(LlamaAttention): attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, @@ -421,9 +400,10 @@ class FalconH1Mixer(nn.Module): ): # 1. 
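The removed docstring explained `logits_to_keep`; the memory saving it describes follows directly from slicing the hidden states before the LM head. A toy sketch with made-up sizes (the `slice` construction mirrors the pattern used in the modeling code):

```python
import torch

batch, seq_len, hidden, vocab = 1, 2048, 64, 32000
hidden_states = torch.randn(batch, seq_len, hidden)
lm_head = torch.nn.Linear(hidden, vocab, bias=False)

logits_to_keep = 1  # only the last position is needed during generation
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = lm_head(hidden_states[:, slice_indices, :])

print(logits.shape)  # torch.Size([1, 1, 32000]) instead of [1, 2048, 32000]
```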
Gated MLP's linear projection hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask) + # Add Multipliers hidden_states = hidden_states * self.ssm_in_multiplier projected_states = self.in_proj(hidden_states) - projected_states = projected_states * self.mup_vector + projected_states = projected_states * self.mup_vector # ADD Mup Multipliers d_to_remove = 2 * self.intermediate_size + 2 * self.n_groups * self.ssm_state_size + self.num_heads # Set up dimensions for reshapes later @@ -617,10 +597,13 @@ class FalconH1Mixer(nn.Module): # 1. Gated MLP's linear projection input_states = apply_mask_to_padding_states(input_states, attention_mask) + # Add Multipliers + input_states = input_states * self.ssm_in_multiplier projected_states = self.in_proj(input_states) - gate, hidden_states_B_C, dt = projected_states.split( - [self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 - ) + projected_states = projected_states * self.mup_vector # ADD Mup Multipliers + gate, hidden_states_B_C, dt = projected_states.split([ + self.intermediate_size, self.conv_dim, self.num_heads + ], dim=-1) use_precomputed_states = ( cache_params is not None @@ -731,8 +714,8 @@ class FalconH1Mixer(nn.Module): hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float() B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float() C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float() - B = B.repeat(1, 1, self.num_heads // self.n_groups, 1) - C = C.repeat(1, 1, self.num_heads // self.n_groups, 1) + B = B.repeat_interleave(self.num_heads // self.n_groups, dim=2, output_size=self.num_heads) + C = C.repeat_interleave(self.num_heads // self.n_groups, dim=2, output_size=self.num_heads) pad_size = (self.chunk_size - seq_len % self.chunk_size) % self.chunk_size D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size) @@ -1013,13 +996,6 @@ def compute_mup_vector(config): @auto_docstring # Adapted from transformers.models.jamba.modeling_jamba.JambaModel class FalconH1Model(FalconH1PreTrainedModel): - """ - Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`FalconH1DecoderLayer`] - - Args: - config: FalconH1Config - """ - def __init__(self, config: FalconH1Config): super().__init__(config) self.padding_idx = config.pad_token_id @@ -1053,6 +1029,7 @@ class FalconH1Model(FalconH1PreTrainedModel): def set_input_embeddings(self, value): self.embed_tokens = value + @can_return_tuple @auto_docstring def forward( self, @@ -1064,7 +1041,6 @@ class FalconH1Model(FalconH1PreTrainedModel): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, # NOOP kwargs, for now ) -> Union[Tuple, BaseModelOutputWithPast]: @@ -1074,8 +1050,6 @@ class FalconH1Model(FalconH1PreTrainedModel): ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") @@ -1145,8 +1119,6 @@ class FalconH1Model(FalconH1PreTrainedModel): next_cache = None if not use_cache else past_key_values - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=next_cache, @@ -1283,9 +1255,6 @@ class FalconH1Model(FalconH1PreTrainedModel): class FalconH1ForCausalLM(LlamaForCausalLM): - @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") - @auto_docstring - @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( self, input_ids: torch.LongTensor = None, @@ -1297,26 +1266,11 @@ class FalconH1ForCausalLM(LlamaForCausalLM): use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - logits_to_keep (`int` or `torch.Tensor`, *optional*): - If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all - `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that - token can save memory, which becomes pretty significant for long sequences or large vocabulary size. - If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. - This is useful when using packed tensor format (single dimension for batch and sequence length). 
- - Returns: - Example: ```python @@ -1337,7 +1291,6 @@ class FalconH1ForCausalLM(LlamaForCausalLM): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( @@ -1349,7 +1302,6 @@ class FalconH1ForCausalLM(LlamaForCausalLM): use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, cache_position=cache_position, **kwargs, ) @@ -1363,10 +1315,6 @@ class FalconH1ForCausalLM(LlamaForCausalLM): if labels is not None: loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - return CausalLMOutputWithPast( loss=loss, logits=logits, diff --git a/src/transformers/models/fnet/tokenization_fnet_fast.py b/src/transformers/models/fnet/tokenization_fnet_fast.py index ac33bc13c60..9550bcbb4ae 100644 --- a/src/transformers/models/fnet/tokenization_fnet_fast.py +++ b/src/transformers/models/fnet/tokenization_fnet_fast.py @@ -113,10 +113,6 @@ class FNetTokenizerFast(PreTrainedTokenizerFast): self.keep_accents = keep_accents self.vocab_file = vocab_file - @property - def can_save_slow_tokenizer(self) -> bool: - return os.path.isfile(self.vocab_file) if self.vocab_file else False - def build_inputs_with_special_tokens( self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None ) -> List[int]: diff --git a/src/transformers/models/fuyu/modeling_fuyu.py b/src/transformers/models/fuyu/modeling_fuyu.py index 4f120afa1ba..590cf8f8d17 100644 --- a/src/transformers/models/fuyu/modeling_fuyu.py +++ b/src/transformers/models/fuyu/modeling_fuyu.py @@ -130,7 +130,7 @@ class FuyuModel(FuyuPreTrainedModel): ) return output_embeddings - def get_image_features(self, pixel_values: torch.FloatTensor): + def get_image_features(self, pixel_values: torch.FloatTensor, **kwargs): """ Encodes images into continuous embeddings that can be forwarded to the language model. 
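
(Illustration only, not part of the diff.) The `FalconH1Mixer` hunk above replaces `Tensor.repeat` with `Tensor.repeat_interleave` when expanding the grouped `B`/`C` SSM states across heads. Both calls produce the same shape but a different head-to-group assignment: `repeat` tiles the whole group axis, while `repeat_interleave` keeps each group's copies contiguous. A minimal sketch with made-up sizes:

```python
import torch

# Hypothetical sizes: 2 SSM groups serving 4 heads (2 heads per group).
num_heads, n_groups, state_size = 4, 2, 3
# Mark each group by its index so the layout difference is visible.
B = torch.arange(n_groups, dtype=torch.float32).view(1, 1, n_groups, 1).expand(1, 1, n_groups, state_size)

tiled = B.repeat(1, 1, num_heads // n_groups, 1)                 # old behaviour
interleaved = B.repeat_interleave(num_heads // n_groups, dim=2)  # new behaviour

print(tiled[0, 0, :, 0])        # tensor([0., 1., 0., 1.]) -> groups alternate across heads
print(interleaved[0, 0, :, 0])  # tensor([0., 0., 1., 1.]) -> each group's heads are contiguous
```

The contiguous layout keeps every group's heads adjacent, which is presumably why the patch switches; the `output_size=` argument seen in the diff is only an optional size hint that `repeat_interleave` accepts.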
diff --git a/src/transformers/models/fuyu/processing_fuyu.py b/src/transformers/models/fuyu/processing_fuyu.py index 960b69ed31b..4852f3aaf9e 100644 --- a/src/transformers/models/fuyu/processing_fuyu.py +++ b/src/transformers/models/fuyu/processing_fuyu.py @@ -22,7 +22,13 @@ from typing import Dict, List, Optional, Tuple, Union import numpy as np from ...image_utils import ImageInput -from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order +from ...processing_utils import ( + MultiModalData, + ProcessingKwargs, + ProcessorMixin, + Unpack, + _validate_images_text_input_order, +) from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import is_torch_available, logging, requires_backends from ...utils.import_utils import requires @@ -64,6 +70,7 @@ class FuyuProcessorKwargs(ProcessingKwargs, total=False): "return_token_type_ids": False, "return_length": False, "verbose": True, + "return_mm_token_type_ids": False, }, "images_kwargs": {}, } @@ -343,7 +350,6 @@ class FuyuProcessor(ProcessorMixin): """ attributes = ["image_processor", "tokenizer"] - valid_kwargs = [] image_processor_class = "FuyuImageProcessor" tokenizer_class = "AutoTokenizer" @@ -355,6 +361,8 @@ class FuyuProcessor(ProcessorMixin): self.max_position_embeddings = 16384 # TODO Can't derive this from model files: where to set it? self.pad_token_id = 0 self.dummy_image_index = -1 + self.image_token_id = tokenizer.encode("|SPEAKER|", add_special_tokens=False)[1] + self.image_newline_id = tokenizer.encode("|NEWLINE|", add_special_tokens=False)[1] def _left_pad_inputs_with_attention_mask(self, model_inputs: List[Dict], return_attention_mask: bool): max_length_input_ids = max(entry["input_ids"].shape[1] for entry in model_inputs) @@ -403,6 +411,11 @@ class FuyuProcessor(ProcessorMixin): for key in batched_keys: batched_inputs[key] = torch.cat(batched_inputs[key], dim=0) + # Cast images to tensor as well, if only one image passed and no padding needed + # NOTE: vLLM expects all processor outputs to be a tensor + if len(batched_inputs["image_patches"]) == 1: + batched_inputs["image_patches"] = torch.cat(batched_inputs["image_patches"], dim=0) + return batched_inputs def get_sample_encoding( @@ -517,6 +530,7 @@ class FuyuProcessor(ProcessorMixin): tokenizer_init_kwargs=self.tokenizer.init_kwargs, **kwargs, ) + return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False) if not output_kwargs["text_kwargs"].setdefault("return_attention_mask", True): raise ValueError("`return_attention_mask=False` is not supported for this model.") @@ -550,8 +564,6 @@ class FuyuProcessor(ProcessorMixin): # --- Use self.tokenizer to get the ids of special tokens to insert into image ids --- - image_placeholder_id = self.tokenizer("|SPEAKER|", add_special_tokens=False)["input_ids"][1] - image_newline_id = self.tokenizer("|NEWLINE|", add_special_tokens=False)["input_ids"][1] tensor_batch_images = torch.stack([img[0] for img in batch_images]).unsqueeze(1) # --- Use self.image_processor again to obtain the full token ids and batch inputs --- @@ -565,16 +577,63 @@ class FuyuProcessor(ProcessorMixin): scale_factors=[scale_factor], image_unpadded_heights=torch.tensor([image_unpadded_height]), image_unpadded_widths=torch.tensor([image_unpadded_width]), - image_placeholder_id=image_placeholder_id, - image_newline_id=image_newline_id, + image_placeholder_id=self.image_token_id, + image_newline_id=self.image_newline_id, 
tensor_batch_images=tensor_batch_image.unsqueeze(0), ) all_encodings.append(sample_encoding) + batch_encoding = self._left_pad_inputs_with_attention_mask( model_inputs=all_encodings, return_attention_mask=True ) + if return_mm_token_type_ids: + input_ids = batch_encoding["input_ids"] + mm_token_type_ids = torch.zeros_like(input_ids) + mm_token_type_ids[input_ids == self.image_token_id] = 1 + mm_token_type_ids[input_ids == self.image_newline_id] = 1 + batch_encoding["mm_token_type_ids"] = mm_token_type_ids + return FuyuBatchFeature(data=batch_encoding) + def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs): + """ + Computes the number of placeholder tokens needed for multimodal inputs with the given sizes. + + Args: + image_sizes (`List[List[int]]`, *optional*): + The input sizes formatted as (height, width) per each image. + + Returns: + `MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided + input modalities, along with other useful data. + """ + + vision_data = {} + if image_sizes is not None: + size = kwargs.get("size", None) or self.image_processor.size + padded_height, padded_width = size["height"], size["width"] + + num_image_tokens = [] + num_image_patches = [1] * len(image_sizes) + for image_size in image_sizes: + height_scale_factor = padded_height / image_size[0] + width_scale_factor = padded_width / image_size[1] + optimal_scale_factor = min(height_scale_factor, width_scale_factor) + + # We can use torch here because Fuyu processor has hard dependency on torch + model_image_input = self.image_processor.preprocess_with_tokenizer_info( + image_input=torch.zeros(1, 1, 3, padded_height, padded_width), + image_present=torch.ones(1, 1, 1), + image_unpadded_h=torch.tensor([[int(image_size[0] * optimal_scale_factor)]]), + image_unpadded_w=torch.tensor([[int(image_size[1] * optimal_scale_factor)]]), + image_placeholder_id=0, # dummy ids, we can be sure `id=0` is never out-of-range + image_newline_id=0, + variable_sized=True, + ) + num_image_tokens.append(model_image_input["image_input_ids"][0][0].shape[-1]) + vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches}) + return MultiModalData(**vision_data) + def post_process_box_coordinates(self, outputs, target_sizes=None): """ Transforms raw coordinates detected by [`FuyuForCausalLM`] to the original images' coordinate space. diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 897f329e56c..2a296089198 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -239,15 +239,8 @@ class GemmaAttention(nn.Module): key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
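
(Illustration only, not part of the diff.) In the Fuyu processor hunk above, the new `return_mm_token_type_ids` path marks image placeholder and newline positions by comparing `input_ids` against the cached special-token ids. A self-contained sketch of that step, using made-up token ids:

```python
import torch

# Hypothetical stand-ins for self.image_token_id / self.image_newline_id.
image_token_id, image_newline_id = 71011, 71019

input_ids = torch.tensor([[1, 71011, 71011, 71019, 42, 7]])

mm_token_type_ids = torch.zeros_like(input_ids)
mm_token_type_ids[input_ids == image_token_id] = 1
mm_token_type_ids[input_ids == image_newline_id] = 1

print(mm_token_type_ids)  # tensor([[0, 1, 1, 1, 0, 0]])
```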
- ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, @@ -422,7 +415,6 @@ class GemmaModel(GemmaPreTrainedModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, - output_attentions=output_attentions, ) # embed positions diff --git a/src/transformers/models/gemma/modular_gemma.py b/src/transformers/models/gemma/modular_gemma.py index 1a1e8cc1c63..e934df7ef80 100644 --- a/src/transformers/models/gemma/modular_gemma.py +++ b/src/transformers/models/gemma/modular_gemma.py @@ -416,7 +416,6 @@ class GemmaModel(LlamaModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, - output_attentions=output_attentions, ) # embed positions diff --git a/src/transformers/models/gemma/tokenization_gemma_fast.py b/src/transformers/models/gemma/tokenization_gemma_fast.py index 24e2c90c307..bc6e0c8ba7c 100644 --- a/src/transformers/models/gemma/tokenization_gemma_fast.py +++ b/src/transformers/models/gemma/tokenization_gemma_fast.py @@ -114,10 +114,6 @@ class GemmaTokenizerFast(PreTrainedTokenizerFast): self.update_post_processor() self.vocab_file = vocab_file - @property - def can_save_slow_tokenizer(self) -> bool: - return os.path.isfile(self.vocab_file) if self.vocab_file else False - # Copied from transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast.update_post_processor def update_post_processor(self): """ diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index fe5576ae1c8..7bb865bc5dc 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -218,13 +218,7 @@ class Gemma2Attention(nn.Module): attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, @@ -445,7 +439,6 @@ class Gemma2Model(Gemma2PreTrainedModel): "attention_mask": attention_mask, "cache_position": cache_position, "past_key_values": past_key_values, - "output_attentions": output_attentions, } # Create the masks causal_mask_mapping = { diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index 7d0b721d809..31b251f4ca7 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -283,13 +283,7 @@ class Gemma2Attention(GemmaAttention): attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. 
This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, @@ -428,7 +422,6 @@ class Gemma2Model(GemmaModel): "attention_mask": attention_mask, "cache_position": cache_position, "past_key_values": past_key_values, - "output_attentions": output_attentions, } # Create the masks causal_mask_mapping = { diff --git a/src/transformers/models/gemma3/modeling_gemma3.py b/src/transformers/models/gemma3/modeling_gemma3.py index 122d16aafce..08740173009 100644 --- a/src/transformers/models/gemma3/modeling_gemma3.py +++ b/src/transformers/models/gemma3/modeling_gemma3.py @@ -345,14 +345,7 @@ class Gemma3Attention(nn.Module): attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. " - "Falling back to eager attention. This warning can be removed using the argument " - '`attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, @@ -566,7 +559,6 @@ class Gemma3TextModel(Gemma3PreTrainedModel): "attention_mask": attention_mask, "cache_position": cache_position, "past_key_values": past_key_values, - "output_attentions": output_attentions, } # Create the masks causal_mask_mapping = { @@ -790,7 +782,7 @@ class Gemma3MultiModalProjector(nn.Module): return projected_vision_outputs.type_as(vision_outputs) -def token_type_ids_mask_function(token_type_ids: Optional[torch.Tensor]) -> Optional[Callable]: +def token_type_ids_mask_function(token_type_ids: Optional[torch.Tensor], tokens_per_image: int) -> Optional[Callable]: """ This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths, not start and end indices. 
@@ -800,8 +792,13 @@ def token_type_ids_mask_function(token_type_ids: Optional[torch.Tensor]) -> Opti return None def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool: - # If it's 1, we need to unmask it - return token_type_ids[batch_idx, kv_idx] == 1 + # If the difference is less than image size, both are part of the same image block + same_image_block = torch.abs(kv_idx - q_idx) <= tokens_per_image + # If it's 1 for both query and key/value, we are in an image block + is_image_block = (token_type_ids[batch_idx, q_idx] == 1) & (token_type_ids[batch_idx, kv_idx] == 1) + + # This is bidirectional attention whenever we are dealing with image tokens + return is_image_block & same_image_block return inner_mask @@ -949,12 +946,11 @@ class Gemma3Model(Gemma3PreTrainedModel): "attention_mask": attention_mask, "cache_position": cache_position, "past_key_values": past_key_values, - "output_attentions": output_attentions, } if token_type_ids is not None and inputs_embeds.shape[1] != 1: # We need to pass an additional mask function to account for token type ids, and it needs to be an `or` mask_kwargs["or_mask_function"] = token_type_ids_mask_function( - token_type_ids.to(cache_position.device) + token_type_ids.to(cache_position.device), self.config.mm_tokens_per_image ) # Create the masks @@ -1017,6 +1013,12 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin): def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + # Make modules available throught conditional class for BC @property def language_model(self): @@ -1082,7 +1084,7 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin): >>> inputs = processor.apply_chat_template( ... messages, - ... tokenizer=True, + ... tokenize=True, ... return_dict=True, ... return_tensors="pt", ... 
add_generation_prompt=True @@ -1200,7 +1202,6 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin): attention_mask: Optional[torch.Tensor], cache_position: torch.Tensor, past_key_values: Optional[Cache], - output_attentions: bool = False, token_type_ids: Optional[torch.Tensor] = None, **kwargs, ) -> dict: @@ -1211,12 +1212,13 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin): "attention_mask": attention_mask, "cache_position": cache_position, "past_key_values": past_key_values, - "output_attentions": output_attentions, } # Add the token type ids mask for generate as well if token_type_ids is not None and input_embeds.shape[1] != 1: # We need to pass an additional mask function to account for token type ids, and it needs to be an `or` - mask_kwargs["or_mask_function"] = token_type_ids_mask_function(token_type_ids.to(cache_position.device)) + mask_kwargs["or_mask_function"] = token_type_ids_mask_function( + token_type_ids.to(cache_position.device), config.mm_tokens_per_image + ) return create_masks_for_generate(**mask_kwargs) diff --git a/src/transformers/models/gemma3/modular_gemma3.py b/src/transformers/models/gemma3/modular_gemma3.py index f0761d863d1..d679d30c8b9 100644 --- a/src/transformers/models/gemma3/modular_gemma3.py +++ b/src/transformers/models/gemma3/modular_gemma3.py @@ -424,14 +424,7 @@ class Gemma3Attention(Gemma2Attention): attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. " - "Falling back to eager attention. This warning can be removed using the argument " - '`attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, @@ -617,7 +610,6 @@ class Gemma3TextModel(Gemma2Model): "attention_mask": attention_mask, "cache_position": cache_position, "past_key_values": past_key_values, - "output_attentions": output_attentions, } # Create the masks causal_mask_mapping = { @@ -730,7 +722,7 @@ class Gemma3MultiModalProjector(nn.Module): return projected_vision_outputs.type_as(vision_outputs) -def token_type_ids_mask_function(token_type_ids: Optional[torch.Tensor]) -> Optional[Callable]: +def token_type_ids_mask_function(token_type_ids: Optional[torch.Tensor], tokens_per_image: int) -> Optional[Callable]: """ This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths, not start and end indices. 
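
(Illustration only, not part of the diff.) The reworked `token_type_ids_mask_function` shown above in `modeling_gemma3.py` (and mirrored here in the modular file) unmasks a key position only when both query and key are image tokens and they lie within `mm_tokens_per_image` of each other, giving bidirectional attention inside a single image block. A standalone sketch of that predicate with toy inputs:

```python
import torch

def image_block_mask(token_type_ids, tokens_per_image, batch_idx, q_idx, kv_idx):
    # Same logic as the inner_mask added in the diff.
    same_image_block = torch.abs(kv_idx - q_idx) <= tokens_per_image
    is_image_block = (token_type_ids[batch_idx, q_idx] == 1) & (token_type_ids[batch_idx, kv_idx] == 1)
    return is_image_block & same_image_block

# Toy sequence: [text, img, img, img, text]; assume 4 image tokens per image.
token_type_ids = torch.tensor([[0, 1, 1, 1, 0]])

print(image_block_mask(token_type_ids, 4, 0, torch.tensor(1), torch.tensor(3)))  # tensor(True): image attends forward to image
print(image_block_mask(token_type_ids, 4, 0, torch.tensor(1), torch.tensor(4)))  # tensor(False): key is a text token
```

As the diff notes, this predicate is combined as an `or_mask_function` with the causal mask, so text tokens keep causal attention while image tokens see their whole block.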
@@ -740,8 +732,13 @@ def token_type_ids_mask_function(token_type_ids: Optional[torch.Tensor]) -> Opti return None def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool: - # If it's 1, we need to unmask it - return token_type_ids[batch_idx, kv_idx] == 1 + # If the difference is less than image size, both are part of the same image block + same_image_block = torch.abs(kv_idx - q_idx) <= tokens_per_image + # If it's 1 for both query and key/value, we are in an image block + is_image_block = (token_type_ids[batch_idx, q_idx] == 1) & (token_type_ids[batch_idx, kv_idx] == 1) + + # This is bidirectional attention whenever we are dealing with image tokens + return is_image_block & same_image_block return inner_mask @@ -840,12 +837,11 @@ class Gemma3Model(PaliGemmaModel): "attention_mask": attention_mask, "cache_position": cache_position, "past_key_values": past_key_values, - "output_attentions": output_attentions, } if token_type_ids is not None and inputs_embeds.shape[1] != 1: # We need to pass an additional mask function to account for token type ids, and it needs to be an `or` mask_kwargs["or_mask_function"] = token_type_ids_mask_function( - token_type_ids.to(cache_position.device) + token_type_ids.to(cache_position.device), self.config.mm_tokens_per_image ) # Create the masks @@ -929,7 +925,7 @@ class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration): >>> inputs = processor.apply_chat_template( ... messages, - ... tokenizer=True, + ... tokenize=True, ... return_dict=True, ... return_tensors="pt", ... add_generation_prompt=True @@ -1050,7 +1046,6 @@ class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration): attention_mask: Optional[torch.Tensor], cache_position: torch.Tensor, past_key_values: Optional[Cache], - output_attentions: bool = False, token_type_ids: Optional[torch.Tensor] = None, **kwargs, ) -> dict: @@ -1061,12 +1056,13 @@ class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration): "attention_mask": attention_mask, "cache_position": cache_position, "past_key_values": past_key_values, - "output_attentions": output_attentions, } # Add the token type ids mask for generate as well if token_type_ids is not None and input_embeds.shape[1] != 1: # We need to pass an additional mask function to account for token type ids, and it needs to be an `or` - mask_kwargs["or_mask_function"] = token_type_ids_mask_function(token_type_ids.to(cache_position.device)) + mask_kwargs["or_mask_function"] = token_type_ids_mask_function( + token_type_ids.to(cache_position.device), config.mm_tokens_per_image + ) return create_masks_for_generate(**mask_kwargs) diff --git a/src/transformers/models/gemma3/processing_gemma3.py b/src/transformers/models/gemma3/processing_gemma3.py index f887e11d5c1..ab6f03290a7 100644 --- a/src/transformers/models/gemma3/processing_gemma3.py +++ b/src/transformers/models/gemma3/processing_gemma3.py @@ -20,7 +20,7 @@ import numpy as np from ...feature_extraction_utils import BatchFeature from ...image_utils import ImageInput, make_nested_list_of_images -from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack +from ...processing_utils import ImagesKwargs, MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import to_py_obj @@ -38,6 +38,7 @@ class Gemma3ProcessorKwargs(ProcessingKwargs, total=False): _defaults = { "text_kwargs": { "padding": False, + "return_mm_token_type_ids": True, }, "images_kwargs": { 
"do_pan_and_scan": False, @@ -50,7 +51,6 @@ class Gemma3ProcessorKwargs(ProcessingKwargs, total=False): class Gemma3Processor(ProcessorMixin): attributes = ["image_processor", "tokenizer"] - valid_kwargs = ["chat_template", "image_seq_length"] image_processor_class = "AutoImageProcessor" tokenizer_class = "AutoTokenizer" @@ -137,17 +137,42 @@ class Gemma3Processor(ProcessorMixin): text = [prompt.replace(self.boi_token, self.full_image_sequence) for prompt in text] return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None) - text_inputs = self.tokenizer(text=text, **output_kwargs["text_kwargs"], return_tensors="np") + return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False) + text_inputs = self.tokenizer(text=text, **output_kwargs["text_kwargs"]) self._check_special_mm_tokens(text, text_inputs, modalities=["image"]) # Add token type ids manually, as tokenizer can't do arbitrary position token types - array_ids = text_inputs["input_ids"] - mm_token_type_ids = np.zeros_like(text_inputs["input_ids"]) - mm_token_type_ids[array_ids == self.image_token_id] = 1 - text_inputs = {k: v.tolist() for k, v in text_inputs.items()} # in case user requested list inputs - text_inputs["token_type_ids"] = mm_token_type_ids.tolist() + if return_mm_token_type_ids: + array_ids = np.array(text_inputs["input_ids"]) + mm_token_type_ids = np.zeros_like(array_ids) + mm_token_type_ids[array_ids == self.image_token_id] = 1 + text_inputs["token_type_ids"] = mm_token_type_ids.tolist() + return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors) + def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs): + """ + Computes the number of placeholder tokens needed for multimodal inputs with the given sizes. + + Args: + image_sizes (`List[List[int]]`, *optional*): + The input sizes formatted as (height, width) per each image. + + Returns: + `MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided + input modalities, along with other useful data. + """ + + vision_data = {} + if image_sizes is not None: + # NOTE: no image cropping supported yet + num_image_tokens = [self.image_seq_length] * len(image_sizes) + num_image_patches = [1] * len(image_sizes) + + vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches}) + + return MultiModalData(**vision_data) + # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Gemma def batch_decode(self, *args, **kwargs): """ diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index f3ac600e22b..235f8258c10 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -201,15 +201,8 @@ class GlmAttention(nn.Module): key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
- ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, @@ -443,7 +436,6 @@ class GlmModel(GlmPreTrainedModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, - output_attentions=output_attentions, ) hidden_states = inputs_embeds diff --git a/src/transformers/models/glm4/modeling_glm4.py b/src/transformers/models/glm4/modeling_glm4.py index 4525ba15018..f32bfb3a392 100644 --- a/src/transformers/models/glm4/modeling_glm4.py +++ b/src/transformers/models/glm4/modeling_glm4.py @@ -259,15 +259,8 @@ class Glm4Attention(nn.Module): key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, @@ -451,7 +444,6 @@ class Glm4Model(Glm4PreTrainedModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, - output_attentions=output_attentions, ) hidden_states = inputs_embeds diff --git a/src/transformers/models/got_ocr2/image_processing_got_ocr2.py b/src/transformers/models/got_ocr2/image_processing_got_ocr2.py index dc06f1ef391..d706c0f3403 100644 --- a/src/transformers/models/got_ocr2/image_processing_got_ocr2.py +++ b/src/transformers/models/got_ocr2/image_processing_got_ocr2.py @@ -491,5 +491,33 @@ class GotOcr2ImageProcessor(BaseImageProcessor): return processed_images + def get_number_of_image_patches(self, height: int, width: int, images_kwargs=None): + """ + A utility that returns number patches for a given image size. + + Args: + height (`int`): + Height of the input image. + width (`int`): + Width of the input image. + images_kwargs (`dict`, *optional*) + Any kwargs to override defaults of the image processor. + Returns: + `int`: Number of patches per image. 
+ """ + min_patches = images_kwargs.get("min_patches", None) or self.min_patches + max_patches = images_kwargs.get("max_patches", None) or self.max_patches + patch_size = images_kwargs.get("size", None) or self.size + crop_to_patches = images_kwargs.get("crop_to_patches", None) or self.crop_to_patches + + num_patches = 1 + if crop_to_patches and max_patches > 1: + num_columns, num_rows = get_optimal_tiled_canvas( + (height, width), (patch_size["height"], patch_size["width"]), min_patches, max_patches + ) + num_patches += num_columns * num_rows + + return num_patches + __all__ = ["GotOcr2ImageProcessor"] diff --git a/src/transformers/models/got_ocr2/image_processing_got_ocr2_fast.py b/src/transformers/models/got_ocr2/image_processing_got_ocr2_fast.py index 3b8b4c2f560..95179d7a94c 100644 --- a/src/transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +++ b/src/transformers/models/got_ocr2/image_processing_got_ocr2_fast.py @@ -228,5 +228,33 @@ class GotOcr2ImageProcessorFast(BaseImageProcessorFast): data={"pixel_values": processed_images, "num_patches": num_patches}, tensor_type=return_tensors ) + def get_number_of_image_tokens(self, height: int, width: int, images_kwargs=None): + """ + A utility that returns number patches for a given image size. + + Args: + height (`int`): + Height of the input image. + width (`int`): + Width of the input image. + images_kwargs (`dict`, *optional*) + Any kwargs to override defaults of the image processor. + Returns: + `int`: Number of patches per image. + """ + min_patches = images_kwargs.get("min_patches", None) or self.min_patches + max_patches = images_kwargs.get("max_patches", None) or self.max_patches + patch_size = images_kwargs.get("size", None) or self.size + crop_to_patches = images_kwargs.get("crop_to_patches", None) or self.crop_to_patches + + num_patches = 1 + if crop_to_patches and max_patches > 1: + num_columns, num_rows = get_optimal_tiled_canvas( + (height, width), (patch_size["height"], patch_size["width"]), min_patches, max_patches + ) + num_patches += num_columns * num_rows + + return num_patches + __all__ = ["GotOcr2ImageProcessorFast"] diff --git a/src/transformers/models/got_ocr2/modeling_got_ocr2.py b/src/transformers/models/got_ocr2/modeling_got_ocr2.py index 6da4405fad5..0d6b44214ba 100644 --- a/src/transformers/models/got_ocr2/modeling_got_ocr2.py +++ b/src/transformers/models/got_ocr2/modeling_got_ocr2.py @@ -755,6 +755,12 @@ class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin): def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + # Make modules available throught conditional class for BC @property def language_model(self): diff --git a/src/transformers/models/got_ocr2/processing_got_ocr2.py b/src/transformers/models/got_ocr2/processing_got_ocr2.py index 5e40d14dee8..b712245a64c 100644 --- a/src/transformers/models/got_ocr2/processing_got_ocr2.py +++ b/src/transformers/models/got_ocr2/processing_got_ocr2.py @@ -95,7 +95,6 @@ class GotOcr2Processor(ProcessorMixin): """ attributes = ["image_processor", "tokenizer"] - valid_kwargs = ["chat_template"] image_processor_class = "AutoImageProcessor" tokenizer_class = "PreTrainedTokenizerFast" diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index 9c32acdb06a..16de0f23db9 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ 
b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -166,28 +166,9 @@ class GPTNeoXAttention(nn.Module): } key_states, value_states = layer_past.update(key_states, value_states, self.layer_idx, cache_kwargs) - # Checking for fallbacks in case an unsupported feature is requested - attention_type = self.config._attn_implementation - if (output_attentions or head_mask is not None) and self.config._attn_implementation in [ - "sdpa", - "flash_attention_2", - ]: - logger.warning_once( - f"Setting `attention_type` to `eager` because `{attention_type}` does not support" - f" `output_attentions=True` or `head_mask`." - ) - attention_type = "eager" - - elif self.training and self.attention_dropout > 0 and self.config._attn_implementation == "flex_attention": - logger.warning_once( - f"Setting `attention_type` to `eager` because `dropout` is not supported in `{attention_type}`." - ) - attention_type = "eager" - attention_interface: Callable = eager_attention_forward - attention_interface = ( - ALL_ATTENTION_FUNCTIONS[attention_type] if attention_type != "eager" else attention_interface - ) + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] # Compute attention attn_output, attn_weights = attention_interface( @@ -409,7 +390,6 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, - output_attentions=output_attentions, ) # Prepare head mask if needed diff --git a/src/transformers/models/gpt_neox/modular_gpt_neox.py b/src/transformers/models/gpt_neox/modular_gpt_neox.py index 70bee31b280..e7d67a97644 100644 --- a/src/transformers/models/gpt_neox/modular_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modular_gpt_neox.py @@ -153,28 +153,9 @@ class GPTNeoXAttention(nn.Module): } key_states, value_states = layer_past.update(key_states, value_states, self.layer_idx, cache_kwargs) - # Checking for fallbacks in case an unsupported feature is requested - attention_type = self.config._attn_implementation - if (output_attentions or head_mask is not None) and self.config._attn_implementation in [ - "sdpa", - "flash_attention_2", - ]: - logger.warning_once( - f"Setting `attention_type` to `eager` because `{attention_type}` does not support" - f" `output_attentions=True` or `head_mask`." - ) - attention_type = "eager" - - elif self.training and self.attention_dropout > 0 and self.config._attn_implementation == "flex_attention": - logger.warning_once( - f"Setting `attention_type` to `eager` because `dropout` is not supported in `{attention_type}`." 
- ) - attention_type = "eager" - attention_interface: Callable = eager_attention_forward - attention_interface = ( - ALL_ATTENTION_FUNCTIONS[attention_type] if attention_type != "eager" else attention_interface - ) + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] # Compute attention attn_output, attn_weights = attention_interface( @@ -356,7 +337,6 @@ class GPTNeoXModel(LlamaModel, nn.Module): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, - output_attentions=output_attentions, ) # Prepare head mask if needed diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index fdba3f4c0eb..11f2873f3df 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -165,15 +165,8 @@ class GraniteAttention(nn.Module): key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, @@ -446,7 +439,6 @@ class GraniteModel(GranitePreTrainedModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, - output_attentions=output_attentions, ) hidden_states = inputs_embeds diff --git a/src/transformers/models/granite/modular_granite.py b/src/transformers/models/granite/modular_granite.py index 424a0cc3fa2..33f3b3363e9 100644 --- a/src/transformers/models/granite/modular_granite.py +++ b/src/transformers/models/granite/modular_granite.py @@ -181,7 +181,6 @@ class GraniteModel(LlamaModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, - output_attentions=output_attentions, ) hidden_states = inputs_embeds diff --git a/src/transformers/models/granite_speech/processing_granite_speech.py b/src/transformers/models/granite_speech/processing_granite_speech.py index ec36eb49703..9032601a6b2 100644 --- a/src/transformers/models/granite_speech/processing_granite_speech.py +++ b/src/transformers/models/granite_speech/processing_granite_speech.py @@ -31,8 +31,6 @@ logger = logging.get_logger(__name__) class GraniteSpeechProcessor(ProcessorMixin): attributes = ["audio_processor", "tokenizer"] - valid_kwargs = ["audio_token"] - audio_processor_class = "GraniteSpeechFeatureExtractor" tokenizer_class = "AutoTokenizer" diff --git a/src/transformers/models/granitemoe/modeling_granitemoe.py b/src/transformers/models/granitemoe/modeling_granitemoe.py index fdd7addc450..a3a314a6abb 100644 --- a/src/transformers/models/granitemoe/modeling_granitemoe.py +++ b/src/transformers/models/granitemoe/modeling_granitemoe.py @@ -452,13 +452,7 @@ class GraniteMoeAttention(nn.Module): attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": 
- if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py index 48e1dc0020f..d6ff36bf324 100644 --- a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +++ b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py @@ -203,13 +203,7 @@ class GraniteMoeHybridAttention(nn.Module): attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py b/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py index 0845ba7b696..dc429aa55bc 100644 --- a/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py +++ b/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py @@ -387,13 +387,7 @@ class GraniteMoeSharedAttention(nn.Module): attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/helium/modeling_helium.py b/src/transformers/models/helium/modeling_helium.py index 5d58ca59458..b9cb3bafc13 100644 --- a/src/transformers/models/helium/modeling_helium.py +++ b/src/transformers/models/helium/modeling_helium.py @@ -241,15 +241,8 @@ class HeliumAttention(nn.Module): key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. 
Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, @@ -428,7 +421,6 @@ class HeliumModel(HeliumPreTrainedModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, - output_attentions=output_attentions, ) hidden_states = inputs_embeds diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py index d920e998f97..115345407e6 100755 --- a/src/transformers/models/hubert/modeling_hubert.py +++ b/src/transformers/models/hubert/modeling_hubert.py @@ -4,8 +4,23 @@ # the file from the modular. If any change should be done, please apply the change to the # modular_hubert.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import warnings -from typing import Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union import numpy as np import torch @@ -15,15 +30,17 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.fsdp import is_fsdp_managed_module -from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput -from ...modeling_utils import PreTrainedModel -from ...utils import auto_docstring, logging +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import auto_docstring, is_torch_flex_attn_available, logging from .configuration_hubert import HubertConfig -if is_flash_attn_available(): - from ...modeling_flash_attention_utils import _flash_attention_forward +if is_torch_flex_attn_available(): + from ...integrations.flex_attention import make_flex_block_causal_mask logger = logging.get_logger(__name__) @@ -224,6 +241,36 @@ class HubertFeatureProjection(nn.Module): return hidden_states +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: Optional[float] = None, + dropout: float = 0.0, + head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + if scaling is None: + scaling = query.size(-1) ** -0.5 + + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if 
attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) + + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + class HubertAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -258,9 +305,6 @@ class HubertAttention(nn.Module): self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - def forward( self, hidden_states: torch.Tensor, @@ -269,6 +313,9 @@ class HubertAttention(nn.Module): attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + # TODO: we need a refactor so that the different attention modules can get their specific kwargs + # ATM, we have mixed things encoder, decoder, and encoder-decoder attn + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -276,10 +323,16 @@ class HubertAttention(nn.Module): # for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() + # determine input shapes + bsz, tgt_len = hidden_states.shape[:-1] + src_len = key_value_states.shape[1] if is_cross_attention else tgt_len + + q_input_shape = (bsz, tgt_len, -1, self.head_dim) + kv_input_shape = (bsz, src_len, -1, self.head_dim) # get query proj - query_states = self.q_proj(hidden_states) * self.scaling + query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) + # get key, value proj # `past_key_value[0].shape[2] == key_value_states.shape[1]` # is checking that the `sequence_length` of the `past_key_value` is the same as @@ -294,18 +347,18 @@ class HubertAttention(nn.Module): value_states = past_key_value[1] elif is_cross_attention: # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) elif past_key_value is not None: # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) else: # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) if self.is_decoder: # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention 
key/value_states. @@ -317,298 +370,29 @@ class HubertAttention(nn.Module): # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) - key_states = key_states.reshape(*proj_shape) - value_states = value_states.reshape(*proj_shape) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if layer_head_mask is not None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError( - f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}" - ) - attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to be reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) - else: - attn_weights_reshaped = None - - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) - - attn_output = self.out_proj(attn_output) - - return attn_output, attn_weights_reshaped, past_key_value - - -class HubertFlashAttention2(HubertAttention): - """ - Hubert flash attention module. This module inherits from `HubertAttention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. 
- """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() - - def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) - - def forward( - self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - - bsz, q_len, _ = hidden_states.size() - - # get query proj - query_states = self._reshape(self.q_proj(hidden_states), -1, bsz) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): - # reuse k,v, cross_attentions - key_states = past_key_value[0].transpose(1, 2) - value_states = past_key_value[1].transpose(1, 2) - elif is_cross_attention: - # cross_attentions - key_states = self._reshape(self.k_proj(key_value_states), -1, bsz) - value_states = self._reshape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) - value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1) - value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1) - else: - # self_attention - key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) - value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. 
Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2)) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (LlamaRMSNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = _flash_attention_forward( + attn_output, attn_weights = attention_interface( + self, query_states, key_states, value_states, attention_mask, - q_len, - dropout=self.dropout if self.training else 0.0, - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + output_attentions=output_attentions, + head_mask=layer_head_mask, + **kwargs, ) - attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - if not output_attentions: - attn_weights = None - return attn_output, attn_weights, past_key_value -class HubertSdpaAttention(HubertAttention): - def forward( - self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel""" - if output_attentions: - # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "HubertModel is using HubertSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` . Falling back to the manual attention" - ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
- ) - return super().forward( - hidden_states, - key_value_states=key_value_states, - past_key_value=past_key_value, - attention_mask=attention_mask, - output_attentions=output_attentions, - ) - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - - bsz, tgt_len, _ = hidden_states.size() - - # get query proj - query_states = self.q_proj(hidden_states) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): - # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) - - query_states = self._shape(query_states, tgt_len, bsz) - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. - is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False - - # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, - # but we are fine here as `_shape` do call `.contiguous()`. 
Reference: https://github.com/pytorch/pytorch/issues/112577 - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=attention_mask, - dropout_p=self.dropout if self.training else 0.0, - is_causal=is_causal, - ) - - if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) - - attn_output = self.out_proj(attn_output) - - return attn_output, None, past_key_value - - class HubertFeedForward(nn.Module): def __init__(self, config): super().__init__() @@ -633,21 +417,15 @@ class HubertFeedForward(nn.Module): return hidden_states -HUBERT_ATTENTION_CLASSES = { - "eager": HubertAttention, - "sdpa": HubertSdpaAttention, - "flash_attention_2": HubertFlashAttention2, -} - - class HubertEncoderLayer(nn.Module): def __init__(self, config): super().__init__() - self.attention = HUBERT_ATTENTION_CLASSES[config._attn_implementation]( + self.attention = HubertAttention( embed_dim=config.hidden_size, num_heads=config.num_attention_heads, dropout=config.attention_dropout, is_decoder=False, + config=config, ) self.dropout = nn.Dropout(config.hidden_dropout) @@ -684,7 +462,6 @@ class HubertEncoder(nn.Module): self.dropout = nn.Dropout(config.hidden_dropout) self.layers = nn.ModuleList([HubertEncoderLayer(config) for _ in range(config.num_hidden_layers)]) self.gradient_checkpointing = False - self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" def forward( self, @@ -701,16 +478,11 @@ class HubertEncoder(nn.Module): # make sure padded tokens output 0 expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) hidden_states[~expand_attention_mask] = 0 - if self._use_flash_attention_2: - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - else: - # extend attention_mask - attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype) - attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min - attention_mask = attention_mask.expand( - attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1] - ) + + attention_mask = self._update_full_mask( + attention_mask, + hidden_states, + ) position_embeddings = self.pos_conv_embed(hidden_states) hidden_states = hidden_states + position_embeddings @@ -759,6 +531,28 @@ class HubertEncoder(nn.Module): attentions=all_self_attentions, ) + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. 
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + class HubertAttnAdapterLayer(nn.Module): def __init__(self, config): @@ -788,11 +582,12 @@ class HubertAttnAdapterLayer(nn.Module): class HubertEncoderLayerStableLayerNorm(nn.Module): def __init__(self, config): super().__init__() - self.attention = HUBERT_ATTENTION_CLASSES[config._attn_implementation]( + self.attention = HubertAttention( embed_dim=config.hidden_size, num_heads=config.num_attention_heads, dropout=config.attention_dropout, is_decoder=False, + config=config, ) self.dropout = nn.Dropout(config.hidden_dropout) self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -841,7 +636,6 @@ class HubertEncoderStableLayerNorm(nn.Module): [HubertEncoderLayerStableLayerNorm(config) for _ in range(config.num_hidden_layers)] ) self.gradient_checkpointing = False - self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" def forward( self, @@ -855,19 +649,14 @@ class HubertEncoderStableLayerNorm(nn.Module): all_self_attentions = () if output_attentions else None if attention_mask is not None: - # make sure padded tokens are not attended to + # make sure padded tokens output 0 expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) - hidden_states = hidden_states * expand_attention_mask.to(dtype=hidden_states.dtype) - if self._use_flash_attention_2: - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - else: - # extend attention_mask - attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype) - attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min - attention_mask = attention_mask.expand( - attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1] - ) + hidden_states[~expand_attention_mask] = 0 + + attention_mask = self._update_full_mask( + attention_mask, + hidden_states, + ) position_embeddings = self.pos_conv_embed(hidden_states) hidden_states = hidden_states + position_embeddings @@ -918,6 +707,28 @@ class HubertEncoderStableLayerNorm(nn.Module): attentions=all_self_attentions, ) + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. 
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + @auto_docstring class HubertPreTrainedModel(PreTrainedModel): @@ -927,6 +738,7 @@ class HubertPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _supports_flash_attn_2 = True _supports_sdpa = True + _supports_flex_attn = True def _init_weights(self, module): """Initialize the weights""" diff --git a/src/transformers/models/hubert/modular_hubert.py b/src/transformers/models/hubert/modular_hubert.py index b3e3d24cc0e..c0454452f02 100644 --- a/src/transformers/models/hubert/modular_hubert.py +++ b/src/transformers/models/hubert/modular_hubert.py @@ -1,3 +1,19 @@ +# coding=utf-8 +# Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Hubert model.""" + from typing import Optional, Tuple, Union import torch @@ -115,6 +131,7 @@ class HubertPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _supports_flash_attn_2 = True _supports_sdpa = True + _supports_flex_attn = True def _init_weights(self, module): """Initialize the weights""" diff --git a/src/transformers/models/idefics/processing_idefics.py b/src/transformers/models/idefics/processing_idefics.py index 37876080dfc..e226e15da19 100644 --- a/src/transformers/models/idefics/processing_idefics.py +++ b/src/transformers/models/idefics/processing_idefics.py @@ -211,7 +211,6 @@ class IdeficsProcessor(ProcessorMixin): """ attributes = ["image_processor", "tokenizer"] - valid_kwargs = ["image_size", "add_end_of_utterance_token"] image_processor_class = "IdeficsImageProcessor" tokenizer_class = "LlamaTokenizerFast" diff --git a/src/transformers/models/idefics2/processing_idefics2.py b/src/transformers/models/idefics2/processing_idefics2.py index ab144f3f9de..5be15d8cd8b 100644 --- a/src/transformers/models/idefics2/processing_idefics2.py +++ b/src/transformers/models/idefics2/processing_idefics2.py @@ -85,7 +85,6 @@ class Idefics2Processor(ProcessorMixin): """ attributes = ["image_processor", "tokenizer"] - valid_kwargs = ["image_seq_len", "chat_template"] image_processor_class = "Idefics2ImageProcessor" tokenizer_class = "AutoTokenizer" diff --git a/src/transformers/models/idefics3/image_processing_idefics3.py b/src/transformers/models/idefics3/image_processing_idefics3.py index b2f049e998a..e84c4157b2a 100644 --- a/src/transformers/models/idefics3/image_processing_idefics3.py +++ b/src/transformers/models/idefics3/image_processing_idefics3.py @@ -850,5 +850,46 @@ class Idefics3ImageProcessor(BaseImageProcessor): 
return encoding + def get_number_of_image_patches(self, height: int, width: int, images_kwargs=None): + """ + A utility that returns number of image patches for a given image size. + + Args: + height (`int`): + Height of the input image. + width (`int`): + Width of the input image. + images_kwargs (`dict`, *optional*) + Any kwargs to override defaults of the image processor. + Returns: + `int`: Number of patches per image. + """ + do_image_splitting = images_kwargs.get("do_image_splitting", None) or self.do_image_splitting + max_image_size = images_kwargs.get("max_image_size", None) or self.max_image_size + size = images_kwargs.get("size", None) or self.size + + if do_image_splitting: + height, width = _resize_output_size_rescale_to_max_len(height, width, max_len=size["longest_edge"]) + height, width = _resize_output_size_scale_below_upper_bound(height, width, max_len=4096) + aspect_ratio = width / height + + if width >= height: + resized_width = math.ceil(width / max_image_size["longest_edge"]) * max_image_size["longest_edge"] + resized_height = int(width / aspect_ratio) + resized_height = math.ceil(height / max_image_size["longest_edge"]) * max_image_size["longest_edge"] + elif height > width: + resized_height = math.ceil(height / max_image_size["longest_edge"]) * max_image_size["longest_edge"] + resized_width = int(height * aspect_ratio) + resized_width = math.ceil(width / max_image_size["longest_edge"]) * max_image_size["longest_edge"] + + max_height = max_width = max_image_size["longest_edge"] + if resized_height > max_height or resized_width > max_width: + # Calculate the number of splits + num_rows = math.ceil(resized_height / max_height) + num_cols = math.ceil(resized_width / max_width) + num_patches = num_rows * num_cols + 1 + + return num_patches + __all__ = ["Idefics3ImageProcessor"] diff --git a/src/transformers/models/idefics3/processing_idefics3.py b/src/transformers/models/idefics3/processing_idefics3.py index 1fcce0a453a..5f4450df8b4 100644 --- a/src/transformers/models/idefics3/processing_idefics3.py +++ b/src/transformers/models/idefics3/processing_idefics3.py @@ -16,13 +16,16 @@ Processor class for Idefics3. 
""" +import math import re from itertools import accumulate from typing import TYPE_CHECKING, Dict, List, Optional, Union +import numpy as np + from ...feature_extraction_utils import BatchFeature from ...image_utils import ImageInput, is_valid_image, load_image -from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack +from ...processing_utils import ImagesKwargs, MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack from ...tokenization_utils_base import AddedToken, BatchEncoding, TextInput from ...utils import logging @@ -98,6 +101,7 @@ class Idefics3ProcessorKwargs(ProcessingKwargs, total=False): "add_special_tokens": True, "padding": False, "is_split_into_words": False, + "return_mm_token_type_ids": False, }, "images_kwargs": { "return_row_col_info": True, @@ -129,7 +133,6 @@ class Idefics3Processor(ProcessorMixin): """ attributes = ["image_processor", "tokenizer"] - valid_kwargs = ["image_seq_len", "chat_template"] image_processor_class = "Idefics3ImageProcessor" tokenizer_class = "AutoTokenizer" @@ -146,6 +149,12 @@ class Idefics3Processor(ProcessorMixin): self.end_of_utterance_token = AddedToken("", normalized=False, special=True).content self.global_image_tag = "" # https://github.com/huggingface/transformers/pull/32473/files/8063e5e17362571b693f1db95167f5443a3be1b2#r1734825341 self.image_seq_len = image_seq_len + self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token) + self.fake_image_token_id = tokenizer.convert_tokens_to_ids(self.fake_image_token) + self.global_image_token_id = tokenizer.convert_tokens_to_ids(self.global_image_tag) + self.row_col_ids = [ + tokenizer.convert_tokens_to_ids(f"") for i in range(6) for j in range(6) + ] # This regex matches one or more occurrences of tags (optionally surrounded by newline characters) # or tags (where x and y are digits, also optionally surrounded by newline characters). 
@@ -241,6 +250,7 @@ class Idefics3Processor(ProcessorMixin): ) image_seq_len = image_seq_len if image_seq_len is not None else self.image_seq_len + return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False) return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None) n_images_in_text = [] @@ -302,9 +312,11 @@ class Idefics3Processor(ProcessorMixin): global_img_token = self.global_image_tag prompt_strings = [] + batch_image_seq_lengths = [] for sample, sample_rows, sample_cols in zip(text, image_rows, image_cols): # Replace the image token with fake tokens around the expanded image token sequence of length `image_seq_len` image_prompt_strings = [] + image_seq_lengths = [] for n_rows, n_cols in zip(sample_rows, sample_cols): image_prompt_string = get_image_prompt_string( n_rows, @@ -314,8 +326,12 @@ class Idefics3Processor(ProcessorMixin): fake_token_around_image=fake_image_token, global_img_token=global_img_token, ) + # Add +2 and +3 for special BOI/EOI/fake_image_wrapper tokens + row_length = (self.image_seq_len + 2) * n_cols + 1 + image_seq_lengths.append((self.image_seq_len + 3) + row_length * n_rows) image_prompt_strings.append(image_prompt_string) + batch_image_seq_lengths.append(image_seq_lengths) split_sample = sample.split(image_token) if len(split_sample) == 0: raise ValueError("The image token should be present in the text.") @@ -338,7 +354,59 @@ class Idefics3Processor(ProcessorMixin): text_inputs = self.tokenizer(text=text, **output_kwargs["text_kwargs"]) inputs.update(text_inputs) - return BatchFeature(inputs, tensor_type=return_tensors) + if return_mm_token_type_ids: + array_ids = np.array(inputs["input_ids"]) + mm_token_type_ids = np.zeros_like(array_ids) + for i, seq_lengths in enumerate(batch_image_seq_lengths): + image_start_positions = np.where(array_ids[i] == self.fake_image_token_id)[0] + j = 0 + for seq_len in seq_lengths: + if j >= len(image_start_positions): + break + start = image_start_positions[j] + end = start + seq_len + mm_token_type_ids[i, start:end] = 1 + j = np.searchsorted(image_start_positions, end) + + inputs["mm_token_type_ids"] = mm_token_type_ids.tolist() + + return BatchFeature(data=inputs, tensor_type=return_tensors) + + def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs): + """ + Computes the number of placeholder tokens needed for multimodal inputs with the given sizes. + + Args: + image_sizes (`List[List[int]]`, *optional*): + The input sizes formatted as (height, width) per each image. + + Returns: + `MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided + input modalities, along with other useful data. 
+ """ + + vision_data = {} + if image_sizes is not None: + images_kwargs = Idefics3ProcessorKwargs._defaults.get("images_kwargs", {}) + images_kwargs.update(kwargs) + + num_image_patches = [ + self.image_processor.get_number_of_image_patches(*image_size, images_kwargs) + for image_size in image_sizes + ] + + base_image_length = self.image_seq_len + 3 + col_length = self.image_seq_len + 2 + num_image_tokens = [] + + for num_patches in num_image_patches: + num_cols = num_rows = int(math.sqrt(num_patches - 1)) + row_length = col_length * num_cols + 1 + num_image_tokens.append(base_image_length + (row_length * num_rows)) + + vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches}) + + return MultiModalData(**vision_data) def batch_decode(self, *args, **kwargs): """ diff --git a/src/transformers/models/informer/modeling_informer.py b/src/transformers/models/informer/modeling_informer.py index 8b728a19dff..330bc620bc0 100644 --- a/src/transformers/models/informer/modeling_informer.py +++ b/src/transformers/models/informer/modeling_informer.py @@ -1,3 +1,9 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/informer/modular_informer.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_informer.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # coding=utf-8 # Copyright 2023 Amazon and The HuggingFace Inc. team. All rights reserved. # @@ -12,19 +18,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""PyTorch Informer model.""" -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import numpy as np import torch from torch import nn from ...activations import ACT2FN +from ...cache_utils import Cache, EncoderDecoderCache from ...modeling_attn_mask_utils import ( _prepare_4d_attention_mask, + _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask, + _prepare_4d_causal_attention_mask_for_sdpa, ) +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -32,19 +41,20 @@ from ...modeling_outputs import ( Seq2SeqTSModelOutput, Seq2SeqTSPredictionOutput, ) -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput -from ...utils import ( - auto_docstring, - logging, -) +from ...utils import auto_docstring, is_torch_flex_attn_available, logging from .configuration_informer import InformerConfig +if is_torch_flex_attn_available(): + from ...integrations.flex_attention import make_flex_block_causal_mask + + logger = logging.get_logger(__name__) -# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesFeatureEmbedder with TimeSeries->Informer class InformerFeatureEmbedder(nn.Module): """ Embed a sequence of categorical features. 
@@ -79,7 +89,6 @@ class InformerFeatureEmbedder(nn.Module): ) -# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesStdScaler with TimeSeriesTransformer->Informer,TimeSeries->Informer class InformerStdScaler(nn.Module): """ Standardize features by calculating the mean and scaling along the first dimension, and then normalizes it by @@ -115,7 +124,6 @@ class InformerStdScaler(nn.Module): return (data - loc) / scale, loc, scale -# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesMeanScaler with TimeSeriesTransformer->Informer,TimeSeries->Informer class InformerMeanScaler(nn.Module): """ Computes a scaling factor as the weighted average absolute value along the first dimension, and scales the data @@ -170,7 +178,6 @@ class InformerMeanScaler(nn.Module): return scaled_data, torch.zeros_like(scale), scale -# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesNOPScaler with TimeSeriesTransformer->Informer,TimeSeries->Informer class InformerNOPScaler(nn.Module): """ Assigns a scaling factor equal to 1 along the first dimension, and therefore applies no scaling to the input data. @@ -198,40 +205,6 @@ class InformerNOPScaler(nn.Module): return data, loc, scale -# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.weighted_average -def weighted_average(input_tensor: torch.Tensor, weights: Optional[torch.Tensor] = None, dim=None) -> torch.Tensor: - """ - Computes the weighted average of a given tensor across a given `dim`, masking values associated with weight zero, - meaning instead of `nan * 0 = nan` you will get `0 * 0 = 0`. - - Args: - input_tensor (`torch.FloatTensor`): - Input tensor, of which the average must be computed. - weights (`torch.FloatTensor`, *optional*): - Weights tensor, of the same shape as `input_tensor`. - dim (`int`, *optional*): - The dim along which to average `input_tensor`. - - Returns: - `torch.FloatTensor`: The tensor with values averaged along the specified `dim`. - """ - if weights is not None: - weighted_tensor = torch.where(weights != 0, input_tensor * weights, torch.zeros_like(input_tensor)) - sum_weights = torch.clamp(weights.sum(dim=dim) if dim else weights.sum(), min=1.0) - return (weighted_tensor.sum(dim=dim) if dim else weighted_tensor.sum()) / sum_weights - else: - return input_tensor.mean(dim=dim) - - -# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.nll -def nll(input: torch.distributions.Distribution, target: torch.Tensor) -> torch.Tensor: - """ - Computes the negative log likelihood loss from input distribution with respect to target. 
- """ - return -input.log_prob(target) - - -# Copied from transformers.models.marian.modeling_marian.MarianSinusoidalPositionalEmbedding with Marian->Informer class InformerSinusoidalPositionalEmbedding(nn.Embedding): """This module produces sinusoidal positional embeddings of any length.""" @@ -266,7 +239,6 @@ class InformerSinusoidalPositionalEmbedding(nn.Embedding): return super().forward(position_ids) -# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesValueEmbedding with TimeSeries->Info class InformerValueEmbedding(nn.Module): def __init__(self, feature_size, d_model): super().__init__() @@ -276,7 +248,156 @@ class InformerValueEmbedding(nn.Module): return self.value_projection(x) -# Copied from transformers.models.hubert.modeling_hubert.HubertAttention with Hubert->Informer +@auto_docstring +class InformerPreTrainedModel(PreTrainedModel): + config_class = InformerConfig + base_model_prefix = "model" + main_input_name = "past_values" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + std = self.config.init_std + if isinstance(module, (nn.Linear, nn.Conv1d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, InformerSinusoidalPositionalEmbedding): + module._init_weight() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + + # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenDecoder._update_causal_mask + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + past_key_values_length: int, + ): + if self.config._attn_implementation == "flash_attention_2": + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. 
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + input_shape, + inputs_embeds, + past_key_values_length, + ) + elif self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + # Other attention flavors support in-built causal (when `mask is None`) + # while we need to create our specific block mask regardless + elif attention_mask is None: + attention_mask = make_flex_block_causal_mask( + torch.ones( + size=(input_shape), + device=inputs_embeds.device, + ) + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + return attention_mask + + # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenDecoder._update_cross_attn_mask + def _update_cross_attn_mask( + self, + encoder_hidden_states: Union[torch.Tensor, None], + encoder_attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + ): + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1], + ) + elif self.config._attn_implementation == "flex_attention": + if isinstance(encoder_attention_mask, torch.Tensor): + encoder_attention_mask = make_flex_block_causal_mask( + encoder_attention_mask, + query_length=input_shape[-1], + is_causal=False, + ) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + return encoder_attention_mask + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: Optional[float] = None, + dropout: float = 0.0, + head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + if scaling is None: + scaling = query.size(-1) ** -0.5 + + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) + + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + class InformerAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -289,6 +410,7 @@ class InformerAttention(nn.Module): bias: bool = True, is_causal: bool = False, config: Optional[InformerConfig] = None, + layer_idx: Optional[int] = None, ): super().__init__() self.embed_dim = embed_dim @@ -305,23 +427,31 @@ class 
InformerAttention(nn.Module): self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder self.is_causal = is_causal + self.layer_idx = layer_idx + if layer_idx is None and self.is_decoder: + logger.warning_once( + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " + "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + cache_position: Optional[torch.Tensor] = None, + # TODO: we need a refactor so that the different attention modules can get their specific kwargs + # ATM, we have mixed things encoder, decoder, and encoder-decoder attn + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -329,110 +459,69 @@ class InformerAttention(nn.Module): # for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() + # determine input shapes + bsz, tgt_len = hidden_states.shape[:-1] + src_len = key_value_states.shape[1] if is_cross_attention else tgt_len + + q_input_shape = (bsz, tgt_len, -1, self.head_dim) + kv_input_shape = (bsz, src_len, -1, self.head_dim) # get query proj - query_states = self.q_proj(hidden_states) * self.scaling - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): + query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) + + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = 
self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(*kv_input_shape).transpose(1, 2) + value_states = value_states.view(*kv_input_shape).transpose(1, 2) - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) - - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) - key_states = key_states.reshape(*proj_shape) - value_states = value_states.reshape(*proj_shape) - - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} ) - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True - attn_weights = nn.functional.softmax(attn_weights, dim=-1) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - if layer_head_mask is not None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError( - f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}" - ) - attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its 
gradient. - # In order to do so, attn_weights have to be reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) - else: - attn_weights_reshaped = None - - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + output_attentions=output_attentions, + head_mask=layer_head_mask, + **kwargs, + ) + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights_reshaped, past_key_value + return attn_output, attn_weights, past_key_value class InformerProbSparseAttention(nn.Module): @@ -448,6 +537,7 @@ class InformerProbSparseAttention(nn.Module): is_decoder: bool = False, sampling_factor: int = 5, bias: bool = True, + layer_idx: Optional[int] = None, ): super().__init__() self.factor = sampling_factor @@ -463,6 +553,7 @@ class InformerProbSparseAttention(nn.Module): ) self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder + self.layer_idx = layer_idx self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) @@ -480,6 +571,7 @@ class InformerProbSparseAttention(nn.Module): attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -488,45 +580,43 @@ class InformerProbSparseAttention(nn.Module): is_cross_attention = key_value_states is not None bsz, tgt_len, _ = hidden_states.size() + src_len = key_value_states.shape[1] if is_cross_attention else tgt_len + kv_input_shape = (bsz, src_len, -1, self.head_dim) # get query proj query_states = self.q_proj(hidden_states) * self.scaling - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): - # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = 
self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: + # reuse k,v, cross_attentions + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] + else: + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(*kv_input_shape).transpose(1, 2) + value_states = value_states.view(*kv_input_shape).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True proj_shape = (bsz * self.num_heads, -1, self.head_dim) query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) @@ -681,6 +771,14 @@ class InformerEncoderLayer(nn.Module): def __init__(self, config: InformerConfig): super().__init__() self.embed_dim = config.d_model + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) + self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + if config.attention_type == "prob": self.self_attn = InformerProbSparseAttention( embed_dim=self.embed_dim, @@ -693,14 +791,8 @@ class InformerEncoderLayer(nn.Module): embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, + config=config, ) - self.self_attn_layer_norm = 
nn.LayerNorm(self.embed_dim) - self.dropout = config.dropout - self.activation_fn = ACT2FN[config.activation_function] - self.activation_dropout = config.activation_dropout - self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) - self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) - self.final_layer_norm = nn.LayerNorm(self.embed_dim) def forward( self, @@ -754,25 +846,9 @@ class InformerEncoderLayer(nn.Module): class InformerDecoderLayer(nn.Module): - def __init__(self, config: InformerConfig): + def __init__(self, config: InformerConfig, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.d_model - - if config.attention_type == "prob": - self.self_attn = InformerProbSparseAttention( - embed_dim=self.embed_dim, - num_heads=config.decoder_attention_heads, - dropout=config.attention_dropout, - sampling_factor=config.sampling_factor, - is_decoder=True, - ) - else: - self.self_attn = InformerAttention( - embed_dim=self.embed_dim, - num_heads=config.decoder_attention_heads, - dropout=config.attention_dropout, - is_decoder=True, - ) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] self.activation_dropout = config.activation_dropout @@ -783,12 +859,33 @@ class InformerDecoderLayer(nn.Module): config.decoder_attention_heads, dropout=config.attention_dropout, is_decoder=True, + config=config, + layer_idx=layer_idx, ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) self.final_layer_norm = nn.LayerNorm(self.embed_dim) + if config.attention_type == "prob": + self.self_attn = InformerProbSparseAttention( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + dropout=config.attention_dropout, + sampling_factor=config.sampling_factor, + is_decoder=True, + layer_idx=layer_idx, + ) + else: + self.self_attn = InformerAttention( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + config=config, + layer_idx=layer_idx, + ) + def forward( self, hidden_states: torch.Tensor, @@ -797,9 +894,10 @@ class InformerDecoderLayer(nn.Module): encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -818,47 +916,43 @@ class InformerDecoderLayer(nn.Module): output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. 
""" residual = hidden_states # Self Attention - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - # add present self-attn cache to positions 1,2 of present_key_value tuple - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights, past_key_value = self.self_attn( hidden_states=hidden_states, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) # Cross-Attention Block - cross_attn_present_key_value = None cross_attn_weights = None if encoder_hidden_states is not None: residual = hidden_states - # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + hidden_states, cross_attn_weights, past_key_value = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, + past_key_value=past_key_value, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) - # add cross-attn to positions 3,4 of present_key_value tuple - present_key_value = present_key_value + cross_attn_present_key_value - # Fully Connected residual = hidden_states hidden_states = self.activation_fn(self.fc1(hidden_states)) @@ -874,36 +968,15 @@ class InformerDecoderLayer(nn.Module): outputs += (self_attn_weights, cross_attn_weights) if use_cache: - outputs += (present_key_value,) + outputs += (past_key_value,) return outputs -@auto_docstring -class InformerPreTrainedModel(PreTrainedModel): - config_class = InformerConfig - base_model_prefix = "model" - main_input_name = "past_values" - supports_gradient_checkpointing = True - - def _init_weights(self, module): - std = self.config.init_std - if isinstance(module, (nn.Linear, nn.Conv1d)): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, InformerSinusoidalPositionalEmbedding): - module._init_weight() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - class InformerEncoder(InformerPreTrainedModel): """ - Informer encoder consisting of *config.encoder_layers* self attention layers with distillation layers. Each - attention layer is an [`InformerEncoderLayer`]. + Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a + [`InformerEncoderLayer`]. 
Args: config: InformerConfig @@ -914,7 +987,6 @@ class InformerEncoder(InformerPreTrainedModel): self.dropout = config.dropout self.layerdrop = config.encoder_layerdrop - self.gradient_checkpointing = False if config.prediction_length is None: raise ValueError("The `prediction_length` config needs to be specified.") @@ -924,6 +996,7 @@ class InformerEncoder(InformerPreTrainedModel): ) self.layers = nn.ModuleList([InformerEncoderLayer(config) for _ in range(config.encoder_layers)]) self.layernorm_embedding = nn.LayerNorm(config.d_model) + self.gradient_checkpointing = False if config.distil: self.conv_layers = nn.ModuleList( @@ -932,7 +1005,6 @@ class InformerEncoder(InformerPreTrainedModel): self.conv_layers.append(None) else: self.conv_layers = [None] * config.encoder_layers - # Initialize weights and apply final processing self.post_init() @@ -1053,7 +1125,7 @@ class InformerEncoder(InformerPreTrainedModel): class InformerDecoder(InformerPreTrainedModel): """ - Informer decoder consisting of *config.decoder_layers* layers. Each layer is a + Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`InformerDecoderLayer`] Args: @@ -1071,7 +1143,7 @@ class InformerDecoder(InformerPreTrainedModel): self.embed_positions = InformerSinusoidalPositionalEmbedding( config.context_length + config.prediction_length, config.d_model ) - self.layers = nn.ModuleList([InformerDecoderLayer(config) for _ in range(config.decoder_layers)]) + self.layers = nn.ModuleList([InformerDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)]) self.layernorm_embedding = nn.LayerNorm(config.d_model) self.gradient_checkpointing = False @@ -1091,6 +1163,7 @@ class InformerDecoder(InformerPreTrainedModel): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: r""" Args: @@ -1148,6 +1221,9 @@ class InformerDecoder(InformerPreTrainedModel): for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. 
""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -1157,20 +1233,35 @@ class InformerDecoder(InformerPreTrainedModel): return_dict = return_dict if return_dict is not None else self.config.use_return_dict input_shape = inputs_embeds.size()[:-1] - - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 - - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length - ) - - # expand encoder attention mask - if encoder_hidden_states is not None and encoder_attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _prepare_4d_attention_mask( - encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + # initialize `past_key_values` + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + input_shape[1], device=inputs_embeds.device + ) + + attention_mask = self._update_causal_mask( + attention_mask, + input_shape, + inputs_embeds, + past_key_values_length, + ) + encoder_attention_mask = self._update_cross_attn_mask( + encoder_hidden_states, + encoder_attention_mask, + input_shape, + inputs_embeds, + ) hidden_states = self.value_embedding(inputs_embeds) embed_pos = self.embed_positions(inputs_embeds.size(), past_key_values_length=self.config.context_length) @@ -1188,7 +1279,7 @@ class InformerDecoder(InformerPreTrainedModel): all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = () if use_cache else None + next_decoder_cache = None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): @@ -1208,8 +1299,6 @@ class InformerDecoder(InformerPreTrainedModel): if dropout_probability < self.layerdrop: continue - past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, @@ -1222,6 +1311,7 @@ class InformerDecoder(InformerPreTrainedModel): None, output_attentions, use_cache, + cache_position, ) else: layer_outputs = decoder_layer( @@ -1233,14 +1323,15 @@ class InformerDecoder(InformerPreTrainedModel): cross_attn_layer_head_mask=( cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None ), - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[3 if 
output_attentions else 1],) + next_decoder_cache = layer_outputs[3 if output_attentions else 1] if output_attentions: all_self_attns += (layer_outputs[1],) @@ -1253,6 +1344,9 @@ class InformerDecoder(InformerPreTrainedModel): all_hidden_states += (hidden_states,) next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v @@ -1269,7 +1363,6 @@ class InformerDecoder(InformerPreTrainedModel): @auto_docstring -# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesTransformerModel with TimeSeriesTransformer->Informer,TIME_SERIES_TRANSFORMER->INFORMER,time-series-transformer->informer,TimeSeries->Informer class InformerModel(InformerPreTrainedModel): def __init__(self, config: InformerConfig): super().__init__(config) @@ -1408,7 +1501,6 @@ class InformerModel(InformerPreTrainedModel): def get_decoder(self): return self.decoder - # Ignore copy @auto_docstring def forward( self, @@ -1429,6 +1521,7 @@ class InformerModel(InformerPreTrainedModel): output_attentions: Optional[bool] = None, use_cache: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Seq2SeqTSModelOutput, Tuple]: r""" past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, input_size)`): @@ -1586,7 +1679,16 @@ class InformerModel(InformerPreTrainedModel): attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, ) - dec_input = transformer_inputs[:, self.config.context_length :, ...] + # Avoid empty tensors and instead create a zeroes tensor which + # will be treated the same in torch, i.e. matmul with empty == all 0s + if self.config.context_length >= transformer_inputs.shape[1]: + bsz, _, dim = transformer_inputs.shape + dec_input = torch.zeros( + size=(bsz, 1, dim), device=transformer_inputs.device, dtype=transformer_inputs.dtype + ) + else: + dec_input = transformer_inputs[:, self.config.context_length :, ...] + decoder_outputs = self.decoder( inputs_embeds=dec_input, attention_mask=decoder_attention_mask, @@ -1598,6 +1700,7 @@ class InformerModel(InformerPreTrainedModel): output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) if not return_dict: @@ -1618,11 +1721,42 @@ class InformerModel(InformerPreTrainedModel): ) +def weighted_average(input_tensor: torch.Tensor, weights: Optional[torch.Tensor] = None, dim=None) -> torch.Tensor: + """ + Computes the weighted average of a given tensor across a given `dim`, masking values associated with weight zero, + meaning instead of `nan * 0 = nan` you will get `0 * 0 = 0`. + + Args: + input_tensor (`torch.FloatTensor`): + Input tensor, of which the average must be computed. + weights (`torch.FloatTensor`, *optional*): + Weights tensor, of the same shape as `input_tensor`. + dim (`int`, *optional*): + The dim along which to average `input_tensor`. + + Returns: + `torch.FloatTensor`: The tensor with values averaged along the specified `dim`. 
+ """ + if weights is not None: + weighted_tensor = torch.where(weights != 0, input_tensor * weights, torch.zeros_like(input_tensor)) + sum_weights = torch.clamp(weights.sum(dim=dim) if dim else weights.sum(), min=1.0) + return (weighted_tensor.sum(dim=dim) if dim else weighted_tensor.sum()) / sum_weights + else: + return input_tensor.mean(dim=dim) + + +def nll(input: torch.distributions.Distribution, target: torch.Tensor) -> torch.Tensor: + """ + Computes the negative log likelihood loss from input distribution with respect to target. + """ + return -input.log_prob(target) + + @auto_docstring -# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesTransformerForPrediction with TimeSeriesTransformer->Informer,TIME_SERIES_TRANSFORMER->INFORMER,time-series-transformer->informer class InformerForPrediction(InformerPreTrainedModel): def __init__(self, config: InformerConfig): super().__init__(config) + self.model = InformerModel(config) if config.distribution_output == "student_t": self.distribution_output = StudentTOutput(dim=config.input_size) @@ -1660,7 +1794,6 @@ class InformerForPrediction(InformerPreTrainedModel): sliced_params = [p[:, -trailing_n:] for p in params] return self.distribution_output.distribution(sliced_params, loc=loc, scale=scale) - # Ignore copy @auto_docstring def forward( self, @@ -1682,6 +1815,7 @@ class InformerForPrediction(InformerPreTrainedModel): output_attentions: Optional[bool] = None, use_cache: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Seq2SeqTSModelOutput, Tuple]: r""" past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, input_size)`): @@ -1853,6 +1987,7 @@ class InformerForPrediction(InformerPreTrainedModel): output_attentions=output_attentions, use_cache=use_cache, return_dict=return_dict, + cache_position=cache_position, ) prediction_loss = None diff --git a/src/transformers/models/informer/modular_informer.py b/src/transformers/models/informer/modular_informer.py new file mode 100644 index 00000000000..15bcb8d38a8 --- /dev/null +++ b/src/transformers/models/informer/modular_informer.py @@ -0,0 +1,997 @@ +# coding=utf-8 +# Copyright 2023 Amazon and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch Informer model.""" + +from typing import Optional, Tuple, Union + +import numpy as np +import torch +from torch import nn + +from ...cache_utils import EncoderDecoderCache +from ...modeling_attn_mask_utils import ( + _prepare_4d_attention_mask, + _prepare_4d_attention_mask_for_sdpa, + _prepare_4d_causal_attention_mask, + _prepare_4d_causal_attention_mask_for_sdpa, +) +from ...modeling_outputs import ( + BaseModelOutput, +) +from ...modeling_utils import PreTrainedModel +from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput +from ...utils import ( + auto_docstring, + is_torch_flex_attn_available, +) +from ..bart.modeling_bart import BartAttention +from ..time_series_transformer.modeling_time_series_transformer import ( + TimeSeriesFeatureEmbedder, + TimeSeriesMeanScaler, + TimeSeriesNOPScaler, + TimeSeriesSinusoidalPositionalEmbedding, + TimeSeriesStdScaler, + TimeSeriesTransformerDecoder, + TimeSeriesTransformerDecoderLayer, + TimeSeriesTransformerEncoder, + TimeSeriesTransformerEncoderLayer, + TimeSeriesTransformerForPrediction, + TimeSeriesTransformerModel, + TimeSeriesValueEmbedding, +) +from .configuration_informer import InformerConfig + + +if is_torch_flex_attn_available(): + from ...integrations.flex_attention import make_flex_block_causal_mask + + +def nll(input: torch.distributions.Distribution, target: torch.Tensor) -> torch.Tensor: + """ + Computes the negative log likelihood loss from input distribution with respect to target. + """ + return -input.log_prob(target) + + +class InformerFeatureEmbedder(TimeSeriesFeatureEmbedder): + pass + + +class InformerStdScaler(TimeSeriesStdScaler): + pass + + +class InformerMeanScaler(TimeSeriesMeanScaler): + pass + + +class InformerNOPScaler(TimeSeriesNOPScaler): + pass + + +class InformerSinusoidalPositionalEmbedding(TimeSeriesSinusoidalPositionalEmbedding): + pass + + +class InformerValueEmbedding(TimeSeriesValueEmbedding): + pass + + +@auto_docstring +class InformerPreTrainedModel(PreTrainedModel): + config_class = InformerConfig + base_model_prefix = "model" + main_input_name = "past_values" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + std = self.config.init_std + if isinstance(module, (nn.Linear, nn.Conv1d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, InformerSinusoidalPositionalEmbedding): + module._init_weight() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. 
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + + # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenDecoder._update_causal_mask + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + past_key_values_length: int, + ): + if self.config._attn_implementation == "flash_attention_2": + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + input_shape, + inputs_embeds, + past_key_values_length, + ) + elif self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + # Other attention flavors support in-built causal (when `mask is None`) + # while we need to create our specific block mask regardless + elif attention_mask is None: + attention_mask = make_flex_block_causal_mask( + torch.ones( + size=(input_shape), + device=inputs_embeds.device, + ) + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + return attention_mask + + # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenDecoder._update_cross_attn_mask + def _update_cross_attn_mask( + self, + encoder_hidden_states: Union[torch.Tensor, None], + encoder_attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + ): + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. 
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1], + ) + elif self.config._attn_implementation == "flex_attention": + if isinstance(encoder_attention_mask, torch.Tensor): + encoder_attention_mask = make_flex_block_causal_mask( + encoder_attention_mask, + query_length=input_shape[-1], + is_causal=False, + ) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + return encoder_attention_mask + + +class InformerAttention(BartAttention): + pass + + +class InformerProbSparseAttention(nn.Module): + """Probabilistic Attention mechanism to select the "active" + queries rather than the "lazy" queries and provides a sparse Transformer thus mitigating the quadratic compute and + memory requirements of vanilla attention""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + sampling_factor: int = 5, + bias: bool = True, + layer_idx: Optional[int] = None, + ): + super().__init__() + self.factor = sampling_factor + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads})." + ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + self.layer_idx = layer_idx + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + cache_position: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + src_len = key_value_states.shape[1] if is_cross_attention else tgt_len + kv_input_shape = (bsz, src_len, -1, self.head_dim) + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and 
is_updated: + # reuse k,v, cross_attentions + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] + else: + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(*kv_input_shape).transpose(1, 2) + value_states = value_states.view(*kv_input_shape).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.reshape(*proj_shape) + value_states = value_states.reshape(*proj_shape) + + key_states_time_length = key_states.size(1) # L_K + log_key_states_time_length = np.ceil(np.log1p(key_states_time_length)).astype("int").item() # log_L_K + + query_states_time_length = query_states.size(1) # L_Q + log_query_states_time_length = np.ceil(np.log1p(query_states_time_length)).astype("int").item() # log_L_Q + + u_part = min(self.factor * query_states_time_length * log_key_states_time_length, key_states_time_length) + u = min(self.factor * log_query_states_time_length, query_states_time_length) + + if key_states_time_length > 0: + index_sample = torch.randint(0, key_states_time_length, (u_part,)) + k_sample = key_states[:, index_sample, :] + else: + k_sample = key_states + + queries_keys_sample = torch.bmm(query_states, k_sample.transpose(1, 2)) # Q_K_sampled + + # find the Top_k query with sparsity measurement + if u > 0: + sparsity_measurement = queries_keys_sample.max(dim=-1)[0] - torch.div( + queries_keys_sample.sum(dim=-1), key_states_time_length + ) # M + top_u_sparsity_measurement = sparsity_measurement.topk(u, sorted=False)[1] # M_top + + # calculate q_reduce: query_states[:, top_u_sparsity_measurement] + dim_for_slice = torch.arange(query_states.size(0)).unsqueeze(-1) + q_reduce = query_states[dim_for_slice, top_u_sparsity_measurement] + else: + q_reduce = query_states + top_u_sparsity_measurement = None + + # Use q_reduce to calculate attention weights + attn_weights = torch.bmm(q_reduce, key_states.transpose(1, 2)) + + src_len = key_states.size(1) + if attn_weights.size() != (bsz * self.num_heads, u, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, u, src_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + prob_mask = attention_mask.expand(bsz, self.num_heads, tgt_len, src_len).reshape( + bsz * self.num_heads, tgt_len, src_len + ) + + if top_u_sparsity_measurement is not None: + dim_for_slice = torch.arange(prob_mask.size(0)).unsqueeze(-1) + prob_mask = prob_mask[dim_for_slice, top_u_sparsity_measurement, :] + + attn_weights = attn_weights.view(bsz, self.num_heads, u, src_len) + prob_mask.view( + bsz, self.num_heads, u, src_len + ) + attn_weights = attn_weights.view(bsz * self.num_heads, u, 
src_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if layer_head_mask is not None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}" + ) + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, u, src_len) + attn_weights = attn_weights.view(bsz * self.num_heads, u, src_len) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to be reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, u, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, u, src_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + attn_output = torch.bmm(attn_probs, value_states) + + # calculate context for updating the attn_output, based on: + # https://github.com/zhouhaoyi/Informer2020/blob/ac59c7447135473fb2aafeafe94395f884d5c7a5/models/attn.py#L74 + if self.is_decoder: + # cast to float32 before operation to avoid overflow + context = value_states.cumsum(dim=-2, dtype=torch.float32).to(value_states.dtype) + else: + v_mean_dim_time = value_states.mean(dim=-2) + context = ( + v_mean_dim_time.unsqueeze(dim=1) + .expand(bsz * self.num_heads, query_states_time_length, v_mean_dim_time.size(-1)) + .clone() + ) + + if top_u_sparsity_measurement is not None: + # update context: copy the attention output to the context at top_u_sparsity_measurement index + dim_for_slice = torch.arange(context.size(0)).unsqueeze(-1) + context[dim_for_slice, top_u_sparsity_measurement, :] = attn_output + attn_output = context + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism. 
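The ProbSparse forward pass above only scores a subset of "active" queries; a rough standalone sketch of that selection with toy shapes and `sampling_factor = 1` (all tensors and sizes here are illustrative, not taken from the model):

import torch

torch.manual_seed(0)
query_states = torch.randn(2, 4, 8)  # (batch * heads, L_Q, head_dim)
key_states = torch.randn(2, 4, 8)    # (batch * heads, L_K, head_dim)

L_K, L_Q = key_states.size(1), query_states.size(1)
log_L_K = int(torch.ceil(torch.log1p(torch.tensor(float(L_K)))))
log_L_Q = int(torch.ceil(torch.log1p(torch.tensor(float(L_Q)))))

u_part = min(1 * L_Q * log_L_K, L_K)  # number of sampled keys
u = min(1 * log_L_Q, L_Q)             # number of "active" queries kept

k_sample = key_states[:, torch.randint(0, L_K, (u_part,)), :]
scores_sample = torch.bmm(query_states, k_sample.transpose(1, 2))

# Sparsity measurement M: max of the sampled scores minus their sum / L_K; a large M marks an "active" query.
sparsity = scores_sample.max(dim=-1)[0] - scores_sample.sum(dim=-1) / L_K
top_u = sparsity.topk(u, sorted=False)[1]  # only these queries attend over the full key set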
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped, past_key_value + + +# source: https://github.com/zhouhaoyi/Informer2020/blob/main/models/encoder.py +class InformerConvLayer(nn.Module): + def __init__(self, c_in): + super().__init__() + self.downConv = nn.Conv1d( + in_channels=c_in, + out_channels=c_in, + kernel_size=3, + padding=1, + padding_mode="circular", + ) + self.norm = nn.BatchNorm1d(c_in) + self.activation = nn.ELU() + self.maxPool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1) + + def forward(self, x): + x = self.downConv(x.permute(0, 2, 1)) + x = self.norm(x) + x = self.activation(x) + x = self.maxPool(x) + x = x.transpose(1, 2) + return x + + +class InformerEncoderLayer(TimeSeriesTransformerEncoderLayer): + def __init__(self, config: InformerConfig): + super().__init__(config) + + del self.self_attn + + if config.attention_type == "prob": + self.self_attn = InformerProbSparseAttention( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + sampling_factor=config.sampling_factor, + ) + else: + self.self_attn = InformerAttention( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + config=config, + ) + + +class InformerDecoderLayer(TimeSeriesTransformerDecoderLayer): + def __init__(self, config: InformerConfig, layer_idx: Optional[int] = None): + super().__init__(config) + + del self.self_attn + + if config.attention_type == "prob": + self.self_attn = InformerProbSparseAttention( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + dropout=config.attention_dropout, + sampling_factor=config.sampling_factor, + is_decoder=True, + layer_idx=layer_idx, + ) + else: + self.self_attn = InformerAttention( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + config=config, + layer_idx=layer_idx, + ) + + +class InformerEncoder(TimeSeriesTransformerEncoder): + def __init__(self, config: InformerConfig): + super().__init__(config) + + self.dropout = config.dropout + self.layerdrop = config.encoder_layerdrop + self.gradient_checkpointing = False + if config.prediction_length is None: + raise ValueError("The `prediction_length` config needs to be specified.") + + self.value_embedding = InformerValueEmbedding(feature_size=config.feature_size, d_model=config.d_model) + self.embed_positions = InformerSinusoidalPositionalEmbedding( + config.context_length + config.prediction_length, config.d_model + ) + self.layers = nn.ModuleList([InformerEncoderLayer(config) for _ in range(config.encoder_layers)]) + self.layernorm_embedding = nn.LayerNorm(config.d_model) + + if config.distil: + self.conv_layers = nn.ModuleList( + [InformerConvLayer(config.d_model) for _ in range(config.encoder_layers - 1)] + ) + self.conv_layers.append(None) + else: + self.conv_layers = [None] * config.encoder_layers + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + r""" + Args: + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, 
*optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + hidden_states = self.value_embedding(inputs_embeds) + embed_pos = self.embed_positions(inputs_embeds.size()) + + hidden_states = self.layernorm_embedding(hidden_states + embed_pos) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + # expand attention_mask + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + if head_mask.size()[0] != (len(self.layers)): + raise ValueError( + f"The head_mask should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}." 
+ ) + + for idx, (encoder_layer, conv_layer) in enumerate(zip(self.layers, self.conv_layers)): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + to_drop = False + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: # skip the layer + to_drop = True + + if to_drop: + layer_outputs = (None, None) + else: + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + encoder_layer.__call__, + hidden_states, + attention_mask, + (head_mask[idx] if head_mask is not None else None), + output_attentions, + ) + if conv_layer is not None: + output = self._gradient_checkpointing_func(conv_layer, layer_outputs[0]) + layer_outputs = (output,) + layer_outputs[1:] + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) + if conv_layer is not None: + output = conv_layer(layer_outputs[0]) + layer_outputs = (output,) + layer_outputs[1:] + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +class InformerDecoder(TimeSeriesTransformerDecoder): + def __init__(self, config: InformerConfig): + super().__init__(config) + self.dropout = config.dropout + self.layerdrop = config.decoder_layerdrop + if config.prediction_length is None: + raise ValueError("The `prediction_length` config needs to be specified.") + + self.value_embedding = InformerValueEmbedding(feature_size=config.feature_size, d_model=config.d_model) + self.embed_positions = InformerSinusoidalPositionalEmbedding( + config.context_length + config.prediction_length, config.d_model + ) + self.layers = nn.ModuleList([InformerDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)]) + self.layernorm_embedding = nn.LayerNorm(config.d_model) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + +class InformerModel(TimeSeriesTransformerModel, nn.Module): + def __init__(self, config: InformerConfig): + nn.Module().__init__(config) + + if config.scaling == "mean" or config.scaling is True: + self.scaler = InformerMeanScaler(config) + elif config.scaling == "std": + self.scaler = InformerStdScaler(config) + else: + self.scaler = InformerNOPScaler(config) + + if config.num_static_categorical_features > 0: + self.embedder = InformerFeatureEmbedder( + cardinalities=config.cardinality, + embedding_dims=config.embedding_dimension, + ) + + # transformer encoder-decoder and mask initializer + self.encoder = InformerEncoder(config) + self.decoder = InformerDecoder(config) + + # Initialize weights and apply final processing + self.post_init() + + def forward(self, **super_kwargs): + r""" + past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, input_size)`): + Past values of the time series, that serve as context in order to predict the future. 
The sequence size of + this tensor must be larger than the `context_length` of the model, since the model will use the larger size + to construct lag features, i.e. additional values from the past which are added in order to serve as "extra + context". + + The `sequence_length` here is equal to `config.context_length` + `max(config.lags_sequence)`, which if no + `lags_sequence` is configured, is equal to `config.context_length` + 7 (as by default, the largest + look-back index in `config.lags_sequence` is 7). The property `_past_length` returns the actual length of + the past. + + The `past_values` is what the Transformer encoder gets as input (with optional additional features, such as + `static_categorical_features`, `static_real_features`, `past_time_features` and lags). + + Optionally, missing values need to be replaced with zeros and indicated via the `past_observed_mask`. + + For multivariate time series, the `input_size` > 1 dimension is required and corresponds to the number of + variates in the time series per time step. + past_time_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_features)`): + Required time features, which the model internally will add to `past_values`. These could be things like + "month of year", "day of the month", etc. encoded as vectors (for instance as Fourier features). These + could also be so-called "age" features, which basically help the model know "at which point in life" a + time-series is. Age features have small values for distant past time steps and increase monotonically the + more we approach the current time step. Holiday features are also a good example of time features. + + These features serve as the "positional encodings" of the inputs. So contrary to a model like BERT, where + the position encodings are learned from scratch internally as parameters of the model, the Time Series + Transformer requires to provide additional time features. The Time Series Transformer only learns + additional embeddings for `static_categorical_features`. + + Additional dynamic real covariates can be concatenated to this tensor, with the caveat that these features + must be known at prediction time. + + The `num_features` here is equal to `config.num_time_features` + `config.num_dynamic_real_features`. + past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, input_size)`, *optional*): + Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected in + `[0, 1]`: + + - 1 for values that are **observed**, + - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros). + static_categorical_features (`torch.LongTensor` of shape `(batch_size, number of static categorical features)`, *optional*): + Optional static categorical features for which the model will learn an embedding, which it will add to the + values of the time series. + + Static categorical features are features which have the same value for all time steps (static over time). + + A typical example of a static categorical feature is a time series ID. + static_real_features (`torch.FloatTensor` of shape `(batch_size, number of static real features)`, *optional*): + Optional static real features which the model will add to the values of the time series. + + Static real features are features which have the same value for all time steps (static over time). + + A typical example of a static real feature is promotion information.
+ future_values (`torch.FloatTensor` of shape `(batch_size, prediction_length)` or `(batch_size, prediction_length, input_size)`, *optional*): + Future values of the time series, that serve as labels for the model. The `future_values` is what the + Transformer needs during training to learn to output, given the `past_values`. + + The sequence length here is equal to `prediction_length`. + + See the demo notebook and code snippets for details. + + Optionally, during training any missing values need to be replaced with zeros and indicated via the + `future_observed_mask`. + + For multivariate time series, the `input_size` > 1 dimension is required and corresponds to the number of + variates in the time series per time step. + future_time_features (`torch.FloatTensor` of shape `(batch_size, prediction_length, num_features)`): + Required time features for the prediction window, which the model internally will add to `future_values`. + These could be things like "month of year", "day of the month", etc. encoded as vectors (for instance as + Fourier features). These could also be so-called "age" features, which basically help the model know "at + which point in life" a time-series is. Age features have small values for distant past time steps and + increase monotonically the more we approach the current time step. Holiday features are also a good example + of time features. + + These features serve as the "positional encodings" of the inputs. So contrary to a model like BERT, where + the position encodings are learned from scratch internally as parameters of the model, the Time Series + Transformer requires to provide additional time features. The Time Series Transformer only learns + additional embeddings for `static_categorical_features`. + + Additional dynamic real covariates can be concatenated to this tensor, with the caveat that these features + must be known at prediction time. + + The `num_features` here is equal to `config.num_time_features` + `config.num_dynamic_real_features`. + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*): + Tuple consists of `last_hidden_state`, `hidden_states` (*optional*) and `attentions` (*optional*) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` (*optional*) is a sequence of + hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. + + Examples: + + ```python + >>> from huggingface_hub import hf_hub_download + >>> import torch + >>> from transformers import InformerModel + + >>> file = hf_hub_download( + ... repo_id="hf-internal-testing/tourism-monthly-batch", filename="train-batch.pt", repo_type="dataset" + ... ) + >>> batch = torch.load(file) + + >>> model = InformerModel.from_pretrained("huggingface/informer-tourism-monthly") + + >>> # during training, one provides both past and future values + >>> # as well as possible additional features + >>> outputs = model( + ... past_values=batch["past_values"], + ... past_time_features=batch["past_time_features"], + ... past_observed_mask=batch["past_observed_mask"], + ... static_categorical_features=batch["static_categorical_features"], + ... static_real_features=batch["static_real_features"], + ...
future_values=batch["future_values"], + ... future_time_features=batch["future_time_features"], + ... ) + + >>> last_hidden_state = outputs.last_hidden_state + ```""" + super().forward(**super_kwargs) + + +class InformerForPrediction(TimeSeriesTransformerForPrediction, nn.Module): + def __init__(self, config: InformerConfig): + nn.Module().__init__(config) + + self.model = InformerModel(config) + if config.distribution_output == "student_t": + self.distribution_output = StudentTOutput(dim=config.input_size) + elif config.distribution_output == "normal": + self.distribution_output = NormalOutput(dim=config.input_size) + elif config.distribution_output == "negative_binomial": + self.distribution_output = NegativeBinomialOutput(dim=config.input_size) + else: + raise ValueError(f"Unknown distribution output {config.distribution_output}") + + self.parameter_projection = self.distribution_output.get_parameter_projection(self.model.config.d_model) + self.target_shape = self.distribution_output.event_shape + + if config.loss == "nll": + self.loss = nll + else: + raise ValueError(f"Unknown loss function {config.loss}") + + # Initialize weights of distribution_output and apply final processing + self.post_init() + + @auto_docstring + def forward(self, **super_kwargs): + r""" + past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, input_size)`): + Past values of the time series, that serve as context in order to predict the future. The sequence size of + this tensor must be larger than the `context_length` of the model, since the model will use the larger size + to construct lag features, i.e. additional values from the past which are added in order to serve as "extra + context". + + The `sequence_length` here is equal to `config.context_length` + `max(config.lags_sequence)`, which if no + `lags_sequence` is configured, is equal to `config.context_length` + 7 (as by default, the largest + look-back index in `config.lags_sequence` is 7). The property `_past_length` returns the actual length of + the past. + + The `past_values` is what the Transformer encoder gets as input (with optional additional features, such as + `static_categorical_features`, `static_real_features`, `past_time_features` and lags). + + Optionally, missing values need to be replaced with zeros and indicated via the `past_observed_mask`. + + For multivariate time series, the `input_size` > 1 dimension is required and corresponds to the number of + variates in the time series per time step. + past_time_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_features)`): + Required time features, which the model internally will add to `past_values`. These could be things like + "month of year", "day of the month", etc. encoded as vectors (for instance as Fourier features). These + could also be so-called "age" features, which basically help the model know "at which point in life" a + time-series is. Age features have small values for distant past time steps and increase monotonically the + more we approach the current time step. Holiday features are also a good example of time features. + + These features serve as the "positional encodings" of the inputs. So contrary to a model like BERT, where + the position encodings are learned from scratch internally as parameters of the model, the Time Series + Transformer requires to provide additional time features. The Time Series Transformer only learns + additional embeddings for `static_categorical_features`. 
+ + Additional dynamic real covariates can be concatenated to this tensor, with the caveat that these features + must be known at prediction time. + + The `num_features` here is equal to `config.num_time_features` + `config.num_dynamic_real_features`. + past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, input_size)`, *optional*): + Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected in + `[0, 1]`: + + - 1 for values that are **observed**, + - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros). + static_categorical_features (`torch.LongTensor` of shape `(batch_size, number of static categorical features)`, *optional*): + Optional static categorical features for which the model will learn an embedding, which it will add to the + values of the time series. + + Static categorical features are features which have the same value for all time steps (static over time). + + A typical example of a static categorical feature is a time series ID. + static_real_features (`torch.FloatTensor` of shape `(batch_size, number of static real features)`, *optional*): + Optional static real features which the model will add to the values of the time series. + + Static real features are features which have the same value for all time steps (static over time). + + A typical example of a static real feature is promotion information. + future_values (`torch.FloatTensor` of shape `(batch_size, prediction_length)` or `(batch_size, prediction_length, input_size)`, *optional*): + Future values of the time series, that serve as labels for the model. The `future_values` is what the + Transformer needs during training to learn to output, given the `past_values`. + + The sequence length here is equal to `prediction_length`. + + See the demo notebook and code snippets for details. + + Optionally, during training any missing values need to be replaced with zeros and indicated via the + `future_observed_mask`. + + For multivariate time series, the `input_size` > 1 dimension is required and corresponds to the number of + variates in the time series per time step. + future_time_features (`torch.FloatTensor` of shape `(batch_size, prediction_length, num_features)`): + Required time features for the prediction window, which the model internally will add to `future_values`. + These could be things like "month of year", "day of the month", etc. encoded as vectors (for instance as + Fourier features). These could also be so-called "age" features, which basically help the model know "at + which point in life" a time-series is. Age features have small values for distant past time steps and + increase monotonically the more we approach the current time step. Holiday features are also a good example + of time features. + + These features serve as the "positional encodings" of the inputs. So contrary to a model like BERT, where + the position encodings are learned from scratch internally as parameters of the model, the Time Series + Transformer requires to provide additional time features. The Time Series Transformer only learns + additional embeddings for `static_categorical_features`. + + Additional dynamic real covariates can be concatenated to this tensor, with the caveat that these features + must be known at prediction time. + + The `num_features` here is equal to `config.num_time_features` + `config.num_dynamic_real_features`.
+ future_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, input_size)`, *optional*): + Boolean mask to indicate which `future_values` were observed and which were missing. Mask values selected + in `[0, 1]`: + + - 1 for values that are **observed**, + - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros). + + This mask is used to filter out missing values for the final loss calculation. + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*): + Tuple consists of `last_hidden_state`, `hidden_states` (*optional*) and `attentions` (*optional*) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` (*optional*) is a sequence of + hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. + + Examples: + + ```python + >>> from huggingface_hub import hf_hub_download + >>> import torch + >>> from transformers import InformerForPrediction + + >>> file = hf_hub_download( + ... repo_id="hf-internal-testing/tourism-monthly-batch", filename="train-batch.pt", repo_type="dataset" + ... ) + >>> batch = torch.load(file) + + >>> model = InformerForPrediction.from_pretrained( + ... "huggingface/informer-tourism-monthly" + ... ) + + >>> # during training, one provides both past and future values + >>> # as well as possible additional features + >>> outputs = model( + ... past_values=batch["past_values"], + ... past_time_features=batch["past_time_features"], + ... past_observed_mask=batch["past_observed_mask"], + ... static_categorical_features=batch["static_categorical_features"], + ... static_real_features=batch["static_real_features"], + ... future_values=batch["future_values"], + ... future_time_features=batch["future_time_features"], + ... ) + + >>> loss = outputs.loss + >>> loss.backward() + + >>> # during inference, one only provides past values + >>> # as well as possible additional features + >>> # the model autoregressively generates future values + >>> outputs = model.generate( + ... past_values=batch["past_values"], + ... past_time_features=batch["past_time_features"], + ... past_observed_mask=batch["past_observed_mask"], + ... static_categorical_features=batch["static_categorical_features"], + ... static_real_features=batch["static_real_features"], + ... future_time_features=batch["future_time_features"], + ... 
) + + >>> mean_prediction = outputs.sequences.mean(dim=1) + ```""" + super().forward(**super_kwargs) + + +__all__ = ["InformerForPrediction", "InformerModel", "InformerPreTrainedModel"] diff --git a/src/transformers/models/instructblip/modeling_instructblip.py b/src/transformers/models/instructblip/modeling_instructblip.py index c90d22f012d..8018dbe76a9 100644 --- a/src/transformers/models/instructblip/modeling_instructblip.py +++ b/src/transformers/models/instructblip/modeling_instructblip.py @@ -25,6 +25,7 @@ from torch import nn from ...activations import ACT2FN from ...generation import GenerationMixin from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -277,7 +278,7 @@ class InstructBlipMLP(nn.Module): # Copied from transformers.models.blip.modeling_blip.BlipEncoderLayer with Blip->InstructBlip -class InstructBlipEncoderLayer(nn.Module): +class InstructBlipEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: InstructBlipConfig): super().__init__() self.embed_dim = config.hidden_size @@ -423,19 +424,12 @@ class InstructBlipEncoder(nn.Module): for idx, encoder_layer in enumerate(self.layers): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - output_attentions=output_attentions, - ) + + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] @@ -744,7 +738,7 @@ class InstructBlipQFormerOutput(nn.Module): return hidden_states -class InstructBlipQFormerLayer(nn.Module): +class InstructBlipQFormerLayer(GradientCheckpointingLayer): def __init__(self, config, layer_idx): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -885,31 +879,22 @@ class InstructBlipQFormerEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: - if use_cache: - logger.warning( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - query_length, + if getattr(self.config, "gradient_checkpointing", False) and self.training and use_cache: + logger.warning( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
) + use_cache = False + + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + query_length, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/instructblip/processing_instructblip.py b/src/transformers/models/instructblip/processing_instructblip.py index 408dfbd0756..d3df6f4ef90 100644 --- a/src/transformers/models/instructblip/processing_instructblip.py +++ b/src/transformers/models/instructblip/processing_instructblip.py @@ -22,12 +22,7 @@ from typing import List, Union from ...image_processing_utils import BatchFeature from ...image_utils import ImageInput from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack -from ...tokenization_utils_base import ( - AddedToken, - BatchEncoding, - PreTokenizedInput, - TextInput, -) +from ...tokenization_utils_base import AddedToken, BatchEncoding, PreTokenizedInput, TextInput from ...utils import logging from ..auto import AutoTokenizer @@ -72,7 +67,6 @@ class InstructBlipProcessor(ProcessorMixin): """ attributes = ["image_processor", "tokenizer", "qformer_tokenizer"] - valid_kwargs = ["num_query_tokens"] image_processor_class = ("BlipImageProcessor", "BlipImageProcessorFast") tokenizer_class = "AutoTokenizer" qformer_tokenizer_class = "AutoTokenizer" diff --git a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py index b9f40deffef..cc18bbf90b6 100644 --- a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py +++ b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py @@ -29,6 +29,7 @@ from torch import nn from ...activations import ACT2FN from ...generation import GenerationMixin from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -247,7 +248,7 @@ class InstructBlipVideoMLP(nn.Module): return hidden_states -class InstructBlipVideoEncoderLayer(nn.Module): +class InstructBlipVideoEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: InstructBlipVideoConfig): super().__init__() self.embed_dim = config.hidden_size @@ -352,19 +353,12 @@ class InstructBlipVideoEncoder(nn.Module): for idx, encoder_layer in enumerate(self.layers): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - output_attentions=output_attentions, - ) + + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] @@ -606,7 +600,7 @@ class InstructBlipVideoQFormerOutput(nn.Module): return hidden_states -class InstructBlipVideoQFormerLayer(nn.Module): +class InstructBlipVideoQFormerLayer(GradientCheckpointingLayer): def __init__(self, config, layer_idx): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -746,31 +740,22 @@ class InstructBlipVideoQFormerEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not 
None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: - if use_cache: - logger.warning( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - query_length, + if getattr(self.config, "gradient_checkpointing", False) and self.training and use_cache: + logger.warning( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) + use_cache = False + + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + query_length, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/instructblipvideo/processing_instructblipvideo.py b/src/transformers/models/instructblipvideo/processing_instructblipvideo.py index 8c59606e4b6..fad69b72e2f 100644 --- a/src/transformers/models/instructblipvideo/processing_instructblipvideo.py +++ b/src/transformers/models/instructblipvideo/processing_instructblipvideo.py @@ -57,7 +57,6 @@ class InstructBlipVideoProcessor(ProcessorMixin): """ attributes = ["video_processor", "tokenizer", "qformer_tokenizer"] - valid_kwargs = ["num_query_tokens"] video_processor_class = "AutoVideoProcessor" tokenizer_class = "AutoTokenizer" qformer_tokenizer_class = "AutoTokenizer" diff --git a/src/transformers/models/internvl/__init__.py b/src/transformers/models/internvl/__init__.py index 26514250827..6d4ffe7befa 100644 --- a/src/transformers/models/internvl/__init__.py +++ b/src/transformers/models/internvl/__init__.py @@ -21,6 +21,7 @@ if TYPE_CHECKING: from .configuration_internvl import * from .modeling_internvl import * from .processing_internvl import * + from .video_processing_internvl import * else: import sys diff --git a/src/transformers/models/internvl/convert_internvl_weights_to_hf.py b/src/transformers/models/internvl/convert_internvl_weights_to_hf.py index f687a2e7146..fa6d4bc9e52 100644 --- a/src/transformers/models/internvl/convert_internvl_weights_to_hf.py +++ b/src/transformers/models/internvl/convert_internvl_weights_to_hf.py @@ -28,6 +28,7 @@ from transformers import ( InternVLConfig, InternVLForConditionalGeneration, InternVLProcessor, + InternVLVideoProcessor, InternVLVisionConfig, LlamaConfig, Qwen2Config, @@ -56,7 +57,7 @@ UNNECESSARY_CONFIG_KEYS = [ "_name_or_path", "_attn_implementation_autoset", "au # fmt: off ORIGINAL_TO_CONVERTED_KEY_MAPPING_VISION = { # Vision encoder mapping - r"vision_model": r"vision_tower", + r"vision_model": r"model.vision_tower", r"layers": r"layer", r"class_embedding": r"cls_token", r"position_embedding": r"position_embeddings", @@ -71,7 +72,7 @@ ORIGINAL_TO_CONVERTED_KEY_MAPPING_VISION = { } ORIGINAL_TO_CONVERTED_KEY_MAPPING_TEXT_LLAMA = { - # Vision encoder mapping + r"language_model.model.": r"model.language_model.", r"tok_embeddings": r"embed_tokens", r"attention.wo": r"self_attn.o_proj", r"feed_forward.w1": r"mlp.gate_proj", @@ -79,14 +80,20 @@ ORIGINAL_TO_CONVERTED_KEY_MAPPING_TEXT_LLAMA = { r"feed_forward.w3": r"mlp.up_proj", r"attention_norm": r"input_layernorm", 
r"ffn_norm": r"post_attention_layernorm", - r"output": r"lm_head", + r"language_model.output": r"lm_head", +} + +ORIGINAL_TO_CONVERTED_KEY_MAPPING_TEXT_QWEN2 = { + # Vision encoder mapping + r"language_model.model.": r"model.language_model.", + r"language_model.lm_head": r"lm_head", } ORIGINAL_TO_CONVERTED_KEY_MAPPING_MULTI = { # Vision encoder mapping - r"mlp1.0": r"multi_modal_projector.layer_norm", - r"mlp1.1": r"multi_modal_projector.linear_1", - r"mlp1.3": r"multi_modal_projector.linear_2", + r"mlp1.0": r"model.multi_modal_projector.layer_norm", + r"mlp1.1": r"model.multi_modal_projector.linear_1", + r"mlp1.3": r"model.multi_modal_projector.linear_2", } @@ -98,7 +105,7 @@ chat_template = ( "{% else %}" "{% for content in message['content'] %}" "{% if content['type'] == 'image' %}" - "{{ '\n' }}" + "{{ '\n' }}" "{% elif content['type'] == 'video' %}" "{{ '