mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Add documentation to dynamic module utils (#25534)
* Add documentation to dynamic module utils * Address review comments
This commit is contained in:
parent
d1832dd808
commit
297a6a7aea
@ -20,9 +20,10 @@ import re
|
||||
import shutil
|
||||
import signal
|
||||
import sys
|
||||
import typing
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional, Union
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from .utils import (
|
||||
HF_MODULES_CACHE,
|
||||
@ -57,6 +58,10 @@ def init_hf_modules():
|
||||
def create_dynamic_module(name: Union[str, os.PathLike]):
|
||||
"""
|
||||
Creates a dynamic module in the cache directory for modules.
|
||||
|
||||
Args:
|
||||
name (`str` or `os.PathLike`):
|
||||
The name of the dynamic module to create.
|
||||
"""
|
||||
init_hf_modules()
|
||||
dynamic_module_path = (Path(HF_MODULES_CACHE) / name).resolve()
|
||||
@ -67,15 +72,20 @@ def create_dynamic_module(name: Union[str, os.PathLike]):
|
||||
init_path = dynamic_module_path / "__init__.py"
|
||||
if not init_path.exists():
|
||||
init_path.touch()
|
||||
# It is extremely important to invalidate the cache when we change stuff in those modules, or users end up
|
||||
# with errors about module that do not exist. Same for all other `invalidate_caches` in this file.
|
||||
importlib.invalidate_caches()
|
||||
|
||||
|
||||
def get_relative_imports(module_file):
|
||||
def get_relative_imports(module_file: Union[str, os.PathLike]) -> List[str]:
|
||||
"""
|
||||
Get the list of modules that are relatively imported in a module file.
|
||||
|
||||
Args:
|
||||
module_file (`str` or `os.PathLike`): The module file to inspect.
|
||||
|
||||
Returns:
|
||||
`List[str]`: The list of relative imports in the module.
|
||||
"""
|
||||
with open(module_file, "r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
@ -88,13 +98,17 @@ def get_relative_imports(module_file):
|
||||
return list(set(relative_imports))
|
||||
|
||||
|
||||
def get_relative_import_files(module_file):
|
||||
def get_relative_import_files(module_file: Union[str, os.PathLike]) -> List[str]:
|
||||
"""
|
||||
Get the list of all files that are needed for a given module. Note that this function recurses through the relative
|
||||
imports (if a imports b and b imports c, it will return module files for b and c).
|
||||
|
||||
Args:
|
||||
module_file (`str` or `os.PathLike`): The module file to inspect.
|
||||
|
||||
Returns:
|
||||
`List[str]`: The list of all relative imports a given module needs (recursively), which will give us the list
|
||||
of module files a given module needs.
|
||||
"""
|
||||
no_change = False
|
||||
files_to_check = [module_file]
|
||||
@ -117,9 +131,15 @@ def get_relative_import_files(module_file):
|
||||
return all_relative_imports
|
||||
|
||||
|
||||
def get_imports(filename):
|
||||
def get_imports(filename: Union[str, os.PathLike]) -> List[str]:
|
||||
"""
|
||||
Extracts all the libraries that are imported in a file.
|
||||
Extracts all the libraries (not relative imports this time) that are imported in a file.
|
||||
|
||||
Args:
|
||||
filename (`str` or `os.PathLike`): The module file to inspect.
|
||||
|
||||
Returns:
|
||||
`List[str]`: The list of all packages required to use the input module.
|
||||
"""
|
||||
with open(filename, "r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
@ -136,9 +156,16 @@ def get_imports(filename):
|
||||
return list(set(imports))
|
||||
|
||||
|
||||
def check_imports(filename):
|
||||
def check_imports(filename: Union[str, os.PathLike]) -> List[str]:
|
||||
"""
|
||||
Check if the current Python environment contains all the libraries that are imported in a file.
|
||||
Check if the current Python environment contains all the libraries that are imported in a file. Will raise if a
|
||||
library is missing.
|
||||
|
||||
Args:
|
||||
filename (`str` or `os.PathLike`): The module file to check.
|
||||
|
||||
Returns:
|
||||
`List[str]`: The list of relative imports in the file.
|
||||
"""
|
||||
imports = get_imports(filename)
|
||||
missing_packages = []
|
||||
@ -157,9 +184,16 @@ def check_imports(filename):
|
||||
return get_relative_imports(filename)
|
||||
|
||||
|
||||
def get_class_in_module(class_name, module_path):
|
||||
def get_class_in_module(class_name: str, module_path: Union[str, os.PathLike]) -> typing.Type:
|
||||
"""
|
||||
Import a module on the cache directory for modules and extract a class from it.
|
||||
|
||||
Args:
|
||||
class_name (`str`): The name of the class to import.
|
||||
module_path (`str` or `os.PathLike`): The path to the module to import.
|
||||
|
||||
Returns:
|
||||
`typing.Type`: The class looked for.
|
||||
"""
|
||||
module_path = module_path.replace(os.path.sep, ".")
|
||||
module = importlib.import_module(module_path)
|
||||
@ -179,7 +213,7 @@ def get_cached_module_file(
|
||||
repo_type: Optional[str] = None,
|
||||
_commit_hash: Optional[str] = None,
|
||||
**deprecated_kwargs,
|
||||
):
|
||||
) -> str:
|
||||
"""
|
||||
Prepares Downloads a module from a local folder or a distant repo and returns its path inside the cached
|
||||
Transformers module.
|
||||
@ -354,7 +388,7 @@ def get_class_from_dynamic_module(
|
||||
repo_type: Optional[str] = None,
|
||||
code_revision: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
) -> typing.Type:
|
||||
"""
|
||||
Extracts a class from a module file, present in the local folder or repository of a model.
|
||||
|
||||
@ -416,7 +450,7 @@ def get_class_from_dynamic_module(
|
||||
</Tip>
|
||||
|
||||
Returns:
|
||||
`type`: The class, dynamically imported from the module.
|
||||
`typing.Type`: The class, dynamically imported from the module.
|
||||
|
||||
Examples:
|
||||
|
||||
@ -463,7 +497,7 @@ def get_class_from_dynamic_module(
|
||||
return get_class_in_module(class_name, final_module.replace(".py", ""))
|
||||
|
||||
|
||||
def custom_object_save(obj, folder, config=None):
|
||||
def custom_object_save(obj: Any, folder: Union[str, os.PathLike], config: Optional[Dict] = None) -> List[str]:
|
||||
"""
|
||||
Save the modeling files corresponding to a custom model/configuration/tokenizer etc. in a given folder. Optionally
|
||||
adds the proper fields in a config.
|
||||
@ -473,6 +507,9 @@ def custom_object_save(obj, folder, config=None):
|
||||
folder (`str` or `os.PathLike`): The folder where to save.
|
||||
config (`PretrainedConfig` or dictionary, `optional`):
|
||||
A config in which to register the auto_map corresponding to this custom object.
|
||||
|
||||
Returns:
|
||||
`List[str]`: The list of files saved.
|
||||
"""
|
||||
if obj.__module__ == "__main__":
|
||||
logger.warning(
|
||||
|
Loading…
Reference in New Issue
Block a user