fix modular order (#35297)

* fix modular order

* fix

* style
This commit is contained in:
Arthur 2024-12-17 08:05:35 +01:00 committed by GitHub
parent f5620a7634
commit a7f5479b45
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 36 additions and 28 deletions

View File

@ -1,40 +1,48 @@
import ast
from collections import defaultdict, deque
from collections import defaultdict
# Function to perform topological sorting
def topological_sort(dependencies):
# Create a graph and in-degree count for each node
new_dependencies = {}
graph = defaultdict(list)
in_degree = defaultdict(int)
# Build the graph
for node, deps in dependencies.items():
for dep in deps:
graph[dep].append(node) # node depends on dep
in_degree[node] += 1 # increase in-degree of node
if "example" not in node and "auto" not in dep:
graph[dep.split(".")[-2]].append(node.split("/")[-2])
new_dependencies[node.split("/")[-2]] = node
# Add all nodes with zero in-degree to the queue
zero_in_degree_queue = deque([node for node in dependencies if in_degree[node] == 0])
# Create a graph and in-degree count for each node
def filter_one_by_one(filtered_list, reverse):
if len(reverse) == 0:
return filtered_list
sorted_list = []
# Perform topological sorting
while zero_in_degree_queue:
current = zero_in_degree_queue.popleft()
sorted_list.append(current)
graph = defaultdict(list)
# Build the graph
for node, deps in reverse.items():
for dep in deps:
graph[dep].append(node)
# For each node that current points to, reduce its in-degree
for neighbor in graph[current]:
in_degree[neighbor] -= 1
if in_degree[neighbor] == 0:
zero_in_degree_queue.append(neighbor)
base_modules = set(reverse.keys()) - set(graph.keys())
if base_modules == reverse.keys():
# we are at the end
return filtered_list + list(graph.keys())
to_add = []
for k in graph.keys():
if len(graph[k]) == 1 and graph[k][0] in base_modules:
if graph[k][0] in reverse:
del reverse[graph[k][0]]
if k not in filtered_list:
to_add += [k]
for k in base_modules:
if k not in filtered_list:
to_add += [k]
filtered_list += list(to_add)
return filter_one_by_one(filtered_list, reverse)
# Handle nodes that have no dependencies and were not initially part of the loop
for node in dependencies:
if node not in sorted_list:
sorted_list.append(node)
final_order = filter_one_by_one([], graph)
return sorted_list
return [new_dependencies.get(k) for k in final_order if k in new_dependencies]
# Function to extract class and import info from a file
@ -46,7 +54,7 @@ def extract_classes_and_imports(file_path):
for node in ast.walk(tree):
if isinstance(node, (ast.Import, ast.ImportFrom)):
module = node.module if isinstance(node, ast.ImportFrom) else None
if module and "transformers" in module:
if module and (".modeling_" in module):
imports.add(module)
return imports
@ -56,7 +64,7 @@ def map_dependencies(py_files):
dependencies = defaultdict(set)
# First pass: Extract all classes and map to files
for file_path in py_files:
dependencies[file_path].add(None)
# dependencies[file_path].add(None)
class_to_file = extract_classes_and_imports(file_path)
for module in class_to_file:
dependencies[file_path].add(module)
@ -66,4 +74,4 @@ def map_dependencies(py_files):
def find_priority_list(py_files):
dependencies = map_dependencies(py_files)
ordered_classes = topological_sort(dependencies)
return ordered_classes[::-1]
return ordered_classes

View File

@ -1678,7 +1678,7 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--files_to_parse",
default=["src/transformers/models/aria/modular_aria.py"],
default=["all"],
nargs="+",
help="A list of `modular_xxxx` files that should be converted to single model file",
)