diff --git a/.circleci/config.yml b/.circleci/config.yml index cdd97f4fcec..6558dc1454b 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -142,6 +142,7 @@ jobs: - run: python utils/custom_init_isort.py --check_only - run: python utils/sort_auto_mappings.py --check_only - run: python utils/check_doc_toc.py + - run: python utils/check_docstrings.py --check_all check_repository_consistency: working_directory: ~/transformers @@ -190,4 +191,4 @@ workflows: - check_circleci_user - check_code_quality - check_repository_consistency - - fetch_all_tests \ No newline at end of file + - fetch_all_tests diff --git a/Makefile b/Makefile index f9b2a8c9a7c..cfa40b7bd6e 100644 --- a/Makefile +++ b/Makefile @@ -56,6 +56,7 @@ quality: python utils/custom_init_isort.py --check_only python utils/sort_auto_mappings.py --check_only python utils/check_doc_toc.py + python utils/check_docstrings.py --check_all # Format source code automatically and check is there are any problems left that need manual fixing diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 1cb9fcf5cc2..e465b0e08d8 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -112,6 +112,7 @@ class CacheConfig: Args: config_dict (Dict[str, Any]): Dictionary containing configuration parameters. **kwargs: Additional keyword arguments to override dictionary values. + Returns: CacheConfig: Instance of CacheConfig constructed from the dictionary. """ diff --git a/utils/check_docstrings.py b/utils/check_docstrings.py index b67920daaf8..f57427c4f65 100644 --- a/utils/check_docstrings.py +++ b/utils/check_docstrings.py @@ -43,10 +43,12 @@ from pathlib import Path from typing import Any, Optional, Tuple, Union from check_repo import ignore_undocumented +from git import Repo from transformers.utils import direct_transformers_import +PATH_TO_REPO = Path(__file__).parent.parent.resolve() PATH_TO_TRANSFORMERS = Path("src").resolve() / "transformers" # This is to make sure the transformers module imported is the one in the repo. @@ -943,14 +945,33 @@ def fix_docstring(obj: Any, old_doc_args: str, new_doc_args: str): f.write("\n".join(lines)) -def check_docstrings(overwrite: bool = False): +def check_docstrings(overwrite: bool = False, check_all: bool = False): """ - Check docstrings of all public objects that are callables and are documented. + Check docstrings of all public objects that are callables and are documented. By default, only checks the diff. Args: overwrite (`bool`, *optional*, defaults to `False`): Whether to fix inconsistencies or not. + check_all (`bool`, *optional*, defaults to `False`): + Whether to check all files. """ + module_diff_files = None + if not check_all: + module_diff_files = set() + repo = Repo(PATH_TO_REPO) + # Diff from index to unstaged files + for modified_file_diff in repo.index.diff(None): + if modified_file_diff.a_path.startswith("src/transformers"): + module_diff_files.add(modified_file_diff.a_path) + # Diff from index to `main` + for modified_file_diff in repo.index.diff(repo.refs.main.commit): + if modified_file_diff.a_path.startswith("src/transformers"): + module_diff_files.add(modified_file_diff.a_path) + # quick escape route: if there are no module files in the diff, skip this check + if len(module_diff_files) == 0: + return + print(" Checking docstrings in the following files:" + "\n - " + "\n - ".join(module_diff_files)) + failures = [] hard_failures = [] to_clean = [] @@ -963,6 +984,13 @@ def check_docstrings(overwrite: bool = False): if not callable(obj) or not isinstance(obj, type) or getattr(obj, "__doc__", None) is None: continue + # If we are checking against the diff, we skip objects that are not part of the diff. + if module_diff_files is not None: + object_file = find_source_file(getattr(transformers, name)) + object_file_relative_path = "src/" + str(object_file).split("/src/")[1] + if object_file_relative_path not in module_diff_files: + continue + # Check docstring try: result = match_docstring_with_signature(obj) @@ -1013,6 +1041,9 @@ def check_docstrings(overwrite: bool = False): if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.") + parser.add_argument( + "--check_all", action="store_true", help="Whether to check all files. By default, only checks the diff" + ) args = parser.parse_args() - check_docstrings(overwrite=args.fix_and_overwrite) + check_docstrings(overwrite=args.fix_and_overwrite, check_all=args.check_all)