Repo checks: skip docstring checks if not in the diff (#32328)

* tmp

* skip files not in the diff

* use git.Repo instead of an external subprocess

* add tiny change to confirm that the diff is working on pushed changes

* add make quality task

* more profesh main commit reference
This commit is contained in:
Joao Gante 2024-07-30 18:56:10 +01:00 committed by GitHub
parent 516af4bb63
commit 026a173a64
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 38 additions and 4 deletions

View File

@ -142,6 +142,7 @@ jobs:
- run: python utils/custom_init_isort.py --check_only
- run: python utils/sort_auto_mappings.py --check_only
- run: python utils/check_doc_toc.py
- run: python utils/check_docstrings.py --check_all
check_repository_consistency:
working_directory: ~/transformers
@ -190,4 +191,4 @@ workflows:
- check_circleci_user
- check_code_quality
- check_repository_consistency
- fetch_all_tests
- fetch_all_tests

View File

@ -56,6 +56,7 @@ quality:
python utils/custom_init_isort.py --check_only
python utils/sort_auto_mappings.py --check_only
python utils/check_doc_toc.py
python utils/check_docstrings.py --check_all
# Format source code automatically and check if there are any problems left that need manual fixing

View File

@ -112,6 +112,7 @@ class CacheConfig:
Args:
config_dict (Dict[str, Any]): Dictionary containing configuration parameters.
**kwargs: Additional keyword arguments to override dictionary values.
Returns:
CacheConfig: Instance of CacheConfig constructed from the dictionary.
"""

View File

@ -43,10 +43,12 @@ from pathlib import Path
from typing import Any, Optional, Tuple, Union
from check_repo import ignore_undocumented
from git import Repo
from transformers.utils import direct_transformers_import
PATH_TO_REPO = Path(__file__).parent.parent.resolve()
PATH_TO_TRANSFORMERS = Path("src").resolve() / "transformers"
# This is to make sure the transformers module imported is the one in the repo.
@ -943,14 +945,33 @@ def fix_docstring(obj: Any, old_doc_args: str, new_doc_args: str):
f.write("\n".join(lines))
def check_docstrings(overwrite: bool = False):
def check_docstrings(overwrite: bool = False, check_all: bool = False):
"""
Check docstrings of all public objects that are callables and are documented.
Check docstrings of all public objects that are callables and are documented. By default, only checks the diff.
Args:
overwrite (`bool`, *optional*, defaults to `False`):
Whether to fix inconsistencies or not.
check_all (`bool`, *optional*, defaults to `False`):
Whether to check all files.
"""
module_diff_files = None
if not check_all:
module_diff_files = set()
repo = Repo(PATH_TO_REPO)
# Diff from index to unstaged files
for modified_file_diff in repo.index.diff(None):
if modified_file_diff.a_path.startswith("src/transformers"):
module_diff_files.add(modified_file_diff.a_path)
# Diff from index to `main`
for modified_file_diff in repo.index.diff(repo.refs.main.commit):
if modified_file_diff.a_path.startswith("src/transformers"):
module_diff_files.add(modified_file_diff.a_path)
# quick escape route: if there are no module files in the diff, skip this check
if len(module_diff_files) == 0:
return
print(" Checking docstrings in the following files:" + "\n - " + "\n - ".join(module_diff_files))
failures = []
hard_failures = []
to_clean = []
@ -963,6 +984,13 @@ def check_docstrings(overwrite: bool = False):
if not callable(obj) or not isinstance(obj, type) or getattr(obj, "__doc__", None) is None:
continue
# If we are checking against the diff, we skip objects that are not part of the diff.
if module_diff_files is not None:
object_file = find_source_file(getattr(transformers, name))
object_file_relative_path = "src/" + str(object_file).split("/src/")[1]
if object_file_relative_path not in module_diff_files:
continue
# Check docstring
try:
result = match_docstring_with_signature(obj)
@ -1013,6 +1041,9 @@ def check_docstrings(overwrite: bool = False):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.")
parser.add_argument(
"--check_all", action="store_true", help="Whether to check all files. By default, only checks the diff"
)
args = parser.parse_args()
check_docstrings(overwrite=args.fix_and_overwrite)
check_docstrings(overwrite=args.fix_and_overwrite, check_all=args.check_all)