mirror of https://github.com/huggingface/transformers.git · synced 2025-07-03 04:40:06 +06:00

1812 lines
95 KiB
Python
# coding=utf-8
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import glob
import importlib
import os
import re
from abc import ABC, abstractmethod
from collections import Counter, defaultdict, deque
from typing import Dict, Optional, Set, Union

import libcst as cst
from check_copies import run_ruff
from create_dependency_mapping import find_priority_list
from libcst import ClassDef, CSTVisitor
from libcst import matchers as m
from libcst.metadata import MetadataWrapper, ParentNodeProvider, PositionProvider, ScopeProvider

from transformers import logging
from transformers.models.auto.configuration_auto import CONFIG_MAPPING_NAMES


logger = logging.get_logger(__name__)

AUTO_GENERATED_MESSAGE = """#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from {relative_path}.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# {short_name} file directly. One of our CI enforces this.
#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
"""

def get_module_source_from_name(module_name: str) -> str:
    # Extract the source code from the module name
    spec = importlib.util.find_spec(module_name)
    if spec is None or spec.origin is None:
        raise ValueError(f"Cannot open file associated with {module_name} module.")

    with open(spec.origin, "r", encoding="utf-8") as file:
        source_code = file.read()
    return source_code

def preserve_case_replace(text, patterns: dict, default_name: str):
    # Create a regex pattern to match all variations
    regex_pattern = "|".join(re.escape(key) for key in patterns.keys())
    compiled_regex = re.compile(f"(?<![a-z0-9])({regex_pattern})(.|$)", re.IGNORECASE | re.DOTALL)

    def replace(match):
        matched_pattern = match.group(1)
        next_char = match.group(2)
        new_pattern = patterns.get(matched_pattern, default_name)

        # In this case, the cased old model did not respect CamelCase and was all UPPERCASE, so we need to rely on next char
        # The heuristic is: if next char is not a letter, then it is not part of a model name and result should be `new_name`.upper()
        if len(patterns) == 2 and matched_pattern.isupper():
            if not next_char.isalpha():
                # `new_name.upper()` is just the other entry for `matched_pattern.lower()`, uppercased
                new_pattern = patterns[matched_pattern.lower()].upper()

        return new_pattern + next_char

    return compiled_regex.sub(replace, text)

def get_cased_name(lowercase_name: str) -> str:
    """From a model name in lowercase in the format `my_model`, return the cased name in the format `MyModel`."""
    alt_lowercase_name = lowercase_name.replace("_", "-")
    if lowercase_name in CONFIG_MAPPING_NAMES:
        return CONFIG_MAPPING_NAMES[lowercase_name].replace("Config", "")
    elif alt_lowercase_name in CONFIG_MAPPING_NAMES:
        return CONFIG_MAPPING_NAMES[alt_lowercase_name].replace("Config", "")
    else:
        return "".join(x.title() for x in lowercase_name.split("_"))

def get_lowercase_name(cased_name: str) -> str:
    """From a model name in CamelCase in the format `MyModel`, return the lowercase name in the format `my_model`."""
    inverse_mapping = {value: key for key, value in CONFIG_MAPPING_NAMES.items()}
    if cased_name + "Config" in inverse_mapping:
        return inverse_mapping[cased_name + "Config"]
    else:
        return "_".join([s.lower() for s in re.findall(r"[A-Z][^A-Z]*", cased_name)])

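# Illustrative examples (editor's note, not part of the original module; the exact values depend on the
# current entries of CONFIG_MAPPING_NAMES and are only a sketch of the expected behavior):
#   get_cased_name("llama")          -> "Llama"        (via CONFIG_MAPPING_NAMES["llama"] == "LlamaConfig")
#   get_cased_name("my_new_model")   -> "MyNewModel"   (fallback: title-case each underscore-separated part)
#   get_lowercase_name("Llama")      -> "llama"
#   get_lowercase_name("MyNewModel") -> "my_new_model"
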
class ReplaceNameTransformer(m.MatcherDecoratableTransformer):
    """A transformer that replaces `old_name` with `new_name` in comments, strings and any references.
    It should take into account names like `MyNewModel`, or `my_new_model`, without using the AUTO_MAPPING.
    Supported renaming patterns:
        - llama -> my_new_model and my_new_model -> llama
        - Llama -> MyNewModel and MyNewModel -> Llama
        - LLAMA -> MY_NEW_MODEL and MY_NEW_MODEL -> LLAMA
        - LLaMa -> MyNewModel and MyNewModel -> Llama
    """

    def __init__(self, old_name: str, new_name: str, original_new_model_name: str = "", only_doc: bool = False):
        super().__init__()
        old_name = old_name.replace("-", "_")
        new_name = new_name.replace("-", "_")
        self.old_name = old_name
        self.new_name = new_name
        self.cased_new_name = get_cased_name(self.new_name)
        self.cased_old_name = get_cased_name(self.old_name)
        self.patterns = {
            old_name: new_name,
            old_name.upper(): new_name.upper(),
            # For some old models, `self.cased_old_name` == `old_name.upper()`, in which case this overwrites the previous entry
            self.cased_old_name: self.cased_new_name,
        }
        # In case new_name is a prefix alias, and not the original new model name
        self.original_new_model_name = original_new_model_name
        self.only_doc = only_doc

    def _replace_name(self, original_node, updated_node):
        if re.findall(r"# Copied from", updated_node.value):
            return cst.RemoveFromParent()
        update = preserve_case_replace(updated_node.value, self.patterns, self.cased_new_name)
        return updated_node.with_changes(value=update)

    @m.leave(m.SimpleString() | m.Comment())
    def replace_name(self, original_node, updated_node):
        return self._replace_name(original_node, updated_node)

    def leave_Name(self, original_node, updated_node):
        if not self.only_doc:
            return self._replace_name(original_node, updated_node)
        return updated_node

    def leave_ImportFrom(self, original_node, updated_node):
        """The imports from other file types (configuration, processing etc) should use the original model name."""
        if self.original_new_model_name != self.new_name and m.matches(updated_node.module, m.Name()):
            patterns = "|".join(ALL_FILE_TYPES)
            regex = rf"({patterns})_{self.new_name}"
            new_source = re.sub(
                regex, lambda m: f"{m.group(1)}_{self.original_new_model_name}", updated_node.module.value
            )
            updated_node = updated_node.with_changes(module=updated_node.module.with_changes(value=new_source))
        return updated_node

DOCSTRING_NODE = m.SimpleStatementLine(
    body=[
        m.Expr(
            value=m.SimpleString(
                # match anything between """ """
                value=m.MatchIfTrue(lambda value: re.search(r"\"\"\"[\s\S]*\"\"\"", value) is not None)
            )
        )
    ]
)


def SUPER_CALL_NODE(func_name):
    return m.Call(func=m.Attribute(value=m.Call(func=m.Name("super")), attr=m.Name(func_name)))


def is_call_to_super(node, func_name):
    return m.matches(
        node, m.SimpleStatementLine(body=[m.Return(SUPER_CALL_NODE(func_name)) | m.Expr(SUPER_CALL_NODE(func_name))])
    )


def get_full_attribute_name(node: Union[cst.Attribute, cst.Name]) -> Optional[str]:
    """Get the full name of an Attribute or Name node (e.g. `"nn.Module"` for an Attribute representing it). If the
    successive values of an Attribute are not Name nodes, return `None`."""
    if m.matches(node, m.Name()):
        return node.value
    elif m.matches(node, m.Attribute()):
        if not m.matches(node.attr, m.Name()):
            return None
        name = node.attr.value
        new_node = node.value
        while m.matches(new_node, m.Attribute()):
            if not m.matches(new_node.attr, m.Name()):
                return None
            name = new_node.attr.value + "." + name
            new_node = new_node.value
        if not m.matches(new_node, m.Name()):
            return None
        return new_node.value + "." + name
    return None

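# Illustrative sketch (editor's note, not part of the original module):
#   get_full_attribute_name(cst.parse_expression("nn.Module"))        -> "nn.Module"
#   get_full_attribute_name(cst.parse_expression("torch.nn.Module"))  -> "torch.nn.Module"
#   get_full_attribute_name(cst.parse_expression("foo().bar"))        -> None  (a non-Name value appears in the chain)
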
# Transformer class to replace ClassB.call_to_method and ClassB().call_to_method with super().call_to_method
class ReplaceMethodCallTransformer(cst.CSTTransformer):
    def __init__(self, all_bases: Set[str]):
        self.all_bases = all_bases

    def leave_Attribute(self, original_node: cst.Attribute, updated_node: cst.Attribute) -> cst.CSTNode:
        # Handle ClassB.call_to_method or module.classB.call_to_method
        if (
            m.matches(original_node.value, m.Name() | m.Attribute())
            and get_full_attribute_name(original_node.value) in self.all_bases
            and m.matches(original_node.attr, m.Name())
        ):
            # Replace with super().call_to_method
            return updated_node.with_changes(
                value=cst.Call(cst.Name("super")),
            )
        # Handle ClassB().call_to_method or module.ClassB().call_to_method
        elif (
            m.matches(original_node.value, m.Call())
            and m.matches(original_node.value.func, m.Name() | m.Attribute())
            and get_full_attribute_name(original_node.value.func) in self.all_bases
            and m.matches(original_node.attr, m.Name())
        ):
            # Replace with super().call_to_method
            return updated_node.with_changes(value=cst.Call(cst.Name("super")))
        return updated_node

    def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.CSTNode:
        # Check if the function being called is of the form ClassB().func_a or ClassB.func_a
        if m.matches(original_node.func, m.Attribute()) and (
            # Match ClassB().func_a(...) or module
            (
                m.matches(original_node.func.value, m.Call())
                and m.matches(original_node.func.value.func, m.Name() | m.Attribute())
                and get_full_attribute_name(original_node.func.value.func) in self.all_bases
            )
            or
            # Match ClassB.func_a(...)
            (
                m.matches(original_node.func.value, m.Name() | m.Attribute())
                and get_full_attribute_name(original_node.func.value) in self.all_bases
            )
        ):
            # Check if the first argument is 'self', and remove it
            if len(original_node.args) > 0 and m.matches(original_node.args[0].value, m.Name("self")):
                # Create the new argument list without 'self'
                new_args = updated_node.args[1:]
            else:
                new_args = updated_node.args

            return updated_node.with_changes(args=new_args)
        return updated_node

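# Illustrative sketch (editor's note, not part of the original module). With all_bases = {"LlamaModel"},
# this transformer rewrites explicit parent-class calls into super() calls, e.g. inside a method body:
#   LlamaModel.forward(self, input_ids)   ->  super().forward(input_ids)
#   LlamaModel().forward(input_ids)       ->  super().forward(input_ids)
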
def get_docstring_indent(docstring):
    # Match the first line after the opening triple quotes
    match = re.search(r'(?:"""|\'\'\'|```)\n(\s+)', docstring)
    if match:
        # Return the indentation spaces captured
        return len(match.group(1))
    return 0

def is_full_docstring(original_docstring: str, new_docstring: str, original_level: int) -> bool:
    """Check if `new_docstring` is a full docstring, or if it is only part of a docstring that should then
    be merged with the existing old one.
    """
    # libcst returns the docstrings with literal `r"""` quotes in front
    new_docstring = new_docstring.split('"""', 1)[1]
    # The docstring contains Args definition, so it is self-contained
    if re.search(r"\n\s*Args:\n", new_docstring):
        return True
    elif re.search(r"\n\s*Args:\n", original_docstring):
        return False
    # Check if the docstring contains args docstring (meaning it is self contained):
    param_pattern = re.compile(
        # |--- Group 1 ---|| Group 2 ||- Group 3 -||---------- Group 4 ----------|
        rf"^\s{{0,{original_level}}}(\w+)\s*\(\s*([^, \)]*)(\s*.*?)\s*\)\s*:\s*((?:(?!\n^\s{{0,{original_level}}}\w+\s*\().)*)",
        re.DOTALL | re.MULTILINE,
    )
    match_object = param_pattern.search(new_docstring)
    if match_object is not None:
        return True
    # If it contains Returns, but starts with text indented with an additional 4 spaces before, it is self-contained
    # (this is the scenario when using `@add_start_docstrings_to_model_forward`, but adding more args to docstring)
    match_object = re.search(r"\n([^\S\n]*)Returns:\n", new_docstring)
    if match_object is not None:
        full_indent = match_object.group(1)
        striped_doc = new_docstring.strip("\n")
        if striped_doc.startswith(full_indent + " " * 4) or striped_doc.startswith(full_indent + "\t"):
            return True
    return False

def merge_docstrings(original_docstring, updated_docstring):
    original_level = get_docstring_indent(original_docstring)
    if not is_full_docstring(original_docstring, updated_docstring, original_level):
        # Split the docstring at the example section, assuming `"""` is used to define the docstring
        parts = original_docstring.split("```")
        if "```" in updated_docstring and len(parts) > 1:
            updated_docstring = updated_docstring.lstrip('r"')
            new_parts = updated_docstring.split("```")
            if len(new_parts) != 3:
                raise ValueError("There should only be one example, and it should have opening and closing '```'")
            parts[1] = new_parts[1]
            updated_docstring = "".join(
                [
                    f"\n{original_level * ' '}```",
                    parts[1],
                    "```",
                    parts[2],
                ]
            )
            docstring_opening, original_start_docstring = parts[0].rstrip(" \n").split('"""')[:2]
            new_start_docstring = new_parts[0].rstrip(" \n")
            docstring_opening += '"""'
            if new_start_docstring.startswith(original_start_docstring):
                updated_docstring = new_start_docstring + "\n" + updated_docstring
            elif original_start_docstring.endswith(new_start_docstring):
                updated_docstring = original_start_docstring + "\n" + updated_docstring
            else:
                updated_docstring = original_start_docstring + "\n" + new_start_docstring + "\n" + updated_docstring
            updated_docstring = docstring_opening + updated_docstring
        elif updated_docstring not in original_docstring:
            # add tabulation if we are at the lowest level.
            if re.search(r"\n\s*.*\(.*\)\:\n\s*\w", updated_docstring):
                updated_docstring = updated_docstring.replace("\n    ", "\n        ")
            updated_docstring = original_docstring.rstrip('"') + "\n" + updated_docstring.lstrip('r"\n')
    return updated_docstring

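# Illustrative note (editor's note, not part of the original module): when the child class only overrides part
# of a docstring, e.g. just a new ``` code example ``` section, merge_docstrings() keeps the original text and
# swaps in the new example; a docstring that carries its own "Args:" section is considered complete
# (see is_full_docstring) and is used as-is instead of the original one.
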
class SuperTransformer(cst.CSTTransformer):
    METADATA_DEPENDENCIES = (ParentNodeProvider,)

    def __init__(self, python_module: cst.Module, original_methods, updated_methods, all_bases=None):
        self.python_module = python_module
        self.original_methods = original_methods
        self.updated_methods = updated_methods
        self.all_assign_target = {}
        self.deleted_targets = {}  # child node can delete some arguments
        self.all_bases = all_bases or []
        self.transformer = ReplaceMethodCallTransformer(set(self.all_bases))

    def update_body(self, existing_body, new_statements):
        """
        Helper method to update the body by removing duplicates before adding new statements.
        `existing_body` is the body of the original method (from the parent class),
        `new_statements` are the additional statements.
        """
        deduplicated_new_body = []
        existing_nodes = set()
        for node in new_statements:
            if m.matches(node, m.SimpleStatementLine(body=[m.Assign()])):
                target = self.python_module.code_for_node(node.body[0].targets[0].target)
                self.all_assign_target[target] = node
            if m.matches(node, m.SimpleStatementLine(body=[m.Del()])):
                target = self.python_module.code_for_node(node.body[0].target)
                self.deleted_targets[target] = node

        for stmt in existing_body:
            if m.matches(stmt, m.SimpleStatementLine(body=[m.Assign()])):
                target = self.python_module.code_for_node(stmt.body[0].targets[0].target)
                if target in self.deleted_targets:
                    continue
                if target in self.all_assign_target:
                    stmt = self.all_assign_target[target]
            # Skip the docstring (will be added later on, at the beginning)
            elif m.matches(stmt, DOCSTRING_NODE):
                continue
            comment_less_code = re.sub(r"#.*", "", self.python_module.code_for_node(stmt)).strip()
            comment_less_code = re.sub(r"\ *\n", "\n", comment_less_code).strip()
            deduplicated_new_body.append(stmt)
            existing_nodes.add(comment_less_code)

        for node in new_statements:
            code = self.python_module.code_for_node(node)
            comment_less_code = re.sub(r"#.*", "", code).strip()
            comment_less_code = re.sub(r"\ *\n", "\n", comment_less_code).strip()
            if node not in deduplicated_new_body and comment_less_code not in existing_nodes:
                if not m.matches(node, m.SimpleStatementLine(body=[m.Del()])):
                    deduplicated_new_body.append(node)
                    existing_nodes.add(comment_less_code)

        deduplicated_new_body = self._fix_post_init_location(deduplicated_new_body)

        return deduplicated_new_body

    def _fix_post_init_location(self, new_body: list[cst.CSTNode]):
        """Fix the location of the `post_init()` in the new body, if we added statements after the call to
        `super()` (it needs to be the very last statement called)"""
        # Fix the post_init() that has to be last
        for i, node in enumerate(new_body):
            code = self.python_module.code_for_node(node)
            comment_less_code = re.sub(r"#.*", "", code).strip()
            comment_less_code = re.sub(r"\ *\n", "\n", comment_less_code).strip()
            if "self.post_init(" in comment_less_code and i < len(new_body) - 1:
                # Remove it and add it again at the end
                new_body.pop(i)
                new_body.append(node)
                break
        return new_body

    def _fix_init_location(self, new_body):
        """Fix the location of the `super().__init__()` in the new body, if we had new statements before it."""
        start_index = 0
        for i, node in enumerate(new_body):
            if m.matches(node, DOCSTRING_NODE) and i == start_index:
                start_index += 1
                continue
            code = self.python_module.code_for_node(node)
            comment_less_code = re.sub(r"#.*", "", code).strip()
            comment_less_code = re.sub(r"\ *\n", "\n", comment_less_code).strip()
            if "super().__init__" in comment_less_code and i > start_index:
                # Remove it and add it again at the top after the docstrings
                node = new_body.pop(i)
                new_body = new_body[:start_index] + [node] + new_body[start_index:]
                break
        return new_body

    def replace_super_calls(self, node: cst.IndentedBlock, func_name: str) -> cst.CSTNode:
        """Updates the body of the input `node`'s `func_name` function by replacing calls
        to super().func_name() with the source code of the parent class' `func_name`.
        It keeps everything that is defined before `super().func_name()`.
        """
        self.has_docstring = False
        parent_has_docstring = False
        if func_name in self.original_methods:
            parent_has_docstring = m.matches(self.original_methods[func_name].body.body[0], DOCSTRING_NODE)
        new_body = []
        has_super_call = False

        for i, expr in enumerate(node.body):
            if is_call_to_super(expr, func_name):
                has_super_call = True
                new_body.extend(self.update_body(self.original_methods[func_name].body.body, node.body[i + 1 :]))
                new_body = self._fix_init_location(new_body)
            else:
                expr = expr.visit(self.transformer)
            if m.matches(expr, DOCSTRING_NODE):
                self.has_docstring = True
                if parent_has_docstring:  # actually here we ought to de-duplicate?
                    original_docstring = self.original_methods[func_name].body.body[0].body[0].value.value
                    updated_docstring = expr.body[0].value.value
                    merged_doc = merge_docstrings(original_docstring, updated_docstring)
                    new_node = [expr.with_changes(body=[cst.Expr(value=cst.SimpleString(value=merged_doc))])]
                else:
                    new_node = [expr]
                new_body.extend(new_node)
            elif not m.matches(expr, m.SimpleStatementLine(body=[m.Del()])) and not has_super_call:
                new_body.append(expr)
        if not self.has_docstring and parent_has_docstring:
            new_body = [self.original_methods[func_name].body.body[0]] + new_body
        return node.with_changes(body=new_body)

    def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.CSTNode:
        if updated_node.name.value in self.updated_methods:
            name = updated_node.name.value
            new_body = self.replace_super_calls(updated_node.body, name)
            return updated_node.with_changes(body=new_body, params=updated_node.params)
        return updated_node

    def leave_Return(self, original_node: cst.Return, updated_node: cst.Return) -> cst.CSTNode:
        """When a return statement is reached, it is replaced with the unrolled super code."""
        if m.matches(updated_node.value, m.Call(func=m.Attribute(attr=m.Name("super")))):
            func_def = self.get_metadata(ParentNodeProvider, original_node)
            if m.matches(func_def, m.FunctionDef()) and func_def.name.value in self.original_methods:
                updated_return_value = updated_node.value.with_changes(
                    args=[
                        cst.Arg(
                            value=cst.Call(func=cst.Name("super"), args=[cst.Arg(value=cst.Name(func_def.name.value))])
                        )
                    ]
                )
                return updated_node.with_changes(value=updated_return_value)
        return updated_node

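# Illustrative sketch (editor's note, not part of the original module) of what SuperTransformer does when
# unrolling `super().__init__()`, assuming the parent body is `self.a = 1; self.b = 2`:
#
#   def __init__(self):              def __init__(self):
#       super().__init__()     ->        self.a = 1
#       self.b = 3                       self.b = 3
#
# The parent's assignment to `self.b` is replaced by the child's, duplicates are dropped, and
# `self.post_init()` (if present) is moved back to the very end of the body.
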
def find_all_dependencies(
    dependency_mapping: Dict[str, set],
    start_entity: Optional[str] = None,
    initial_dependencies: Optional[set] = None,
    initial_checked_dependencies: Optional[set] = None,
    return_parent: bool = False,
) -> Union[list, set]:
    """Return all the dependencies of the given `start_entity` or `initial_dependencies`. This is basically some kind of
    BFS traversal algorithm. It can either start from `start_entity`, or `initial_dependencies`.

    Args:
        dependency_mapping (`Dict[str, set]`):
            A mapping from entities (usually function/assignment names), to immediate dependencies. That is, for function names,
            a mapping {"foo": {"bar", "test"}} would indicate that functions `bar` and `test` are immediately called
            in `foo`'s definition.
        start_entity (str | None, *optional*):
            A key of `dependency_mapping`, indicating from which entity to start the search.
        initial_dependencies (set | None, *optional*):
            If `start_entity` is not provided, this can be used as an alternative. In this case, the search will continue
            from all the entities in `initial_dependencies`, if they are in `dependency_mapping`.
        initial_checked_dependencies (set | None, *optional*):
            If provided, entities already present in `initial_checked_dependencies` will not be part of the returned dependencies.
        return_parent (bool, *optional*):
            If `True`, will return a list consisting of tuples (dependency, parent) instead of a simple set of dependencies. Note
            that the order of the items in the list reflects the traversal order. Thus, no child can ever appear before its parent.
    Returns:
        A set of all the dependencies, or a list of tuples `(dependency, parent)` if `return_parent=True`.

    Example:
    Given the following structure in the `modular_xxx.py` file:
    ```
    def foo1():
        pass

    def foo2():
        pass

    def bar():
        foo1()

    def foobar():
        bar()
        foo2()

    class MyLayer(SomeOtherModelLayer):
        def forward(...):
            foobar()
    ```
    and the `dependency_mapping` created when visiting the `modular_xxx.py` file, we get:
    ```
    dependency_mapping = {'bar': {'foo1'}, 'foobar': {'bar', 'foo2'}}
    find_all_dependencies(dependency_mapping, start_entity='foobar', return_parent=True)
    >>> [('bar', 'foobar'), ('foo2', 'foobar'), ('foo1', 'bar')]
    ```
    That is, all the functions needed (and potentially their immediate parent) so that the function to be added
    in MyLayer (`foobar`) can work correctly.
    """
    if initial_dependencies is None and start_entity is not None:
        initial_dependencies = dependency_mapping[start_entity]
    if initial_checked_dependencies is None:
        initial_checked_dependencies = set()

    dependency_queue = deque(initial_dependencies)
    all_dependencies = set()
    all_dependencies_with_parent = []
    checked_dependencies = set(initial_checked_dependencies)
    parents = dict.fromkeys(initial_dependencies, start_entity)
    while len(dependency_queue) > 0:
        # Pick element to visit
        current = dependency_queue.popleft()
        if current not in checked_dependencies:
            # Add the dependencies
            all_dependencies.add(current)
            all_dependencies_with_parent += [(current, parents[current])]
            if current in dependency_mapping.keys():
                # Update dependency queue
                dependency_queue.extend(dependency_mapping[current])
                parents.update(dict.fromkeys(dependency_mapping[current], current))
            # add visited node to the list
            checked_dependencies.add(current)

    if not return_parent:
        return all_dependencies
    # no child can ever appear before its parent thanks to the queue (needed to add them at the correct location in the body later)
    return all_dependencies_with_parent

# Top-level variables that match the following patterns will always use the value in the `modular_xxx.py` file
ASSIGNMENTS_REGEX_TO_KEEP = [r"_CHECKPOINT", r"_EXPECTED", r"_FOR_DOC", r"_HIDDEN_STATES_START_POSITION"]

# Top-level variables that match the following patterns will use the value in the `modular_xxx.py` file only if they are not None
ASSIGNMENTS_REGEX_TO_KEEP_IF_NOT_NONE = [r"_DOCSTRING"]


class ClassDependencyMapper(CSTVisitor):
    """A visitor which is designed to analyze a single class node to get all its dependencies that are shared with the set of
    `global_names`.
    """

    def __init__(
        self, class_name: str, global_names: set[str], objects_imported_from_modeling: Optional[set[str]] = None
    ):
        super().__init__()
        self.class_name = class_name
        self.dependencies = set()
        self.global_names = global_names
        self.objects_imported_from_modeling = (
            set() if objects_imported_from_modeling is None else objects_imported_from_modeling
        )

    def visit_Name(self, node):
        if (
            node.value != self.class_name
            and node.value in self.global_names
            and node.value not in self.objects_imported_from_modeling
        ):
            self.dependencies.add(node.value)


def dependencies_for_class_node(node: cst.ClassDef, global_names: set[str]) -> set:
    """Create immediate dependencies for a class node based on the `global_names`."""
    temp_module = cst.Module(body=[node])
    visitor = ClassDependencyMapper(node.name.value, global_names)
    temp_module.visit(visitor)
    return visitor.dependencies


def augmented_dependencies_for_class_node(
    node: cst.ClassDef, mapper: "ModuleMapper", objects_imported_from_modeling: Optional[set[str]] = None
) -> set:
    """Create augmented dependencies for a class node based on a `mapper`.
    Augmented dependencies mean immediate dependencies + recursive function and assignment dependencies.
    """
    temp_module = cst.Module(body=[node])
    visitor = ClassDependencyMapper(node.name.value, set(mapper.global_nodes.keys()), objects_imported_from_modeling)
    temp_module.visit(visitor)
    return mapper.augment_dependencies(visitor.dependencies)


# All the potential file types to create
ALL_FILE_TYPES = (
    "modeling",
    "configuration",
    "tokenization",
    "processing",
    "image_processing",
    "video_processing",
    "feature_extractor",
)

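# Illustrative sketch (editor's note, not part of the original module):
#   module = cst.parse_module("def rotate(x):\n    ...\n\nclass Attn:\n    def forward(self, x):\n        return rotate(x)\n")
#   class_node = module.body[-1]
#   dependencies_for_class_node(class_node, global_names={"rotate"})  # -> {"rotate"}
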
class ModuleMapper(CSTVisitor, ABC):
    """An abstract visitor class which analyses a module, creating a mapping of dependencies for classes, functions and assignments.
    Class dependencies are computed with `compute_class_dependencies()`, while function and assignment dependencies are stored in
    `self.object_recursive_dependency_mapping` (can be computed by `_compute_recursive_object_dependencies()`).
    It defines common visiting patterns (i.e. common visit_xxx/leave_xxx functions) between the modular file and the
    modeling files that will be visited.
    """

    METADATA_DEPENDENCIES = (ParentNodeProvider, PositionProvider)

    def __init__(self, python_module: cst.Module):
        # fmt: off
        self.python_module: cst.Module = python_module  # original cst.Module being visited
        self.classes: Dict[str, cst.ClassDef] = {}  # mapping from class names to Nodes (it will be ordered by default!!)
        self.imports = []  # stores all import statements
        self.functions: Dict[str, cst.FunctionDef] = {}  # mapping of global scope function names to Nodes
        self.object_dependency_mapping = defaultdict(set)  # immediate function/assignment dependency mapping (i.e. dependencies immediately in the function/assignment definition)
        self.assignments: Dict[str, cst.SimpleStatementLine] = {}  # mapping of global assignments names to Nodes
        self.current_function = None  # this keeps track of the current module-scope function
        self.current_class = None  # this keeps track of the current module-scope class
        self.current_assignment = None  # this keeps track of the current module-scope assignment
        # this keeps track of objects imported from modeling files (`from .configuration import Config`) -> `Config` should not be a dependency
        self.objects_imported_from_modeling = set()
        # regex pattern joining every possible file type
        self.match_patterns = "|".join(ALL_FILE_TYPES)
        # fmt: on

    def visit_ImportFrom(self, node):
        """This keeps track of objects imported from neighbor modeling files (e.g. in `modeling_xxx.py`, we have
        `from .configuration_xxx import Config`, then `Config` should be recorded as it is not a dependency that needs
        to be added (because it will be part of the imports)"""
        import_module = self.python_module.code_for_node(node.module)
        import_statement = "." * len(node.relative) + import_module
        if re.search(rf"^\.({self.match_patterns})_.*", import_statement):
            for imported_object in node.names:
                # If an alias is present, we record it and not the original name
                if imported_object.evaluated_alias is not None:
                    self.objects_imported_from_modeling.add(imported_object.evaluated_alias)
                else:
                    self.objects_imported_from_modeling.add(imported_object.evaluated_name)

    def visit_SimpleStatementLine(self, node):
        """
        Global Assigns like `GEMMA_INPUT_DOCSTRING = 'THIS IS THE INPUT'` and all import statements
        are extracted and saved in their corresponding dict. They are then used when updating dependency mappings.
        """
        parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node)
        simple_top_level_assign_structure = m.SimpleStatementLine(
            body=[m.Assign(targets=[m.AssignTarget(target=m.Name())])]
        )
        if m.matches(parent_node, m.Module()):
            if m.matches(node, simple_top_level_assign_structure):
                left_hand_side = node.body[0].targets[0].target.value
                self.current_assignment = left_hand_side
                self.assignments[left_hand_side] = node
            elif m.matches(node, m.SimpleStatementLine(body=[m.Import() | m.ImportFrom()])):
                self.imports.append(node)

    def leave_SimpleStatementLine(self, node):
        # No need to check for the parent here -> every time we exit one, it should be None anyway, independently of where the
        # SimpleStatement is located
        self.current_assignment = None

    def visit_FunctionDef(self, node):
        parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node)
        if m.matches(parent_node, m.Module()):
            self.current_function = node.name.value
            self.functions[node.name.value] = node

    def leave_FunctionDef(self, node):
        parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node)
        if m.matches(parent_node, m.Module()):
            self.current_function = None

    def visit_If(self, node):
        # If we are inside a function, do not add the import to the list of imports
        if self.current_function is None and self.current_class is None:
            for stmt in node.body.body:
                if m.matches(stmt, m.SimpleStatementLine(body=[m.ImportFrom() | m.Import()])):
                    self.imports.append(node)

    def visit_ClassDef(self, node: ClassDef) -> None:
        """Record class nodes to create their dependencies at the end."""
        self.classes[node.name.value] = node
        self.current_class = node.name.value

    def leave_ClassDef(self, node):
        self.current_class = None

    def visit_Name(self, node: cst.Name):
        """This is used to create a mapping from module-scope functions and assignments to objects used inside them."""
        if self.current_function is not None:
            self.object_dependency_mapping[self.current_function].add(node.value)
        if self.current_assignment is not None:
            self.object_dependency_mapping[self.current_assignment].add(node.value)

    def leave_Module(self, node):
        """When leaving the module, we store the position of each global scoped node to allow sorting the dependencies
        based on their position in the code later. We use the PositionProvider metadata wrapper for this.
        We also make sure to update `self.object_dependency_mapping` so that it contains only names recorded in
        `self.global_nodes`.
        """
        # assign all nodes
        self.global_nodes = {**self.assignments, **self.classes, **self.functions}
        # now sort the class dependency_mapping based on the position of the nodes
        self.start_lines = {}
        for id, node in self.global_nodes.items():
            self.start_lines[id] = self.get_metadata(cst.metadata.PositionProvider, node).start.line

    def _restrict_dependencies_to_known_entities(self):
        """Since we added every Name as part of `self.object_dependency_mapping`, we need to remove those that
        are not part of the recorded objects in `self.global_nodes` (i.e. built-in variables, imports, etc).
        This should be called only after all merging operations have been finalized!!"""
        global_objects = set(self.global_nodes.keys())
        for object_name, dependencies in self.object_dependency_mapping.items():
            self.object_dependency_mapping[object_name] = {dep for dep in dependencies if dep in global_objects}

    def _compute_recursive_object_dependencies(self) -> dict[str, set]:
        """Based on immediate dependency mapping, create the recursive dependency mapping. For example, given the
        following file:
        ```
        def foo():
            pass

        def bar():
            foo()

        def test():
            bar()
        ```
        this visitor can only record immediate dependencies, i.e. it will record the following
        `self.object_dependency_mapping = {"test": {"bar"}, "bar": {"foo"}}`. This function is used to create
        the recursive mapping, i.e. `recursive_dependencies = {"test": {"bar", "foo"}, "bar": {"foo"}}`.
        """
        recursive_dependencies = {}
        for object_name in self.object_dependency_mapping.keys():
            all_dependencies = find_all_dependencies(self.object_dependency_mapping, start_entity=object_name)
            recursive_dependencies[object_name] = all_dependencies
        return recursive_dependencies

    def augment_dependencies(self, dependencies: set[str]) -> set[str]:
        """For a set of `dependencies`, augment them by adding all potential dependencies of the **functions** and
        **assignments** present in the `dependencies`.
        """
        new_dependencies = dependencies.copy()
        # Go through the set of dependencies
        for dep in tuple(dependencies):
            if dep in self.object_recursive_dependency_mapping.keys():
                new_dependencies.update(self.object_recursive_dependency_mapping[dep])
        return new_dependencies

    def compute_class_dependencies(self):
        """For each visited class, find its dependencies based on visiting the current file + potential merged dependencies."""
        self.class_dependency_mapping = {}
        for class_name, class_node in self.classes.items():
            dependencies = dependencies_for_class_node(class_node, set(self.global_nodes.keys()))
            # Correctly augment class dependencies with all needed objects
            self.class_dependency_mapping[class_name] = self.augment_dependencies(dependencies)

    @abstractmethod
    def compute_relative_order(self, missing_dependencies: set) -> dict[str, int]:
        raise NotImplementedError

class ModelFileMapper(ModuleMapper):
    """A mapper designed to parse modeling files (like `modeling_llama.py`). When encountering such a file
    in the `modular_xxx.py` file, we need to correctly visit it and merge the dependencies of the modular and current file.
    For this reason, this class should only be instantiated from the class method `visit_and_merge_dependencies`, which takes
    care of correctly merging dependencies, then finalizes all dependency graph computations.
    Note that we only merge functions and assignments here, as classes will be treated later on as they may be modified.
    For example, if you redefine `apply_rotary_pos_emb()` in the modular, the new node should be used in the dependencies
    of the modeling files as well.
    """

    def __init__(self, python_module: cst.Module):
        super().__init__(python_module)

    def compute_relative_order(self, missing_dependencies: set[str]) -> dict[str, int]:
        """Compute in which relative order the `missing_dependencies` should appear when the nodes are added to the final file that
        will be created based on the modular.
        """
        relative_order = {}
        idx = 0
        classes = sorted(
            [dep for dep in tuple(missing_dependencies) if dep in self.classes], key=lambda x: self.start_lines[x]
        )
        # This is because for merged dependencies, we only have relative order in the other visited file, so we need
        # to track dependency order relative to a given class
        if len(classes) > 0 and not hasattr(self, "class_dependency_mapping"):
            raise ValueError("Cannot correctly find the relative order of the dependencies.")

        remaining_dependencies = missing_dependencies.copy()

        # Start by tracking relative order class by class
        for class_name in classes:
            class_dependencies = tuple(self.class_dependency_mapping[class_name] & remaining_dependencies)
            original_dependencies = []
            merged_dependencies = []
            # We need to differentiate between nodes that were already present (we can get relative order globally) and
            # nodes that were merged (we can get relative order only relative to the class the dependencies relate to)
            for class_dep in class_dependencies:
                if class_dep in self.start_lines:
                    original_dependencies.append(class_dep)
                else:
                    merged_dependencies.append(class_dep)
            # We need to sort deterministically before actual sorting, so that entries missing (i.e. with value 1e10)
            # will always get the same order independently of the system (they come from a set, which has no deterministic order)
            original_dependencies = sorted(original_dependencies, reverse=True)
            # Sort both lists according to the order in their respective file
            original_dependencies = sorted(original_dependencies, key=lambda x: self.start_lines.get(x, 1e10))
            merged_dependencies = sorted(merged_dependencies, key=lambda x: self.modular_file_start_lines[x])

            # Add all original nodes first, then merged ones
            for dep in original_dependencies + merged_dependencies:
                remaining_dependencies.remove(dep)
                relative_order[dep] = idx
                idx += 1
            # Add the class itself (it can sometimes already be present if the order of classes in the source file
            # does not make sense, i.e. a class is used somewhere before being defined like in `rt_detr`...)
            if class_name in remaining_dependencies:
                remaining_dependencies.remove(class_name)
                relative_order[class_name] = idx
                idx += 1

        # Now add what still remains
        remaining_dependencies = tuple(remaining_dependencies)
        original_dependencies = []
        merged_dependencies = []
        for dep in remaining_dependencies:
            if dep in self.modular_file_start_lines:
                merged_dependencies.append(dep)
            else:
                original_dependencies.append(dep)
        # We need to sort deterministically before actual sorting, so that entries missing (i.e. with value 1e10)
        # will always get the same order independently of the system (they come from a set, which has no deterministic order)
        original_dependencies = sorted(original_dependencies, reverse=True)
        # Sort both lists according to the order in their respective file
        original_dependencies = sorted(original_dependencies, key=lambda x: self.start_lines.get(x, 1e10))
        merged_dependencies = sorted(merged_dependencies, key=lambda x: self.modular_file_start_lines[x])

        # Add all original nodes first, then merged ones
        for dep in original_dependencies + merged_dependencies:
            relative_order[dep] = idx
            idx += 1

        return relative_order

    def _merge_functions(self, functions: dict[str, cst.CSTNode], object_mapping: dict[str, set]):
        """Update the global nodes and function dependency mapping with those from the modular file.

        Merging rule: if any function with the same name was redefined in the modular, use it and its dependencies
        instead of the original ones (this may mean to add new functions as well, if any redefined function uses a new one).
        """
        # Add/overwrite all needed function nodes and dependencies
        self.functions.update(functions)
        self.object_dependency_mapping.update(
            {obj: dep for obj, dep in object_mapping.items() if obj in functions.keys()}
        )
        # Add them to global nodes
        self.global_nodes.update(self.functions)

    def _merge_assignments(self, assignments: dict[str, cst.CSTNode], object_mapping: dict[str, set]):
        """Update the global nodes with the assignment from the modular file.

        Merging rule: if any assignment with the same name was redefined in the modular, we use it and its dependencies ONLY if it matches
        a pattern in `ASSIGNMENTS_REGEX_TO_KEEP_IF_NOT_NONE` and its value is not None, or if it matches a pattern in `ASSIGNMENTS_REGEX_TO_KEEP`.
        Otherwise, we use the original value and dependencies. This rule was chosen to avoid having to rewrite the big docstrings.
        """
        for assignment, node in assignments.items():
            should_keep = any(re.search(pattern, assignment) for pattern in ASSIGNMENTS_REGEX_TO_KEEP)

            should_keep_if_not_none = any(
                re.search(pattern, assignment) for pattern in ASSIGNMENTS_REGEX_TO_KEEP_IF_NOT_NONE
            ) and not (hasattr(node.body[0].value, "value") and node.body[0].value.value == "None")

            if should_keep or should_keep_if_not_none or assignment not in self.assignments:
                self.assignments[assignment] = node
                if assignment in object_mapping:
                    self.object_dependency_mapping[assignment] = object_mapping[assignment]
        # Add them to global nodes
        self.global_nodes.update(self.assignments)

    def _merge_classes(self, classes: dict[str, cst.CSTNode]):
        """Update the global nodes with the new classes from the modular (i.e. classes which do not exist in current file, and
        are not imported). We do NOT update any dependency mapping here. This is because we only need the names of newly defined
        classes in the modular to be discoverable when computing dependencies for new nodes later on. For this reason, we
        do not add the new classes to `self.classes`, but only to `global_nodes`.
        """
        # Add/overwrite all needed class nodes
        self.global_nodes.update(
            {
                name: node
                for name, node in classes.items()
                if name not in self.classes and name not in self.objects_imported_from_modeling
            }
        )

    def merge_modular_dependencies(self, classes, functions, assignments, object_mapping, start_lines):
        """Merge classes, functions and assignments from the modular definitions into the current module file,
        then record the relative order of all nodes.
        Note: This function takes care of updating `global_nodes` and `object_recursive_dependency_mapping` as well after the
        merge with the other files' dependencies.
        """
        self._merge_functions(functions, object_mapping)
        self._merge_assignments(assignments, object_mapping)
        self._merge_classes(classes)
        self.modular_file_start_lines = start_lines

        # Restrict the dependency mappings to the known entities to avoid Python's built-ins and imports
        self._restrict_dependencies_to_known_entities()
        # Create the global mapping of recursive dependencies for functions and assignments
        self.object_recursive_dependency_mapping = self._compute_recursive_object_dependencies()

    @classmethod
    def visit_and_merge_dependencies(
        cls, module: cst.Module, classes, functions, assignments, object_mapping, start_lines
    ) -> "ModelFileMapper":
        wrapper = MetadataWrapper(module)
        mapper = cls(module)
        wrapper.visit(mapper)
        # Merge dependencies
        mapper.merge_modular_dependencies(classes, functions, assignments, object_mapping, start_lines)
        # Create the class dependencies graph
        mapper.compute_class_dependencies()
        return mapper

def common_partial_suffix(str1: str, str2: str) -> str:
    """Return the biggest common suffix between 2 strings. If one string is a full suffix of the other string,
    we do not consider it a common suffix and return `""`"""
    common_suffix = ""
    for i in range(1, min(len(str1), len(str2)) + 1):
        if str1[-i] == str2[-i]:
            common_suffix = str1[-i] + common_suffix
        else:
            break
    # We do not allow full string suffix
    if common_suffix == str1 or common_suffix == str2:
        common_suffix = ""
    return common_suffix

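# Illustrative examples (editor's note, not part of the original module):
#   common_partial_suffix("MyNewModelConfig", "LlamaConfig")  -> "Config"
#   common_partial_suffix("Config", "LlamaConfig")            -> ""   (full-string suffixes are rejected)
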
def replace_class_node(
    mapper: ModelFileMapper, class_node: cst.ClassDef, renamed_super_class: str, original_super_class: str
):
    """
    Replace a class node which inherits from another modeling class. This function works in the following way:
    - start from the base class node of the inherited class (a cst.Node)
    - replace all methods of the base node with the methods defined in the child class
    - append all new methods defined in the child class
    - replace all calls to super() with the unravelled code

                    | ```python                           |              | ```python
                    | class GemmaModel(LlamaModel):       |              | class GemmaModel(nn.Module):
                    |     def __init__(self):             |              |     def __init__(self):
    Going from:     |         super().__init__()          |      to:     |         super().__init__(config)
                    |         self.dropout = 0.2          |              |         self.dropout = 0.2
                    | ```                                 |              |         self.padding_idx = config.pad_token_id
                                                                         |         self.vocab_size = config.vocab_size
                                                                         |         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
                                                                         |         self.layers = nn.ModuleList(
                                                                         |             [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
                                                                         |         )
                                                                         |         self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
                                                                         |         self.gradient_checkpointing = False
                                                                         |         # Initialize weights and apply final processing
                                                                         |         self.post_init()
                                                                         | ```
    """
    all_bases = [get_full_attribute_name(k.value) for k in class_node.bases]
    if any(base is None for base in all_bases):
        raise ValueError(f"Could not parse the name of the bases for {class_node.name.value}")

    original_node = mapper.classes[renamed_super_class]
    # Always use the new name of the class (in case we use e.g. `ColPaliForRetrieval` inheriting from `PaliGemmaForConditionalGeneration`)
    new_name = class_node.name

    # If the new class name is different from the renamed super class name, we need to update the docstrings/comments accordingly
    if new_name.value != renamed_super_class:
        common_suffix = common_partial_suffix(new_name.value, renamed_super_class)
        # Note that this works even without common prefix, in which case it does not replace anything
        old, new = renamed_super_class.replace(common_suffix, ""), new_name.value.replace(common_suffix, "")
        temp_module = cst.Module(body=[original_node])
        original_node = temp_module.visit(
            ReplaceNameTransformer(get_lowercase_name(old), get_lowercase_name(new), only_doc=True)
        ).body[0]

    # If we explicitly passed a new base with common suffix to an old base, it is for switching the prefix
    # e.g. if the "natural" parent class is `PreTrainedModel` but we wanted to rename it to `PreTrainedVisionModel`
    additional_bases = [base for base in all_bases if base != original_super_class]
    new_bases = []
    for original_base in original_node.bases:
        new_base = original_base
        # we only potentially switch base for Name-based bases, not Attribute
        if m.matches(original_base.value, m.Name()):
            original_base_name = original_base.value.value
            for additional_base_name in additional_bases:
                suffix = common_partial_suffix(original_base_name, additional_base_name)
                if len(suffix) > 0 and suffix[0].isupper():
                    new_name_node = original_base.value.with_changes(value=additional_base_name)
                    new_base = original_base.with_changes(value=new_name_node)
                    break
        new_bases.append(new_base)

    original_methods = {
        f.name.value if hasattr(f, "name") else mapper.python_module.code_for_node(f): f
        for f in original_node.body.body
    }
    updated_methods = {
        f.name.value if hasattr(f, "name") else mapper.python_module.code_for_node(f): f for f in class_node.body.body
    }
    end_meth = []

    assign_targets = {}
    docstring_node = []
    # Iterate directly from node.body as there can be property/setters with same names which are overwritten when we use a dict
    for func in original_node.body.body:
        name = func.name.value if hasattr(func, "name") else mapper.python_module.code_for_node(func)
        if m.matches(func, m.FunctionDef()) and name in updated_methods and updated_methods[name] is not None:
            new_params = updated_methods[name].params
            # Replace the method in the replacement class, preserving decorators
            kwarg_name = getattr(updated_methods[name].params, "star_kwarg", None)
            if kwarg_name and kwarg_name.name.value == "super_kwargs":
                parent_params = {k.name.value: k for k in func.params.params}
                parent_params.update({k.name.value: k for k in new_params.params[1:]})
                new_params = new_params.with_changes(
                    params=list(parent_params.values()), star_kwarg=func.params.star_kwarg
                )
            # Keep decorators in `modular_xxx.py` if any, else original decorators
            new_decorators = (
                updated_methods[name].decorators if len(updated_methods[name].decorators) > 0 else func.decorators
            )

            # Keep return annotation in `modular_xxx.py` if any, else original return annotation
            new_return_annotation = updated_methods[name].returns if updated_methods[name].returns else func.returns

            if not re.match(
                r"\ndef .*\(.*\):\n    raise.*Error\(.*",
                mapper.python_module.code_for_node(updated_methods[name]),
            ):
                func = func.with_changes(
                    body=updated_methods[name].body,
                    params=new_params,
                    decorators=new_decorators,
                    returns=new_return_annotation,
                )
            else:
                continue

        if m.matches(func, m.SimpleStatementLine(body=[m.Assign()])):
            target = mapper.python_module.code_for_node(func.body[0].targets[0])
            assign_targets[target] = func
        elif m.matches(func, m.SimpleStatementLine(body=[m.AnnAssign()])):
            target = mapper.python_module.code_for_node(func.body[0].target)
            assign_targets[target] = func
        elif m.matches(func, DOCSTRING_NODE):
            docstring_node = [func]
        else:
            end_meth.append(func)

    # Port new methods that are defined only in the modular file and append them at the end
    for func in class_node.body.body:
        name = func.name.value if hasattr(func, "name") else mapper.python_module.code_for_node(func)
        if m.matches(func, DOCSTRING_NODE):  # This processes the docstring of the class!
            # Extract the original docstring
            updated_docstring = func.body[0].value.value
            if len(docstring_node) == 0:  # If the original docstring is empty, just create one from the updated.
                docstring_node = [
                    cst.SimpleStatementLine(body=[cst.Expr(value=cst.SimpleString(value=updated_docstring))])
                ]
            else:
                original_docstring = docstring_node[0].body[0].value.value
                merged_doc = merge_docstrings(original_docstring, updated_docstring)
                # Update the docstring in the original function
                docstring_node = [
                    docstring_node[0].with_changes(body=[cst.Expr(value=cst.SimpleString(value=merged_doc))])
                ]
        if name not in original_methods and func is not None and isinstance(func, cst.FunctionDef):
            end_meth.append(func)
        if m.matches(func, m.SimpleStatementLine(body=[m.Assign()])):
            # TODO: we only handle single-target assigns, which might cause issues
            target = mapper.python_module.code_for_node(func.body[0].targets[0])
            assign_targets[target] = func
        if m.matches(func, m.SimpleStatementLine(body=[m.AnnAssign()])):
            target = mapper.python_module.code_for_node(func.body[0].target)
            assign_targets[target] = func
    end_meth = docstring_node + list(assign_targets.values()) + end_meth

    # Replace the calls to `super()` with the unrolled code
    result_node = original_node.with_changes(body=cst.IndentedBlock(body=end_meth))
    temp_module = cst.Module(body=[result_node])
    new_module = MetadataWrapper(temp_module)
    new_replacement_class = new_module.visit(
        SuperTransformer(temp_module, original_methods, updated_methods, all_bases)
    )
    new_replacement_body = new_replacement_class.body[0].body  # get the indented block

    # Use decorators redefined in `modular_xxx.py` if any
    new_decorators = class_node.decorators if len(class_node.decorators) > 0 else original_node.decorators

    return original_node.with_changes(
        body=new_replacement_body, decorators=new_decorators, bases=new_bases, name=new_name
    )

TYPE_TO_FILE_TYPE = {
    "Config": "configuration",
    "Tokenizer": "tokenization",
    "Processor": "processing",
    "ImageProcessor": "image_processing",
    "ImageProcessorFast": "image_processing*_fast",  # "*" indicates where to insert the model name before the "_fast" suffix
    "VideoProcessor": "video_processing",
    "VideoProcessorInitKwargs": "video_processing",
    "FastImageProcessorKwargs": "image_processing*_fast",
    "FeatureExtractor": "feature_extractor",
    "ProcessorKwargs": "processing",
    "VideosKwargs": "processing",
    "ImagesKwargs": "processing",
    "TextKwargs": "processing",
}

def find_file_type(class_name: str) -> str:
    """Based on a class name, find the file type corresponding to the class.
    If the class name is `LlamaConfig` it will return `configuration`.
    The list of suffixes is in `TYPE_TO_FILE_TYPE`. If there is no match, we default to `modeling`.
    """
    match_pattern = "|".join(TYPE_TO_FILE_TYPE.keys())
    match = re.search(rf"({match_pattern})$", class_name)
    if match:
        file_type = TYPE_TO_FILE_TYPE[match.group(1)]
    else:
        file_type = "modeling"
    return file_type

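# Illustrative examples (editor's note, not part of the original module):
#   find_file_type("LlamaConfig")            -> "configuration"
#   find_file_type("Qwen2VLVideoProcessor")  -> "video_processing"
#   find_file_type("LlamaModel")             -> "modeling"   (no suffix in TYPE_TO_FILE_TYPE matches)
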
# These top-level variables will always appear at the very beginning of the file, in the order they are defined in
# this list (this is to avoid having variables at weird places, even if they are not used before)
VARIABLES_AT_THE_BEGINNING = (
    "logger",
    "_CHECKPOINT_FOR_DOC",
    "_CONFIG_FOR_DOC",
)

# These specific modeling imports should not be visited as other modeling files
IMPORTS_TO_SKIP_IN_MODULAR = ("auto.modeling_auto",)


def append_new_import_node(
    node: cst.CSTNode, unused_imports: set[str], added_names: set, imports_to_keep: list[cst.CSTNode]
):
    """Insert the new `node` to the list of `imports_to_keep` in-place, if it is not part of the `unused_imports` or `added_names`.
    Also modifies `added_names` in-place accordingly."""
    import_node = node.body[0]
    names_to_keep = []
    for name in import_node.names:
        name_value = name.evaluated_alias or name.evaluated_name
        if name_value not in unused_imports and name_value not in added_names:
            names_to_keep.append(name.with_changes(comma=cst.MaybeSentinel.DEFAULT))
            added_names.add(name_value)
    if len(names_to_keep) > 0:
        new_node = node.with_changes(body=[import_node.with_changes(names=names_to_keep)])
        imports_to_keep.append(new_node)


def get_needed_imports(body: dict[str, dict], all_imports: list[cst.CSTNode]) -> list[cst.CSTNode]:
    """Get all the imports needed in the `body`, from the list of `all_imports`.
    `body` is a dict with the following structure `{str: {"insert_idx": int, "node": cst.CSTNode}}`.
    Note: we need to use `isinstance` on scope assignments, m.matches apparently does not work here yet!
    """
    new_body = [k[1]["node"] for k in sorted(body.items(), key=lambda x: x[1]["insert_idx"])]
    wrapper = MetadataWrapper(cst.Module(body=all_imports + new_body))
    scopes = set(wrapper.resolve(ScopeProvider).values())
    unused_imports = set()
    import_ref_count = defaultdict(lambda: 0)
    for scope in scopes:
        for assignment in scope.assignments:
            node = assignment.node
            if isinstance(assignment, cst.metadata.Assignment) and isinstance(node, (cst.Import, cst.ImportFrom)):
                ref_count = len(assignment.references)
                name = assignment.name
                import_ref_count[name] = max(ref_count, import_ref_count[name])
    # Similar imports may be redefined, and only used between their 1st and 2nd definition, so if we already have
    # a ref count > 0 at any point, the import is actually used
    unused_imports = {name for name, count in import_ref_count.items() if count <= 0 or name in body.keys()}

    imports_to_keep = []
    # We need to keep track of which names were already imported, because some imports may be duplicated from multiple
    # sources, or be both protected and unprotected due to inconsistency between models
    added_names = set()
    existing_protected_statements = set()  # str repr of the import nodes - does not work with the nodes directly
    for node in all_imports:
        if m.matches(node, m.If()):  # handle safe imports
            new_statements = []
            for stmt_node in node.body.body:
                append_new_import_node(stmt_node, unused_imports, added_names, new_statements)
            new_statements = [stmt for stmt in new_statements if str(stmt) not in existing_protected_statements]
            if len(new_statements) > 0:
                new_node = node.with_changes(body=node.body.with_changes(body=new_statements))
                imports_to_keep.append(new_node)
                existing_protected_statements.update({str(stmt) for stmt in new_statements})
        else:
            append_new_import_node(node, unused_imports, added_names, imports_to_keep)

    protected_import_nodes = [node for node in imports_to_keep if m.matches(node, m.If())]
    usual_import_nodes = [node for node in imports_to_keep if not m.matches(node, m.If())]

    # Protected imports always appear at the end of all imports
    return usual_import_nodes + protected_import_nodes


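# A rough sketch of the inputs expected by `get_needed_imports` (names are hypothetical): with
#   body = {"MyModelConfig": {"insert_idx": 0, "node": <cst.ClassDef for MyModelConfig>}}
# and `all_imports` holding the parsed statements `import math` and `import torch`, only the import nodes that
# are actually referenced inside the body are kept, and any `if`-protected (safe) import blocks are moved to
# the end of the returned list.

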
def split_all_assignment(node: cst.CSTNode) -> dict[str, cst.CSTNode]:
    """Split the `__all__` assignment found in the modular between the corresponding files."""
    all_all_per_file = {}
    assign_node = node.body[0]
    if isinstance(assign_node.value, cst.List):
        # Extract the elements from the list
        all_all_to_add = defaultdict(list)
        for element in assign_node.value.elements:
            if isinstance(element.value, cst.SimpleString):
                # Remove quotes and add the string to the elements list
                class_name = element.value.value
                file = find_file_type(element.value.evaluated_value)
                all_all_to_add[file] += [class_name]
        for file, new_alls in all_all_to_add.items():
            new_node = assign_node.with_changes(
                value=cst.List(elements=[cst.Element(value=cst.SimpleString(value=k)) for k in new_alls])
            )
            all_all_per_file[file] = node.with_changes(body=[new_node])
    return all_all_per_file


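# A brief illustration of the split above (hypothetical model name): a modular file containing
#   __all__ = ["MyModelConfig", "MyModelModel", "MyModelForCausalLM"]
# is dispatched by `split_all_assignment` into one `__all__` node per target file, roughly
#   {"configuration": __all__ = ["MyModelConfig"], "modeling": __all__ = ["MyModelModel", "MyModelForCausalLM"]}

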
class ModularFileMapper(ModuleMapper):
    """This is a Mapper to visit a modular file (like `modular_llama.py`). It visits the whole file, recording dependencies,
    then visits all imported modeling files (like `modeling_llama.py`), and manages their mutual dependencies.
    Calling the method `create_modules()` after visit will create all modules based on this modular file.
    """

    def __init__(self, python_module, new_name):
        super().__init__(python_module)
        # fmt: off
        self.model_name = new_name  # name of the model being defined. Should be in the format of `llama` or `layout_xlm` or `phi3`

        self.model_specific_imported_objects: Dict[str, str] = {}  # e.g. {"LlamaModel": "transformers.models.llama.modeling_llama"}
        self.model_specific_modules: Dict[str, cst.Module] = {}  # e.g. {"transformers.models.llama.modeling_llama": cst.Module}

        self.all_all_to_add = {}
        # fmt: on

    def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
        """When visiting imports from modeling files (i.e. `transformers.models.xxx`), we get the code, parse it, and
        save it in `self.model_specific_modules` to visit later. The imported objects are saved in `self.model_specific_imported_objects`.
        """
        import_module = self.python_module.code_for_node(node.module)
        import_statement = "." * len(node.relative) + import_module
        if any(import_to_skip in import_statement for import_to_skip in IMPORTS_TO_SKIP_IN_MODULAR):
            return
        if m.matches(node.module, m.Attribute()):
            for imported_ in node.names:
                _import = re.search(
                    rf"(?:transformers\.models\.)|(?:\.\.)\w+\.({self.match_patterns})_.*", import_statement
                )
                if _import:
                    source = _import.group(1)
                    if source == "modeling" and "Config" in self.python_module.code_for_node(imported_):
                        raise ValueError(
                            f"You are importing {self.python_module.code_for_node(imported_)} from the modeling file. Import from the `configuration_xxxx.py` file instead"
                        )
                    if import_module not in self.model_specific_modules:
                        if "models" not in import_module:
                            import_module = "models." + import_module
                        if "transformers" not in import_module:
                            import_module = "transformers." + import_module
                        source_code = get_module_source_from_name(import_module)
                        tree = cst.parse_module(source_code)
                        self.model_specific_modules[import_module] = tree
                    imported_object = self.python_module.code_for_node(imported_.name)
                    self.model_specific_imported_objects[imported_object] = import_module
        if m.matches(node.module, m.Name()):
            if "transformers" == import_module:
                raise ValueError(
                    f"You are importing from {import_module} directly using global imports. Import from the correct local path"
                )

    def visit_SimpleStatementLine(self, node):
        """If we visit an import statement not previously visited, record it. If we visit a module-scope assignment,
        simply record it or, if it is `__all__`, split it between the files where it should be dispatched.
        """
        parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node)
        simple_top_level_assign_structure = m.SimpleStatementLine(
            body=[m.Assign(targets=[m.AssignTarget(target=m.Name())])]
        )
        if m.matches(parent_node, m.Module()):
            if m.matches(node, m.SimpleStatementLine(body=[m.Import()])):
                self.imports.append(node)
            elif m.matches(node, m.SimpleStatementLine(body=[m.ImportFrom()])):
                import_module = self.python_module.code_for_node(node.body[0].module)
                import_statement = "." * len(node.body[0].relative) + import_module
                if not (
                    re.search(rf"(?:transformers\.models\.)|(?:\.\.)\w+\.({self.match_patterns})_.*", import_statement)
                    and not any(import_to_skip in import_statement for import_to_skip in IMPORTS_TO_SKIP_IN_MODULAR)
                ):
                    self.imports.append(node)
            elif m.matches(node, simple_top_level_assign_structure):
                assigned_variable = node.body[0].targets[0].target.value
                # __all__ is treated differently and not added to general assignments
                if assigned_variable == "__all__":
                    self.all_all_to_add = split_all_assignment(node)
                else:
                    self.current_assignment = assigned_variable
                    self.assignments[assigned_variable] = node

    def leave_Module(self, node):
        """When we leave the modular file, we do the following in order:
        1. for each modeling file found in the imports, rename it with the new model name, visit it, and update
           its dependency graph with the new function and assignment definitions found in the modular
        2. update the modular dependency graph with the imported functions and assignments (found when visiting the matching files)
        3. compute the nested (recursive) function and assignment dependencies
        """
        # Takes care of finalizing our visit
        super().leave_Module(node)

        # 1. for each modeling file found in the imports, rename it with the new model name, visit it, and update dependencies
        self.visited_modules = {}
        self.renamers = {}
        name_prefixes = self.infer_new_model_name()
        for file, module in self.model_specific_modules.items():
            file_model_name = file.split(".")[-2]
            new_name = name_prefixes[file]
            renamer = ReplaceNameTransformer(file_model_name, new_name, self.model_name)
            renamed_module = module.visit(renamer)
            self.visited_modules[file] = ModelFileMapper.visit_and_merge_dependencies(
                renamed_module,
                self.classes,
                self.functions,
                self.assignments,
                self.object_dependency_mapping,
                self.start_lines,
            )
            # We record it so that we can rename classes later the exact same way
            self.renamers[file] = renamer

        # 2. in turn, we need to add the imported functions/assignments to the dependencies of the modular mapper, using the
        # definitions found in the visited files
        self.merge_model_specific_imports(self.visited_modules)

        # 3. compute the nested (recursive) function and assignment dependencies
        self.object_recursive_dependency_mapping = self._compute_recursive_object_dependencies()

        # We need to keep track of which objects were imported directly into which modeling file to not add them wrongly later
        # Note that we may visit several of the same file types, thus we save them per file type, not file
        self.imported_objects_per_file = defaultdict(set)
        for file, mapper in self.visited_modules.items():
            file_type = re.search(rf"^transformers\.models\.\w+\.({self.match_patterns})_.*", file).group(1)
            self.imported_objects_per_file[file_type].update(mapper.objects_imported_from_modeling)

    def merge_model_specific_imports(self, visited_modules):
        """Merge the functions and assignments imported from the modeling files into the modular nodes and dependency graph,
        based on the visited files."""
        self.start_lines_file_mapping = {}
        self.added_objects_file_mapping = {}
        for object_name, file in self.model_specific_imported_objects.items():
            visited_module = visited_modules[file]
            self.start_lines_file_mapping[file] = visited_module.start_lines
            # Add functions and their dependencies
            if object_name in visited_module.functions and object_name not in self.functions:
                self.functions[object_name] = visited_module.functions[object_name]
                self.added_objects_file_mapping[object_name] = file
                dependencies = visited_module.object_dependency_mapping.get(object_name, None)
                if dependencies is not None:
                    self.object_dependency_mapping[object_name] = dependencies
                    for dep in dependencies:
                        if dep not in self.global_nodes:
                            self.added_objects_file_mapping[dep] = file
                            self.functions[dep] = visited_module.global_nodes[dep]

                # Add/overwrite the imported functions to other visited modules as well, in case they are absent/different
                # in the modeling source file of the inherited class. See `examples/modular-transformers/modular_switch_function.py`
                # and `examples/modular-transformers/modular_add_function.py` for examples
                recursive_dependencies = visited_module.object_recursive_dependency_mapping.get(object_name, set())
                node_recursive_dependencies_mapping = {
                    dep: visited_module.global_nodes[dep] for dep in recursive_dependencies
                }
                for filename, module_mapper in self.visited_modules.items():
                    if filename != file:
                        module_mapper.global_nodes[object_name] = visited_module.functions[object_name]
                        if len(recursive_dependencies) > 0:
                            module_mapper.object_recursive_dependency_mapping[object_name] = recursive_dependencies
                            module_mapper.global_nodes.update(node_recursive_dependencies_mapping)

            # Add assignments and their dependencies
            elif object_name in visited_module.assignments and object_name not in self.assignments:
                self.assignments[object_name] = visited_module.assignments[object_name]
                self.added_objects_file_mapping[object_name] = file
                dependencies = visited_module.object_dependency_mapping.get(object_name, None)
                if dependencies is not None:
                    self.object_dependency_mapping[object_name] = dependencies
                    for dep in dependencies:
                        if dep not in self.global_nodes:
                            self.added_objects_file_mapping[dep] = file
                            self.assignments[dep] = visited_module.global_nodes[dep]

        # Do not forget to re-assign all nodes after the merge
        self.global_nodes = {**self.assignments, **self.classes, **self.functions}
        # And restrict dependencies to those nodes only
        self._restrict_dependencies_to_known_entities()

    def compute_relative_order(self, missing_dependencies: set) -> dict[str, int]:
        """Compute in which relative order the `missing_dependencies` should appear when the nodes are added to the final file that
        will be created based on the modular.
        """
        relative_order = {}
        idx = 0

        original_dependencies = []
        other_files_dependencies = defaultdict(list)
        for dep in tuple(missing_dependencies):
            if dep in self.added_objects_file_mapping:
                file = self.added_objects_file_mapping[dep]
                other_files_dependencies[file].append(dep)
            else:
                original_dependencies.append(dep)
        # Sort all lists according to the order in their respective file
        all_dependencies = []
        for file, dependencies in other_files_dependencies.items():
            sorted_dependencies = sorted(dependencies, key=lambda x: self.start_lines_file_mapping[file][x])
            all_dependencies += sorted_dependencies
        all_dependencies += sorted(original_dependencies, key=lambda x: self.start_lines[x])

        # Add all original nodes first, then merged ones (one file at a time)
        for dep in all_dependencies:
            relative_order[dep] = idx
            idx += 1

        return relative_order

    def infer_new_model_name(self) -> dict:
        """Infer whether we are using a model name prefix different from the usual model name as defined from the filename.
        This is useful e.g. when we define a new multi-modal model, and only the text part inherits from `LlamaModel`,
        so we have something like:
        ```python
        class NewModelNameTextDecoderLayer(LlamaDecoderLayer):
            pass
        ```
        with the `Text` prefix added to the model name.
        However, in case multiple prefixes are used, we raise a warning and use the most frequent prefix, to avoid parsing
        the same file multiple times and inconsistencies in the objects added from dependencies.
        If the new prefix collides with a prefix of another class in the file where we are importing from, then we also
        raise a warning, and use the default prefix (model name) to avoid collisions in dependencies.
        """
        prefix_model_name_mapping = defaultdict(Counter)
        cased_default_name = get_cased_name(self.model_name)
        # Iterate over all new classes to get modeling super classes
        for class_name, class_node in self.classes.items():
            modeling_bases = [
                k.value.value for k in class_node.bases if k.value.value in self.model_specific_imported_objects
            ]
            if len(modeling_bases) > 1:
                raise ValueError(
                    f"{class_name} was defined with more than 1 model-specific super class. This is unsupported. We found {(*modeling_bases,)}."
                )
            if len(modeling_bases) == 1:
                filename = self.model_specific_imported_objects[modeling_bases[0]]
                cased_model_name = cased_default_name  # the default name prefix
                suffix = common_partial_suffix(class_name, modeling_bases[0])
                if len(suffix) > 0 and suffix[0].isupper():
                    cased_model_name = class_name.replace(suffix, "")
                    # If both the old model and new model share the last part of their name, it is detected as a common
                    # suffix, but it should not be the case -> use the full name in this case
                    if len(cased_model_name) < len(cased_default_name) and cased_default_name in class_name:
                        cased_model_name = cased_default_name
                prefix_model_name_mapping[filename].update([cased_model_name])

        # Check if we found multiple prefixes for some modeling files
        final_name_mapping = {}
        for file, prefixes_counter in prefix_model_name_mapping.items():
            if len(prefixes_counter) > 1:
                _, total = prefixes_counter.most_common(1)[0]
                most_used_entities = [name for name, count in prefixes_counter.most_common() if count == total]
                # if the default name is in the pool of equally used prefixes, use it, otherwise last encountered
                final_name = cased_default_name if cased_default_name in most_used_entities else most_used_entities[-1]
            else:
                final_name = list(prefixes_counter)[0]
            # Check if the prefix can be used without collisions in the names
            old_cased_model_name = get_cased_name(file.split(".")[-2])
            old_model_name_prefix = final_name.replace(cased_default_name, old_cased_model_name)
            # Raise adequate warning depending on the situation
            has_prefix_collision = f"\nclass {old_model_name_prefix}" in get_module_source_from_name(file)
            if final_name != cased_default_name and has_prefix_collision:
                if len(prefixes_counter) > 1:
                    logger.warning(
                        f"We detected multiple prefix names when inheriting from {file}: {(*set(prefixes_counter),)}. However, the "
                        f"most used one, '{final_name}', is already present in the source file and will likely cause consistency "
                        f"issues. For this reason we fallback to the default prefix '{cased_default_name}' when grabbing args "
                        "and dependencies. Make sure to subclass the intermediate classes with the prefix you want (if different "
                        f"from '{cased_default_name}') or use a single prefix in all the modular (best)."
                    )
                else:
                    logger.warning(
                        f"We detected the use of the new default prefix {final_name} when inheriting from {file}. However, it is "
                        "already present in the source file and will likely cause consistency issues. For this reason we fallback "
                        f"to the default prefix '{cased_default_name}' when grabbing args and dependencies. Make sure to subclass "
                        f"the intermediate classes with the prefix you want (if different from '{cased_default_name}')"
                    )
                final_name = cased_default_name
            elif len(prefixes_counter) > 1:
                logger.warning(
                    f"We detected multiple prefix names when inheriting from {file}: {(*set(prefixes_counter),)}. We will only "
                    f"use the most used '{final_name}' prefix when grabbing args and dependencies. Make sure to subclass the "
                    f"intermediate classes with the prefix you want (if different from '{final_name}') or use a single prefix "
                    "in all the modular (best)."
                )
            final_name_mapping[file] = get_lowercase_name(final_name)

        # Check that we are not missing imported files
        for file in self.model_specific_modules.keys():
            if file not in final_name_mapping.keys():
                final_name_mapping[file] = self.model_name

        return final_name_mapping


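# A hedged illustration of the prefix inference above (class names are hypothetical): if a modular file for
# `new_model` defines `class NewModelTextDecoderLayer(LlamaDecoderLayer)`, the common suffix `DecoderLayer` is
# stripped and the `NewModelText` prefix is kept, so `modeling_llama.py` is renamed and visited with the
# lowercased prefix `new_model_text` instead of the default `new_model` (assuming `get_lowercase_name` performs
# the usual CamelCase -> snake_case conversion).

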
def check_dependencies_and_create_import_node(
    file_type: str, new_dependencies: set[str], mapper: ModuleMapper, new_name: str
) -> tuple[set[str], dict[str, cst.CSTNode]]:
    """Check that all class nodes in the `new_dependencies` belong to the correct `file_type`. If this is not the case,
    we need to remove it from the dependencies, and create a new import to it instead.
    This scenario may appear in the following case:
    If a new class in the `modular_xxx.py` file does not belong to `type_xxx.py`, but is used somewhere in `other_type_xxx.py`
    (e.g. as a type hint), but none of the visited files had a similar class, then it would be imported in `type_xxx.py` as
    part of the standard dependency graph (because we never encountered an import towards this new class in any file).
    For example imagine the following `modular.py`:
    ```
    from ..llama.modeling_llama import LlamaModel

    class NewNameTextConfig(PretrainedConfig):
        ...

    class NewNameConfig(PretrainedConfig):
        ...

    class NewNameModel(LlamaModel):
        config = NewNameConfig()
        text_config = NewNameTextConfig()
        ...
    ```
    then without the help of this function, `NewNameTextConfig` would be imported in the `modeling_newname.py` as well as
    `configuration_newname.py`, because `modeling_llama.py` tells us to not import `NewNameConfig`, but has no
    knowledge of `NewNameTextConfig`.
    """
    class_dependencies = {dep for dep in new_dependencies if m.matches(mapper.global_nodes[dep], m.ClassDef())}
    corrected_dependencies = new_dependencies.copy()
    new_imports = {}
    for class_name in class_dependencies:
        class_file_type = find_file_type(class_name)
        # In this case, we need to remove it from the dependencies and create a new import instead
        if class_file_type != file_type:
            corrected_dependencies.remove(class_name)
            import_statement = f"from .{class_file_type}_{new_name} import {class_name}"
            new_imports[class_name] = cst.parse_statement(import_statement)

    return corrected_dependencies, new_imports


def get_class_node_and_dependencies(
    modular_mapper: ModularFileMapper, class_name: str, node: cst.CSTNode, files: dict[str, dict]
) -> tuple[dict, str, dict]:
    """Return a single class node (and all its dependency nodes), to be added to the `files`. It creates the new
    class node based on the inherited classes if needed. Also returns any new imports of a new class defined in
    the modular that we may need.
    """
    # An exception was already raised if this has len > 1
    model_specific_bases = [
        k.value.value for k in node.bases if k.value.value in modular_mapper.model_specific_imported_objects
    ]
    super_class = model_specific_bases[0] if len(model_specific_bases) == 1 else None

    file_type = find_file_type(class_name)
    file_to_update = files[file_type]
    model_name = modular_mapper.model_name

    # This is used to avoid adding objects to the dependencies graph if they will be imported already
    imported_objects = modular_mapper.imported_objects_per_file[file_type]

    # We need to replace the class node with the transformers (modeling file) super class node
    if super_class is not None:
        super_file_name = modular_mapper.model_specific_imported_objects[super_class]

        # Get the mapper corresponding to the inherited class
        mapper = modular_mapper.visited_modules[super_file_name]
        # Rename the super class according to the exact same rule we used when renaming the whole module
        renamer = modular_mapper.renamers[super_file_name]
        renamed_super_class = preserve_case_replace(super_class, renamer.patterns, renamer.cased_new_name)

        # Create the new class node
        updated_node = replace_class_node(mapper, node, renamed_super_class, super_class)

        # Grab all immediate dependencies of the new node
        new_node_dependencies = augmented_dependencies_for_class_node(updated_node, mapper, imported_objects)

        # At this point, if any class dependency is found, but belongs to another file, it means that we need to remove
        # it from the dependencies, and add a new import of it instead
        new_node_dependencies, new_imports = check_dependencies_and_create_import_node(
            file_type, new_node_dependencies, mapper, model_name
        )

        # The node was modified -> look for all recursive dependencies of the new node
        all_dependencies_to_add = find_all_dependencies(
            dependency_mapping=mapper.class_dependency_mapping,
            initial_dependencies=new_node_dependencies,
            initial_checked_dependencies=set(file_to_update.keys()),
        )

        relative_dependency_order = mapper.compute_relative_order(all_dependencies_to_add)
        nodes_to_add = {
            dep: (relative_dependency_order[dep], mapper.global_nodes[dep]) for dep in all_dependencies_to_add
        }

    # No transformers (modeling file) super class, just check functions and assignments dependencies
    else:
        updated_node = node
        # The node was NOT modified -> no need to look recursively for other class dependencies. Indeed, even if they are not
        # already defined (which would mean a weird order of the code in the modular...), they will be in the future
        all_dependencies_to_add = augmented_dependencies_for_class_node(updated_node, modular_mapper, imported_objects)

        # At this point, if any class dependency is found, but belongs to another file, it means that we need to remove
        # it from the dependencies, and add a new import of it instead
        all_dependencies_to_add, new_imports = check_dependencies_and_create_import_node(
            file_type, all_dependencies_to_add, modular_mapper, model_name
        )

        relative_dependency_order = modular_mapper.compute_relative_order(all_dependencies_to_add)
        nodes_to_add = {
            dep: (relative_dependency_order[dep], modular_mapper.global_nodes[dep])
            for dep in all_dependencies_to_add
            if dep not in file_to_update.keys()
        }

    # Add the class node itself to the nodes to add
    class_idx = max(relative_dependency_order.values()) + 1 if len(relative_dependency_order) > 0 else 0
    nodes_to_add[class_name] = (class_idx, updated_node)

    return nodes_to_add, file_type, new_imports


def create_modules(modular_mapper: ModularFileMapper) -> dict[str, cst.Module]:
    """Create all the new modules based on visiting the modular file. It replaces all classes as necessary."""
    files = defaultdict(dict)
    current_file_indices = defaultdict(lambda: 0)

    # For each class defined in modular, potentially replace the node and add it with its dependencies
    for class_name, node in modular_mapper.classes.items():
        nodes_to_add, file_type, new_imports = get_class_node_and_dependencies(modular_mapper, class_name, node, files)

        # Add the potential new imports that we may need to the `modular_mapper` variable
        modular_mapper.imported_objects_per_file[file_type].update(new_imports.keys())
        modular_mapper.imports.extend(list(new_imports.values()))

        # Sort the nodes according to their relative order
        nodes_to_add = sorted(nodes_to_add.items(), key=lambda x: x[1][0])
        # Write all nodes to file
        for dependency, (_, node) in nodes_to_add:
            # This is used to keep certain variables at the beginning of the file
            try:
                # The -1000 is arbitrary -> just keep it bigger than the list
                idx = -1000 + VARIABLES_AT_THE_BEGINNING.index(dependency)
            except ValueError:
                idx = current_file_indices[file_type]
                current_file_indices[file_type] += 1
            files[file_type][dependency] = {"insert_idx": idx, "node": node}

    # Add the __all__ statement to files at the end
    for file_type, node in modular_mapper.all_all_to_add.items():
        idx = current_file_indices[file_type]
        files[file_type]["__all__"] = {"insert_idx": idx, "node": node}

    # Aggregate all the imports statements (we look for duplicates with the code_for_node, not the nodes themselves because
    # they are wrapped in SimpleStatementLine or If which could have different newlines, blanks etc)
    all_imports = modular_mapper.imports.copy()
    all_imports_code = {modular_mapper.python_module.code_for_node(node).strip() for node in all_imports}
    for file, mapper in modular_mapper.visited_modules.items():
        new_imports = [
            node for node in mapper.imports if mapper.python_module.code_for_node(node).strip() not in all_imports_code
        ]
        new_imports_code = {mapper.python_module.code_for_node(node).strip() for node in new_imports}
        all_imports.extend(new_imports)
        all_imports_code.update(new_imports_code)

    # Find the correct imports, and write the new modules
    for file, body in files.items():
        new_body = [k[1]["node"] for k in sorted(body.items(), key=lambda x: x[1]["insert_idx"])]
        needed_imports = get_needed_imports(body, all_imports)
        full_module = needed_imports + new_body
        new_module = cst.Module(body=full_module, header=modular_mapper.python_module.header)
        files[file] = new_module

    return files


def convert_modular_file(modular_file):
    pattern = re.search(r"modular_(.*)(?=\.py$)", modular_file)
    output = {}
    if pattern is not None:
        model_name = pattern.groups()[0]
        # Parse the Python file
        with open(modular_file, "r", encoding="utf-8") as file:
            code = file.read()
        module = cst.parse_module(code)
        wrapper = MetadataWrapper(module)
        cst_transformers = ModularFileMapper(module, model_name)
        wrapper.visit(cst_transformers)
        for file, module in create_modules(cst_transformers).items():
            if module != {}:
                # Get relative path starting from src/transformers/
                relative_path = re.search(
                    r"(src/transformers/.*|examples/.*)", os.path.abspath(modular_file).replace("\\", "/")
                ).group(1)

                header = AUTO_GENERATED_MESSAGE.format(
                    relative_path=relative_path, short_name=os.path.basename(relative_path)
                )
                ruffed_code = run_ruff(header + module.code, True)
                formatted_code = run_ruff(ruffed_code, False)
                output[file] = [formatted_code, ruffed_code]
        return output
    else:
        print(f"modular pattern not found in {modular_file}, exiting")
        return {}


def save_modeling_file(modular_file, converted_file):
    for file_type in converted_file.keys():
        file_name_prefix = file_type.split("*")[0]
        file_name_suffix = file_type.split("*")[-1] if "*" in file_type else ""
        new_file_name = modular_file.replace("modular_", f"{file_name_prefix}_").replace(
            ".py", f"{file_name_suffix}.py"
        )
        non_comment_lines = len(
            [line for line in converted_file[file_type][0].strip().split("\n") if not line.strip().startswith("#")]
        )
        if len(converted_file[file_type][0].strip()) > 0 and non_comment_lines > 0:
            with open(new_file_name, "w", encoding="utf-8") as f:
                f.write(converted_file[file_type][0])
        else:
            non_comment_lines = len(
                [line for line in converted_file[file_type][0].strip().split("\n") if not line.strip().startswith("#")]
            )
            if len(converted_file[file_type][1].strip()) > 0 and non_comment_lines > 0:
                logger.warning("The modeling code contains errors, it's written without formatting")
                with open(new_file_name, "w", encoding="utf-8") as f:
                    f.write(converted_file[file_type][1])


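# A short illustration of the naming scheme above (hypothetical model name): for `modular_my_model.py`, the
# file type "configuration" yields `configuration_my_model.py`, while "image_processing*_fast" is split on "*"
# and yields `image_processing_my_model_fast.py`.

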
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--files_to_parse",
        default=["all"],
        nargs="+",
        help="A list of `modular_xxxx` files that should be converted to single model file",
    )
    args = parser.parse_args()
    if args.files_to_parse == ["all"]:
        args.files_to_parse = glob.glob("src/transformers/models/**/modular_*.py", recursive=True)
    if args.files_to_parse == ["examples"]:
        args.files_to_parse = glob.glob("examples/**/modular_*.py", recursive=True)
    else:
        for i, model_name in enumerate(args.files_to_parse):
            if os.sep not in model_name:
                full_path = os.path.join("src", "transformers", "models", model_name, f"modular_{model_name}.py")
                # If it does not exist, try in the examples section
                if not os.path.isfile(full_path):
                    full_path = os.path.join("examples", "modular-transformers", f"modular_{model_name}.py")
                # We did not find it anywhere
                if not os.path.isfile(full_path):
                    raise ValueError(f"Cannot find a modular file for {model_name}. Please provide the full path.")
                args.files_to_parse[i] = full_path

    priority_list, _ = find_priority_list(args.files_to_parse)
    assert len(priority_list) == len(args.files_to_parse), "Some files will not be converted"

    for file_name in priority_list:
        print(f"Converting {file_name} to a single model single file format")
        module_path = file_name.replace("/", ".").replace(".py", "").replace("src.", "")
        converted_files = convert_modular_file(file_name)
        converter = save_modeling_file(file_name, converted_files)
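# A minimal usage sketch (the script location shown here is an assumption and may differ per checkout):
#   python utils/modular_model_converter.py --files_to_parse llama
# Bare model names are resolved to `src/transformers/models/<name>/modular_<name>.py` (with a fallback to
# `examples/modular-transformers/`), and the generated files are written next to the modular file.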