diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 82e038dc28e..3c751bceaac 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -2378,6 +2378,7 @@ def spread_import_structure(nested_import_structure): return flattened_import_structure +@lru_cache() def define_import_structure(module_path: str, prefix: str = None) -> IMPORT_STRUCTURE_T: """ This method takes a module_path as input and creates an import structure digestible by a _LazyModule. diff --git a/utils/tests_fetcher.py b/utils/tests_fetcher.py index 3c02c7be62a..e2a256dfd6f 100644 --- a/utils/tests_fetcher.py +++ b/utils/tests_fetcher.py @@ -736,7 +736,7 @@ def get_module_dependencies(module_fname: str, cache: Dict[str, List[str]] = Non # the object is fully defined in the __init__) if module.endswith("__init__.py"): # So we get the imports from that init then try to find where our objects come from. - new_imported_modules = extract_imports(module, cache=cache) + new_imported_modules = dict(extract_imports(module, cache=cache)) # Add imports via `define_import_structure` after the #35167 as we remove explicit import in `__init__.py` from transformers.utils.import_utils import define_import_structure @@ -749,9 +749,15 @@ def get_module_dependencies(module_fname: str, cache: Dict[str, List[str]] = Non # We replace with os.path.sep so that it's Windows-compatible _module = _module.replace(".", os.path.sep) _module = module.replace("__init__.py", f"{_module}.py") - new_imported_modules.append((_module, list(_imports))) + if _module not in new_imported_modules: + new_imported_modules[_module] = list(_imports) + else: + original_imports = new_imported_modules[_module] + for potential_new_item in list(_imports): + if potential_new_item not in original_imports: + new_imported_modules[_module].append(potential_new_item) - for new_module, new_imports in new_imported_modules: + for new_module, new_imports in new_imported_modules.items(): if any(i in new_imports for i in imports): if new_module not in dependencies: new_modules.append((new_module, [i for i in new_imports if i in imports])) @@ -1041,18 +1047,17 @@ def infer_tests_to_run( """ if not test_all: modified_files = get_modified_python_files(diff_with_last_commit=diff_with_last_commit) - reverse_map = create_reverse_dependency_map() - impacted_files = modified_files.copy() - for f in modified_files: - if f in reverse_map: - impacted_files.extend(reverse_map[f]) else: - impacted_files = modified_files = [ - str(k) for k in PATH_TO_TESTS.glob("*/*") if str(k).endswith(".py") and "test_" in str(k) - ] + modified_files = [str(k) for k in PATH_TO_TESTS.glob("*/*") if str(k).endswith(".py") and "test_" in str(k)] print("\n### test_all is TRUE, FETCHING ALL FILES###\n") print(f"\n### MODIFIED FILES ###\n{_print_list(modified_files)}") + reverse_map = create_reverse_dependency_map() + impacted_files = modified_files.copy() + for f in modified_files: + if f in reverse_map: + impacted_files.extend(reverse_map[f]) + # Remove duplicates impacted_files = sorted(set(impacted_files)) print(f"\n### IMPACTED FILES ###\n{_print_list(impacted_files)}")