mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-04 13:20:12 +06:00
Improve modular converter (#33991)
* improve modular * style * Update modular_model_converter.py * pretty print warning * style * Support to remove unused classes as part of added dependencies as well * nits * correct bug * add example * style * Add documentation
This commit is contained in:
parent
fb360a6c7a
commit
17806d11ba
@ -52,6 +52,7 @@ For example:
|
|||||||
reference it (in case of addition) or completely remove it (in case of deletion).
|
reference it (in case of addition) or completely remove it (in case of deletion).
|
||||||
- If a class inherits from another, for example: class GemmaModel(LlamaModel):, dependencies are automatically
|
- If a class inherits from another, for example: class GemmaModel(LlamaModel):, dependencies are automatically
|
||||||
inferred. All submodules will be automatically inferred from the superclass.
|
inferred. All submodules will be automatically inferred from the superclass.
|
||||||
|
- If you define new functions in the `modular` and use them inside classes, the linter will automatically infer the
|
||||||
|
|
||||||
You should be able to write everything (the tokenizer, the image processor, the model, the config) in this `modular`
|
You should be able to write everything (the tokenizer, the image processor, the model, the config) in this `modular`
|
||||||
file, and the corresponding files will be created for you.
|
file, and the corresponding files will be created for you.
|
||||||
@ -158,6 +159,25 @@ class GemmaTokenizer(LlamaTokenizer):
|
|||||||
raise AttributeError("Not needed for Gemma")
|
raise AttributeError("Not needed for Gemma")
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Define new functions
|
||||||
|
|
||||||
|
If you define a new function in the `modular` file to be used inside a class, say
|
||||||
|
|
||||||
|
```python
|
||||||
|
def my_new_function(*args, **kwargs):
|
||||||
|
# Do something here
|
||||||
|
pass
|
||||||
|
|
||||||
|
class GemmaModel(LlamaModel):
|
||||||
|
def forward(*args, **kwargs):
|
||||||
|
# Call the function
|
||||||
|
example = my_new_function(*args, **kwargs)
|
||||||
|
# continue here
|
||||||
|
```
|
||||||
|
|
||||||
|
the `my_new_function` function (and, recursively, any other new functions called in its body) will be automatically copy-pasted
|
||||||
|
in the file where it is used.
|
||||||
|
|
||||||
### Calling `super()`
|
### Calling `super()`
|
||||||
We recently shipped a few features that allow you to go from:
|
We recently shipped a few features that allow you to go from:
|
||||||
```python
|
```python
|
||||||
@ -174,4 +194,4 @@ We now also support special cases like
|
|||||||
class GemmaVisionModel(CLIPModel):
|
class GemmaVisionModel(CLIPModel):
|
||||||
pass
|
pass
|
||||||
```
|
```
|
||||||
where the name of your class `GemmaVision` is not the same as the modular `Gemma`. This is super useful for composite models
|
where the name of your class `GemmaVision` is not the same as the modular `Gemma`. This is super useful for composite models.
|
@ -1,9 +1,9 @@
|
|||||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||||
# This file was automatically generated from <path_to_modular_file.py>.
|
# This file was automatically generated from src/transformers/models/gemma/modular_gemma.py.
|
||||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||||
# the file from the modular. If any change should be done, please apply the change to the
|
# the file from the modular. If any change should be done, please apply the change to the
|
||||||
# modular_xxx.py file directly. One of our CI enforces this
|
# modular_gemma.py file directly. One of our CI enforces this.
|
||||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
|
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
|
||||||
#
|
#
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||||
# This file was automatically generated from <path_to_modular_file.py>.
|
# This file was automatically generated from src/transformers/models/gemma/modular_gemma.py.
|
||||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||||
# the file from the modular. If any change should be done, please apply the change to the
|
# the file from the modular. If any change should be done, please apply the change to the
|
||||||
# modular_xxx.py file directly. One of our CI enforces this
|
# modular_gemma.py file directly. One of our CI enforces this.
|
||||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
|
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
|
||||||
#
|
#
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||||
# This file was automatically generated from <path_to_modular_file.py>.
|
# This file was automatically generated from src/transformers/models/gemma/modular_gemma.py.
|
||||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||||
# the file from the modular. If any change should be done, please apply the change to the
|
# the file from the modular. If any change should be done, please apply the change to the
|
||||||
# modular_xxx.py file directly. One of our CI enforces this
|
# modular_gemma.py file directly. One of our CI enforces this.
|
||||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
|
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
|
||||||
#
|
#
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||||
# This file was automatically generated from <path_to_modular_file.py>.
|
# This file was automatically generated from src/transformers/models/gemma2/modular_gemma2.py.
|
||||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||||
# the file from the modular. If any change should be done, please apply the change to the
|
# the file from the modular. If any change should be done, please apply the change to the
|
||||||
# modular_xxx.py file directly. One of our CI enforces this
|
# modular_gemma2.py file directly. One of our CI enforces this.
|
||||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
|
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
|
||||||
#
|
#
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||||
# This file was automatically generated from <path_to_modular_file.py>.
|
# This file was automatically generated from src/transformers/models/gemma2/modular_gemma2.py.
|
||||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||||
# the file from the modular. If any change should be done, please apply the change to the
|
# the file from the modular. If any change should be done, please apply the change to the
|
||||||
# modular_xxx.py file directly. One of our CI enforces this
|
# modular_gemma2.py file directly. One of our CI enforces this.
|
||||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
|
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
|
||||||
#
|
#
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||||
# This file was automatically generated from <path_to_modular_file.py>.
|
# This file was automatically generated from src/transformers/models/instructblipvideo/modular_instructblipvideo.py.
|
||||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||||
# the file from the modular. If any change should be done, please apply the change to the
|
# the file from the modular. If any change should be done, please apply the change to the
|
||||||
# modular_xxx.py file directly. One of our CI enforces this
|
# modular_instructblipvideo.py file directly. One of our CI enforces this.
|
||||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
# Copyright 2024 HuggingFace Inc. team. All rights reserved.
|
# Copyright 2024 HuggingFace Inc. team. All rights reserved.
|
||||||
#
|
#
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||||
# This file was automatically generated from <path_to_modular_file.py>.
|
# This file was automatically generated from src/transformers/models/instructblipvideo/modular_instructblipvideo.py.
|
||||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||||
# the file from the modular. If any change should be done, please apply the change to the
|
# the file from the modular. If any change should be done, please apply the change to the
|
||||||
# modular_xxx.py file directly. One of our CI enforces this
|
# modular_instructblipvideo.py file directly. One of our CI enforces this.
|
||||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
# Copyright 2024 HuggingFace Inc. team. All rights reserved.
|
# Copyright 2024 HuggingFace Inc. team. All rights reserved.
|
||||||
#
|
#
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||||
# This file was automatically generated from <path_to_modular_file.py>.
|
# This file was automatically generated from src/transformers/models/llava_next_video/modular_llava_next_video.py.
|
||||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||||
# the file from the modular. If any change should be done, please apply the change to the
|
# the file from the modular. If any change should be done, please apply the change to the
|
||||||
# modular_xxx.py file directly. One of our CI enforces this
|
# modular_llava_next_video.py file directly. One of our CI enforces this.
|
||||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
# Copyright 2024 HuggingFace Inc. team. All rights reserved.
|
# Copyright 2024 HuggingFace Inc. team. All rights reserved.
|
||||||
#
|
#
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||||
# This file was automatically generated from <path_to_modular_file.py>.
|
# This file was automatically generated from src/transformers/models/llava_next_video/modular_llava_next_video.py.
|
||||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||||
# the file from the modular. If any change should be done, please apply the change to the
|
# the file from the modular. If any change should be done, please apply the change to the
|
||||||
# modular_xxx.py file directly. One of our CI enforces this
|
# modular_llava_next_video.py file directly. One of our CI enforces this.
|
||||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
# Copyright 2024 HuggingFace Inc. team. All rights reserved.
|
# Copyright 2024 HuggingFace Inc. team. All rights reserved.
|
||||||
#
|
#
|
||||||
|
@ -15,8 +15,9 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import glob
|
import glob
|
||||||
import importlib
|
import importlib
|
||||||
|
import os
|
||||||
import re
|
import re
|
||||||
from collections import defaultdict
|
from collections import defaultdict, deque
|
||||||
from typing import Dict, List, Set
|
from typing import Dict, List, Set
|
||||||
|
|
||||||
import libcst as cst
|
import libcst as cst
|
||||||
@ -33,12 +34,19 @@ from transformers.models.auto.configuration_auto import CONFIG_MAPPING_NAMES
|
|||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
AUTO_GENERATED_MESSAGE = """# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
# This is used to avoid overwriting these top-level assignments even if they are in the dependency graph. Otherwise, the
|
||||||
# This file was automatically generated from <path_to_modular_file.py>.
|
# value from the dependency is used, then mapped to current name convention, resulting in wrong value.
|
||||||
|
# The corresponding mapped value is used to define the file target for the assignment
|
||||||
|
ASSIGNMENTS_TO_KEEP = {
|
||||||
|
"_CHECKPOINT_FOR_DOC": "modeling",
|
||||||
|
}
|
||||||
|
|
||||||
|
AUTO_GENERATED_MESSAGE = """# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||||
|
# This file was automatically generated from {relative_path}.
|
||||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||||
# the file from the modular. If any change should be done, please apply the change to the
|
# the file from the modular. If any change should be done, please apply the change to the
|
||||||
# modular_xxx.py file directly. One of our CI enforces this
|
# {short_name} file directly. One of our CI enforces this.
|
||||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@ -114,10 +122,13 @@ class ClassFinder(CSTVisitor):
|
|||||||
if m.matches(node, m.SimpleStatementLine(body=[m.Assign()])) and m.matches(
|
if m.matches(node, m.SimpleStatementLine(body=[m.Assign()])) and m.matches(
|
||||||
self.get_metadata(cst.metadata.ParentNodeProvider, node), m.Module()
|
self.get_metadata(cst.metadata.ParentNodeProvider, node), m.Module()
|
||||||
):
|
):
|
||||||
if hasattr(node.body[0].targets[0].target, "value"):
|
left_hand_side = node.body[0].targets[0].target
|
||||||
self.assignments[node.body[0].targets[0].target.value] = node
|
if hasattr(left_hand_side, "value"):
|
||||||
|
if left_hand_side.value not in ASSIGNMENTS_TO_KEEP.keys():
|
||||||
|
self.assignments[left_hand_side.value] = node
|
||||||
else:
|
else:
|
||||||
for idx, target in enumerate(list(node.body[0].targets[0].target.elements)):
|
for idx, target in enumerate(list(left_hand_side.elements)):
|
||||||
|
if target.value.value not in ASSIGNMENTS_TO_KEEP.keys():
|
||||||
self.assignments[target.value.value] = node.body[0].value.elements[idx].value
|
self.assignments[target.value.value] = node.body[0].value.elements[idx].value
|
||||||
if m.matches(node, m.SimpleStatementLine(body=[m.Import() | m.ImportFrom()])):
|
if m.matches(node, m.SimpleStatementLine(body=[m.Import() | m.ImportFrom()])):
|
||||||
self.imports[node.body[0].names] = node
|
self.imports[node.body[0].names] = node
|
||||||
@ -612,6 +623,99 @@ def get_new_part(class_name, base_class):
|
|||||||
return snake_case
|
return snake_case
|
||||||
|
|
||||||
|
|
||||||
|
def find_all_dependencies(function: str, dependency_mapping: dict[str, set]):
|
||||||
|
"""Return all the dependencies of the given top-level function. Given the following structure in the `modular_xxx.py` file:
|
||||||
|
```
|
||||||
|
def foo1():
|
||||||
|
pass
|
||||||
|
|
||||||
|
def foo2():
|
||||||
|
pass
|
||||||
|
|
||||||
|
def bar():
|
||||||
|
foo1()
|
||||||
|
|
||||||
|
def foobar():
|
||||||
|
bar()
|
||||||
|
foo2()
|
||||||
|
|
||||||
|
class MyLayer(SomeOtherModelLayer):
|
||||||
|
def forward(...):
|
||||||
|
foobar()
|
||||||
|
```
|
||||||
|
and the `dependency_mapping` created when visiting the `modular_xxx.py` file, we get:
|
||||||
|
```
|
||||||
|
dependency_mapping = {'bar': {'foo1'}, 'foobar': {'bar', 'foo2'}}
|
||||||
|
find_all_dependencies('foobar', dependency_mapping)
|
||||||
|
>>> [('bar', 'foobar'), ('foo2', 'foobar'), ('foo1', 'bar')]
|
||||||
|
```
|
||||||
|
That is, all the functions needed (and their immediate parent) so that the function to be added in MyLayer (`foobar`) can
|
||||||
|
work correctly.
|
||||||
|
"""
|
||||||
|
all_dependencies = deque(dependency_mapping[function])
|
||||||
|
all_dependencies_with_parent = [(dep, function) for dep in dependency_mapping[function]]
|
||||||
|
checked_dependencies = set(function)
|
||||||
|
while len(all_dependencies) > 0:
|
||||||
|
# Pick element to visit
|
||||||
|
parent = all_dependencies.popleft()
|
||||||
|
if parent not in checked_dependencies:
|
||||||
|
# Update dependencies
|
||||||
|
all_dependencies.extend(dependency_mapping[parent])
|
||||||
|
all_dependencies_with_parent += [(dependency, parent) for dependency in dependency_mapping[parent]]
|
||||||
|
# add visited node to the list
|
||||||
|
checked_dependencies.add(parent)
|
||||||
|
|
||||||
|
# no child can ever appear before its parent thanks to the queue (needed to add them at the correct location in the body later)
|
||||||
|
return all_dependencies_with_parent
|
||||||
|
|
||||||
|
|
||||||
|
class PostModularConverterCleaner(CSTTransformer):
|
||||||
|
"""Allow simple cleaning after conversion. Remove top-level functions/classes without any calls (they may arise due
|
||||||
|
to dependency mapping, even if code parts with those functions/classes were overwritten)"""
|
||||||
|
|
||||||
|
METADATA_DEPENDENCIES = (ParentNodeProvider,)
|
||||||
|
|
||||||
|
def __init__(self, added_dependencies: set):
|
||||||
|
super().__init__()
|
||||||
|
self.top_level_functions_or_classes = {}
|
||||||
|
self.all_used_functions_or_classes = set()
|
||||||
|
self.added_dependencies = added_dependencies
|
||||||
|
|
||||||
|
def visit_FunctionDef(self, node):
|
||||||
|
parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node)
|
||||||
|
if m.matches(parent_node, m.Module()):
|
||||||
|
self.top_level_functions_or_classes[node.name.value] = node
|
||||||
|
|
||||||
|
def visit_ClassDef(self, node):
|
||||||
|
parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node)
|
||||||
|
if m.matches(parent_node, m.Module()):
|
||||||
|
self.top_level_functions_or_classes[node.name.value] = node
|
||||||
|
|
||||||
|
def visit_Name(self, node: cst.Name):
|
||||||
|
"""This is used to find any mention of a top-level function or class except its own definition.
|
||||||
|
It will contain other names as well, but those will not be used. This is the most general way to do it
|
||||||
|
since mentions may appear in a lot of different contexts (apart from simple Call to the function/class).
|
||||||
|
e.g. Attention classes are only mentionned by their name in a dict assignment.
|
||||||
|
"""
|
||||||
|
parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node)
|
||||||
|
|
||||||
|
if not (
|
||||||
|
(m.matches(parent_node, m.ClassDef()) and parent_node.name.value == node.value)
|
||||||
|
or (m.matches(parent_node, m.FunctionDef()) and parent_node.name.value == node.value)
|
||||||
|
):
|
||||||
|
self.all_used_functions_or_classes.add(node.value)
|
||||||
|
|
||||||
|
def leave_Module(self, original_node: cst.Module, node):
|
||||||
|
# Find any class/function that was mistakenly added as part of the dependencies and remove it
|
||||||
|
unused = self.added_dependencies - self.all_used_functions_or_classes
|
||||||
|
nodes_to_remove = [
|
||||||
|
self.top_level_functions_or_classes[name] for name in unused if name in self.top_level_functions_or_classes
|
||||||
|
]
|
||||||
|
new_body = [node_ for node_ in original_node.body if node_ not in nodes_to_remove]
|
||||||
|
# Return a new module with the updated body
|
||||||
|
return node.with_changes(body=new_body)
|
||||||
|
|
||||||
|
|
||||||
class ModularConverterTransformer(CSTTransformer):
|
class ModularConverterTransformer(CSTTransformer):
|
||||||
METADATA_DEPENDENCIES = (ParentNodeProvider, ScopeProvider, PositionProvider)
|
METADATA_DEPENDENCIES = (ParentNodeProvider, ScopeProvider, PositionProvider)
|
||||||
|
|
||||||
@ -643,6 +747,13 @@ class ModularConverterTransformer(CSTTransformer):
|
|||||||
self.match_patterns = "|".join(self.files.keys())
|
self.match_patterns = "|".join(self.files.keys())
|
||||||
self.all_definitions = {}
|
self.all_definitions = {}
|
||||||
self.class_to_file_type = {}
|
self.class_to_file_type = {}
|
||||||
|
self.current_class = None # keep track of current top-level class during visit
|
||||||
|
self.current_top_level_function = None # keep track of current top-level function during visit
|
||||||
|
# Mapping from top-level functions to classes using them
|
||||||
|
self.function_call_class_mapping = defaultdict(lambda: set())
|
||||||
|
# Mapping from top-level functions to other top-level functions dependencies
|
||||||
|
self.function_call_dependency_mapping = defaultdict(lambda: set())
|
||||||
|
self.added_dependencies = set()
|
||||||
|
|
||||||
def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
|
def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
|
||||||
"""When visiting imports from `transformers.models.xxx` we need to:
|
"""When visiting imports from `transformers.models.xxx` we need to:
|
||||||
@ -692,9 +803,20 @@ class ModularConverterTransformer(CSTTransformer):
|
|||||||
if updated_node not in self.all_imports:
|
if updated_node not in self.all_imports:
|
||||||
self.all_imports.append(updated_node)
|
self.all_imports.append(updated_node)
|
||||||
return updated_node
|
return updated_node
|
||||||
|
elif m.matches(original_node, m.SimpleStatementLine(body=[m.Assign()])):
|
||||||
|
if original_node.body[0].targets[0].target.value in ASSIGNMENTS_TO_KEEP.keys():
|
||||||
|
file_ = ASSIGNMENTS_TO_KEEP[original_node.body[0].targets[0].target.value]
|
||||||
|
self.files[file_][original_node.body[0].targets[0].target.value] = {
|
||||||
|
"node": original_node,
|
||||||
|
"insert_idx": self.global_scope_index,
|
||||||
|
}
|
||||||
self.global_scope_index += 100
|
self.global_scope_index += 100
|
||||||
return updated_node
|
return updated_node
|
||||||
|
|
||||||
|
def visit_ClassDef(self, node: cst.ClassDef):
|
||||||
|
"""Used to keep track of current class"""
|
||||||
|
self.current_class = node.name.value
|
||||||
|
|
||||||
def leave_ClassDef(self, original_node, updated_node):
|
def leave_ClassDef(self, original_node, updated_node):
|
||||||
"""
|
"""
|
||||||
1. Filter the `base` classes of this class
|
1. Filter the `base` classes of this class
|
||||||
@ -772,9 +894,10 @@ class ModularConverterTransformer(CSTTransformer):
|
|||||||
node = class_finder.global_nodes.get(dependency, None)
|
node = class_finder.global_nodes.get(dependency, None)
|
||||||
if node is not None:
|
if node is not None:
|
||||||
if dependency not in file_to_update:
|
if dependency not in file_to_update:
|
||||||
node = self.all_definitions.get(dependency, node)
|
node = self.all_definitions.pop(dependency, node)
|
||||||
start_insert_idx -= 1
|
start_insert_idx -= 1
|
||||||
file_to_update[dependency] = {"insert_idx": start_insert_idx, "node": node}
|
file_to_update[dependency] = {"insert_idx": start_insert_idx, "node": node}
|
||||||
|
self.added_dependencies.add(dependency)
|
||||||
elif dependency not in self.inserted_deps:
|
elif dependency not in self.inserted_deps:
|
||||||
# make sure the node is written after its dependencies
|
# make sure the node is written after its dependencies
|
||||||
start_insert_idx = file_to_update[dependency]["insert_idx"] - 1
|
start_insert_idx = file_to_update[dependency]["insert_idx"] - 1
|
||||||
@ -811,8 +934,15 @@ class ModularConverterTransformer(CSTTransformer):
|
|||||||
else:
|
else:
|
||||||
self.class_to_file_type[class_name] = "modeling"
|
self.class_to_file_type[class_name] = "modeling"
|
||||||
self.files["modeling"][class_name] = {"insert_idx": self.global_scope_index, "node": updated_node}
|
self.files["modeling"][class_name] = {"insert_idx": self.global_scope_index, "node": updated_node}
|
||||||
|
|
||||||
|
self.current_class = None
|
||||||
return updated_node
|
return updated_node
|
||||||
|
|
||||||
|
def visit_FunctionDef(self, node):
|
||||||
|
parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node)
|
||||||
|
if m.matches(parent_node, m.Module()):
|
||||||
|
self.current_top_level_function = node.name.value
|
||||||
|
|
||||||
def leave_FunctionDef(self, original_node, node):
|
def leave_FunctionDef(self, original_node, node):
|
||||||
parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node)
|
parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node)
|
||||||
if m.matches(parent_node, m.Module()):
|
if m.matches(parent_node, m.Module()):
|
||||||
@ -852,7 +982,71 @@ class ModularConverterTransformer(CSTTransformer):
|
|||||||
logger.warning(f"one import is protected with `if`. Hard guess where it's used {full_statement}")
|
logger.warning(f"one import is protected with `if`. Hard guess where it's used {full_statement}")
|
||||||
return node
|
return node
|
||||||
|
|
||||||
def leave_Module(self, original_node: cst.Assign, node):
|
def visit_Call(self, node: cst.Call):
|
||||||
|
"""This is used to create a mapping from functions to class calling them, and from top-level functions to functions called inside them.
|
||||||
|
Important note: we only rely on direct Call to the functions here, not indirect mentions (such as assigning a variable with the function,
|
||||||
|
add calling the variable later). This should be enough as the `modular_xxx` and `modeling_xxx` structures should be as simple as possible."""
|
||||||
|
# Only map function calls if we're inside a class (i.e., current_class is set)
|
||||||
|
if self.current_class is not None:
|
||||||
|
# Simple function calls such as foo()
|
||||||
|
if isinstance(node.func, cst.Name):
|
||||||
|
self.function_call_class_mapping[node.func.value].add(self.current_class)
|
||||||
|
elif self.current_top_level_function is not None:
|
||||||
|
# Simple function calls such as foo()
|
||||||
|
if isinstance(node.func, cst.Name):
|
||||||
|
self.function_call_dependency_mapping[self.current_top_level_function].add(node.func.value)
|
||||||
|
|
||||||
|
def _maybe_add_function_to_body(
|
||||||
|
self,
|
||||||
|
top_level_function: str,
|
||||||
|
body: dict,
|
||||||
|
function_node: cst.FunctionDef,
|
||||||
|
matching_callers: set | None = None,
|
||||||
|
parent: str | None = None,
|
||||||
|
) -> bool:
|
||||||
|
"""Check if the `top_level_function` should be added to the body (i.e. it is not already present, and `matching_callers`
|
||||||
|
is not empy, or `parent`is provided). If it should be added, do it (in the correct location, just before its caller) and return
|
||||||
|
`True`. Return `False` otherwise.
|
||||||
|
"""
|
||||||
|
if matching_callers is None and parent is None:
|
||||||
|
raise ValueError("Cannot add function if both the parent and the matching callers are None.")
|
||||||
|
if matching_callers is None:
|
||||||
|
matching_callers = {parent}
|
||||||
|
if len(matching_callers) > 0 and top_level_function not in body.keys():
|
||||||
|
# Add the function just before the first class using it
|
||||||
|
new_idx = min([body[element]["insert_idx"] for element in matching_callers])
|
||||||
|
# Reorder the elements
|
||||||
|
for element in body.keys():
|
||||||
|
if body[element]["insert_idx"] >= new_idx:
|
||||||
|
body[element]["insert_idx"] += 1
|
||||||
|
# Assign new element to body (after changing the count to avoid messing it)
|
||||||
|
body[top_level_function] = {"insert_idx": new_idx, "node": function_node}
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _recursively_add_all_new_needed_functions_in_files(self):
|
||||||
|
"""For all top-level functions which were newly defined in the `modular_xxx.py`, check if they are used in a class in
|
||||||
|
the different files, and add them to the file if it is the case (also recursively adding all other functions that
|
||||||
|
may be needed in that function body)."""
|
||||||
|
# At this point, `self.all_definitions` only contains newly defined top-level functions in the `modualr_xxx.py`
|
||||||
|
for top_level_function, function_node in self.all_definitions.items():
|
||||||
|
calling_entities = self.function_call_class_mapping[top_level_function]
|
||||||
|
# The function may be needed in different files, we need to iterate on them
|
||||||
|
for file, body in self.files.items():
|
||||||
|
file_elements = set(body.keys())
|
||||||
|
# If the intersection is not null, top_level_func must be added to file
|
||||||
|
matching_callers = calling_entities & file_elements
|
||||||
|
added = self._maybe_add_function_to_body(top_level_function, body, function_node, matching_callers)
|
||||||
|
# If the function was added, we need to recursively add all its dependencies
|
||||||
|
if added:
|
||||||
|
for dependency, parent in find_all_dependencies(
|
||||||
|
top_level_function, self.function_call_dependency_mapping
|
||||||
|
):
|
||||||
|
self._maybe_add_function_to_body(
|
||||||
|
dependency, body, self.all_definitions[dependency], parent=parent
|
||||||
|
)
|
||||||
|
|
||||||
|
def leave_Module(self, original_node: cst.Module, node):
|
||||||
imports = {self.python_module.code_for_node(k): k for k in self.all_imports}
|
imports = {self.python_module.code_for_node(k): k for k in self.all_imports}
|
||||||
dependency_imports = {file_type: imports.copy() for file_type in self.files}
|
dependency_imports = {file_type: imports.copy() for file_type in self.files}
|
||||||
for super_file_name, visiter in self.visited_module.items():
|
for super_file_name, visiter in self.visited_module.items():
|
||||||
@ -861,12 +1055,19 @@ class ModularConverterTransformer(CSTTransformer):
|
|||||||
{self.python_module.code_for_node(k): k for k in visiter.imports.values()}
|
{self.python_module.code_for_node(k): k for k in visiter.imports.values()}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Check if any new top-level function from the `modular_xxx.py` should be added to the different files
|
||||||
|
# (if it is called in a class in the file, then it will be copy pasted from `modular.py` to that file).
|
||||||
|
self._recursively_add_all_new_needed_functions_in_files()
|
||||||
|
|
||||||
for file, body in self.files.items():
|
for file, body in self.files.items():
|
||||||
new_body = [k[1]["node"] for k in sorted(body.items(), key=lambda x: x[1]["insert_idx"])]
|
new_body = [k[1]["node"] for k in sorted(body.items(), key=lambda x: x[1]["insert_idx"])]
|
||||||
if len(new_body) > 0:
|
if len(new_body) > 0:
|
||||||
if file in dependency_imports.keys():
|
if file in dependency_imports.keys():
|
||||||
new_body = list(dependency_imports[file].values()) + new_body
|
new_body = list(dependency_imports[file].values()) + new_body
|
||||||
self.files[file] = cst.Module(body=[*new_body], header=node.header)
|
new_module = cst.Module(body=[*new_body], header=node.header)
|
||||||
|
# Final cleanup
|
||||||
|
new_module = MetadataWrapper(new_module).visit(PostModularConverterCleaner(self.added_dependencies))
|
||||||
|
self.files[file] = new_module
|
||||||
return node
|
return node
|
||||||
|
|
||||||
|
|
||||||
@ -885,7 +1086,14 @@ def convert_modular_file(modular_file, old_model_name=None, new_model_name=None,
|
|||||||
wrapper.visit(cst_transformers)
|
wrapper.visit(cst_transformers)
|
||||||
for file, node in cst_transformers.files.items():
|
for file, node in cst_transformers.files.items():
|
||||||
if node != {}:
|
if node != {}:
|
||||||
ruffed_code = run_ruff(AUTO_GENERATED_MESSAGE + node.code, True)
|
# Get relative path starting from src/transformers/
|
||||||
|
relative_path = re.search(
|
||||||
|
f"{os.sep}(src{os.sep}transformers{os.sep}.*)", os.path.abspath(modular_file)
|
||||||
|
).group(1)
|
||||||
|
header = AUTO_GENERATED_MESSAGE.format(
|
||||||
|
relative_path=relative_path, short_name=os.path.basename(relative_path)
|
||||||
|
)
|
||||||
|
ruffed_code = run_ruff(header + node.code, True)
|
||||||
formatted_code = run_ruff(ruffed_code, False)
|
formatted_code = run_ruff(ruffed_code, False)
|
||||||
output[file] = [formatted_code, ruffed_code]
|
output[file] = [formatted_code, ruffed_code]
|
||||||
return output
|
return output
|
||||||
@ -916,7 +1124,7 @@ if __name__ == "__main__":
|
|||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--files_to_parse",
|
"--files_to_parse",
|
||||||
default=["src/transformers/models/gemma/modular_gemma.py"],
|
default=["src/transformers/models/roberta/modular_roberta.py"],
|
||||||
nargs="+",
|
nargs="+",
|
||||||
help="A list of `modular_xxxx` files that should be converted to single model file",
|
help="A list of `modular_xxxx` files that should be converted to single model file",
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user