mirror of https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
parent 442d356aa5
commit f797e3d98a
@@ -2378,6 +2378,7 @@ def spread_import_structure(nested_import_structure):
     return flattened_import_structure
 
 
+@lru_cache()
 def define_import_structure(module_path: str, prefix: str = None) -> IMPORT_STRUCTURE_T:
     """
     This method takes a module_path as input and creates an import structure digestible by a _LazyModule.
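Hedged note on the hunk above: decorating `define_import_structure` with `@lru_cache()` means repeated calls with the same `module_path`/`prefix` reuse the previously computed structure instead of rescanning the package. A minimal sketch of that caching behaviour, with an illustrative stand-in function (the body and paths below are not from the diff):

    from functools import lru_cache

    @lru_cache()
    def define_import_structure_sketch(module_path: str, prefix: str = None) -> dict:
        # Hypothetical stand-in: pretend walking the package is expensive.
        print(f"scanning {module_path} ...")
        return {module_path: {"objects": []}}

    define_import_structure_sketch("src/transformers/models/bert")  # does the work, prints "scanning ..."
    define_import_structure_sketch("src/transformers/models/bert")  # served from the cache, prints nothing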
@@ -736,7 +736,7 @@ def get_module_dependencies(module_fname: str, cache: Dict[str, List[str]] = Non
             # the object is fully defined in the __init__)
             if module.endswith("__init__.py"):
                 # So we get the imports from that init then try to find where our objects come from.
-                new_imported_modules = extract_imports(module, cache=cache)
+                new_imported_modules = dict(extract_imports(module, cache=cache))
 
                 # Add imports via `define_import_structure` after the #35167 as we remove explicit import in `__init__.py`
                 from transformers.utils.import_utils import define_import_structure
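Hedged reading of the hunk above: `extract_imports` yields `(module, imports)` pairs, and wrapping the result in `dict(...)` keys the collection by module path so later additions can merge into an existing entry rather than append duplicate tuples. A small illustrative sketch with made-up module names (not taken from the repository):

    # Illustrative only: pretend extract_imports returned these (module, imports) pairs.
    pairs = [
        ("src/transformers/models/bert/modeling_bert.py", ["BertModel"]),
        ("src/transformers/models/bert/configuration_bert.py", ["BertConfig"]),
    ]
    new_imported_modules = dict(pairs)

    # Keyed by module path, so a second sighting of the same file merges instead of duplicating.
    print(new_imported_modules["src/transformers/models/bert/modeling_bert.py"])  # ['BertModel']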
@@ -749,9 +749,15 @@ def get_module_dependencies(module_fname: str, cache: Dict[str, List[str]] = Non
                         # We replace with os.path.sep so that it's Windows-compatible
                         _module = _module.replace(".", os.path.sep)
                         _module = module.replace("__init__.py", f"{_module}.py")
-                        new_imported_modules.append((_module, list(_imports)))
+                        if _module not in new_imported_modules:
+                            new_imported_modules[_module] = list(_imports)
+                        else:
+                            original_imports = new_imported_modules[_module]
+                            for potential_new_item in list(_imports):
+                                if potential_new_item not in original_imports:
+                                    new_imported_modules[_module].append(potential_new_item)
 
-                for new_module, new_imports in new_imported_modules:
+                for new_module, new_imports in new_imported_modules.items():
                     if any(i in new_imports for i in imports):
                         if new_module not in dependencies:
                             new_modules.append((new_module, [i for i in new_imports if i in imports]))
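The hunk above switches `new_imported_modules` from a list of `(module, imports)` tuples to a dict keyed by module path, merging import names without duplicates; the consumer loop therefore iterates `.items()`. A self-contained sketch of that merge logic, using hypothetical data (the helper name and paths are illustrative):

    def merge_imports(new_imported_modules: dict, _module: str, _imports: list) -> None:
        # Same shape as the diff: add a new key, or extend the existing entry without duplicates.
        if _module not in new_imported_modules:
            new_imported_modules[_module] = list(_imports)
        else:
            original_imports = new_imported_modules[_module]
            for potential_new_item in list(_imports):
                if potential_new_item not in original_imports:
                    new_imported_modules[_module].append(potential_new_item)

    mods = {"models/bert/modeling_bert.py": ["BertModel"]}
    merge_imports(mods, "models/bert/modeling_bert.py", ["BertModel", "BertForMaskedLM"])
    print(mods)  # {'models/bert/modeling_bert.py': ['BertModel', 'BertForMaskedLM']}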
@@ -1041,18 +1047,17 @@ def infer_tests_to_run(
     """
     if not test_all:
         modified_files = get_modified_python_files(diff_with_last_commit=diff_with_last_commit)
-        reverse_map = create_reverse_dependency_map()
-        impacted_files = modified_files.copy()
-        for f in modified_files:
-            if f in reverse_map:
-                impacted_files.extend(reverse_map[f])
     else:
-        impacted_files = modified_files = [
-            str(k) for k in PATH_TO_TESTS.glob("*/*") if str(k).endswith(".py") and "test_" in str(k)
-        ]
+        modified_files = [str(k) for k in PATH_TO_TESTS.glob("*/*") if str(k).endswith(".py") and "test_" in str(k)]
         print("\n### test_all is TRUE, FETCHING ALL FILES###\n")
     print(f"\n### MODIFIED FILES ###\n{_print_list(modified_files)}")
 
+    reverse_map = create_reverse_dependency_map()
+    impacted_files = modified_files.copy()
+    for f in modified_files:
+        if f in reverse_map:
+            impacted_files.extend(reverse_map[f])
+
     # Remove duplicates
     impacted_files = sorted(set(impacted_files))
     print(f"\n### IMPACTED FILES ###\n{_print_list(impacted_files)}")
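As reconstructed above, `infer_tests_to_run` now expands modified files into impacted files via `create_reverse_dependency_map()` on both paths (with or without `test_all`), instead of only in the non-`test_all` branch. A stripped-down sketch of that expansion step with a stand-in reverse map (the helper name and file paths are made up for illustration):

    def impacted_from(modified_files: list, reverse_map: dict) -> list:
        # Mirror of the diff's flow: start from the modified files, pull in every file
        # recorded as depending on them, then deduplicate and sort.
        impacted_files = modified_files.copy()
        for f in modified_files:
            if f in reverse_map:
                impacted_files.extend(reverse_map[f])
        return sorted(set(impacted_files))

    # Hypothetical reverse-dependency map: value = files impacted when the key changes.
    reverse_map = {"src/transformers/utils/import_utils.py": ["tests/utils/test_import_utils.py"]}
    print(impacted_from(["src/transformers/utils/import_utils.py"], reverse_map))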