# Mirror of https://github.com/huggingface/transformers.git
# Synced 2025-07-03 12:50:06 +06:00
# 71 lines, 2.7 KiB, Python
import ast
|
|
from collections import defaultdict
|
|
|
|
|
|
def topological_sort(dependencies: dict) -> list:
    """
    Sort modular files so that each file comes after the modular files it depends on.

    Args:
        dependencies (`dict`):
            Mapping from a modular file path (whose file name contains `modular_<name>.py`) to the set of dotted
            module names it imports (e.g. `transformers.models.llama.modeling_llama` or `llama.modeling_llama`).

    Returns:
        `list`: The keys of `dependencies`, ordered so that dependencies come first. Files that become ready at the
        same step are ordered alphabetically by model name, so the result is deterministic.

    Raises:
        `ValueError`: If the dependencies between the modular files form a cycle.
    """

    def model_name(file_path):
        # ".../modular_llama.py" -> "llama"
        return file_path.rsplit("modular_", 1)[1].replace(".py", "")

    # Nodes are the names of the models to convert (only those are added to the graph)
    nodes = {model_name(file_path) for file_path in dependencies}

    # Graph from each model to convert, to the models to convert that must be converted before it
    graph = {}
    name_mapping = {}
    for file_path, imported_modules in dependencies.items():
        node_name = model_name(file_path)
        # "transformers.models.llama.modeling_llama" -> "llama": the package holding the imported module
        dep_names = {module.split(".")[-2] for module in imported_modules if "." in module}
        graph[node_name] = {dep for dep in dep_names if dep in nodes and dep != node_name}
        name_mapping[node_name] = file_path

    sorting_list = []
    while graph:
        # Nodes with no remaining dependency can be converted now
        leaf_nodes = {node for node, deps in graph.items() if not deps}
        if not leaf_nodes:
            # Without this check a cycle would make the loop spin forever
            raise ValueError(
                f"Circular dependency detected; the following models could not be ordered: {sorted(graph)}"
            )
        # Sort within a step so the output does not depend on set iteration order
        sorting_list.extend(sorted(leaf_nodes))
        # Remove the leaves from the graph (and from the deps of the remaining nodes)
        graph = {node: deps - leaf_nodes for node, deps in graph.items() if node not in leaf_nodes}

    return [name_mapping[node] for node in sorting_list]
|
|
|
|
|
|
def extract_classes_and_imports(file_path):
    """
    Collect the model modules imported by a Python file.

    Despite its name, this only inspects import statements. It returns the dotted names of imported modules that
    look like model code, i.e. that contain `.modeling_` or `transformers.models`.

    Both `from X import Y` and plain `import X` statements are considered. For relative imports such as
    `from ..llama.modeling_llama import LlamaModel`, the module is recorded without its leading dots
    (`llama.modeling_llama`), as exposed by `ast.ImportFrom.module`.

    Args:
        file_path: Path of the Python file to parse (read as UTF-8).

    Returns:
        `set[str]`: The matching dotted module names.
    """
    with open(file_path, "r", encoding="utf-8") as file:
        source = file.read()
    tree = ast.parse(source, filename=file_path)

    def is_model_module(module):
        return ".modeling_" in module or "transformers.models" in module

    imports = set()
    for node in ast.walk(tree):
        if isinstance(node, ast.ImportFrom):
            # `module` is None for `from . import x`
            candidates = [node.module] if node.module else []
        elif isinstance(node, ast.Import):
            # `import a.b.modeling_c [as m]`: the module names are on the aliases
            candidates = [alias.name for alias in node.names]
        else:
            continue
        imports.update(module for module in candidates if is_model_module(module))
    return imports
|
|
|
|
|
|
def map_dependencies(py_files):
    """
    Map each file to the model modules it imports.

    Args:
        py_files: Iterable of paths to the modular files.

    Returns:
        `defaultdict(set)`: Maps each path in `py_files` to the set of model module names it imports, as returned by
        `extract_classes_and_imports`. Every file gets an entry, even one importing no model module, so that it is
        not dropped from the topological order built on this mapping.
    """
    dependencies = defaultdict(set)
    for file_path in py_files:
        # Indexing the defaultdict creates the (possibly empty) entry for this file before adding its imports
        dependencies[file_path].update(extract_classes_and_imports(file_path))
    return dependencies
|
|
|
|
|
|
def find_priority_list(py_files):
    """
    Order modular files so that the files other modular files depend on come first.

    Modular models that do not depend on any other modular model appear earliest in the returned order.

    Args:
        py_files: List of paths to the modular files.

    Returns:
        A tuple `(ordered_files, dependencies)`: the topologically sorted file paths (list) and the dependency
        mapping built by `map_dependencies` (dict).
    """
    dependency_map = map_dependencies(py_files)
    return topological_sort(dependency_map), dependency_map
|