Improve modular converter (#33991)

* improve modular

* style

* Update modular_model_converter.py

* pretty print warning

* style

* Support removing unused classes added as part of dependencies as well

* nits

* correct bug

* add example

* style

* Add documentation
Cyril Vallez 2024-10-08 14:53:58 +02:00 committed by GitHub
parent fb360a6c7a
commit 17806d11ba
11 changed files with 299 additions and 71 deletions


@ -52,6 +52,7 @@ For example:
reference it (in case of addition) or completely remove it (in case of deletion).
- If a class inherits from another, for example: `class GemmaModel(LlamaModel):`, dependencies are automatically
inferred. All submodules will be automatically inferred from the superclass.
- If you define new functions in the `modular` and use them inside classes, the linter will automatically infer the
dependencies and copy those new functions over to the generated files (see "Define new functions" below).

You should be able to write everything (the tokenizer, the image processor, the model, the config) in this `modular`
file, and the corresponding files will be created for you.
@ -158,6 +159,25 @@ class GemmaTokenizer(LlamaTokenizer):
        raise AttributeError("Not needed for Gemma")
```
### Define new functions
If you define a new function in the `modular` file to be used inside a class, say
```python
def my_new_function(*args, **kwargs):
    # Do something here
    pass


class GemmaModel(LlamaModel):
    def forward(self, *args, **kwargs):
        # Call the function
        example = my_new_function(*args, **kwargs)
        # continue here
```
the `my_new_function` function (and, recursively, any other new functions called in its body) will be automatically copy-pasted
in the file where it is used.
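For illustration, the generated `modeling_gemma.py` would then contain something along these lines (a rough sketch, not verbatim converter output; the converter also unrolls the `LlamaModel` body into the generated class):

```python
# Sketch of the generated modeling_gemma.py
def my_new_function(*args, **kwargs):
    # Do something here
    pass


class GemmaModel(GemmaPreTrainedModel):  # body unrolled from LlamaModel by the converter
    def forward(self, *args, **kwargs):
        example = my_new_function(*args, **kwargs)
        # continue here
```

The function is inserted just before the first class that calls it, so the generated file stays self-contained.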

### Calling `super()`
We recently shipped a few features that allow you to go from:
```python
@ -174,4 +194,4 @@ We now also support special cases like
class GemmaVisionModel(CLIPModel):
    pass
```
where the name of your class (`GemmaVision`) is not the same as the modular name (`Gemma`). This is super useful for composite models.


@ -1,9 +1,9 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/gemma/modular_gemma.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_gemma.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
#


@ -1,9 +1,9 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/gemma/modular_gemma.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_gemma.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
#


@ -1,9 +1,9 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/gemma/modular_gemma.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_gemma.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
#


@ -1,9 +1,9 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/gemma2/modular_gemma2.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_gemma2.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
#


@ -1,9 +1,9 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/gemma2/modular_gemma2.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_gemma2.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
#


@ -1,9 +1,9 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/instructblipvideo/modular_instructblipvideo.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_instructblipvideo.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2024 HuggingFace Inc. team. All rights reserved.
#


@ -1,9 +1,9 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/instructblipvideo/modular_instructblipvideo.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_instructblipvideo.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2024 HuggingFace Inc. team. All rights reserved.
#


@ -1,9 +1,9 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/llava_next_video/modular_llava_next_video.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_llava_next_video.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2024 HuggingFace Inc. team. All rights reserved.
#


@ -1,9 +1,9 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/llava_next_video/modular_llava_next_video.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_llava_next_video.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2024 HuggingFace Inc. team. All rights reserved.
#


@ -15,8 +15,9 @@
import argparse
import glob
import importlib
import os
import re
from collections import defaultdict, deque
from typing import Dict, List, Set

import libcst as cst
@ -33,12 +34,19 @@ from transformers.models.auto.configuration_auto import CONFIG_MAPPING_NAMES
logger = logging.get_logger(__name__)


# This is used to avoid overwriting these top-level assignments even if they are in the dependency graph. Otherwise, the
# value from the dependency is used, then mapped to the current name convention, resulting in a wrong value.
# The corresponding mapped value is used to define the target file for the assignment.
ASSIGNMENTS_TO_KEEP = {
    "_CHECKPOINT_FOR_DOC": "modeling",
}

AUTO_GENERATED_MESSAGE = """# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from {relative_path}.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# {short_name} file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
"""
@ -114,11 +122,14 @@ class ClassFinder(CSTVisitor):
        if m.matches(node, m.SimpleStatementLine(body=[m.Assign()])) and m.matches(
            self.get_metadata(cst.metadata.ParentNodeProvider, node), m.Module()
        ):
            left_hand_side = node.body[0].targets[0].target
            if hasattr(left_hand_side, "value"):
                if left_hand_side.value not in ASSIGNMENTS_TO_KEEP.keys():
                    self.assignments[left_hand_side.value] = node
            else:
                for idx, target in enumerate(list(left_hand_side.elements)):
                    if target.value.value not in ASSIGNMENTS_TO_KEEP.keys():
                        self.assignments[target.value.value] = node.body[0].value.elements[idx].value

        if m.matches(node, m.SimpleStatementLine(body=[m.Import() | m.ImportFrom()])):
            self.imports[node.body[0].names] = node
@ -612,6 +623,99 @@ def get_new_part(class_name, base_class):
    return snake_case


def find_all_dependencies(function: str, dependency_mapping: dict[str, set]):
    """Return all the dependencies of the given top-level function. Given the following structure in the `modular_xxx.py` file:
    ```
    def foo1():
        pass

    def foo2():
        pass

    def bar():
        foo1()

    def foobar():
        bar()
        foo2()

    class MyLayer(SomeOtherModelLayer):
        def forward(...):
            foobar()
    ```
    and the `dependency_mapping` created when visiting the `modular_xxx.py` file, we get:
    ```
    dependency_mapping = {'bar': {'foo1'}, 'foobar': {'bar', 'foo2'}}
    find_all_dependencies('foobar', dependency_mapping)
    >>> [('bar', 'foobar'), ('foo2', 'foobar'), ('foo1', 'bar')]
    ```
    That is, all the functions needed (and their immediate parent) so that the function to be added in MyLayer (`foobar`) can
    work correctly.
    """
    all_dependencies = deque(dependency_mapping[function])
    all_dependencies_with_parent = [(dep, function) for dep in dependency_mapping[function]]
    checked_dependencies = set(function)
    while len(all_dependencies) > 0:
        # Pick element to visit
        parent = all_dependencies.popleft()
        if parent not in checked_dependencies:
            # Update dependencies
            all_dependencies.extend(dependency_mapping[parent])
            all_dependencies_with_parent += [(dependency, parent) for dependency in dependency_mapping[parent]]
            # add visited node to the list
            checked_dependencies.add(parent)
    # no child can ever appear before its parent thanks to the queue (needed to add them at the correct location in the body later)
    return all_dependencies_with_parent


class PostModularConverterCleaner(CSTTransformer):
    """Allow simple cleaning after conversion. Remove top-level functions/classes without any calls (they may arise due
    to dependency mapping, even if code parts with those functions/classes were overwritten)"""

    METADATA_DEPENDENCIES = (ParentNodeProvider,)

    def __init__(self, added_dependencies: set):
        super().__init__()
        self.top_level_functions_or_classes = {}
        self.all_used_functions_or_classes = set()
        self.added_dependencies = added_dependencies

    def visit_FunctionDef(self, node):
        parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node)
        if m.matches(parent_node, m.Module()):
            self.top_level_functions_or_classes[node.name.value] = node

    def visit_ClassDef(self, node):
        parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node)
        if m.matches(parent_node, m.Module()):
            self.top_level_functions_or_classes[node.name.value] = node

    def visit_Name(self, node: cst.Name):
        """This is used to find any mention of a top-level function or class except its own definition.
        It will contain other names as well, but those will not be used. This is the most general way to do it
        since mentions may appear in a lot of different contexts (apart from simple Call to the function/class).
        e.g. Attention classes are only mentioned by their name in a dict assignment.
        """
        parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node)
        if not (
            (m.matches(parent_node, m.ClassDef()) and parent_node.name.value == node.value)
            or (m.matches(parent_node, m.FunctionDef()) and parent_node.name.value == node.value)
        ):
            self.all_used_functions_or_classes.add(node.value)

    def leave_Module(self, original_node: cst.Module, node):
        # Find any class/function that was mistakenly added as part of the dependencies and remove it
        unused = self.added_dependencies - self.all_used_functions_or_classes
        nodes_to_remove = [
            self.top_level_functions_or_classes[name] for name in unused if name in self.top_level_functions_or_classes
        ]
        new_body = [node_ for node_ in original_node.body if node_ not in nodes_to_remove]
        # Return a new module with the updated body
        return node.with_changes(body=new_body)


class ModularConverterTransformer(CSTTransformer):
    METADATA_DEPENDENCIES = (ParentNodeProvider, ScopeProvider, PositionProvider)
@ -643,6 +747,13 @@ class ModularConverterTransformer(CSTTransformer):
        self.match_patterns = "|".join(self.files.keys())
        self.all_definitions = {}
        self.class_to_file_type = {}
        self.current_class = None  # keep track of current top-level class during visit
        self.current_top_level_function = None  # keep track of current top-level function during visit
        # Mapping from top-level functions to classes using them
        self.function_call_class_mapping = defaultdict(lambda: set())
        # Mapping from top-level functions to other top-level function dependencies
        self.function_call_dependency_mapping = defaultdict(lambda: set())
        self.added_dependencies = set()

    def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
        """When visiting imports from `transformers.models.xxx` we need to:
@ -692,9 +803,20 @@ class ModularConverterTransformer(CSTTransformer):
                if updated_node not in self.all_imports:
                    self.all_imports.append(updated_node)
                return updated_node
            elif m.matches(original_node, m.SimpleStatementLine(body=[m.Assign()])):
                if original_node.body[0].targets[0].target.value in ASSIGNMENTS_TO_KEEP.keys():
                    file_ = ASSIGNMENTS_TO_KEEP[original_node.body[0].targets[0].target.value]
                    self.files[file_][original_node.body[0].targets[0].target.value] = {
                        "node": original_node,
                        "insert_idx": self.global_scope_index,
                    }
            self.global_scope_index += 100
        return updated_node

    def visit_ClassDef(self, node: cst.ClassDef):
        """Used to keep track of current class"""
        self.current_class = node.name.value

    def leave_ClassDef(self, original_node, updated_node):
        """
        1. Filter the `base` classes of this class
@ -772,9 +894,10 @@ class ModularConverterTransformer(CSTTransformer):
                node = class_finder.global_nodes.get(dependency, None)
                if node is not None:
                    if dependency not in file_to_update:
                        node = self.all_definitions.pop(dependency, node)
                        start_insert_idx -= 1
                        file_to_update[dependency] = {"insert_idx": start_insert_idx, "node": node}
                        self.added_dependencies.add(dependency)
                    elif dependency not in self.inserted_deps:
                        # make sure the node is written after its dependencies
                        start_insert_idx = file_to_update[dependency]["insert_idx"] - 1
@ -811,8 +934,15 @@ class ModularConverterTransformer(CSTTransformer):
        else:
            self.class_to_file_type[class_name] = "modeling"
            self.files["modeling"][class_name] = {"insert_idx": self.global_scope_index, "node": updated_node}

        self.current_class = None
        return updated_node

    def visit_FunctionDef(self, node):
        parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node)
        if m.matches(parent_node, m.Module()):
            self.current_top_level_function = node.name.value

    def leave_FunctionDef(self, original_node, node):
        parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node)
        if m.matches(parent_node, m.Module()):
@ -852,7 +982,71 @@ class ModularConverterTransformer(CSTTransformer):
logger.warning(f"one import is protected with `if`. Hard guess where it's used {full_statement}") logger.warning(f"one import is protected with `if`. Hard guess where it's used {full_statement}")
return node return node
def leave_Module(self, original_node: cst.Assign, node): def visit_Call(self, node: cst.Call):
"""This is used to create a mapping from functions to class calling them, and from top-level functions to functions called inside them.
Important note: we only rely on direct Call to the functions here, not indirect mentions (such as assigning a variable with the function,
add calling the variable later). This should be enough as the `modular_xxx` and `modeling_xxx` structures should be as simple as possible."""
# Only map function calls if we're inside a class (i.e., current_class is set)
if self.current_class is not None:
# Simple function calls such as foo()
if isinstance(node.func, cst.Name):
self.function_call_class_mapping[node.func.value].add(self.current_class)
elif self.current_top_level_function is not None:
# Simple function calls such as foo()
if isinstance(node.func, cst.Name):
self.function_call_dependency_mapping[self.current_top_level_function].add(node.func.value)

    def _maybe_add_function_to_body(
        self,
        top_level_function: str,
        body: dict,
        function_node: cst.FunctionDef,
        matching_callers: set | None = None,
        parent: str | None = None,
    ) -> bool:
        """Check if the `top_level_function` should be added to the body (i.e. it is not already present, and `matching_callers`
        is not empty, or `parent` is provided). If it should be added, do it (in the correct location, just before its caller) and return
        `True`. Return `False` otherwise.
        """
        if matching_callers is None and parent is None:
            raise ValueError("Cannot add function if both the parent and the matching callers are None.")
        if matching_callers is None:
            matching_callers = {parent}
        if len(matching_callers) > 0 and top_level_function not in body.keys():
            # Add the function just before the first class using it
            new_idx = min([body[element]["insert_idx"] for element in matching_callers])
            # Reorder the elements
            for element in body.keys():
                if body[element]["insert_idx"] >= new_idx:
                    body[element]["insert_idx"] += 1
            # Assign new element to body (after changing the count to avoid messing it up)
            body[top_level_function] = {"insert_idx": new_idx, "node": function_node}
            return True
        return False

    def _recursively_add_all_new_needed_functions_in_files(self):
        """For all top-level functions which were newly defined in the `modular_xxx.py`, check if they are used in a class in
        the different files, and add them to the file if it is the case (also recursively adding all other functions that
        may be needed in that function body)."""
        # At this point, `self.all_definitions` only contains newly defined top-level functions in the `modular_xxx.py`
        for top_level_function, function_node in self.all_definitions.items():
            calling_entities = self.function_call_class_mapping[top_level_function]
            # The function may be needed in different files, so we need to iterate on them
            for file, body in self.files.items():
                file_elements = set(body.keys())
                # If the intersection is not null, top_level_function must be added to the file
                matching_callers = calling_entities & file_elements
                added = self._maybe_add_function_to_body(top_level_function, body, function_node, matching_callers)
                # If the function was added, we need to recursively add all its dependencies
                if added:
                    for dependency, parent in find_all_dependencies(
                        top_level_function, self.function_call_dependency_mapping
                    ):
                        self._maybe_add_function_to_body(
                            dependency, body, self.all_definitions[dependency], parent=parent
                        )

    def leave_Module(self, original_node: cst.Module, node):
        imports = {self.python_module.code_for_node(k): k for k in self.all_imports}
        dependency_imports = {file_type: imports.copy() for file_type in self.files}
        for super_file_name, visiter in self.visited_module.items():
@ -861,12 +1055,19 @@ class ModularConverterTransformer(CSTTransformer):
                {self.python_module.code_for_node(k): k for k in visiter.imports.values()}
            )

        # Check if any new top-level function from the `modular_xxx.py` should be added to the different files
        # (if it is called in a class in the file, then it will be copy pasted from `modular.py` to that file).
        self._recursively_add_all_new_needed_functions_in_files()

        for file, body in self.files.items():
            new_body = [k[1]["node"] for k in sorted(body.items(), key=lambda x: x[1]["insert_idx"])]
            if len(new_body) > 0:
                if file in dependency_imports.keys():
                    new_body = list(dependency_imports[file].values()) + new_body
                new_module = cst.Module(body=[*new_body], header=node.header)
                # Final cleanup
                new_module = MetadataWrapper(new_module).visit(PostModularConverterCleaner(self.added_dependencies))
                self.files[file] = new_module

        return node
@ -885,7 +1086,14 @@ def convert_modular_file(modular_file, old_model_name=None, new_model_name=None,
    wrapper.visit(cst_transformers)
    for file, node in cst_transformers.files.items():
        if node != {}:
            # Get relative path starting from src/transformers/
            relative_path = re.search(
                f"{os.sep}(src{os.sep}transformers{os.sep}.*)", os.path.abspath(modular_file)
            ).group(1)
            header = AUTO_GENERATED_MESSAGE.format(
                relative_path=relative_path, short_name=os.path.basename(relative_path)
            )
            ruffed_code = run_ruff(header + node.code, True)
            formatted_code = run_ruff(ruffed_code, False)
            output[file] = [formatted_code, ruffed_code]
    return output
@ -916,7 +1124,7 @@ if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--files_to_parse",
        default=["src/transformers/models/roberta/modular_roberta.py"],
        nargs="+",
        help="A list of `modular_xxxx` files that should be converted to single model file",
    )
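
For reference, the conversion can then presumably be run programmatically or from the command line (a sketch; the module/script location `utils/modular_model_converter.py` and running from the repository root are assumptions not stated in this diff):

```python
# Sketch: convert a single modular file programmatically (import path is an assumption)
from modular_model_converter import convert_modular_file

output = convert_modular_file("src/transformers/models/roberta/modular_roberta.py")
for file_type, (formatted_code, _ruffed_code) in output.items():
    # file_type is e.g. "modeling" or "configuration"; formatted_code is the generated source
    print(file_type, len(formatted_code))
```

or, equivalently, something like `python utils/modular_model_converter.py --files_to_parse src/transformers/models/roberta/modular_roberta.py`.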