transformers/utils/modular_model_converter.py
# coding=utf-8
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import glob
import importlib
import os
import re
from abc import ABC, abstractmethod
from collections import defaultdict, deque
from typing import Dict, Set
import libcst as cst
from check_copies import run_ruff
from create_dependency_mapping import find_priority_list
from libcst import ClassDef, CSTVisitor
from libcst import matchers as m
from libcst.metadata import MetadataWrapper, ParentNodeProvider, PositionProvider, ScopeProvider
from transformers import logging
from transformers.models.auto.configuration_auto import CONFIG_MAPPING_NAMES
logger = logging.get_logger(__name__)
AUTO_GENERATED_MESSAGE = """# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from {relative_path}.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be made, please apply the change to the
# {short_name} file directly. One of our CI checks enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
"""
def get_module_source_from_name(module_name: str) -> str:
# Extract the source code from the module name
spec = importlib.util.find_spec(module_name)
if spec is None or spec.origin is None:
return f"Module {module_name} not found"
with open(spec.origin, "r", encoding="utf-8") as file:
source_code = file.read()
return source_code
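# For illustration (not called by the converter itself): given a module name like
# "transformers.models.llama.modeling_llama", this returns the raw text of
# `modeling_llama.py` (or an error string if the module cannot be found), which is
# later parsed with `cst.parse_module` below.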
def preserve_case_replace(text, patterns: dict, default_name: str):
# Create a regex pattern to match all variations
regex_pattern = "|".join(re.escape(key) for key in patterns.keys())
compiled_regex = re.compile(regex_pattern, re.IGNORECASE)
def replace(match):
word = match.group(0)
result = patterns.get(word, default_name)
return result
return compiled_regex.sub(replace, text)
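# A small illustrative sketch of `preserve_case_replace` (values are made up for a
# llama -> gemma conversion, not taken from a real run):
#   patterns = {"llama": "gemma", "LLAMA": "GEMMA", "Llama": "Gemma"}
#   preserve_case_replace("LlamaModel loads llama weights", patterns, "Gemma")
#   -> "GemmaModel loads gemma weights"
# A casing variant absent from `patterns` (e.g. "LLaMa") still matches case-insensitively
# and falls back to `default_name` ("Gemma" here).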
def convert_to_camelcase(text, old_name: str, default_old_name: str):
# Regex pattern to match consecutive uppercase letters and lowercase the first set
result = re.sub(rf"^({old_name})(?=[a-z]+)", lambda m: default_old_name, text, flags=re.IGNORECASE, count=1)
return result
class ReplaceNameTransformer(m.MatcherDecoratableTransformer):
"""A transformer that replaces `old_name` with `new_name` in comments, string and any references.
It should take into account name like `MyNewModel`, or `my_new_model`. Without using the AUTO_MAPPING.
Supported renaming patterns:
- llama -> my_new_model and my_new_model -> llama
- Llama -> MyNewModel and MyNewModel -> Llama
- LLAMA -> MY_NEW_MODEL and MY_NEW_MODEL -> LLAMA
    - LLaMa -> MyNewModel and MyNewModel -> Llama
"""
def __init__(
self,
old_name,
new_name,
given_old_name=None,
given_new_name=None,
):
super().__init__()
self.old_name = old_name
self.new_name = new_name
self.default_name = "".join(x.title() for x in new_name.split("_"))
if self.new_name in CONFIG_MAPPING_NAMES:
self.default_name = CONFIG_MAPPING_NAMES[self.new_name].replace(
"Config", ""
) # the best source of truth for class names. Could also just use the ones de
self.patterns = {
old_name: new_name,
old_name.upper(): new_name.upper(),
"".join(x.title() for x in old_name.split("_")): self.default_name,
}
if given_old_name is not None and given_new_name is not None and given_old_name not in self.patterns:
self.patterns[given_old_name] = given_new_name
if self.old_name in CONFIG_MAPPING_NAMES:
self.default_old_name = CONFIG_MAPPING_NAMES[self.old_name].replace("Config", "")
if self.default_old_name.isupper():
self.default_old_name = self.default_old_name.capitalize()
@m.leave(m.Name() | m.SimpleString() | m.Comment())
def replace_name(self, original_node, updated_node):
if re.findall(r"# Copied from", updated_node.value):
return cst.RemoveFromParent()
update = preserve_case_replace(updated_node.value, self.patterns, self.default_name)
return updated_node.with_changes(value=update)
def leave_ClassDef(self, original_node, updated_node):
new_name = convert_to_camelcase(updated_node.name.value, self.old_name, self.default_old_name)
return updated_node.with_changes(name=cst.Name(new_name))
DOCSTRING_NODE = m.SimpleStatementLine(
body=[
m.Expr(
value=m.SimpleString(
# match anything between """ """
value=m.MatchIfTrue(lambda value: re.search(r"\"\"\"[\s\S]*\"\"\"", value) is not None)
)
)
]
)
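# DOCSTRING_NODE matches a statement whose sole expression is a """..."""-style string,
# e.g. the `"""Some docstring."""` line opening a class or function body.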
def SUPER_CALL_NODE(func_name):
return m.Call(func=m.Attribute(value=m.Call(func=m.Name("super")), attr=m.Name(func_name)))
def is_call_to_super(node, func_name):
return m.matches(
node, m.SimpleStatementLine(body=[m.Return(SUPER_CALL_NODE(func_name)) | m.Expr(SUPER_CALL_NODE(func_name))])
)
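# As an illustration, with func_name="__init__" this matches statement lines such as
# `super().__init__(config)`, and with func_name="forward" it matches
# `return super().forward(hidden_states)`.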
# Transformer class to replace ClassB.call_to_method and ClassB().call_to_method with super().call_to_method
class ReplaceMethodCallTransformer(cst.CSTTransformer):
def __init__(self, all_bases: Set[str]):
self.all_bases = all_bases
def leave_Attribute(self, original_node: cst.Attribute, updated_node: cst.Attribute) -> cst.CSTNode:
# Handle ClassB.call_to_method
if (
m.matches(original_node.value, m.Name())
and original_node.value.value in self.all_bases
and m.matches(original_node.attr, m.Name())
):
# Replace with super().call_to_method
return updated_node.with_changes(
value=cst.Call(cst.Name("super")),
)
# Handle ClassB().call_to_method
elif (
m.matches(original_node.value, m.Call())
and m.matches(original_node.value.func, m.Name())
and original_node.value.func.value in self.all_bases
and m.matches(original_node.attr, m.Name())
):
# Replace with super().call_to_method
return updated_node.with_changes(func=cst.Attribute(value=cst.Call(func=cst.Name("super"))))
return updated_node
def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.CSTNode:
# Check if the function being called is of the form ClassB().func_a or ClassB.func_a
if m.matches(original_node.func, m.Attribute()) and (
# Match ClassB().func_a(...)
(
m.matches(original_node.func.value, m.Call())
and m.matches(original_node.func.value.func, m.Name())
and original_node.func.value.func.value in self.all_bases
)
or
# Match ClassB.func_a(...)
(m.matches(original_node.func.value, m.Name()) and original_node.func.value.value in self.all_bases)
):
# Check if the first argument is 'self', and remove it
if len(original_node.args) > 0 and m.matches(original_node.args[0].value, m.Name("self")):
# Create the new argument list without 'self'
new_args = updated_node.args[1:]
else:
new_args = updated_node.args
return updated_node.with_changes(args=new_args)
return updated_node
def get_docstring_indent(docstring):
# Match the first line after the opening triple quotes
match = re.search(r'(?:"""|\'\'\'|```)\n(\s+)', docstring)
if match:
# Return the indentation spaces captured
return len(match.group(1))
return 0
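# E.g. (illustrative) get_docstring_indent('"""\n    Args:\n"""') returns 4, the number
# of spaces following the opening triple quotes on the next line.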
def merge_docstrings(original_docstring, updated_docstring):
# indent_level = get_docstring_indent(updated_docstring)
original_level = get_docstring_indent(original_docstring)
if not re.findall(r"\n\s*Args:\n", updated_docstring):
# Split the docstring at the example section, assuming `"""` is used to define the docstring
parts = original_docstring.split("```")
if "```" in updated_docstring and len(parts) > 1:
updated_docstring = updated_docstring.lstrip('r"')
new_parts = updated_docstring.split("```")
if len(new_parts) != 3:
raise ValueError("There should only be one example, and it should have opening and closing '```'")
parts[1] = new_parts[1]
updated_docstring = "".join(
[
parts[0].rstrip(" \n") + new_parts[0],
f"\n{original_level*' '}```",
parts[1],
"```",
parts[2],
]
)
elif updated_docstring not in original_docstring:
# add tabulation if we are at the lowest level.
if re.search(r"\n\s*.*\(.*\)\:\n\s*\w", updated_docstring):
            updated_docstring = updated_docstring.replace("\n    ", "\n        ")
updated_docstring = original_docstring.rstrip('"') + "\n" + updated_docstring.lstrip('r"\n')
return updated_docstring
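# Summary of the branches above: a child docstring that defines its own `Args:` section
# replaces the parent docstring wholesale; otherwise only its ``` example block (if any)
# is spliced into the parent docstring, or the child text is appended when it is not
# already contained in the parent.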
class SuperTransformer(cst.CSTTransformer):
METADATA_DEPENDENCIES = (ParentNodeProvider,)
def __init__(self, python_module: cst.Module, original_methods, updated_methods, all_bases=None):
self.python_module = python_module
self.original_methods = original_methods
self.updated_methods = updated_methods
self.all_assign_target = {}
self.deleted_targets = {} # child node can delete some arguments
self.all_bases = all_bases or []
self.transformer = ReplaceMethodCallTransformer(set(self.all_bases))
def update_body(self, existing_body, new_statements):
"""
Helper method to update the body by removing duplicates before adding new statements.
`existing_body` is the body of the original method, the parent class
`new_statements` are the additional statements
"""
deduplicated_new_body = []
existing_nodes = set()
for node in new_statements:
if m.matches(node, m.SimpleStatementLine(body=[m.Assign()])):
target = self.python_module.code_for_node(node.body[0].targets[0].target)
self.all_assign_target[target] = node
if m.matches(node, m.SimpleStatementLine(body=[m.Del()])):
target = self.python_module.code_for_node(node.body[0].target)
self.deleted_targets[target] = node
for stmt in existing_body:
if m.matches(stmt, m.SimpleStatementLine(body=[m.Assign()])):
target = self.python_module.code_for_node(stmt.body[0].targets[0].target)
if target in self.deleted_targets:
logger.warning(f"Deleted the assign for {target}")
continue
if target in self.all_assign_target:
stmt = self.all_assign_target[target]
# Skip the docstring (will be added later on, at the beginning)
elif m.matches(stmt, DOCSTRING_NODE):
continue
comment_less_code = re.sub(r"#.*", "", self.python_module.code_for_node(stmt)).strip()
comment_less_code = re.sub(r"\ *\n", "\n", comment_less_code).strip()
deduplicated_new_body.append(stmt)
existing_nodes.add(comment_less_code)
for node in new_statements:
code = self.python_module.code_for_node(node)
comment_less_code = re.sub(r"#.*", "", code).strip()
comment_less_code = re.sub(r"\ *\n", "\n", comment_less_code).strip()
if node not in deduplicated_new_body and comment_less_code not in existing_nodes:
if not m.matches(node, m.SimpleStatementLine(body=[m.Del()])):
deduplicated_new_body.append(node)
existing_nodes.add(comment_less_code)
deduplicated_new_body = self._fix_post_init_location(deduplicated_new_body)
return deduplicated_new_body
def _fix_post_init_location(self, new_body: list[cst.CSTNode]):
"""Fix the location of the `post_init()` in the new body, if we added statements after the call to
`super()` (it needs to be the very last statement called)"""
# Fix the post_init() that has to be last
for i, node in enumerate(new_body):
code = self.python_module.code_for_node(node)
comment_less_code = re.sub(r"#.*", "", code).strip()
comment_less_code = re.sub(r"\ *\n", "\n", comment_less_code).strip()
if "self.post_init(" in comment_less_code and i < len(new_body) - 1:
# Remove it and add it again at the end
new_body.pop(i)
new_body.append(node)
break
return new_body
def _fix_init_location(self, new_body):
"""Fix the location of the `super().__init__()` in the new body, if we had new statements before it."""
start_index = 0
for i, node in enumerate(new_body):
if m.matches(node, DOCSTRING_NODE) and i == start_index:
start_index += 1
continue
code = self.python_module.code_for_node(node)
comment_less_code = re.sub(r"#.*", "", code).strip()
comment_less_code = re.sub(r"\ *\n", "\n", comment_less_code).strip()
if "super().__init__" in comment_less_code and i > start_index:
# Remove it and add it again at the top after the docstrings
node = new_body.pop(i)
new_body = new_body[:start_index] + [node] + new_body[start_index:]
break
return new_body
def replace_super_calls(self, node: cst.IndentedBlock, func_name: str) -> cst.CSTNode:
"""Updates the body of the input `node`'s `func_name` function by replacing calls
to super().func_name() with the source code of the parent class' `func_name`.
It keeps everything that is defined before `super().func_name()`.
"""
self.has_docstring = False
parent_has_docstring = False
if func_name in self.original_methods:
parent_has_docstring = m.matches(self.original_methods[func_name].body.body[0], DOCSTRING_NODE)
new_body = []
has_super_call = False
for i, expr in enumerate(node.body):
if is_call_to_super(expr, func_name):
has_super_call = True
new_body.extend(self.update_body(self.original_methods[func_name].body.body, node.body[i + 1 :]))
new_body = self._fix_init_location(new_body)
else:
expr = expr.visit(self.transformer)
if m.matches(expr, DOCSTRING_NODE):
self.has_docstring = True
if parent_has_docstring: # actually here we ought to de-duplicate?
original_docstring = self.original_methods[func_name].body.body[0].body[0].value.value
updated_docstring = expr.body[0].value.value
merged_doc = merge_docstrings(original_docstring, updated_docstring)
new_node = [expr.with_changes(body=[cst.Expr(value=cst.SimpleString(value=merged_doc))])]
else:
new_node = [expr]
new_body.extend(new_node)
elif not m.matches(expr, m.SimpleStatementLine(body=[m.Del()])) and not has_super_call:
new_body.append(expr)
if not self.has_docstring and parent_has_docstring:
new_body = [self.original_methods[func_name].body.body[0]] + new_body
return node.with_changes(body=new_body)
def leave_FunctionDef(self, original_node: cst.Call, updated_node: cst.Call) -> cst.CSTNode:
if updated_node.name.value in self.updated_methods:
name = updated_node.name.value
new_body = self.replace_super_calls(updated_node.body, name)
return updated_node.with_changes(body=new_body, params=updated_node.params)
return updated_node
def leave_Return(self, original_node: cst.Return, updated_node: cst.Return) -> cst.CSTNode:
""" "When a return statement is reached, it is replaced with the unrolled super code"""
if m.matches(updated_node.value, m.Call(func=m.Attribute(attr=m.Name("super")))):
func_def = self.get_metadata(ParentNodeProvider, original_node)
            if m.matches(func_def, m.FunctionDef()) and func_def.name.value in self.original_methods:
updated_return_value = updated_node.value.with_changes(
args=[
cst.Arg(
value=cst.Call(func=cst.Name("super"), args=[cst.Arg(value=cst.Name(func_def.name.value))])
)
]
)
return updated_node.with_changes(value=updated_return_value)
return updated_node
def find_all_dependencies(
dependency_mapping: Dict[str, set],
start_entity: str | None = None,
initial_dependencies: set | None = None,
initial_checked_dependencies: set | None = None,
return_parent: bool = False,
) -> list | set:
"""Return all the dependencies of the given `start_entity` or `initial_dependencies`. This is basically some kind of
BFS traversal algorithm. It can either start from `start_entity`, or `initial_dependencies`.
Args:
dependency_mapping (`Dict[str, set]`):
A mapping from entities (usually function/assignment names), to immediate dependencies. That is, for function names,
a mapping {"foo": {"bar", "test"}} would indicate that functions `bar` and `test` are immediately called
in `foo`'s definition.
start_entity (str | None, *optional*):
A key of `dependency_mapping`, indicating from which entity to start the search.
initial_dependencies (set | None, *optional*):
If `start_entity` is not provided, this can be used as an alternative. In this case, the search will continue
from all the entities in `initial_dependencies`, if they are in `dependency_mapping`.
initial_checked_dependencies (set | None, *optional*):
If provided, entities already present in `initial_checked_dependencies` will not be part of the returned dependencies.
return_parent (bool, *optional*):
If `True`, will return a list consisting of tuples (dependency, parent) instead of a simple set of dependencies. Note
            that the order of the items in the list reflects the traversal order. Thus, no child can ever appear before its parent.
Returns:
A set of all the dependencies, or a list of tuples `(dependency, parent)` if `return_parent=True`.
Example:
Given the following structure in the `modular_xxx.py` file:
```
def foo1():
pass
def foo2():
pass
def bar():
foo1()
def foobar():
bar()
foo2()
class MyLayer(SomeOtherModelLayer):
def forward(...):
foobar()
```
and the `dependency_mapping` created when visiting the `modular_xxx.py` file, we get:
```
dependency_mapping = {'bar': {'foo1'}, 'foobar': {'bar', 'foo2'}}
find_all_dependencies(dependency_mapping, start_entity='foobar', return_parent=True)
>>> [('bar', 'foobar'), ('foo2', 'foobar'), ('foo1', 'bar')]
```
That is, all the functions needed (and potentially their immediate parent) so that the function to be added
in MyLayer (`foobar`) can work correctly.
"""
if initial_dependencies is None and start_entity is not None:
initial_dependencies = dependency_mapping[start_entity]
if initial_checked_dependencies is None:
initial_checked_dependencies = set()
dependency_queue = deque(initial_dependencies)
all_dependencies = set()
all_dependencies_with_parent = []
checked_dependencies = set(initial_checked_dependencies)
parents = {initial_dep: start_entity for initial_dep in initial_dependencies}
while len(dependency_queue) > 0:
# Pick element to visit
current = dependency_queue.popleft()
if current not in checked_dependencies:
# Add the dependencies
all_dependencies.add(current)
all_dependencies_with_parent += [(current, parents[current])]
if current in dependency_mapping.keys():
# Update dependency queue
dependency_queue.extend(dependency_mapping[current])
parents.update({dep: current for dep in dependency_mapping[current]})
# add visited node to the list
checked_dependencies.add(current)
if not return_parent:
return all_dependencies
# no child can ever appear before its parent thanks to the queue (needed to add them at the correct location in the body later)
return all_dependencies_with_parent
# These top-level variables will always use the value in the `modular_xxx.py` file
ASSIGNMENTS_TO_KEEP = {
"_CHECKPOINT_FOR_DOC",
}
class ClassDependencyMapper(CSTVisitor):
"""A visitor which is designed to analyze a single class node to get all its dependencies that are shared with the set of
`global_names`.
"""
def __init__(
self, class_name: str, global_names: set[str], objects_imported_from_modeling: set[str] | None = None
):
super().__init__()
self.class_name = class_name
self.dependencies = set()
self.global_names = global_names
self.objects_imported_from_modeling = (
set() if objects_imported_from_modeling is None else objects_imported_from_modeling
)
def visit_Name(self, node):
if (
node.value != self.class_name
and node.value in self.global_names
and node.value not in self.objects_imported_from_modeling
):
self.dependencies.add(node.value)
def dependencies_for_class_node(node: cst.ClassDef, global_names: set[str]) -> set:
"""Create immediate dependencies for a class node based on the `global_names`."""
temp_module = cst.Module(body=[node])
visitor = ClassDependencyMapper(node.name.value, global_names)
temp_module.visit(visitor)
return visitor.dependencies
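# Illustrative example: for `class DummyAttention(nn.Module)` (a hypothetical class) whose
# body references `apply_rotary_pos_emb` and `logger`, with
# global_names = {"apply_rotary_pos_emb", "logger", "DummyAttention"}, this returns
# {"apply_rotary_pos_emb", "logger"} (the class's own name is never a dependency).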
def augmented_dependencies_for_class_node(
node: cst.ClassDef, mapper: "ModuleMapper", objects_imported_from_modeling: set[str] | None = None
) -> set:
"""Create augmented dependencies for a class node based on a `mapper`.
Augmented dependencies means immediate dependencies + recursive function and assignments dependencies.
"""
temp_module = cst.Module(body=[node])
visitor = ClassDependencyMapper(node.name.value, set(mapper.global_nodes.keys()), objects_imported_from_modeling)
temp_module.visit(visitor)
return mapper.augment_dependencies(visitor.dependencies)
# All the potential file types to create
ALL_FILE_TYPES = (
"modeling",
"configuration",
"tokenization",
"processing",
"image_processing",
"feature_extractor",
)
class ModuleMapper(CSTVisitor, ABC):
"""An abstract visitor class which analyses a module, creating a mapping of dependencies for classes, functions and assignments.
Class dependencies are computed with `compute_class_dependencies()`, while function and assignment dependencies are stored in
`self.object_recursive_dependency_mapping` (can be computed by `_compute_recursive_object_dependencies()`).
It defines common visiting patterns (i.e. common visit_xxx/leave_xxx functions) between the modular file and the
modeling files that will be visited.
"""
METADATA_DEPENDENCIES = (ParentNodeProvider, PositionProvider)
def __init__(self, python_module: cst.Module):
# fmt: off
self.python_module: cst.Module = python_module # original cst.Module being visited
self.classes: Dict[str, cst.ClassDef] = {} # mapping from class names to Nodes (it will be ordered by default!!)
self.imports = [] # stores all import statements
self.functions: Dict[str, cst.FunctionDef] = {} # mapping of global scope function names to Nodes
self.object_dependency_mapping = defaultdict(set) # immediate function/assignment dependency mapping (i.e. dependencies immediately in the function/assignment definition)
self.assignments: Dict[str, cst.SimpleStatementLine] = {} # mapping of global assignments names to Nodes
self.current_function = None # this keeps track of the current module-scope function
self.current_assignment = None # this keeps track of the current module-scope assignment
# this keeps track of objects imported from modeling files (`from .configuration import Config`) -> `Config` should not be a dependency
self.objects_imported_from_modeling = set()
# regex pattern joining every possible file type
self.match_patterns = "|".join(ALL_FILE_TYPES)
# fmt: on
def visit_ImportFrom(self, node):
"""This keeps track of objects imported from neighbor modeling files (e.g. in `modeling_xxx.py, we have
`from .configuration_xxx import Config`, then `Config` should be recorded as it is not a dependency that needs
to be added (because it will be part of the imports)"""
import_module = self.python_module.code_for_node(node.module)
import_statement = "." * len(node.relative) + import_module
if re.search(rf"^\.({self.match_patterns})_.*", import_statement):
for imported_object in node.names:
# If an alias is present, we record it and not the original name
if imported_object.evaluated_alias is not None:
self.objects_imported_from_modeling.add(imported_object.evaluated_alias)
else:
self.objects_imported_from_modeling.add(imported_object.evaluated_name)
def visit_SimpleStatementLine(self, node):
"""
Global Assigns like `GEMMA_INPUT_DOCSTRING = 'THIS IS THE INPUT'` and all import statements
are extracted and saved in their corresponding dict. They are then used when updating dependency mappings.
"""
parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node)
simple_top_level_assign_structure = m.SimpleStatementLine(
body=[m.Assign(targets=[m.AssignTarget(target=m.Name())])]
)
if m.matches(parent_node, m.Module()):
if m.matches(node, simple_top_level_assign_structure):
left_hand_side = node.body[0].targets[0].target.value
self.current_assignment = left_hand_side
self.assignments[left_hand_side] = node
elif m.matches(node, m.SimpleStatementLine(body=[m.Import() | m.ImportFrom()])):
self.imports.append(node)
def leave_SimpleStatementLine(self, node):
        # No need to check for the parent here -> every time we exit one, it should be None anyway, independently of where the
# SimpleStatement is located
self.current_assignment = None
def visit_FunctionDef(self, node):
parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node)
if m.matches(parent_node, m.Module()):
self.current_function = node.name.value
self.functions[node.name.value] = node
def leave_FunctionDef(self, node):
parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node)
if m.matches(parent_node, m.Module()):
self.current_function = None
def visit_If(self, node):
for stmt in node.body.body:
if m.matches(stmt, m.SimpleStatementLine(body=[m.ImportFrom() | m.Import()])):
self.imports.append(node)
def visit_ClassDef(self, node: ClassDef) -> None:
"""Record class nodes to create their dependencies at the end."""
self.classes[node.name.value] = node
    def visit_Name(self, node: cst.Name):
"""This is used to create a mapping from module-scope functions and assignments to objects used inside them."""
if self.current_function is not None:
self.object_dependency_mapping[self.current_function].add(node.value)
if self.current_assignment is not None:
self.object_dependency_mapping[self.current_assignment].add(node.value)
def leave_Module(self, node):
"""When leaving the module, we store the position of each global scoped node to allow sorting the dependencies
based on their position in the code later. We use the PositionProvider metadata wrapper for this.
We also make sure to update `self.object_dependency_mapping` so that it contains only names recorded in
`self.global_nodes`.
"""
# assign all nodes
self.global_nodes = {**self.assignments, **self.classes, **self.functions}
# now sort the class dependency_mapping based on the position of the nodes
self.start_lines = {}
for id, node in self.global_nodes.items():
self.start_lines[id] = self.get_metadata(cst.metadata.PositionProvider, node).start.line
def _restrict_dependencies_to_known_entities(self):
"""Since we added every Name as part of `self.object_dependency_mapping`, we need to remove those that
are not part of the recorded objects in `self.global_nodes` (i.e. built-in variables, imports, etc).
This should be called only after all merging operations have been finalized!!"""
global_objects = set(self.global_nodes.keys())
for object_name, dependencies in self.object_dependency_mapping.items():
self.object_dependency_mapping[object_name] = {dep for dep in dependencies if dep in global_objects}
def _compute_recursive_object_dependencies(self) -> dict[str, set]:
"""Based on immediate dependency mapping, create the recursive dependency mapping. For example, given the
following file:
```
def foo():
pass
def bar():
foo()
def test():
bar()
```
this visitor can only record immediate dependencies, i.e. it will record the following
    `self.object_dependency_mapping = {"test": {"bar"}, "bar": {"foo"}}`. This function is used to create
    the recursive mapping, i.e. `recursive_dependencies = {"test": {"bar", "foo"}, "bar": {"foo"}}`.
"""
recursive_dependencies = {}
for object_name in self.object_dependency_mapping.keys():
all_dependencies = find_all_dependencies(self.object_dependency_mapping, start_entity=object_name)
recursive_dependencies[object_name] = all_dependencies
return recursive_dependencies
def augment_dependencies(self, dependencies: set[str]) -> set[str]:
"""For a set of `dependencies`, augment them by adding all potential dependencies of the **functions** and
**assignments** present in the `dependencies`.
"""
new_dependencies = dependencies.copy()
# Go through the set of dependencies
for dep in tuple(dependencies):
if dep in self.object_recursive_dependency_mapping.keys():
new_dependencies.update(self.object_recursive_dependency_mapping[dep])
return new_dependencies
def compute_class_dependencies(self):
"""For each visited class, find its dependencies based on visiting the current file + potential merged dependencies."""
self.class_dependency_mapping = {}
for class_name, class_node in self.classes.items():
dependencies = dependencies_for_class_node(class_node, set(self.global_nodes.keys()))
# Correctly augment class dependencies with all needed objects
self.class_dependency_mapping[class_name] = self.augment_dependencies(dependencies)
@abstractmethod
def compute_relative_order(self, missing_dependencies: set) -> dict[str, int]:
raise NotImplementedError
class ModelFileMapper(ModuleMapper):
"""A mapper designed to parse modeling files (like `modeling_llama.py`). When encountering such a file
in the `modular_xxx.py` file, we need to correctly visit it and merge the dependencies of the modular and current file.
For this reason, this class should only be instantiated from the class method `visit_and_merge_dependencies`, which takes
care of correctly merging dependencies, then finalizes all dependency graph computations.
Note that we only merge functions and assignments here, as classes will be treated later on as they may be modified.
For example, if you redefine `apply_rotary_pos_emb()` in the modular, the new node should be used in the dependencies
of the modeling files as well.
"""
def __init__(self, python_module: cst.Module):
super().__init__(python_module)
def compute_relative_order(self, missing_dependencies: set[str]) -> dict[str, int]:
"""Compute in which relative order the `missing_dependencies` should appear when the nodes are added to the final file that
will be created based on the modular.
"""
relative_order = {}
idx = 0
classes = sorted(
[dep for dep in tuple(missing_dependencies) if dep in self.classes], key=lambda x: self.start_lines[x]
)
# This is because for merged dependencies, we only have relative order in the other visited file, so we need
# to track dependency order relative to a given class
if len(classes) > 0 and not hasattr(self, "class_dependency_mapping"):
raise ValueError("Cannot correctly find the relative order of the dependencies.")
remaining_dependencies = missing_dependencies.copy()
# Start by tracking relative order class by class
for class_name in classes:
class_dependencies = tuple(self.class_dependency_mapping[class_name] & remaining_dependencies)
original_dependencies = []
merged_dependencies = []
# We need to differentiate between nodes that were already present (we can get relative order globally) and
# nodes that were merged (we can get relative order only relative to the class the dependencies relate to)
for class_dep in class_dependencies:
if class_dep in self.start_lines:
original_dependencies.append(class_dep)
else:
merged_dependencies.append(class_dep)
            # Sort both lists according to the order in their respective file
original_dependencies = sorted(original_dependencies, key=lambda x: self.start_lines[x])
merged_dependencies = sorted(merged_dependencies, key=lambda x: self.modular_file_start_lines[x])
            # Add all original nodes first, then merged ones
for dep in original_dependencies + merged_dependencies:
remaining_dependencies.remove(dep)
relative_order[dep] = idx
idx += 1
# Add the class itself
remaining_dependencies.remove(class_name)
relative_order[class_name] = idx
idx += 1
# Now add what still remains
remaining_dependencies = tuple(remaining_dependencies)
original_dependencies = []
merged_dependencies = []
for dep in remaining_dependencies:
if dep in self.modular_file_start_lines:
merged_dependencies.append(dep)
else:
original_dependencies.append(dep)
        # Sort both lists according to the order in their respective file
original_dependencies = sorted(original_dependencies, key=lambda x: self.start_lines[x])
merged_dependencies = sorted(merged_dependencies, key=lambda x: self.modular_file_start_lines[x])
        # Add all original nodes first, then merged ones
for dep in original_dependencies + merged_dependencies:
relative_order[dep] = idx
idx += 1
return relative_order
def _merge_functions(self, functions: dict[str, cst.CSTNode], object_mapping: dict[str, set]):
"""Update the global nodes and function dependency mapping with those from the modular file.
Merging rule: if any function with the same name was redefined in the modular, use it and its dependencies
instead of the original ones (this may mean to add new functions as well, if any redefined function uses a new one).
"""
# Add/overwrite all needed function nodes and dependencies
self.functions.update(functions)
self.object_dependency_mapping.update(
{obj: dep for obj, dep in object_mapping.items() if obj in functions.keys()}
)
def _merge_assignments(self, assignments: dict[str, cst.CSTNode], object_mapping: dict[str, set]):
"""Update the global nodes with the assignment from the modular file.
Merging rule: if any assignment with the same name was redefined in the modular, we use it and its dependencies ONLY if it is
in `ASSIGNMENTS_TO_KEEP`. Otherwise, we use the original value and dependencies. This rule was chosen to avoid having to rewrite the
big docstrings.
"""
for assignment, node in assignments.items():
if assignment in ASSIGNMENTS_TO_KEEP or assignment not in self.assignments:
self.assignments[assignment] = node
if assignment in object_mapping:
self.object_dependency_mapping[assignment] = object_mapping[assignment]
def _merge_classes(self, classes: dict[str, cst.CSTNode]):
"""Update the global nodes with the new classes from the modular (i.e. classes which do not exist in current file, and
are not imported). We do NOT update any dependency mapping here. This is because we only need the names of newly defined
classes in the modular to be discoverable when computing dependencies for new nodes later on. For this reason, we
do not add the new classes to `self.classes`, but only to `global_nodes`.
"""
        # Add the new classes (only those not already present and not imported from a modeling file)
self.global_nodes.update(
{
name: node
for name, node in classes.items()
if name not in self.classes and name not in self.objects_imported_from_modeling
}
)
def merge_modular_dependencies(self, classes, functions, assignments, object_mapping, start_lines):
"""Merge classes, functions and assignments from the modular definitions into the current module file,
then record the relative order of all nodes.
Note: This function takes care of updating `global_nodes` and `object_recursive_dependency_mapping` as well after the
merge with other files dependencies.
"""
self._merge_functions(functions, object_mapping)
self._merge_assignments(assignments, object_mapping)
self._merge_classes(classes)
self.modular_file_start_lines = start_lines
# Correctly re-set the global nodes at this point
self.global_nodes.update(self.functions)
self.global_nodes.update(self.assignments)
        # Restrict the dependency mappings to the known entities to avoid Python's built-ins
self._restrict_dependencies_to_known_entities()
# Create the global mapping of recursive dependencies for functions and assignments
self.object_recursive_dependency_mapping = self._compute_recursive_object_dependencies()
@classmethod
def visit_and_merge_dependencies(
cls, module: cst.Module, classes, functions, assignments, object_mapping, start_lines
) -> "ModelFileMapper":
wrapper = MetadataWrapper(module)
mapper = cls(module)
wrapper.visit(mapper)
# Merge dependencies
mapper.merge_modular_dependencies(classes, functions, assignments, object_mapping, start_lines)
# Create the class dependencies graph
mapper.compute_class_dependencies()
return mapper
def replace_class_node(mapper: ModelFileMapper, class_node: cst.ClassDef, renamed_super_class: str):
"""
Replace a class node which inherits from another modeling class. This function works in the following way:
- start from the base class node of the inherited class (a cst.Node)
- replace all methods of the base node with the methods defined in the child class
- append all new methods defined in the child class
- replace all calls to super() with the unravelled code
| ```python | | ```python
| class GemmaModel(LlamaModel): | | class GemmaModel(nn.Module):
| def __init__(self): | | def __init__(self):
Going from: | super().__init__() | to: | super().__init__(config)
| self.dropout = 0.2 | | self.dropout = 0.2
| ``` | | self.padding_idx = config.pad_token_id
| self.vocab_size = config.vocab_size
| self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
| self.layers = nn.ModuleList(
| [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
| )
| self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
| self.gradient_checkpointing = False
| # Initialize weights and apply final processing
| self.post_init()
| ```
"""
all_bases = [k.value.value for k in class_node.bases]
original_node = mapper.classes[renamed_super_class]
original_methods = {
f.name.value if hasattr(f, "name") else mapper.python_module.code_for_node(f): f
for f in original_node.body.body
}
updated_methods = {
f.name.value if hasattr(f, "name") else mapper.python_module.code_for_node(f): f for f in class_node.body.body
}
end_meth = []
assign_targets = {}
docstring_node = []
    # Iterate directly over original_node.body as there can be properties/setters with the same name, which would be overwritten if we used a dict
for func in original_node.body.body:
name = func.name.value if hasattr(func, "name") else mapper.python_module.code_for_node(func)
if m.matches(func, m.FunctionDef()) and name in updated_methods and updated_methods[name] is not None:
new_params = updated_methods[name].params
# Replace the method in the replacement class, preserving decorators
kwarg_name = getattr(updated_methods[name].params, "star_kwarg", None)
if kwarg_name and kwarg_name.name.value == "super_kwargs":
parent_params = {k.name.value: k for k in func.params.params}
parent_params.update({k.name.value: k for k in new_params.params[1:]})
new_params = new_params.with_changes(
params=list(parent_params.values()), star_kwarg=func.params.star_kwarg
)
# Keep decorators in `modular_xxx.py` if any, else original decorators
new_decorators = (
updated_methods[name].decorators if len(updated_methods[name].decorators) > 0 else func.decorators
)
if not re.match(
r"\ndef .*\(.*\):\n raise.*Error\(.*",
mapper.python_module.code_for_node(updated_methods[name]),
):
func = func.with_changes(body=updated_methods[name].body, params=new_params, decorators=new_decorators)
else:
continue
if m.matches(func, m.SimpleStatementLine(body=[m.Assign()])):
target = mapper.python_module.code_for_node(func.body[0].targets[0])
assign_targets[target] = func
elif m.matches(func, m.SimpleStatementLine(body=[m.AnnAssign()])):
target = mapper.python_module.code_for_node(func.body[0].target)
assign_targets[target] = func
elif m.matches(func, DOCSTRING_NODE):
docstring_node = [func]
else:
end_meth.append(func)
    # Port new methods that are defined only in the modular file, and append them at the end
for func in class_node.body.body:
name = func.name.value if hasattr(func, "name") else mapper.python_module.code_for_node(func)
if m.matches(func, DOCSTRING_NODE): # This processes the docstring of the class!
# Extract the original docstring
updated_docstring = func.body[0].value.value
original_docstring = docstring_node[0].body[0].value.value
merged_doc = merge_docstrings(original_docstring, updated_docstring)
# Update the docstring in the original function
docstring_node = [
docstring_node[0].with_changes(body=[cst.Expr(value=cst.SimpleString(value=merged_doc))])
]
if name not in original_methods and func is not None and isinstance(func, cst.FunctionDef):
end_meth.append(func)
if m.matches(func, m.SimpleStatementLine(body=[m.Assign()])):
            # TODO: we only handle the first assignment target, which might cause issues
target = mapper.python_module.code_for_node(func.body[0].targets[0])
assign_targets[target] = func
if m.matches(func, m.SimpleStatementLine(body=[m.AnnAssign()])):
target = mapper.python_module.code_for_node(func.body[0].target)
assign_targets[target] = func
end_meth = docstring_node + list(assign_targets.values()) + end_meth
# Replace the calls to `super()` with the unrolled code
result_node = original_node.with_changes(body=cst.IndentedBlock(body=end_meth))
temp_module = cst.Module(body=[result_node])
new_module = MetadataWrapper(temp_module)
new_replacement_class = new_module.visit(
SuperTransformer(temp_module, original_methods, updated_methods, all_bases)
)
new_replacement_body = new_replacement_class.body[0].body # get the indented block
# Use decorators redefined in `modular_xxx.py` if any
new_decorators = class_node.decorators if len(class_node.decorators) > 0 else original_node.decorators
# Always use the new name of the class (in case we use e.g. `ColPaliForRetrieval` inheriting from `PaliGemmaForConditionalGeneration`)
name = class_node.name
return original_node.with_changes(body=new_replacement_body, decorators=new_decorators, name=name)
TYPE_TO_FILE_TYPE = {
"Config": "configuration",
"Tokenizer": "tokenization",
"Processor": "processing",
"ImageProcessor": "image_processing",
"FeatureExtractor": "feature_extractor",
"ProcessorKwargs": "processing",
"ImagesKwargs": "processing",
"TextKwargs": "processing",
}
def find_file_type(class_name: str) -> str:
"""Based on a class name, find the file type corresponding to the class.
    If the class name is `LlamaConfig`, it will return `configuration`.
    The list of suffixes is in `TYPE_TO_FILE_TYPE`. If there is no match, we default to `modeling`.
"""
match_pattern = "|".join(TYPE_TO_FILE_TYPE.keys())
match = re.search(rf"({match_pattern})$", class_name)
if match:
file_type = TYPE_TO_FILE_TYPE[match.group(1)]
else:
file_type = "modeling"
return file_type
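# For example: find_file_type("GemmaConfig") -> "configuration",
# find_file_type("GemmaTokenizer") -> "tokenization", and
# find_file_type("GemmaModel") -> "modeling" (no suffix in TYPE_TO_FILE_TYPE matches).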
# These top-level variables will always appear at the very beginning of the file, in the order they are defined in
# this list (this is to avoid having variables at weird places, even if they are not used before)
VARIABLES_AT_THE_BEGINNING = (
"logger",
"_CHECKPOINT_FOR_DOC",
"_CONFIG_FOR_DOC",
)
# These specific modeling imports should not be visited as other modeling files
IMPORTS_TO_SKIP_IN_MODULAR = ("auto.modeling_auto",)
def append_new_import_node(node: cst.CSTNode, unused_imports: set[str], imports_to_keep: list[cst.CSTNode]):
"""Insert the new `node` to the list of `imports_to_keep` in-place, if it is not part of the `unused_imports`."""
import_node = node.body[0]
names_to_keep = []
for name in import_node.names:
name_value = name.evaluated_name
if name_value not in unused_imports:
names_to_keep.append(name.with_changes(comma=cst.MaybeSentinel.DEFAULT))
if len(names_to_keep) > 0:
new_node = node.with_changes(body=[import_node.with_changes(names=names_to_keep)])
imports_to_keep.append(new_node)
def get_needed_imports(body: dict[str, dict], all_imports: list[cst.CSTNode]) -> list[cst.CSTNode]:
"""Get all the imports needed in the `body`, from the list of `all_imports`.
`body` is a dict with the following structure `{str: {"insert_idx": int, "node": cst.CSTNode}}`.
    Note: we need to use `isinstance` on scope assignments, m.matches apparently does not work here yet!
"""
new_body = [k[1]["node"] for k in sorted(body.items(), key=lambda x: x[1]["insert_idx"])]
wrapper = MetadataWrapper(cst.Module(body=all_imports + new_body))
scopes = set(wrapper.resolve(ScopeProvider).values())
unused_imports = set()
import_ref_count = {}
for scope in scopes:
for assignment in scope.assignments:
node = assignment.node
if isinstance(assignment, cst.metadata.Assignment) and isinstance(node, (cst.Import, cst.ImportFrom)):
ref_count = len(assignment.references)
name = assignment.name
# Similar imports may be redefined, and only used between their 1st and 2nd definition
                # so if we already have a ref count > 0, the import is actually used
if (ref_count == 0 and import_ref_count.get(name, -1) <= 0) or name in body.keys():
unused_imports.add(name)
import_ref_count[name] = ref_count
imports_to_keep = []
for node in all_imports:
if m.matches(node, m.If()): # handle safe imports
new_statements = []
for stmt_node in node.body.body:
append_new_import_node(stmt_node, unused_imports, new_statements)
if len(new_statements) > 0:
new_node = node.with_changes(body=node.body.with_changes(body=new_statements))
imports_to_keep.append(new_node)
else:
append_new_import_node(node, unused_imports, imports_to_keep)
protected_import_nodes = [node for node in imports_to_keep if m.matches(node, m.If())]
usual_import_nodes = [node for node in imports_to_keep if not m.matches(node, m.If())]
# If the same import is both protected and unprotected, only keep the protected one
for protected_node in protected_import_nodes:
for stmt_node in protected_node.body.body:
usual_import_nodes = [node for node in usual_import_nodes if node.body[0] != stmt_node.body[0]]
# Protected imports always appear at the end of all imports
return usual_import_nodes + protected_import_nodes
def split_all_assignment(node: cst.CSTNode) -> dict[str, cst.CSTNode]:
"""Split the `__all__` assignment found in the modular between each corresponding files."""
all_all_per_file = {}
assign_node = node.body[0]
if isinstance(assign_node.value, cst.List):
# Extract the elements from the list
all_all_to_add = defaultdict(list)
for element in assign_node.value.elements:
if isinstance(element.value, cst.SimpleString):
# Remove quotes and add the string to the elements list
class_name = element.value.value
file = find_file_type(element.value.evaluated_value)
all_all_to_add[file] += [class_name]
for file, new_alls in all_all_to_add.items():
new_node = assign_node.with_changes(
value=cst.List(elements=[cst.Element(value=cst.SimpleString(value=k)) for k in new_alls])
)
all_all_per_file[file] = node.with_changes(body=[new_node])
return all_all_per_file
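# Illustrative example: `__all__ = ["GemmaConfig", "GemmaModel"]` in the modular file is
# split into {"configuration": `__all__ = ["GemmaConfig"]`, "modeling":
# `__all__ = ["GemmaModel"]`}, so each generated file receives its own `__all__`.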
class ModularFileMapper(ModuleMapper):
"""This is a Mapper to visit a modular file (like `modular_llama.py`). It visits the whole file, recording dependency,
then visits all imported modeling files (like `modeling_llama.py`), and manages their mutual dependencies.
Calling the method `create_modules()` after visit will create all modules based on this modular file.
"""
def __init__(self, python_module, new_name, given_old_name=None, given_new_name=None):
super().__init__(python_module)
# fmt: off
self.model_name = new_name # name of the model being defined. Should be in the format of `llama` or `layout_xlm` or `phi3`
self.given_old_name = given_old_name
self.given_new_name = given_new_name
self.model_specific_imported_objects: Dict[str, str] = {} # e.g. {"LlamaModel": "transformers.models.llama.modeling_llama"}
self.model_specific_modules: Dict[str, cst.Module] = {} # e.g. {"transformers.models.llama.modeling_llama": cst.Module}
self.all_all_to_add = {}
# fmt: on
def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
"""When visiting imports from modeling files (i.e. `transformers.models.xxx`) we get the code, parse it,
and save it in `self.model_specific_modules` to later visit. The imported objects are saved in `self.model_specific_imported_objects`.
"""
import_module = self.python_module.code_for_node(node.module)
import_statement = "." * len(node.relative) + import_module
if any(import_to_skip in import_statement for import_to_skip in IMPORTS_TO_SKIP_IN_MODULAR):
return
if m.matches(node.module, m.Attribute()):
for imported_ in node.names:
_import = re.search(
rf"(?:transformers\.models\.)|(?:\.\.)\w+\.({self.match_patterns})_.*", import_statement
)
if _import:
source = _import.group(1)
if source == "modeling" and "Config" in self.python_module.code_for_node(imported_):
raise ValueError(
f"You are importing {self.python_module.code_for_node(imported_)} from the modeling file. Import from the `configuration_xxxx.py` file instead"
)
if import_module not in self.model_specific_modules:
if "models" not in import_module:
import_module = "models." + import_module
if "transformers" not in import_module:
import_module = "transformers." + import_module
source_code = get_module_source_from_name(import_module)
tree = cst.parse_module(source_code)
self.model_specific_modules[import_module] = tree
imported_object = self.python_module.code_for_node(imported_.name)
self.model_specific_imported_objects[imported_object] = import_module
if m.matches(node.module, m.Name()):
if "transformers" == import_module:
raise ValueError(
f"You are importing from {import_module} directly using global imports. Import from the correct local path"
)
def visit_SimpleStatementLine(self, node):
"""If we visit an import statement not previously visited, record it. If we visit a module-scope assignment,
simply record it or, if it is `__all__`, split it between files where we should dispatch it.
"""
parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node)
simple_top_level_assign_structure = m.SimpleStatementLine(
body=[m.Assign(targets=[m.AssignTarget(target=m.Name())])]
)
if m.matches(parent_node, m.Module()):
if m.matches(node, m.SimpleStatementLine(body=[m.Import()])):
self.imports.append(node)
elif m.matches(node, m.SimpleStatementLine(body=[m.ImportFrom()])):
import_module = self.python_module.code_for_node(node.body[0].module)
import_statement = "." * len(node.body[0].relative) + import_module
if not (
re.search(rf"(?:transformers\.models\.)|(?:\.\.)\w+\.({self.match_patterns})_.*", import_statement)
and not any(import_to_skip in import_statement for import_to_skip in IMPORTS_TO_SKIP_IN_MODULAR)
):
self.imports.append(node)
elif m.matches(node, simple_top_level_assign_structure):
assigned_variable = node.body[0].targets[0].target.value
# __all__ is treated differently and not added to general assignments
if assigned_variable == "__all__":
self.all_all_to_add = split_all_assignment(node)
else:
self.current_assignment = assigned_variable
self.assignments[assigned_variable] = node
def leave_Module(self, node):
"""When we leave the modular file, we do the following in order:
1. for each modeling file found in the imports, rename it with the new model name, visit it, and update
its dependency graph with the new function and assignment definitions found in the modular
2. update the modular dependency graph with the imported functions and assignments (found when visiting the matching files)
3. compute the nested (recursive) function and assignment dependencies
"""
# Takes care of finalizing our visit
super().leave_Module(node)
# 1. for each modeling file found in the imports, rename it with the new model name, visit it, and update dependencies
self.visited_modules = {}
self.renamers = {}
for file, module in self.model_specific_modules.items():
file_model_name = re.search(r"models\.\w*?\.\w*?_(\S*)", file).groups()[0]
renamer = ReplaceNameTransformer(
file_model_name, self.model_name, self.given_old_name, self.given_new_name
)
renamed_module = module.visit(renamer)
self.visited_modules[file] = ModelFileMapper.visit_and_merge_dependencies(
renamed_module,
self.classes,
self.functions,
self.assignments,
self.object_dependency_mapping,
self.start_lines,
)
# We record it so that we can rename classes later the exact same way
self.renamers[file] = renamer
# 2. in turn, we need to add the imported functions/assignments to the dependencies of the modular mapper, using the
# definitions found in the visited files
self.merge_model_specific_imports(self.visited_modules)
# 3. compute the nested (recursive) function and assignment dependencies
self.object_recursive_dependency_mapping = self._compute_recursive_object_dependencies()
# We need to keep track of which objects were imported directly into which modeling file to not add them wrongly later
# Note that we may visit several of the same file types, thus we save them per file type, not file
self.imported_objects_per_file = defaultdict(set)
for file, mapper in self.visited_modules.items():
file_type = re.search(rf"^transformers\.models\.\w+\.({self.match_patterns})_.*", file).group(1)
self.imported_objects_per_file[file_type].update(mapper.objects_imported_from_modeling)
def merge_model_specific_imports(self, visited_modules):
"""Merge the functions and assignments imported from the modeling files to the modular nodes and dependency graph,
based on the visited files."""
self.start_lines_file_mapping = {}
self.added_objects_file_mapping = {}
for object_name, file in self.model_specific_imported_objects.items():
visited_module = visited_modules[file]
self.start_lines_file_mapping[file] = visited_module.start_lines
# Add functions and their dependencies
if object_name in visited_module.functions and object_name not in self.functions:
self.functions[object_name] = visited_module.functions[object_name]
self.added_objects_file_mapping[object_name] = file
dependencies = visited_module.object_dependency_mapping.get(object_name, None)
if dependencies is not None:
self.object_dependency_mapping[object_name] = dependencies
for dep in dependencies:
if dep not in self.global_nodes:
self.added_objects_file_mapping[dep] = file
self.functions[dep] = visited_module.global_nodes[dep]
# Add assignments and their dependencies
elif object_name in visited_module.assignments and object_name not in self.assignments:
self.assignments[object_name] = visited_module.assignments[object_name]
self.added_objects_file_mapping[object_name] = file
dependencies = visited_module.object_dependency_mapping.get(object_name, None)
if dependencies is not None:
self.object_dependency_mapping[object_name] = dependencies
for dep in dependencies:
if dep not in self.global_nodes:
self.added_objects_file_mapping[dep] = file
self.assignments[dep] = visited_module.global_nodes[dep]
# Do not forget to re-assign all nodes after the merge
self.global_nodes = {**self.assignments, **self.classes, **self.functions}
        # And restrict dependencies to those nodes only
self._restrict_dependencies_to_known_entities()
def compute_relative_order(self, missing_dependencies: set) -> dict[str, int]:
"""Compute in which relative order the `missing_dependencies` should appear when the nodes are added to the final file that
will be created based on the modular.
"""
relative_order = {}
idx = 0
original_dependencies = []
other_files_dependencies = defaultdict(list)
for dep in tuple(missing_dependencies):
if dep in self.added_objects_file_mapping:
file = self.added_objects_file_mapping[dep]
other_files_dependencies[file].append(dep)
else:
original_dependencies.append(dep)
# Sort all lists according to the order in their respective file
all_dependencies = []
for file, dependencies in other_files_dependencies.items():
sorted_dependencies = sorted(dependencies, key=lambda x: self.start_lines_file_mapping[file][x])
all_dependencies += sorted_dependencies
all_dependencies += sorted(original_dependencies, key=lambda x: self.start_lines[x])
        # Add all original nodes first, then merged ones (one file at a time)
for dep in all_dependencies:
relative_order[dep] = idx
idx += 1
return relative_order
def check_dependencies_and_create_import_node(
file_type: str, new_dependencies: set[str], mapper: ModuleMapper, new_name: str
) -> tuple[set[str], dict[str, cst.CSTNode]]:
"""Check that all class nodes in the `new_dependencies` belong to the correct `file_type`. If this is not the case,
we need to remove it from the dependencies, and create a new import to it instead.
This scenario may appear in the following case:
If a new class in the `modular_xxx.py` file does not belong to `type_xxx.py`, but is used somewhere in `other_type_xxx.py`
(e.g. as a type hint), but none of the visited files had a similar class, then it would be imported in `type_xxx.py` as
part of the standard dependency graph (because we never encountered an import towards this new class in any file).
For example imagine the following `modular.py`:
```
from ..llama.modeling_llama import LlamaModel
class NewNameTextConfig(PretrainedConfig):
...
class NewNameConfig(PretrainedConfig):
...
class NewNameModel(LlamaModel):
config = NewNameConfig()
text_config = NewNameTextConfig()
...
```
then without the help of this function, `NewNameTextConfig` would be imported in the `modeling_newname.py` as well as
`configuration_newname.py`, because `modeling_llama.py` tells us to not import `NewNameConfig`, but has no
knowledge of `NewNameTextConfig`.
"""
class_dependencies = {dep for dep in new_dependencies if m.matches(mapper.global_nodes[dep], m.ClassDef())}
corrected_dependencies = new_dependencies.copy()
new_imports = {}
for class_name in class_dependencies:
class_file_type = find_file_type(class_name)
# In this case, we need to remove it from the dependencies and create a new import instead
if class_file_type != file_type:
corrected_dependencies.remove(class_name)
import_statement = f"from .{class_file_type}_{new_name} import {class_name}"
new_imports[class_name] = cst.parse_statement(import_statement)
return corrected_dependencies, new_imports


def get_class_node_and_dependencies(
modular_mapper: ModularFileMapper, class_name: str, node: cst.CSTNode, files: dict[str, dict]
) -> tuple[dict, str, dict]:
"""Return a single class node (and all its dependency nodes), to be added to the `files`. It creates the new
class node based on the inherited classes if needed. Also returns any new imports of a new class defined in
the modular that we nay need.
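
    For instance (a hypothetical sketch), `class NewNameModel(LlamaModel)` yields the renamed Llama class node plus
    all of its recursive dependencies, while `class NewNameConfig(PretrainedConfig)` is kept as-is, since
    `PretrainedConfig` is not a model-specific import.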
"""
bases = [k.value.value for k in node.bases if k.value.value in modular_mapper.model_specific_imported_objects]
if len(bases) > 1:
raise ValueError(
f"{class_name} was defined with more than 1 model-specific super class. This is unsupported. We found {*bases,}."
)
file_type = find_file_type(class_name)
file_to_update = files[file_type]
model_name = modular_mapper.model_name
    # This is used to avoid adding objects to the dependency graph if they will already be imported
imported_objects = modular_mapper.imported_objects_per_file[file_type]
# We need to replace the class node with the transformers (modeling file) super class node
if len(bases) == 1:
super_class = bases[0]
super_file_name = modular_mapper.model_specific_imported_objects[super_class]
# Get the mapper corresponding to the inherited class
mapper = modular_mapper.visited_modules[super_file_name]
# Rename the super class according to the exact same rule we used when renaming the whole module
renamer = modular_mapper.renamers[super_file_name]
renamed_super_class = preserve_case_replace(super_class, renamer.patterns, renamer.default_name)
renamed_super_class = convert_to_camelcase(renamed_super_class, renamer.old_name, renamer.default_old_name)
# Create the new class node
updated_node = replace_class_node(mapper, node, renamed_super_class)
# Grab all immediate dependencies of the new node
new_node_dependencies = augmented_dependencies_for_class_node(updated_node, mapper, imported_objects)
        # At this point, if any class dependency belongs to another file, we need to remove it from the
        # dependencies and add a new import for it instead
new_node_dependencies, new_imports = check_dependencies_and_create_import_node(
file_type, new_node_dependencies, mapper, model_name
)
# The node was modified -> look for all recursive dependencies of the new node
all_dependencies_to_add = find_all_dependencies(
dependency_mapping=mapper.class_dependency_mapping,
initial_dependencies=new_node_dependencies,
initial_checked_dependencies=set(file_to_update.keys()),
)
relative_dependency_order = mapper.compute_relative_order(all_dependencies_to_add)
nodes_to_add = {
dep: (relative_dependency_order[dep], mapper.global_nodes[dep]) for dep in all_dependencies_to_add
}
# No transformers (modeling file) super class, just check functions and assignments dependencies
else:
updated_node = node
        # The node was NOT modified -> no need to look recursively for other class dependencies. Even if they are
        # not defined yet (which would mean a weird code order in the modular file), they will be defined later on
all_dependencies_to_add = augmented_dependencies_for_class_node(updated_node, modular_mapper, imported_objects)
        # At this point, if any class dependency belongs to another file, we need to remove it from the
        # dependencies and add a new import for it instead
all_dependencies_to_add, new_imports = check_dependencies_and_create_import_node(
file_type, all_dependencies_to_add, modular_mapper, model_name
)
relative_dependency_order = modular_mapper.compute_relative_order(all_dependencies_to_add)
nodes_to_add = {
dep: (relative_dependency_order[dep], modular_mapper.global_nodes[dep])
for dep in all_dependencies_to_add
if dep not in file_to_update.keys()
}
# Add the class node itself to the nodes to add
class_idx = max(relative_dependency_order.values()) + 1 if len(relative_dependency_order) > 0 else 0
nodes_to_add[class_name] = (class_idx, updated_node)
return nodes_to_add, file_type, new_imports


def create_modules(modular_mapper: ModularFileMapper) -> dict[str, cst.Module]:
    """Create all the new modules based on visiting the modular file. It replaces all classes as necessary."""
files = defaultdict(dict)
current_file_indices = defaultdict(lambda: 0)
# For each class defined in modular, potentially replace the node and add it with its dependencies
for class_name, node in modular_mapper.classes.items():
nodes_to_add, file_type, new_imports = get_class_node_and_dependencies(modular_mapper, class_name, node, files)
        # Add the potential new imports that we may need to the `modular_mapper` variable
modular_mapper.imported_objects_per_file[file_type].update(new_imports.keys())
modular_mapper.imports.extend(list(new_imports.values()))
# Sort the nodes according to their relative order
nodes_to_add = sorted(nodes_to_add.items(), key=lambda x: x[1][0])
# Write all nodes to file
for dependency, (_, node) in nodes_to_add:
# This is used to keep certain variables at the beginning of the file
try:
                # The -1000 offset is arbitrary -> it just needs to be larger than the number of nodes, so these
                # variables always sort before every other node
idx = -1000 + VARIABLES_AT_THE_BEGINNING.index(dependency)
except ValueError:
idx = current_file_indices[file_type]
current_file_indices[file_type] += 1
files[file_type][dependency] = {"insert_idx": idx, "node": node}
# Add the __all__ statement to files at the end
for file_type, node in modular_mapper.all_all_to_add.items():
idx = current_file_indices[file_type]
files[file_type]["__all__"] = {"insert_idx": idx, "node": node}
    # Aggregate all the import statements (we look for duplicates with the code_for_node, not the nodes themselves because
# they are wrapped in SimpleStatementLine or If which could have different newlines, blanks etc)
all_imports = modular_mapper.imports.copy()
all_imports_code = {modular_mapper.python_module.code_for_node(node).strip() for node in all_imports}
for file, mapper in modular_mapper.visited_modules.items():
new_imports = [
node for node in mapper.imports if mapper.python_module.code_for_node(node).strip() not in all_imports_code
]
new_imports_code = {mapper.python_module.code_for_node(node).strip() for node in new_imports}
all_imports.extend(new_imports)
all_imports_code.update(new_imports_code)
# Find the correct imports, and write the new modules
for file, body in files.items():
new_body = [k[1]["node"] for k in sorted(body.items(), key=lambda x: x[1]["insert_idx"])]
needed_imports = get_needed_imports(body, all_imports)
full_module = needed_imports + new_body
new_module = cst.Module(body=full_module, header=modular_mapper.python_module.header)
files[file] = new_module
return files


def convert_modular_file(modular_file, old_model_name=None, new_model_name=None, cst_transformers=None):
pattern = re.search(r"modular_(.*)(?=\.py$)", modular_file)
output = {}
if pattern is not None:
model_name = pattern.groups()[0]
# Parse the Python file
with open(modular_file, "r", encoding="utf-8") as file:
code = file.read()
module = cst.parse_module(code)
wrapper = MetadataWrapper(module)
if cst_transformers is None:
cst_transformers = ModularFileMapper(module, model_name, old_model_name, new_model_name)
wrapper.visit(cst_transformers)
for file, module in create_modules(cst_transformers).items():
if module != {}:
# Get relative path starting from src/transformers/
relative_path = re.search(
r"(src/transformers/.*|examples/.*)", os.path.abspath(modular_file).replace("\\", "/")
).group(1)
header = AUTO_GENERATED_MESSAGE.format(
relative_path=relative_path, short_name=os.path.basename(relative_path)
)
ruffed_code = run_ruff(header + module.code, True)
formatted_code = run_ruff(ruffed_code, False)
output[file] = [formatted_code, ruffed_code]
return output
else:
print(f"modular pattern not found in {modular_file}, exiting")
return {}
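

# A minimal usage sketch (hypothetical file path): `convert_modular_file` returns a mapping from file type
# (e.g. "modeling" or "configuration") to `[formatted_code, unformatted_code]`:
#
#     converted = convert_modular_file("src/transformers/models/gemma2/modular_gemma2.py")
#     for file_type, (formatted, _) in converted.items():
#         print(file_type, len(formatted.splitlines()))
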
def save_modeling_file(modular_file, converted_file):
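    """Save the converted files to disk, next to the modular file, replacing the `modular_` prefix with the file type
    (e.g. `modeling_`, `configuration_`). The formatted code (index 0) is preferred; if it is empty or contains only
    comments, the unformatted code (index 1) is written instead, with a warning.
    """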
for file_type in converted_file.keys():
non_comment_lines = len(
[line for line in converted_file[file_type][0].strip().split("\n") if not line.strip().startswith("#")]
)
if len(converted_file[file_type][0].strip()) > 0 and non_comment_lines > 0:
with open(modular_file.replace("modular_", f"{file_type}_"), "w", encoding="utf-8") as f:
f.write(converted_file[file_type][0])
else:
            non_comment_lines = len(
                [line for line in converted_file[file_type][1].strip().split("\n") if not line.strip().startswith("#")]
            )
if len(converted_file[file_type][1].strip()) > 0 and non_comment_lines > 0:
logger.warning("The modeling code contains errors, it's written without formatting")
with open(modular_file.replace("modular_", f"{file_type}_"), "w", encoding="utf-8") as f:
f.write(converted_file[file_type][1])


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--files_to_parse",
default=["src/transformers/models/gemma2/modular_gemma2.py"],
nargs="+",
help="A list of `modular_xxxx` files that should be converted to single model file",
)
parser.add_argument(
"--old_model_name",
required=False,
help="The name of the model from which the copying is done in CamelCase. If not provided is inferred from modular-file",
)
parser.add_argument(
"--new_model_name",
required=False,
help="The name of the new model being added in CamelCase. If not provided is inferred from modular-file",
)
args = parser.parse_args()
if args.files_to_parse == ["all"]:
args.files_to_parse = glob.glob("src/transformers/models/**/modular_*.py", recursive=True)
args.files_to_parse += glob.glob("examples/**/modular_*.py", recursive=True)
for file_name in find_priority_list(args.files_to_parse):
print(f"Converting {file_name} to a single model single file format")
module_path = file_name.replace("/", ".").replace(".py", "").replace("src.", "")
converted_files = convert_modular_file(file_name, args.old_model_name, args.new_model_name)
converter = save_modeling_file(file_name, converted_files)
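
# Example invocation (from the repository root):
#   python utils/modular_model_converter.py --files_to_parse all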