Fix the check of models supporting FA/SDPA not being run (#28202)

* add check_support_list.py

* fix

* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Yih-Dar, 2023-12-22 12:56:11 +01:00, committed by GitHub
parent e37ab52dff
commit bb3bd44739
4 changed files with 97 additions and 58 deletions
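
In outline, the new check collects every PyTorch architecture whose modeling file sets `_supports_flash_attn_2 = True` or `_supports_sdpa = True`, then verifies that each collected name appears in the corresponding section of `docs/source/en/perf_infer_gpu_one.md`. Below is a condensed sketch of the scanning half, assuming it is run from the repo root; the helper name is illustrative, not part of the commit.

```python
import os
from glob import glob


def architectures_with_flag(flag: str, models_dir: str = "src/transformers/models") -> list:
    """Illustrative helper: model names whose PyTorch modeling file sets `flag` to True."""
    files = set(glob(os.path.join(models_dir, "**/modeling_*.py")))
    # TF and Flax files also match `modeling_*.py`, so subtract them to keep only PyTorch files.
    files -= set(glob(os.path.join(models_dir, "**/modeling_tf_*.py")))
    files -= set(glob(os.path.join(models_dir, "**/modeling_flax_*.py")))

    found = []
    for filename in files:
        with open(filename, "r") as f:
            if f"{flag} = True" in f.read():
                found.append(os.path.basename(filename).replace(".py", "").replace("modeling_", ""))
    return sorted(found)


# e.g. architectures_with_flag("_supports_flash_attn_2") might return ["llama", ...]
```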


@@ -209,6 +209,7 @@ jobs:
             - run: python utils/update_metadata.py --check-only
             - run: python utils/check_task_guides.py
             - run: python utils/check_docstrings.py
+            - run: python utils/check_support_list.py
 workflows:
     version: 2


@@ -44,6 +44,7 @@ repo-consistency:
	python utils/update_metadata.py --check-only
	python utils/check_task_guides.py
	python utils/check_docstrings.py
+	python utils/check_support_list.py

 # this target runs checks on all files


@@ -16,7 +16,6 @@ import doctest
 import logging
 import os
 import unittest
-from glob import glob
 from pathlib import Path
 from typing import List, Union
@@ -27,63 +26,6 @@ from transformers.testing_utils import require_tf, require_torch, slow
 logger = logging.getLogger()
-
-
-@require_torch
-class TestDocLists(unittest.TestCase):
-    def test_flash_support_list(self):
-        with open("./docs/source/en/perf_infer_gpu_one.md", "r") as f:
-            doctext = f.read()
-
-        doctext = doctext.split("FlashAttention-2 is currently supported for the following architectures:")[1]
-        doctext = doctext.split("You can request to add FlashAttention-2 support")[0]
-
-        patterns = glob("./src/transformers/models/**/modeling_*.py")
-        patterns_tf = glob("./src/transformers/models/**/modeling_tf_*.py")
-        patterns_flax = glob("./src/transformers/models/**/modeling_flax_*.py")
-        patterns = list(set(patterns) - set(patterns_tf) - set(patterns_flax))
-        archs_supporting_fa2 = []
-        for filename in patterns:
-            with open(filename, "r") as f:
-                text = f.read()
-
-            if "_supports_flash_attn_2 = True" in text:
-                model_name = os.path.basename(filename).replace(".py", "").replace("modeling_", "")
-                archs_supporting_fa2.append(model_name)
-
-        for arch in archs_supporting_fa2:
-            if arch not in doctext:
-                raise ValueError(
-                    f"{arch} should be in listed in the flash attention documentation but is not. Please update the documentation."
-                )
-
-    def test_sdpa_support_list(self):
-        with open("./docs/source/en/perf_infer_gpu_one.md", "r") as f:
-            doctext = f.read()
-
-        doctext = doctext.split(
-            "For now, Transformers supports SDPA inference and training for the following architectures:"
-        )[1]
-        doctext = doctext.split("Note that FlashAttention can only be used for models using the")[0]
-
-        patterns = glob("./src/transformers/models/**/modeling_*.py")
-        patterns_tf = glob("./src/transformers/models/**/modeling_tf_*.py")
-        patterns_flax = glob("./src/transformers/models/**/modeling_flax_*.py")
-        patterns = list(set(patterns) - set(patterns_tf) - set(patterns_flax))
-        archs_supporting_sdpa = []
-        for filename in patterns:
-            with open(filename, "r") as f:
-                text = f.read()
-
-            if "_supports_sdpa = True" in text:
-                model_name = os.path.basename(filename).replace(".py", "").replace("modeling_", "")
-                archs_supporting_sdpa.append(model_name)
-
-        for arch in archs_supporting_sdpa:
-            if arch not in doctext:
-                raise ValueError(
-                    f"{arch} should be in listed in the SDPA documentation but is not. Please update the documentation."
-                )
-
-
 @unittest.skip("Temporarily disable the doc tests.")
 @require_torch
 @require_tf


@@ -0,0 +1,95 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utility that checks that support for 3rd-party libraries is listed in the documentation file. Currently, this includes:
- flash attention support
- SDPA support

Use from the root of the repo with (as used in `make repo-consistency`):

```bash
python utils/check_support_list.py
```

It has no auto-fix mode.
"""
import os
from glob import glob


# All paths are set with the intent that you should run this script from the root of the repo with the command
# python utils/check_support_list.py
REPO_PATH = "."


def check_flash_support_list():
    with open(os.path.join(REPO_PATH, "docs/source/en/perf_infer_gpu_one.md"), "r") as f:
        doctext = f.read()

    doctext = doctext.split("FlashAttention-2 is currently supported for the following architectures:")[1]
    doctext = doctext.split("You can request to add FlashAttention-2 support")[0]

    patterns = glob(os.path.join(REPO_PATH, "src/transformers/models/**/modeling_*.py"))
    patterns_tf = glob(os.path.join(REPO_PATH, "src/transformers/models/**/modeling_tf_*.py"))
    patterns_flax = glob(os.path.join(REPO_PATH, "src/transformers/models/**/modeling_flax_*.py"))
    patterns = list(set(patterns) - set(patterns_tf) - set(patterns_flax))
    archs_supporting_fa2 = []
    for filename in patterns:
        with open(filename, "r") as f:
            text = f.read()

        if "_supports_flash_attn_2 = True" in text:
            model_name = os.path.basename(filename).replace(".py", "").replace("modeling_", "")
            archs_supporting_fa2.append(model_name)

    for arch in archs_supporting_fa2:
        if arch not in doctext:
            raise ValueError(
                f"{arch} should be listed in the flash attention documentation but is not. Please update the documentation."
            )


def check_sdpa_support_list():
    with open(os.path.join(REPO_PATH, "docs/source/en/perf_infer_gpu_one.md"), "r") as f:
        doctext = f.read()

    doctext = doctext.split(
        "For now, Transformers supports SDPA inference and training for the following architectures:"
    )[1]
    doctext = doctext.split("Note that FlashAttention can only be used for models using the")[0]

    patterns = glob(os.path.join(REPO_PATH, "src/transformers/models/**/modeling_*.py"))
    patterns_tf = glob(os.path.join(REPO_PATH, "src/transformers/models/**/modeling_tf_*.py"))
    patterns_flax = glob(os.path.join(REPO_PATH, "src/transformers/models/**/modeling_flax_*.py"))
    patterns = list(set(patterns) - set(patterns_tf) - set(patterns_flax))
    archs_supporting_sdpa = []
    for filename in patterns:
        with open(filename, "r") as f:
            text = f.read()

        if "_supports_sdpa = True" in text:
            model_name = os.path.basename(filename).replace(".py", "").replace("modeling_", "")
            archs_supporting_sdpa.append(model_name)

    for arch in archs_supporting_sdpa:
        if arch not in doctext:
            raise ValueError(
                f"{arch} should be listed in the SDPA documentation but is not. Please update the documentation."
            )


if __name__ == "__main__":
    check_flash_support_list()
    check_sdpa_support_list()
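
Since `check_flash_support_list` and `check_sdpa_support_list` differ only in the flag they grep for and in the two sentinel strings that delimit the documentation section, they could in principle share one parameterized helper. A minimal sketch of that idea follows; it is not part of the commit, and `_check_support_list` and its parameters are hypothetical names.

```python
import os
from glob import glob

REPO_PATH = "."


def _check_support_list(doc_start: str, doc_end: str, flag: str, feature_name: str):
    """Hypothetical shared helper: ensure every architecture setting `flag` appears in the doc slice."""
    with open(os.path.join(REPO_PATH, "docs/source/en/perf_infer_gpu_one.md"), "r") as f:
        doctext = f.read()
    # Keep only the text between the two sentinel sentences, i.e. the list of supported architectures.
    doctext = doctext.split(doc_start)[1].split(doc_end)[0]

    patterns = set(glob(os.path.join(REPO_PATH, "src/transformers/models/**/modeling_*.py")))
    patterns -= set(glob(os.path.join(REPO_PATH, "src/transformers/models/**/modeling_tf_*.py")))
    patterns -= set(glob(os.path.join(REPO_PATH, "src/transformers/models/**/modeling_flax_*.py")))

    for filename in patterns:
        with open(filename, "r") as f:
            text = f.read()
        if f"{flag} = True" not in text:
            continue
        arch = os.path.basename(filename).replace(".py", "").replace("modeling_", "")
        if arch not in doctext:
            raise ValueError(
                f"{arch} should be listed in the {feature_name} documentation but is not. Please update the documentation."
            )


if __name__ == "__main__":
    _check_support_list(
        doc_start="FlashAttention-2 is currently supported for the following architectures:",
        doc_end="You can request to add FlashAttention-2 support",
        flag="_supports_flash_attn_2",
        feature_name="flash attention",
    )
    _check_support_list(
        doc_start="For now, Transformers supports SDPA inference and training for the following architectures:",
        doc_end="Note that FlashAttention can only be used for models using the",
        flag="_supports_sdpa",
        feature_name="SDPA",
    )
```

The behavior would match the two functions in the committed file; only the duplicated glob-and-grep logic is factored out.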