diff --git a/docs/source/en/modular_transformers.md b/docs/source/en/modular_transformers.md index dbc8d9116ed..1516233ec4d 100644 --- a/docs/source/en/modular_transformers.md +++ b/docs/source/en/modular_transformers.md @@ -52,6 +52,7 @@ For example: reference it (in case of addition) or completely remove it (in case of deletion). - If a class inherits from another, for example: class GemmaModel(LlamaModel):, dependencies are automatically inferred. All submodules will be automatically inferred from the superclass. +- If you define new functions in the `modular` and use them inside classes, the linter will automatically infer the dependencies and copy the new functions into the generated files where they are used (see the "Define new functions" section below). You should be able to write everything (the tokenizer, the image processor, the model, the config) in this `modular` file, and the corresponding files will be created for you. @@ -158,6 +159,25 @@ class GemmaTokenizer(LlamaTokenizer): raise AttributeError("Not needed for Gemma") ``` +### Define new functions + +If you define a new function in the `modular` file to be used inside a class, say + +```python +def my_new_function(*args, **kwargs): + # Do something here + pass + +class GemmaModel(LlamaModel): + def forward(self, *args, **kwargs): + # Call the function + example = my_new_function(*args, **kwargs) + # continue here +``` + +the `my_new_function` function (and, recursively, any other new functions called in its body) will be automatically copy-pasted +into the file where it is used. + ### Calling `super()` We recently shipped a few features that allow you to go from: ```python @@ -174,4 +194,4 @@ We now also support special cases like class GemmaVisionModel(CLIPModel): pass ``` -where the name of your class `GemmaVision` is not the same as the modular `Gemma`. This is super useful for composite models \ No newline at end of file +where the name of your class `GemmaVision` is not the same as the modular `Gemma`. This is super useful for composite models. \ No newline at end of file diff --git a/src/transformers/models/gemma/configuration_gemma.py b/src/transformers/models/gemma/configuration_gemma.py index 90255086eef..e170803ccca 100644 --- a/src/transformers/models/gemma/configuration_gemma.py +++ b/src/transformers/models/gemma/configuration_gemma.py @@ -1,9 +1,9 @@ -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -# This file was automatically generated from . -# Do NOT edit this file manually as any edits will be overwritten by the generation of -# the file from the modular. If any change should be done, please apply the change to the -# modular_xxx.py file directly. One of our CI enforces this -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/gemma/modular_gemma.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_gemma.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # coding=utf-8 # Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved. # diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index df7f38014a1..ff206a470bc 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -1,9 +1,9 @@ -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -# This file was automatically generated from . 
-# Do NOT edit this file manually as any edits will be overwritten by the generation of -# the file from the modular. If any change should be done, please apply the change to the -# modular_xxx.py file directly. One of our CI enforces this -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/gemma/modular_gemma.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_gemma.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # coding=utf-8 # Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved. # diff --git a/src/transformers/models/gemma/tokenization_gemma.py b/src/transformers/models/gemma/tokenization_gemma.py index 5233037262f..ff0d1d034c2 100644 --- a/src/transformers/models/gemma/tokenization_gemma.py +++ b/src/transformers/models/gemma/tokenization_gemma.py @@ -1,9 +1,9 @@ -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -# This file was automatically generated from . -# Do NOT edit this file manually as any edits will be overwritten by the generation of -# the file from the modular. If any change should be done, please apply the change to the -# modular_xxx.py file directly. One of our CI enforces this -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/gemma/modular_gemma.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_gemma.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # coding=utf-8 # Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved. # diff --git a/src/transformers/models/gemma2/configuration_gemma2.py b/src/transformers/models/gemma2/configuration_gemma2.py index 44f96efb6df..74976bdd340 100644 --- a/src/transformers/models/gemma2/configuration_gemma2.py +++ b/src/transformers/models/gemma2/configuration_gemma2.py @@ -1,9 +1,9 @@ -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -# This file was automatically generated from . -# Do NOT edit this file manually as any edits will be overwritten by the generation of -# the file from the modular. If any change should be done, please apply the change to the -# modular_xxx.py file directly. One of our CI enforces this -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/gemma2/modular_gemma2.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_gemma2.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # coding=utf-8 # Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved. 
# diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 8498d5525ef..d8c75871906 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -1,9 +1,9 @@ -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -# This file was automatically generated from . -# Do NOT edit this file manually as any edits will be overwritten by the generation of -# the file from the modular. If any change should be done, please apply the change to the -# modular_xxx.py file directly. One of our CI enforces this -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/gemma2/modular_gemma2.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_gemma2.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # coding=utf-8 # Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved. # diff --git a/src/transformers/models/instructblipvideo/configuration_instructblipvideo.py b/src/transformers/models/instructblipvideo/configuration_instructblipvideo.py index 02672bdce83..e7c8eeccef9 100644 --- a/src/transformers/models/instructblipvideo/configuration_instructblipvideo.py +++ b/src/transformers/models/instructblipvideo/configuration_instructblipvideo.py @@ -1,9 +1,9 @@ -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -# This file was automatically generated from . -# Do NOT edit this file manually as any edits will be overwritten by the generation of -# the file from the modular. If any change should be done, please apply the change to the -# modular_xxx.py file directly. One of our CI enforces this -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/instructblipvideo/modular_instructblipvideo.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_instructblipvideo.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # coding=utf-8 # Copyright 2024 HuggingFace Inc. team. All rights reserved. # diff --git a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py index c3a2c7add30..a300268ed71 100644 --- a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py +++ b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py @@ -1,9 +1,9 @@ -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -# This file was automatically generated from . -# Do NOT edit this file manually as any edits will be overwritten by the generation of -# the file from the modular. If any change should be done, please apply the change to the -# modular_xxx.py file directly. One of our CI enforces this -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/instructblipvideo/modular_instructblipvideo.py. 
+# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_instructblipvideo.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # coding=utf-8 # Copyright 2024 HuggingFace Inc. team. All rights reserved. # diff --git a/src/transformers/models/llava_next_video/configuration_llava_next_video.py b/src/transformers/models/llava_next_video/configuration_llava_next_video.py index 1631c101830..0e4e39b4b3a 100644 --- a/src/transformers/models/llava_next_video/configuration_llava_next_video.py +++ b/src/transformers/models/llava_next_video/configuration_llava_next_video.py @@ -1,9 +1,9 @@ -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -# This file was automatically generated from . -# Do NOT edit this file manually as any edits will be overwritten by the generation of -# the file from the modular. If any change should be done, please apply the change to the -# modular_xxx.py file directly. One of our CI enforces this -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/llava_next_video/modular_llava_next_video.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_llava_next_video.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # coding=utf-8 # Copyright 2024 HuggingFace Inc. team. All rights reserved. # diff --git a/src/transformers/models/llava_next_video/modeling_llava_next_video.py b/src/transformers/models/llava_next_video/modeling_llava_next_video.py index ea1114df7c2..58fed183267 100644 --- a/src/transformers/models/llava_next_video/modeling_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modeling_llava_next_video.py @@ -1,9 +1,9 @@ -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -# This file was automatically generated from . -# Do NOT edit this file manually as any edits will be overwritten by the generation of -# the file from the modular. If any change should be done, please apply the change to the -# modular_xxx.py file directly. One of our CI enforces this -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/llava_next_video/modular_llava_next_video.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_llava_next_video.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # coding=utf-8 # Copyright 2024 HuggingFace Inc. team. All rights reserved. 
# diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py index cc3089da3f3..599dc70e17e 100644 --- a/utils/modular_model_converter.py +++ b/utils/modular_model_converter.py @@ -15,8 +15,9 @@ import argparse import glob import importlib +import os import re -from collections import defaultdict +from collections import defaultdict, deque from typing import Dict, List, Set import libcst as cst @@ -33,12 +34,19 @@ from transformers.models.auto.configuration_auto import CONFIG_MAPPING_NAMES logger = logging.get_logger(__name__) -AUTO_GENERATED_MESSAGE = """# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -# This file was automatically generated from . -# Do NOT edit this file manually as any edits will be overwritten by the generation of -# the file from the modular. If any change should be done, please apply the change to the -# modular_xxx.py file directly. One of our CI enforces this -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This is used to avoid overwriting these top-level assignments even if they are in the dependency graph. Otherwise, the +# value from the dependency is used, then mapped to the current name convention, resulting in a wrong value. +# The corresponding mapped value is used to define the file target for the assignment. +ASSIGNMENTS_TO_KEEP = { + "_CHECKPOINT_FOR_DOC": "modeling", +} + +AUTO_GENERATED_MESSAGE = """# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from {relative_path}. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# {short_name} file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 """ @@ -114,11 +122,14 @@ class ClassFinder(CSTVisitor): if m.matches(node, m.SimpleStatementLine(body=[m.Assign()])) and m.matches( self.get_metadata(cst.metadata.ParentNodeProvider, node), m.Module() ): - if hasattr(node.body[0].targets[0].target, "value"): - self.assignments[node.body[0].targets[0].target.value] = node + left_hand_side = node.body[0].targets[0].target + if hasattr(left_hand_side, "value"): + if left_hand_side.value not in ASSIGNMENTS_TO_KEEP.keys(): + self.assignments[left_hand_side.value] = node else: - for idx, target in enumerate(list(node.body[0].targets[0].target.elements)): - self.assignments[target.value.value] = node.body[0].value.elements[idx].value + for idx, target in enumerate(list(left_hand_side.elements)): + if target.value.value not in ASSIGNMENTS_TO_KEEP.keys(): + self.assignments[target.value.value] = node.body[0].value.elements[idx].value if m.matches(node, m.SimpleStatementLine(body=[m.Import() | m.ImportFrom()])): self.imports[node.body[0].names] = node @@ -612,6 +623,99 @@ def get_new_part(class_name, base_class): return snake_case +def find_all_dependencies(function: str, dependency_mapping: dict[str, set]): + """Return all the dependencies of the given top-level function. 
Given the following structure in the `modular_xxx.py` file: + ``` + def foo1(): + pass + + def foo2(): + pass + + def bar(): + foo1() + + def foobar(): + bar() + foo2() + + class MyLayer(SomeOtherModelLayer): + def forward(...): + foobar() + ``` + and the `dependency_mapping` created when visiting the `modular_xxx.py` file, we get: + ``` + dependency_mapping = {'bar': {'foo1'}, 'foobar': {'bar', 'foo2'}} + find_all_dependencies('foobar', dependency_mapping) + >>> [('bar', 'foobar'), ('foo2', 'foobar'), ('foo1', 'bar')] + ``` + That is, all the functions needed (and their immediate parent) so that the function to be added in MyLayer (`foobar`) can + work correctly. + """ + all_dependencies = deque(dependency_mapping[function]) + all_dependencies_with_parent = [(dep, function) for dep in dependency_mapping[function]] + checked_dependencies = {function} + while len(all_dependencies) > 0: + # Pick element to visit + parent = all_dependencies.popleft() + if parent not in checked_dependencies: + # Update dependencies + all_dependencies.extend(dependency_mapping[parent]) + all_dependencies_with_parent += [(dependency, parent) for dependency in dependency_mapping[parent]] + # add visited node to the list + checked_dependencies.add(parent) + + # no child can ever appear before its parent thanks to the queue (needed to add them at the correct location in the body later) + return all_dependencies_with_parent + + +class PostModularConverterCleaner(CSTTransformer): + """Allow simple cleaning after conversion. Remove top-level functions/classes without any calls (they may arise due + to dependency mapping, even if code parts with those functions/classes were overwritten)""" + + METADATA_DEPENDENCIES = (ParentNodeProvider,) + + def __init__(self, added_dependencies: set): + super().__init__() + self.top_level_functions_or_classes = {} + self.all_used_functions_or_classes = set() + self.added_dependencies = added_dependencies + + def visit_FunctionDef(self, node): + parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node) + if m.matches(parent_node, m.Module()): + self.top_level_functions_or_classes[node.name.value] = node + + def visit_ClassDef(self, node): + parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node) + if m.matches(parent_node, m.Module()): + self.top_level_functions_or_classes[node.name.value] = node + + def visit_Name(self, node: cst.Name): + """This is used to find any mention of a top-level function or class except its own definition. + It will contain other names as well, but those will not be used. This is the most general way to do it + since mentions may appear in a lot of different contexts (apart from simple Call to the function/class). + e.g. Attention classes are only mentioned by their name in a dict assignment. 
+ """ + parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node) + + if not ( + (m.matches(parent_node, m.ClassDef()) and parent_node.name.value == node.value) + or (m.matches(parent_node, m.FunctionDef()) and parent_node.name.value == node.value) + ): + self.all_used_functions_or_classes.add(node.value) + + def leave_Module(self, original_node: cst.Module, node): + # Find any class/function that was mistakenly added as part of the dependencies and remove it + unused = self.added_dependencies - self.all_used_functions_or_classes + nodes_to_remove = [ + self.top_level_functions_or_classes[name] for name in unused if name in self.top_level_functions_or_classes + ] + new_body = [node_ for node_ in original_node.body if node_ not in nodes_to_remove] + # Return a new module with the updated body + return node.with_changes(body=new_body) + + class ModularConverterTransformer(CSTTransformer): METADATA_DEPENDENCIES = (ParentNodeProvider, ScopeProvider, PositionProvider) @@ -643,6 +747,13 @@ class ModularConverterTransformer(CSTTransformer): self.match_patterns = "|".join(self.files.keys()) self.all_definitions = {} self.class_to_file_type = {} + self.current_class = None # keep track of current top-level class during visit + self.current_top_level_function = None # keep track of current top-level function during visit + # Mapping from top-level functions to classes using them + self.function_call_class_mapping = defaultdict(lambda: set()) + # Mapping from top-level functions to other top-level functions dependencies + self.function_call_dependency_mapping = defaultdict(lambda: set()) + self.added_dependencies = set() def visit_ImportFrom(self, node: cst.ImportFrom) -> None: """When visiting imports from `transformers.models.xxx` we need to: @@ -692,9 +803,20 @@ class ModularConverterTransformer(CSTTransformer): if updated_node not in self.all_imports: self.all_imports.append(updated_node) return updated_node + elif m.matches(original_node, m.SimpleStatementLine(body=[m.Assign()])): + if original_node.body[0].targets[0].target.value in ASSIGNMENTS_TO_KEEP.keys(): + file_ = ASSIGNMENTS_TO_KEEP[original_node.body[0].targets[0].target.value] + self.files[file_][original_node.body[0].targets[0].target.value] = { + "node": original_node, + "insert_idx": self.global_scope_index, + } self.global_scope_index += 100 return updated_node + def visit_ClassDef(self, node: cst.ClassDef): + """Used to keep track of current class""" + self.current_class = node.name.value + def leave_ClassDef(self, original_node, updated_node): """ 1. 
Filter the `base` classes of this class @@ -772,9 +894,10 @@ class ModularConverterTransformer(CSTTransformer): node = class_finder.global_nodes.get(dependency, None) if node is not None: if dependency not in file_to_update: - node = self.all_definitions.get(dependency, node) + node = self.all_definitions.pop(dependency, node) start_insert_idx -= 1 file_to_update[dependency] = {"insert_idx": start_insert_idx, "node": node} + self.added_dependencies.add(dependency) elif dependency not in self.inserted_deps: # make sure the node is written after its dependencies start_insert_idx = file_to_update[dependency]["insert_idx"] - 1 @@ -811,8 +934,15 @@ class ModularConverterTransformer(CSTTransformer): else: self.class_to_file_type[class_name] = "modeling" self.files["modeling"][class_name] = {"insert_idx": self.global_scope_index, "node": updated_node} + + self.current_class = None return updated_node + def visit_FunctionDef(self, node): + parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node) + if m.matches(parent_node, m.Module()): + self.current_top_level_function = node.name.value + def leave_FunctionDef(self, original_node, node): parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node) if m.matches(parent_node, m.Module()): @@ -852,7 +982,71 @@ class ModularConverterTransformer(CSTTransformer): logger.warning(f"one import is protected with `if`. Hard guess where it's used {full_statement}") return node - def leave_Module(self, original_node: cst.Assign, node): + def visit_Call(self, node: cst.Call): + """This is used to create a mapping from functions to classes calling them, and from top-level functions to functions called inside them. + Important note: we only rely on direct Call to the functions here, not indirect mentions (such as assigning a variable with the function, + and calling the variable later). This should be enough as the `modular_xxx` and `modeling_xxx` structures should be as simple as possible.""" + # Only map function calls if we're inside a class (i.e., current_class is set) + if self.current_class is not None: + # Simple function calls such as foo() + if isinstance(node.func, cst.Name): + self.function_call_class_mapping[node.func.value].add(self.current_class) + elif self.current_top_level_function is not None: + # Simple function calls such as foo() + if isinstance(node.func, cst.Name): + self.function_call_dependency_mapping[self.current_top_level_function].add(node.func.value) + + def _maybe_add_function_to_body( + self, + top_level_function: str, + body: dict, + function_node: cst.FunctionDef, + matching_callers: set | None = None, + parent: str | None = None, + ) -> bool: + """Check if the `top_level_function` should be added to the body (i.e. it is not already present, and `matching_callers` + is not empty, or `parent` is provided). If it should be added, do it (in the correct location, just before its caller) and return + `True`. Return `False` otherwise. 
+ """ + if matching_callers is None and parent is None: + raise ValueError("Cannot add function if both the parent and the matching callers are None.") + if matching_callers is None: + matching_callers = {parent} + if len(matching_callers) > 0 and top_level_function not in body.keys(): + # Add the function just before the first class using it + new_idx = min([body[element]["insert_idx"] for element in matching_callers]) + # Reorder the elements + for element in body.keys(): + if body[element]["insert_idx"] >= new_idx: + body[element]["insert_idx"] += 1 + # Assign new element to body (after changing the count to avoid messing it up) + body[top_level_function] = {"insert_idx": new_idx, "node": function_node} + return True + return False + + def _recursively_add_all_new_needed_functions_in_files(self): + """For all top-level functions which were newly defined in the `modular_xxx.py`, check if they are used in a class in + the different files, and add them to the file if it is the case (also recursively adding all other functions that + may be needed in that function body).""" + # At this point, `self.all_definitions` only contains newly defined top-level functions in the `modular_xxx.py` + for top_level_function, function_node in self.all_definitions.items(): + calling_entities = self.function_call_class_mapping[top_level_function] + # The function may be needed in different files, so we need to iterate on them + for file, body in self.files.items(): + file_elements = set(body.keys()) + # If the intersection is not null, top_level_func must be added to file + matching_callers = calling_entities & file_elements + added = self._maybe_add_function_to_body(top_level_function, body, function_node, matching_callers) + # If the function was added, we need to recursively add all its dependencies + if added: + for dependency, parent in find_all_dependencies( + top_level_function, self.function_call_dependency_mapping + ): + self._maybe_add_function_to_body( + dependency, body, self.all_definitions[dependency], parent=parent + ) + + def leave_Module(self, original_node: cst.Module, node): imports = {self.python_module.code_for_node(k): k for k in self.all_imports} dependency_imports = {file_type: imports.copy() for file_type in self.files} for super_file_name, visiter in self.visited_module.items(): @@ -861,12 +1055,19 @@ {self.python_module.code_for_node(k): k for k in visiter.imports.values()} ) + + # Check if any new top-level function from the `modular_xxx.py` should be added to the different files + # (if it is called in a class in the file, then it will be copy-pasted from `modular.py` to that file). 
+ self._recursively_add_all_new_needed_functions_in_files() + for file, body in self.files.items(): new_body = [k[1]["node"] for k in sorted(body.items(), key=lambda x: x[1]["insert_idx"])] if len(new_body) > 0: if file in dependency_imports.keys(): new_body = list(dependency_imports[file].values()) + new_body - self.files[file] = cst.Module(body=[*new_body], header=node.header) + new_module = cst.Module(body=[*new_body], header=node.header) + # Final cleanup + new_module = MetadataWrapper(new_module).visit(PostModularConverterCleaner(self.added_dependencies)) + self.files[file] = new_module return node @@ -885,7 +1086,14 @@ def convert_modular_file(modular_file, old_model_name=None, new_model_name=None, wrapper.visit(cst_transformers) for file, node in cst_transformers.files.items(): if node != {}: - ruffed_code = run_ruff(AUTO_GENERATED_MESSAGE + node.code, True) + # Get relative path starting from src/transformers/ + relative_path = re.search( + f"{os.sep}(src{os.sep}transformers{os.sep}.*)", os.path.abspath(modular_file) + ).group(1) + header = AUTO_GENERATED_MESSAGE.format( + relative_path=relative_path, short_name=os.path.basename(relative_path) + ) + ruffed_code = run_ruff(header + node.code, True) formatted_code = run_ruff(ruffed_code, False) output[file] = [formatted_code, ruffed_code] return output @@ -916,7 +1124,7 @@ if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( "--files_to_parse", - default=["src/transformers/models/gemma/modular_gemma.py"], + default=["src/transformers/models/roberta/modular_roberta.py"], nargs="+", help="A list of `modular_xxxx` files that should be converted to single model file", )
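
For reference, here is a minimal, self-contained sketch that replays the example from the `find_all_dependencies` docstring above. The standalone function mirrors the BFS introduced in `utils/modular_model_converter.py` (seeding `checked_dependencies` with `{function}`), and the `defaultdict` stands in for how `function_call_dependency_mapping` is built by `visit_Call`; since ordering within one BFS level follows set iteration order, the result is compared as a set:

```python
from collections import defaultdict, deque


def find_all_dependencies(function: str, dependency_mapping: dict) -> list:
    """BFS over `dependency_mapping`, returning (dependency, immediate parent) pairs
    such that no dependency ever appears before the function that calls it."""
    all_dependencies = deque(dependency_mapping[function])
    all_dependencies_with_parent = [(dep, function) for dep in dependency_mapping[function]]
    checked_dependencies = {function}
    while all_dependencies:
        parent = all_dependencies.popleft()
        if parent not in checked_dependencies:
            all_dependencies.extend(dependency_mapping[parent])
            all_dependencies_with_parent += [(dep, parent) for dep in dependency_mapping[parent]]
            checked_dependencies.add(parent)
    return all_dependencies_with_parent


# The mapping from the docstring example; leaf functions such as `foo1` have no
# entry of their own, hence the defaultdict (like `function_call_dependency_mapping`)
dependency_mapping = defaultdict(set, {"bar": {"foo1"}, "foobar": {"bar", "foo2"}})
pairs = find_all_dependencies("foobar", dependency_mapping)
# First-level pairs may come out in either order (set iteration), so compare as a set
assert set(pairs) == {("bar", "foobar"), ("foo2", "foobar"), ("foo1", "bar")}
print(pairs)
```

The parent-before-child ordering is what `_recursively_add_all_new_needed_functions_in_files` relies on: each dependency is inserted via `_maybe_add_function_to_body` just before a caller that is already present in the generated file's body.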