mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
parent
f5620a7634
commit
a7f5479b45
@ -1,40 +1,48 @@
|
||||
import ast
|
||||
from collections import defaultdict, deque
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
# Function to perform topological sorting
|
||||
def topological_sort(dependencies):
|
||||
# Create a graph and in-degree count for each node
|
||||
new_dependencies = {}
|
||||
graph = defaultdict(list)
|
||||
in_degree = defaultdict(int)
|
||||
|
||||
# Build the graph
|
||||
for node, deps in dependencies.items():
|
||||
for dep in deps:
|
||||
graph[dep].append(node) # node depends on dep
|
||||
in_degree[node] += 1 # increase in-degree of node
|
||||
if "example" not in node and "auto" not in dep:
|
||||
graph[dep.split(".")[-2]].append(node.split("/")[-2])
|
||||
new_dependencies[node.split("/")[-2]] = node
|
||||
|
||||
# Add all nodes with zero in-degree to the queue
|
||||
zero_in_degree_queue = deque([node for node in dependencies if in_degree[node] == 0])
|
||||
# Create a graph and in-degree count for each node
|
||||
def filter_one_by_one(filtered_list, reverse):
|
||||
if len(reverse) == 0:
|
||||
return filtered_list
|
||||
|
||||
sorted_list = []
|
||||
# Perform topological sorting
|
||||
while zero_in_degree_queue:
|
||||
current = zero_in_degree_queue.popleft()
|
||||
sorted_list.append(current)
|
||||
graph = defaultdict(list)
|
||||
# Build the graph
|
||||
for node, deps in reverse.items():
|
||||
for dep in deps:
|
||||
graph[dep].append(node)
|
||||
|
||||
# For each node that current points to, reduce its in-degree
|
||||
for neighbor in graph[current]:
|
||||
in_degree[neighbor] -= 1
|
||||
if in_degree[neighbor] == 0:
|
||||
zero_in_degree_queue.append(neighbor)
|
||||
base_modules = set(reverse.keys()) - set(graph.keys())
|
||||
if base_modules == reverse.keys():
|
||||
# we are at the end
|
||||
return filtered_list + list(graph.keys())
|
||||
to_add = []
|
||||
for k in graph.keys():
|
||||
if len(graph[k]) == 1 and graph[k][0] in base_modules:
|
||||
if graph[k][0] in reverse:
|
||||
del reverse[graph[k][0]]
|
||||
if k not in filtered_list:
|
||||
to_add += [k]
|
||||
for k in base_modules:
|
||||
if k not in filtered_list:
|
||||
to_add += [k]
|
||||
filtered_list += list(to_add)
|
||||
return filter_one_by_one(filtered_list, reverse)
|
||||
|
||||
# Handle nodes that have no dependencies and were not initially part of the loop
|
||||
for node in dependencies:
|
||||
if node not in sorted_list:
|
||||
sorted_list.append(node)
|
||||
final_order = filter_one_by_one([], graph)
|
||||
|
||||
return sorted_list
|
||||
return [new_dependencies.get(k) for k in final_order if k in new_dependencies]
|
||||
|
||||
|
||||
# Function to extract class and import info from a file
|
||||
@ -46,7 +54,7 @@ def extract_classes_and_imports(file_path):
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, (ast.Import, ast.ImportFrom)):
|
||||
module = node.module if isinstance(node, ast.ImportFrom) else None
|
||||
if module and "transformers" in module:
|
||||
if module and (".modeling_" in module):
|
||||
imports.add(module)
|
||||
return imports
|
||||
|
||||
@ -56,7 +64,7 @@ def map_dependencies(py_files):
|
||||
dependencies = defaultdict(set)
|
||||
# First pass: Extract all classes and map to files
|
||||
for file_path in py_files:
|
||||
dependencies[file_path].add(None)
|
||||
# dependencies[file_path].add(None)
|
||||
class_to_file = extract_classes_and_imports(file_path)
|
||||
for module in class_to_file:
|
||||
dependencies[file_path].add(module)
|
||||
@ -66,4 +74,4 @@ def map_dependencies(py_files):
|
||||
def find_priority_list(py_files):
|
||||
dependencies = map_dependencies(py_files)
|
||||
ordered_classes = topological_sort(dependencies)
|
||||
return ordered_classes[::-1]
|
||||
return ordered_classes
|
||||
|
@ -1678,7 +1678,7 @@ if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--files_to_parse",
|
||||
default=["src/transformers/models/aria/modular_aria.py"],
|
||||
default=["all"],
|
||||
nargs="+",
|
||||
help="A list of `modular_xxxx` files that should be converted to single model file",
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user