Add documentation to dynamic module utils (#25534)

* Add documentation to dynamic module utils

* Address review comments
This commit is contained in:
Sylvain Gugger 2023-08-17 08:28:06 +02:00 committed by GitHub
parent d1832dd808
commit 297a6a7aea
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -20,9 +20,10 @@ import re
import shutil
import signal
import sys
import typing
import warnings
from pathlib import Path
from typing import Dict, Optional, Union
from typing import Any, Dict, List, Optional, Union
from .utils import (
HF_MODULES_CACHE,
@ -57,6 +58,10 @@ def init_hf_modules():
def create_dynamic_module(name: Union[str, os.PathLike]):
"""
Creates a dynamic module in the cache directory for modules.
Args:
name (`str` or `os.PathLike`):
The name of the dynamic module to create.
"""
init_hf_modules()
dynamic_module_path = (Path(HF_MODULES_CACHE) / name).resolve()
@ -67,15 +72,20 @@ def create_dynamic_module(name: Union[str, os.PathLike]):
init_path = dynamic_module_path / "__init__.py"
if not init_path.exists():
init_path.touch()
# It is extremely important to invalidate the cache when we change stuff in those modules, or users end up
# with errors about module that do not exist. Same for all other `invalidate_caches` in this file.
importlib.invalidate_caches()
def get_relative_imports(module_file):
def get_relative_imports(module_file: Union[str, os.PathLike]) -> List[str]:
"""
Get the list of modules that are relatively imported in a module file.
Args:
module_file (`str` or `os.PathLike`): The module file to inspect.
Returns:
`List[str]`: The list of relative imports in the module.
"""
with open(module_file, "r", encoding="utf-8") as f:
content = f.read()
@ -88,13 +98,17 @@ def get_relative_imports(module_file):
return list(set(relative_imports))
def get_relative_import_files(module_file):
def get_relative_import_files(module_file: Union[str, os.PathLike]) -> List[str]:
"""
Get the list of all files that are needed for a given module. Note that this function recurses through the relative
imports (if a imports b and b imports c, it will return module files for b and c).
Args:
module_file (`str` or `os.PathLike`): The module file to inspect.
Returns:
`List[str]`: The list of all relative imports a given module needs (recursively), which will give us the list
of module files a given module needs.
"""
no_change = False
files_to_check = [module_file]
@ -117,9 +131,15 @@ def get_relative_import_files(module_file):
return all_relative_imports
def get_imports(filename):
def get_imports(filename: Union[str, os.PathLike]) -> List[str]:
"""
Extracts all the libraries that are imported in a file.
Extracts all the libraries (not relative imports this time) that are imported in a file.
Args:
filename (`str` or `os.PathLike`): The module file to inspect.
Returns:
`List[str]`: The list of all packages required to use the input module.
"""
with open(filename, "r", encoding="utf-8") as f:
content = f.read()
@ -136,9 +156,16 @@ def get_imports(filename):
return list(set(imports))
def check_imports(filename):
def check_imports(filename: Union[str, os.PathLike]) -> List[str]:
"""
Check if the current Python environment contains all the libraries that are imported in a file.
Check if the current Python environment contains all the libraries that are imported in a file. Will raise if a
library is missing.
Args:
filename (`str` or `os.PathLike`): The module file to check.
Returns:
`List[str]`: The list of relative imports in the file.
"""
imports = get_imports(filename)
missing_packages = []
@ -157,9 +184,16 @@ def check_imports(filename):
return get_relative_imports(filename)
def get_class_in_module(class_name, module_path):
def get_class_in_module(class_name: str, module_path: Union[str, os.PathLike]) -> typing.Type:
"""
Import a module on the cache directory for modules and extract a class from it.
Args:
class_name (`str`): The name of the class to import.
module_path (`str` or `os.PathLike`): The path to the module to import.
Returns:
`typing.Type`: The class looked for.
"""
module_path = module_path.replace(os.path.sep, ".")
module = importlib.import_module(module_path)
@ -179,7 +213,7 @@ def get_cached_module_file(
repo_type: Optional[str] = None,
_commit_hash: Optional[str] = None,
**deprecated_kwargs,
):
) -> str:
"""
Prepares Downloads a module from a local folder or a distant repo and returns its path inside the cached
Transformers module.
@ -354,7 +388,7 @@ def get_class_from_dynamic_module(
repo_type: Optional[str] = None,
code_revision: Optional[str] = None,
**kwargs,
):
) -> typing.Type:
"""
Extracts a class from a module file, present in the local folder or repository of a model.
@ -416,7 +450,7 @@ def get_class_from_dynamic_module(
</Tip>
Returns:
`type`: The class, dynamically imported from the module.
`typing.Type`: The class, dynamically imported from the module.
Examples:
@ -463,7 +497,7 @@ def get_class_from_dynamic_module(
return get_class_in_module(class_name, final_module.replace(".py", ""))
def custom_object_save(obj, folder, config=None):
def custom_object_save(obj: Any, folder: Union[str, os.PathLike], config: Optional[Dict] = None) -> List[str]:
"""
Save the modeling files corresponding to a custom model/configuration/tokenizer etc. in a given folder. Optionally
adds the proper fields in a config.
@ -473,6 +507,9 @@ def custom_object_save(obj, folder, config=None):
folder (`str` or `os.PathLike`): The folder where to save.
config (`PretrainedConfig` or dictionary, `optional`):
A config in which to register the auto_map corresponding to this custom object.
Returns:
`List[str]`: The list of files saved.
"""
if obj.__module__ == "__main__":
logger.warning(