Fix the test fetcher (#37452)

Test fetcher
This commit is contained in:
Lysandre Debut 2025-04-11 12:19:27 +02:00 committed by GitHub
parent 442d356aa5
commit f797e3d98a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 17 additions and 11 deletions

View File

@ -2378,6 +2378,7 @@ def spread_import_structure(nested_import_structure):
return flattened_import_structure
@lru_cache()
def define_import_structure(module_path: str, prefix: str = None) -> IMPORT_STRUCTURE_T:
"""
This method takes a module_path as input and creates an import structure digestible by a _LazyModule.

View File

@ -736,7 +736,7 @@ def get_module_dependencies(module_fname: str, cache: Dict[str, List[str]] = Non
# the object is fully defined in the __init__)
if module.endswith("__init__.py"):
# So we get the imports from that init then try to find where our objects come from.
new_imported_modules = extract_imports(module, cache=cache)
new_imported_modules = dict(extract_imports(module, cache=cache))
# Add imports via `define_import_structure` after the #35167 as we remove explicit import in `__init__.py`
from transformers.utils.import_utils import define_import_structure
@ -749,9 +749,15 @@ def get_module_dependencies(module_fname: str, cache: Dict[str, List[str]] = Non
# We replace with os.path.sep so that it's Windows-compatible
_module = _module.replace(".", os.path.sep)
_module = module.replace("__init__.py", f"{_module}.py")
new_imported_modules.append((_module, list(_imports)))
if _module not in new_imported_modules:
new_imported_modules[_module] = list(_imports)
else:
original_imports = new_imported_modules[_module]
for potential_new_item in list(_imports):
if potential_new_item not in original_imports:
new_imported_modules[_module].append(potential_new_item)
for new_module, new_imports in new_imported_modules:
for new_module, new_imports in new_imported_modules.items():
if any(i in new_imports for i in imports):
if new_module not in dependencies:
new_modules.append((new_module, [i for i in new_imports if i in imports]))
@ -1041,18 +1047,17 @@ def infer_tests_to_run(
"""
if not test_all:
modified_files = get_modified_python_files(diff_with_last_commit=diff_with_last_commit)
reverse_map = create_reverse_dependency_map()
impacted_files = modified_files.copy()
for f in modified_files:
if f in reverse_map:
impacted_files.extend(reverse_map[f])
else:
impacted_files = modified_files = [
str(k) for k in PATH_TO_TESTS.glob("*/*") if str(k).endswith(".py") and "test_" in str(k)
]
modified_files = [str(k) for k in PATH_TO_TESTS.glob("*/*") if str(k).endswith(".py") and "test_" in str(k)]
print("\n### test_all is TRUE, FETCHING ALL FILES###\n")
print(f"\n### MODIFIED FILES ###\n{_print_list(modified_files)}")
reverse_map = create_reverse_dependency_map()
impacted_files = modified_files.copy()
for f in modified_files:
if f in reverse_map:
impacted_files.extend(reverse_map[f])
# Remove duplicates
impacted_files = sorted(set(impacted_files))
print(f"\n### IMPACTED FILES ###\n{_print_list(impacted_files)}")