From a3add29097b5d0aadb57020507240b29d47dd7b3 Mon Sep 17 00:00:00 2001
From: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Date: Tue, 8 Oct 2024 10:19:17 +0200
Subject: [PATCH] Add support for __all__ and potentially deleting functions
 (#33859)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Add support for __all__ and potentially deleting functions
* updates
* update
* nits
* remove dummies
* fix warning
* fixup
* style
* update
* fixup
* skip copied from when # skip
* remove log
* bring dummies back
* fixup
* remove copied from
* fixup
* remove warnings from `make fix-copies`
* fix doc issues
* nits
* Better error message !
* add support for more flexible naming!
* style
* breaking style?
* fix super() renaming issues
* del not needed when you don't call super().__init__()
* style
* no more fmt on :)
* properly remove `self`
* fixup
* fix
* doc nits
* add some doc 🫡
---
 docs/source/en/modular_transformers.md        |  58 ++++-
 .../modular-transformers/modeling_dummy.py    |   1 -
 src/transformers/__init__.py                  |   1 -
 src/transformers/deepspeed.py                 |  41 ----
 .../models/gemma/configuration_gemma.py       |   3 +
 .../models/gemma/modeling_gemma.py            |   3 +
 .../models/gemma/modular_gemma.py             | 179 ++++++++++++++-
 .../models/gemma/tokenization_gemma.py        |  98 ++++----
 .../modeling_instructblipvideo.py             |  13 --
 .../models/llama/tokenization_llama.py        |   4 +-
 .../modeling_llava_next_video.py              |  10 -
 .../seamless_m4t/modeling_seamless_m4t.py     |   2 +-
 .../modeling_seamless_m4t_v2.py               |   2 +-
 utils/modular_model_converter.py              | 210 +++++++++++++++---
 utils/not_doctested.txt                       |   1 -
 15 files changed, 477 insertions(+), 149 deletions(-)
 delete mode 100644 src/transformers/deepspeed.py

diff --git a/docs/source/en/modular_transformers.md b/docs/source/en/modular_transformers.md
index 33d2bb94834..dbc8d9116ed 100644
--- a/docs/source/en/modular_transformers.md
+++ b/docs/source/en/modular_transformers.md
@@ -118,4 +118,60 @@ Additionally, you may find a list of examples here:
 
 ## What it is not
 
-It is not a replacement for the modeling code (yet?), and if your model is not based on anything else that ever existed, then you can add a `modeling` file as usual.
\ No newline at end of file
+It is not a replacement for the modeling code (yet?), and if your model is not based on anything else that ever existed, then you can add a `modeling` file as usual.
+
+
+## Advanced usage
+
+### Removing attributes and functions
+To remove attributes that are not used in your modular model, and that you don't want to see in the unravelled modeling code:
+
+```python
+class GemmaModel(LlamaModel):                 |       class GemmaModel(PreTrainedModel):
+    def __init__(self, config):              |           def __init__(self, config):
+        super().__init__(self, eos_token)    |               super().__init__(config)
+        del self.embed_tokens                |               self.padding_idx = config.pad_token_id
+                                             |               self.vocab_size = config.vocab_size
+                                             |
+                                             |               self.layers = nn.ModuleList(
+                                             |                   [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+                                             |               )
+                                             |               self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+                                             |               self.rotary_emb = LlamaRotaryEmbedding(config=config)
+                                             |               self.gradient_checkpointing = False
+                                             |
+                                             |               # Initialize weights and apply final processing
+                                             |               self.post_init()
+```
+If you check the original `LlamaModel`, it has an `embed_tokens` attribute, which was removed here (as you would expect!)
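Before the doc moves on to function removal, here is a quick sketch of the `__all__` support named in the PR title, which the section above does not illustrate. The sketch is illustrative only; the names below are taken from the Gemma diffs later in this patch, where the converter splits the modular file's single `__all__` into per-file `__all__` lists:

```python
# In the modular file (src/transformers/models/gemma/modular_gemma.py), declare one __all__:
__all__ = [
    "GemmaConfig",
    "GemmaTokenizer",
    "GemmaModel",
    "GemmaForCausalLM",
    "GemmaForSequenceClassification",
    "GemmaForTokenClassification",
]

# The converter then emits a per-file __all__ in each generated file, keeping only the
# names that file actually defines (see the Gemma diffs further down in this patch):
#   configuration_gemma.py -> __all__ = ["GemmaConfig"]
#   tokenization_gemma.py  -> __all__ = ["GemmaTokenizer"]
#   modeling_gemma.py      -> __all__ = ["GemmaModel", "GemmaForCausalLM",
#                                        "GemmaForSequenceClassification", "GemmaForTokenClassification"]
```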
+
+Removing a function is pretty similar: you just need to write it with a body that raises an error (for example `raise AttributeError("Not needed for Gemma")`), to mimic the behaviour you actually want when you remove a parent function in Python.
+
+```python
+class GemmaTokenizer(LlamaTokenizer):
+    ...
+
+    def get_spm_processor(self):
+        raise AttributeError("Not needed for Gemma")
+
+    def unk_token_length(self):
+        raise AttributeError("Not needed for Gemma")
+```
+
+### Calling `super()`
+We recently shipped a few features that allow you to go from:
+```python
+class GemmaTokenizer(LlamaTokenizer, PretrainedTokenizerFast):    |   class GemmaModel(nn.Module):
+    def __init__(self, eos_token=""):                             |       def __init__(self):
+        eos_token = AddedToken(eos_token)                         |           eos_token = AddedToken(eos_token)
+        PretrainedTokenizerFast.__init__(self, eos_token)         |           super().__init__(eos_token)
+```
+This is useful when you **don't** want to unravel the call to `super()`, and you want to differentiate which super init call you are doing!
+
+### Special naming
+We now also support special cases like
+```python
+class GemmaVisionModel(CLIPModel):
+    pass
+```
+where the name of your class `GemmaVision` is not the same as the modular `Gemma`. This is super useful for composite models.
\ No newline at end of file
diff --git a/examples/modular-transformers/modeling_dummy.py b/examples/modular-transformers/modeling_dummy.py
index 51349ecf4ec..b5b1fc6aec8 100644
--- a/examples/modular-transformers/modeling_dummy.py
+++ b/examples/modular-transformers/modeling_dummy.py
@@ -4,7 +4,6 @@
 # the file from the modular. If any change should be done, please apply the change to the
 # modular_xxx.py file directly. One of our CI enforces this
 # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
-
 import math
 from typing import List, Optional, Tuple, Union
 
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index e4382e04c37..f8908f7d53a 100755
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -115,7 +115,6 @@ _import_structure = {
     "data.metrics": [],
     "data.processors": [],
     "debug_utils": [],
-    "deepspeed": [],
     "dependency_versions_check": [],
     "dependency_versions_table": [],
     "dynamic_module_utils": [],
diff --git a/src/transformers/deepspeed.py b/src/transformers/deepspeed.py
deleted file mode 100644
index 6fd22d8c5cb..00000000000
--- a/src/transformers/deepspeed.py
+++ /dev/null
@@ -1,41 +0,0 @@
-# Copyright 2020 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""
-Integration with Deepspeed - kept for backward compatiblity, if you plan to make any edit, make sure to modify the file
-in `integrations/deepspeed` instead.
-
-Check: https://github.com/huggingface/transformers/pull/25599
-"""
-
-import warnings
-
-
-warnings.warn(
-    "transformers.deepspeed module is deprecated and will be removed in a future version.
Please import deepspeed modules directly from transformers.integrations", - FutureWarning, -) - -# Backward compatibility imports, to make sure all those objects can be found in integrations/deepspeed -from .integrations.deepspeed import ( # noqa - HfDeepSpeedConfig, - HfTrainerDeepSpeedConfig, - deepspeed_config, - deepspeed_init, - deepspeed_load_checkpoint, - deepspeed_optim_sched, - is_deepspeed_available, - is_deepspeed_zero3_enabled, - set_hf_deepspeed_config, - unset_hf_deepspeed_config, -) diff --git a/src/transformers/models/gemma/configuration_gemma.py b/src/transformers/models/gemma/configuration_gemma.py index 3ab61c522ef..90255086eef 100644 --- a/src/transformers/models/gemma/configuration_gemma.py +++ b/src/transformers/models/gemma/configuration_gemma.py @@ -143,3 +143,6 @@ class GemmaConfig(PretrainedConfig): tie_word_embeddings=tie_word_embeddings, **kwargs, ) + + +__all__ = ["GemmaConfig"] diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 8fb34e01b8e..df7f38014a1 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -1314,3 +1314,6 @@ class GemmaForTokenClassification(GemmaPreTrainedModel): hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) + + +__all__ = ["GemmaModel", "GemmaForCausalLM", "GemmaForSequenceClassification", "GemmaForTokenClassification"] diff --git a/src/transformers/models/gemma/modular_gemma.py b/src/transformers/models/gemma/modular_gemma.py index 8d7655a52dc..7130a30dc9b 100644 --- a/src/transformers/models/gemma/modular_gemma.py +++ b/src/transformers/models/gemma/modular_gemma.py @@ -14,8 +14,9 @@ # See the License for the specific language governing permissions and # limitations under the License. import math -from typing import List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union +import sentencepiece as spm import torch import torch.utils.checkpoint from torch import nn @@ -27,6 +28,7 @@ from ...configuration_utils import PretrainedConfig from ...modeling_flash_attention_utils import _flash_attention_forward from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...pytorch_utils import ALL_LAYERNORM_LAYERS +from ...tokenization_utils import AddedToken, PreTrainedTokenizer from ...utils import logging from ..llama.modeling_llama import ( LlamaDecoderLayer, @@ -38,6 +40,15 @@ from ..llama.modeling_llama import ( apply_rotary_pos_emb, repeat_kv, ) +from ..llama.tokenization_llama import LlamaTokenizer + + +if TYPE_CHECKING: + from ...tokenization_utils_base import TextInput + +VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"} + +SPIECE_UNDERLINE = "▁" logger = logging.get_logger(__name__) @@ -164,6 +175,162 @@ class GemmaConfig(PretrainedConfig): ) +class GemmaTokenizer(LlamaTokenizer, PreTrainedTokenizer): + """ + Construct a Gemma tokenizer. Based on byte-level Byte-Pair-Encoding. The default padding token is unset as there is + no padding token in the original model. + + Args: + vocab_file (`str`): + Path to the vocabulary file. + unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): + The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token. 
+ eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): + The end of sequence token. + pad_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): + A special token used to make arrays of tokens the same size for batching purpose. Will then be ignored by + attention mechanisms or loss computation. + sp_model_kwargs (`Dict[str, Any]`, `Optional`, *optional*): + Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for + SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things, + to set: + + - `enable_sampling`: Enable subword regularization. + - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout. + + - `nbest_size = {0,1}`: No sampling is performed. + - `nbest_size > 1`: samples from the nbest_size results. + - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice) + using forward-filtering-and-backward-sampling algorithm. + + - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for + BPE-dropout. + + add_bos_token (`bool`, *optional*, defaults to `True`): + Whether or not to add an `bos_token` at the start of sequences. + add_eos_token (`bool`, *optional*, defaults to `False`): + Whether or not to add an `eos_token` at the end of sequences. + clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`): + Whether or not to cleanup spaces after decoding, cleanup consists in removing potential artifacts like + extra spaces. + use_default_system_prompt (`bool`, *optional*, defaults to `False`): + Whether or not the default system prompt for Gemma should be used. + spaces_between_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not to add spaces between special tokens. 
+ """ + + def __init__( + self, + vocab_file, + unk_token="", + bos_token="", + eos_token="", + pad_token="", + sp_model_kwargs: Optional[Dict[str, Any]] = None, + add_bos_token=True, + add_eos_token=False, + clean_up_tokenization_spaces=False, + use_default_system_prompt=False, + spaces_between_special_tokens=False, + **kwargs, + ): + self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs + bos_token = AddedToken(bos_token, normalized=False, special=True) if isinstance(bos_token, str) else bos_token + eos_token = AddedToken(eos_token, normalized=False, special=True) if isinstance(eos_token, str) else eos_token + unk_token = AddedToken(unk_token, normalized=False, special=True) if isinstance(unk_token, str) else unk_token + pad_token = AddedToken(pad_token, normalized=False, special=True) if isinstance(pad_token, str) else pad_token + + self.vocab_file = vocab_file + self.add_bos_token = add_bos_token + self.add_eos_token = add_eos_token + self.use_default_system_prompt = use_default_system_prompt + self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) + self.sp_model.Load(vocab_file) + + PreTrainedTokenizer.__init__( + self, + bos_token=bos_token, + eos_token=eos_token, + unk_token=unk_token, + pad_token=pad_token, + add_bos_token=add_bos_token, + add_eos_token=add_eos_token, + sp_model_kwargs=sp_model_kwargs, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + use_default_system_prompt=use_default_system_prompt, + spaces_between_special_tokens=spaces_between_special_tokens, + **kwargs, + ) + + def get_spm_processor(self): + raise AttributeError("Not needed for Gemma") + + def unk_token_length(self): + raise AttributeError("Not needed for Gemma") + + def tokenize(self, text: "TextInput", **kwargs) -> List[str]: + """ + Args: + text: TextInput + Simply calls PreTrainedTokenizer's method + """ + return PreTrainedTokenizer.tokenize(self, text, **kwargs) + + def _tokenize(self, text, **kwargs): + """ + Args: + text: TextInput + Returns a tokenized string. The Gemma tokenizer never adds a prefix space. 
+ """ + return self.sp_model.encode(text, out_type=str) + + def _decode( + self, + token_ids: List[int], + skip_special_tokens: bool = False, + spaces_between_special_tokens: bool = False, + **kwargs, + ) -> str: + sub_texts = [] + current_sub_text = [] + for ids in token_ids: + if skip_special_tokens and ids in self.all_special_ids: + continue + if ids in self._added_tokens_decoder: + if current_sub_text: + sub_texts.append(self.sp_model.decode(current_sub_text)) + sub_texts.append(self._added_tokens_decoder[ids].content) + current_sub_text = [] + else: + current_sub_text.append(ids) + if current_sub_text: + sub_texts.append(self.sp_model.decode(current_sub_text)) + + if spaces_between_special_tokens: + sub_texts = " ".join(sub_texts) + else: + sub_texts = "".join(sub_texts) + + return sub_texts.replace(SPIECE_UNDERLINE, " ") + + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) in a single string.""" + current_sub_tokens = [] + out_string = "" + for token in tokens: + # make sure that special tokens are not decoded using sentencepiece model + if token in self._added_tokens_encoder: + out_string += self.sp_model.decode(current_sub_tokens) + token + current_sub_tokens = [] + else: + current_sub_tokens.append(token) + out_string += self.sp_model.decode(current_sub_tokens) + return out_string + + class GemmaRMSNorm(nn.Module): def __init__(self, dim: int, eps: float = 1e-6): super().__init__() @@ -874,3 +1041,13 @@ class GemmaForTokenClassification(LlamaForTokenClassification): super().__init__(config) self.model = GemmaModel(config) self.post_init() + + +__all__ = [ + "GemmaConfig", + "GemmaTokenizer", + "GemmaModel", + "GemmaForCausalLM", + "GemmaForSequenceClassification", + "GemmaForTokenClassification", +] diff --git a/src/transformers/models/gemma/tokenization_gemma.py b/src/transformers/models/gemma/tokenization_gemma.py index 09e779478c0..5233037262f 100644 --- a/src/transformers/models/gemma/tokenization_gemma.py +++ b/src/transformers/models/gemma/tokenization_gemma.py @@ -1,5 +1,12 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from . +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_xxx.py file directly. One of our CI enforces this +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # coding=utf-8 -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved. +# # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,9 +19,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- -"""Tokenization classes for Gemma.""" - import os from shutil import copyfile from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple @@ -26,7 +30,7 @@ from ...utils import logging if TYPE_CHECKING: - pass + from ...tokenization_utils_base import TextInput logger = logging.get_logger(__name__) @@ -110,7 +114,6 @@ class GemmaTokenizer(PreTrainedTokenizer): self.add_bos_token = add_bos_token self.add_eos_token = add_eos_token self.use_default_system_prompt = use_default_system_prompt - self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) self.sp_model.Load(vocab_file) @@ -121,85 +124,60 @@ class GemmaTokenizer(PreTrainedTokenizer): pad_token=pad_token, add_bos_token=add_bos_token, add_eos_token=add_eos_token, - sp_model_kwargs=self.sp_model_kwargs, + sp_model_kwargs=sp_model_kwargs, clean_up_tokenization_spaces=clean_up_tokenization_spaces, use_default_system_prompt=use_default_system_prompt, spaces_between_special_tokens=spaces_between_special_tokens, **kwargs, ) - # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.__getstate__ def __getstate__(self): state = self.__dict__.copy() state["sp_model"] = None state["sp_model_proto"] = self.sp_model.serialized_model_proto() return state - # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.__setstate__ def __setstate__(self, d): self.__dict__ = d self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) self.sp_model.LoadFromSerializedProto(self.sp_model_proto) @property - # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.vocab_size def vocab_size(self): """Returns vocab size""" return self.sp_model.get_piece_size() - # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.get_vocab def get_vocab(self): """Returns vocab as a dict""" vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)} vocab.update(self.added_tokens_encoder) return vocab + def tokenize(self, text: "TextInput", **kwargs) -> List[str]: + """ + Args: + text: TextInput + Simply calls PreTrainedTokenizer's method + """ + return super().tokenize(text, **kwargs) + def _tokenize(self, text, **kwargs): """ + Args: + text: TextInput Returns a tokenized string. The Gemma tokenizer never adds a prefix space. 
""" return self.sp_model.encode(text, out_type=str) - # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer._convert_token_to_id def _convert_token_to_id(self, token): """Converts a token (str) in an id using the vocab.""" return self.sp_model.piece_to_id(token) - # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer._convert_id_to_token def _convert_id_to_token(self, index): """Converts an index (integer) in a token (str) using the vocab.""" token = self.sp_model.IdToPiece(index) return token - def _decode( - self, - token_ids: List[int], - skip_special_tokens: bool = False, - spaces_between_special_tokens: bool = False, - **kwargs, - ) -> str: - sub_texts = [] - current_sub_text = [] - for ids in token_ids: - if skip_special_tokens and ids in self.all_special_ids: - continue - if ids in self._added_tokens_decoder: - if current_sub_text: - sub_texts.append(self.sp_model.decode(current_sub_text)) - sub_texts.append(self._added_tokens_decoder[ids].content) - current_sub_text = [] - else: - current_sub_text.append(ids) - if current_sub_text: - sub_texts.append(self.sp_model.decode(current_sub_text)) - - if spaces_between_special_tokens: - sub_texts = " ".join(sub_texts) - else: - sub_texts = "".join(sub_texts) - - return sub_texts.replace(SPIECE_UNDERLINE, " ") - def convert_tokens_to_string(self, tokens): """Converts a sequence of tokens (string) in a single string.""" current_sub_tokens = [] @@ -214,7 +192,6 @@ class GemmaTokenizer(PreTrainedTokenizer): out_string += self.sp_model.decode(current_sub_tokens) return out_string - # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.save_vocabulary def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]: """ Save the vocabulary and special tokens file to a directory. 
@@ -242,7 +219,6 @@ class GemmaTokenizer(PreTrainedTokenizer): return (out_vocab_file,) - # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.build_inputs_with_special_tokens def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): bos_token_id = [self.bos_token_id] if self.add_bos_token else [] eos_token_id = [self.eos_token_id] if self.add_eos_token else [] @@ -254,7 +230,6 @@ class GemmaTokenizer(PreTrainedTokenizer): return output - # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.get_special_tokens_mask def get_special_tokens_mask( self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False ) -> List[int]: @@ -292,7 +267,6 @@ class GemmaTokenizer(PreTrainedTokenizer): + eos_token_id ) - # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.create_token_type_ids_from_sequences def create_token_type_ids_from_sequences( self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None ) -> List[int]: @@ -325,3 +299,35 @@ class GemmaTokenizer(PreTrainedTokenizer): output += [1] * len(bos_token_id + token_ids_1 + eos_token_id) return output + + def _decode( + self, + token_ids: List[int], + skip_special_tokens: bool = False, + spaces_between_special_tokens: bool = False, + **kwargs, + ) -> str: + sub_texts = [] + current_sub_text = [] + for ids in token_ids: + if skip_special_tokens and ids in self.all_special_ids: + continue + if ids in self._added_tokens_decoder: + if current_sub_text: + sub_texts.append(self.sp_model.decode(current_sub_text)) + sub_texts.append(self._added_tokens_decoder[ids].content) + current_sub_text = [] + else: + current_sub_text.append(ids) + if current_sub_text: + sub_texts.append(self.sp_model.decode(current_sub_text)) + + if spaces_between_special_tokens: + sub_texts = " ".join(sub_texts) + else: + sub_texts = "".join(sub_texts) + + return sub_texts.replace(SPIECE_UNDERLINE, " ") + + +__all__ = ["GemmaTokenizer"] diff --git a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py index 0808aa58b85..c3a2c7add30 100644 --- a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py +++ b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py @@ -58,7 +58,6 @@ logger = logging.get_logger(__name__) @dataclass -# Copied from transformers.models.blip_2.modeling_blip_2.Blip2ForConditionalGenerationModelOutput with Blip2->InstructBlipVideo class InstructBlipVideoForConditionalGenerationModelOutput(ModelOutput): """ Class defining the outputs of [`InstructBlipVideoForConditionalGeneration`]. 
@@ -91,7 +90,6 @@ class InstructBlipVideoForConditionalGenerationModelOutput(ModelOutput): ) -# Copied from transformers.models.blip.modeling_blip.BlipVisionEmbeddings with Blip->InstructBlipVideo class InstructBlipVideoVisionEmbeddings(nn.Module): def __init__(self, config: InstructBlipVideoVisionConfig): super().__init__() @@ -166,7 +164,6 @@ class InstructBlipVideoVisionEmbeddings(nn.Module): return embeddings -# Copied from transformers.models.blip_2.modeling_blip_2.Blip2Attention with Blip2->InstructBlipVideo class InstructBlipVideoAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -248,7 +245,6 @@ class InstructBlipVideoAttention(nn.Module): return outputs -# Copied from transformers.models.blip.modeling_blip.BlipMLP class InstructBlipVideoMLP(nn.Module): def __init__(self, config): super().__init__() @@ -264,7 +260,6 @@ class InstructBlipVideoMLP(nn.Module): return hidden_states -# Copied from transformers.models.blip.modeling_blip.BlipEncoderLayer with Blip->InstructBlipVideo class InstructBlipVideoEncoderLayer(nn.Module): def __init__(self, config: InstructBlipVideoConfig): super().__init__() @@ -330,7 +325,6 @@ class InstructBlipVideoPreTrainedModel(PreTrainedModel): ] _keep_in_fp32_modules = [] - # Copied from transformers.models.blip_2.modeling_blip_2.Blip2PreTrainedModel._init_weights with Blip2->InstructBlipVideo def _init_weights(self, module): """Initialize the weights""" factor = self.config.initializer_range @@ -450,7 +444,6 @@ INSTRUCTBLIPVIDEO_INPUTS_DOCSTRING = r""" """ -# Copied from transformers.models.blip.modeling_blip.BlipEncoder with Blip->InstructBlipVideo class InstructBlipVideoEncoder(nn.Module): """ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. 
Each layer is a @@ -537,7 +530,6 @@ class InstructBlipVideoEncoder(nn.Module): ) -# Copied from transformers.models.blip.modeling_blip.BlipVisionModel with Blip->InstructBlipVideo, BLIP->INSTRUCTBLIPVIDEO class InstructBlipVideoVisionModel(InstructBlipVideoPreTrainedModel): main_input_name = "pixel_values" config_class = InstructBlipVideoVisionConfig @@ -738,7 +730,6 @@ class InstructBlipVideoQFormerMultiHeadAttention(nn.Module): return outputs -# Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->InstructBlipVideoQFormer class InstructBlipVideoQFormerSelfOutput(nn.Module): def __init__(self, config): super().__init__() @@ -753,7 +744,6 @@ class InstructBlipVideoQFormerSelfOutput(nn.Module): return hidden_states -# Copied from transformers.models.blip_2.modeling_blip_2.Blip2QFormerAttention with Blip2->InstructBlipVideo class InstructBlipVideoQFormerAttention(nn.Module): def __init__(self, config, is_cross_attention=False): super().__init__() @@ -803,7 +793,6 @@ class InstructBlipVideoQFormerAttention(nn.Module): return outputs -# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->InstructBlipVideoQFormer class InstructBlipVideoQFormerIntermediate(nn.Module): def __init__(self, config): super().__init__() @@ -819,7 +808,6 @@ class InstructBlipVideoQFormerIntermediate(nn.Module): return hidden_states -# Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->InstructBlipVideoQFormer class InstructBlipVideoQFormerOutput(nn.Module): def __init__(self, config): super().__init__() @@ -937,7 +925,6 @@ class InstructBlipVideoQFormerLayer(nn.Module): return layer_output -# Copied from transformers.models.blip_2.modeling_blip_2.Blip2QFormerEncoder with Blip2->InstructBlipVideo class InstructBlipVideoQFormerEncoder(nn.Module): def __init__(self, config): super().__init__() diff --git a/src/transformers/models/llama/tokenization_llama.py b/src/transformers/models/llama/tokenization_llama.py index cc03c1470ee..8e99e4eef59 100644 --- a/src/transformers/models/llama/tokenization_llama.py +++ b/src/transformers/models/llama/tokenization_llama.py @@ -43,14 +43,12 @@ SPIECE_UNDERLINE = "▁" B_INST, E_INST = "[INST]", "[/INST]" B_SYS, E_SYS = "<>\n", "\n<>\n\n" -# fmt: off DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your \ answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\ that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \ -correct. If you don't know the answer to a question, please don't share false information.""" -# fmt: on +correct. 
If you don't know the answer to a question, please don't share false information.""" # fmt: skip class LlamaTokenizer(PreTrainedTokenizer): diff --git a/src/transformers/models/llava_next_video/modeling_llava_next_video.py b/src/transformers/models/llava_next_video/modeling_llava_next_video.py index 74d0145e604..ea1114df7c2 100644 --- a/src/transformers/models/llava_next_video/modeling_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modeling_llava_next_video.py @@ -235,7 +235,6 @@ class LlavaNextVideoPooler(nn.Module): return image_features_spatial_pool.flatten(2).transpose(1, 2).contiguous() -# Copied from transformers.models.llava.modeling_llava.LlavaMultiModalProjector with Llava->LlavaNextVideo class LlavaNextVideoMultiModalProjector(nn.Module): def __init__(self, config: LlavaNextVideoConfig): super().__init__() @@ -272,7 +271,6 @@ LLAVA_NEXT_VIDEO_START_DOCSTRING = r""" "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", LLAVA_NEXT_VIDEO_START_DOCSTRING, ) -# Copied from transformers.models.llava.modeling_llava.LlavaPreTrainedModel with Llava->LlavaNextVideo,llava->llava_next_video class LlavaNextVideoPreTrainedModel(PreTrainedModel): config_class = LlavaNextVideoConfig base_model_prefix = "model" @@ -426,35 +424,27 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene raise ValueError(f"{padding_side} is not `left` or `right`.") self._padding_side = padding_side - # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_input_embeddings def get_input_embeddings(self): return self.language_model.get_input_embeddings() - # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_input_embeddings def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) - # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_output_embeddings def get_output_embeddings(self): return self.language_model.get_output_embeddings() - # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_output_embeddings def set_output_embeddings(self, new_embeddings): self.language_model.set_output_embeddings(new_embeddings) - # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_decoder def set_decoder(self, decoder): self.language_model.set_decoder(decoder) - # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_decoder def get_decoder(self): return self.language_model.get_decoder() - # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.tie_weights def tie_weights(self): return self.language_model.tie_weights() - # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.resize_token_embeddings def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding: model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of) # update vocab size diff --git a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py index 8e226d92a10..eb606208bf7 100755 --- a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py +++ b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py @@ -25,8 +25,8 @@ from torch import Tensor, nn from torch.nn import CrossEntropyLoss from ...activations import ACT2FN -from 
...deepspeed import is_deepspeed_zero3_enabled from ...generation import GenerationMixin +from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutput, diff --git a/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py b/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py index aa710ad9526..da44913e747 100644 --- a/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +++ b/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py @@ -25,8 +25,8 @@ from torch import Tensor, nn from torch.nn import CrossEntropyLoss from ...activations import ACT2FN -from ...deepspeed import is_deepspeed_zero3_enabled from ...generation import GenerationMixin +from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutput, diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py index 1bfc1230a91..cc3089da3f3 100644 --- a/utils/modular_model_converter.py +++ b/utils/modular_model_converter.py @@ -16,7 +16,8 @@ import argparse import glob import importlib import re -from typing import Dict +from collections import defaultdict +from typing import Dict, List, Set import libcst as cst from check_copies import run_ruff @@ -113,7 +114,11 @@ class ClassFinder(CSTVisitor): if m.matches(node, m.SimpleStatementLine(body=[m.Assign()])) and m.matches( self.get_metadata(cst.metadata.ParentNodeProvider, node), m.Module() ): - self.assignments[node.body[0].targets[0].target.value] = node + if hasattr(node.body[0].targets[0].target, "value"): + self.assignments[node.body[0].targets[0].target.value] = node + else: + for idx, target in enumerate(list(node.body[0].targets[0].target.elements)): + self.assignments[target.value.value] = node.body[0].value.elements[idx].value if m.matches(node, m.SimpleStatementLine(body=[m.Import() | m.ImportFrom()])): self.imports[node.body[0].names] = node @@ -217,11 +222,21 @@ class ReplaceNameTransformer(m.MatcherDecoratableTransformer): return compiled_regex.sub(replace, text) + def convert_to_camelcase(self, text): + # Regex pattern to match consecutive uppercase letters and lowercase the first set + result = re.sub(r"^[A-Z]+(?=[A-Z][a-z])", lambda m: m.group(0).capitalize(), text, count=1) + return result + @m.leave(m.Name() | m.SimpleString() | m.Comment()) def replace_name(self, original_node, updated_node): + if re.findall(r"# Copied from", updated_node.value): + return cst.RemoveFromParent() update = self.preserve_case_replace(updated_node.value) return updated_node.with_changes(value=update) + def leave_ClassDef(self, original_node, updated_node): + return updated_node.with_changes(name=cst.Name(self.convert_to_camelcase(updated_node.name.value))) + def find_classes_in_file(module: cst.Module, old_id="llama", new_id="gemma", given_old_name=None, given_new_name=None): """Helper function to rename and then parse a source file using the ClassFinder""" @@ -251,6 +266,63 @@ def SUPER_CALL_NODE(func_name): return m.Call(func=m.Attribute(value=m.Call(func=m.Name("super")), attr=m.Name(func_name))) +def is_call_to_super(node, func_name): + return m.matches( + node, m.SimpleStatementLine(body=[m.Return(SUPER_CALL_NODE(func_name)) | m.Expr(SUPER_CALL_NODE(func_name))]) + ) + + +# Transformer class to replace 
ClassB.call_to_method and ClassB().call_to_method with super().call_to_method +class ReplaceMethodCallTransformer(cst.CSTTransformer): + def __init__(self, all_bases: Set[str]): + self.all_bases = all_bases + + def leave_Attribute(self, original_node: cst.Attribute, updated_node: cst.Attribute) -> cst.CSTNode: + # Handle ClassB.call_to_method + if ( + isinstance(original_node.value, cst.Name) + and original_node.value.value in self.all_bases + and isinstance(original_node.attr, cst.Name) + ): + # Replace with super().call_to_method + return updated_node.with_changes( + value=cst.Call(cst.Name("super")), + ) + # Handle ClassB().call_to_method + elif ( + isinstance(original_node.value, cst.Call) + and isinstance(original_node.value.func, cst.Name) + and original_node.value.func.value in self.all_bases + and isinstance(original_node.attr, cst.Name) + ): + # Replace with super().call_to_method + return updated_node.with_changes(func=cst.Attribute(value=cst.Call(func=cst.Name("super")))) + return updated_node + + def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.CSTNode: + # Check if the function being called is of the form ClassB().func_a or ClassB.func_a + if isinstance(original_node.func, cst.Attribute) and ( + # Match ClassB().func_a(...) + ( + isinstance(original_node.func.value, cst.Call) + and isinstance(original_node.func.value.func, cst.Name) + and original_node.func.value.func.value in self.all_bases + ) + or + # Match ClassB.func_a(...) + (isinstance(original_node.func.value, cst.Name) and original_node.func.value.value in self.all_bases) + ): + # Check if the first argument is 'self', and remove it + if len(original_node.args) > 0 and m.matches(original_node.args[0].value, m.Name("self")): + # Create the new argument list without 'self' + new_args = updated_node.args[1:] + else: + new_args = updated_node.args + + return updated_node.with_changes(args=new_args) + return updated_node + + def get_docstring_indent(docstring): # Match the first line after the opening triple quotes match = re.search(r'(?:"""|\'\'\'|```)\n(\s+)', docstring) @@ -263,7 +335,7 @@ def get_docstring_indent(docstring): def merge_docstrings(original_docstring, updated_docstring): # indent_level = get_docstring_indent(updated_docstring) original_level = get_docstring_indent(original_docstring) - if " Args:\n " not in updated_docstring: + if not re.findall(r"\n\s*Args:\n", updated_docstring): # Split the docstring at the example section, assuming `"""` is used to define the docstring parts = original_docstring.split("```") if "```" in updated_docstring and len(parts) > 1: @@ -292,13 +364,15 @@ def merge_docstrings(original_docstring, updated_docstring): class SuperTransformer(cst.CSTTransformer): METADATA_DEPENDENCIES = (ParentNodeProvider,) - def __init__(self, python_module: cst.Module, original_methods, updated_methods, class_name=""): + def __init__(self, python_module: cst.Module, original_methods, updated_methods, class_name="", all_bases=None): self.python_module = python_module self.original_methods = original_methods self.updated_methods = updated_methods self.all_assign_target = {} self.deleted_targets = {} # child node can delete some arguments self.class_name = class_name + self.all_bases = all_bases or [] + self.transformer = ReplaceMethodCallTransformer(set(self.all_bases)) def update_body(self, existing_body, new_statements): """ @@ -356,18 +430,14 @@ class SuperTransformer(cst.CSTTransformer): parent_has_docstring = m.matches(self.original_methods[func_name].body.body[0], 
DOCSTRING_NODE) new_body = [] has_super_call = False - for idx, expr in enumerate(node.body): - if m.matches( - expr, - m.SimpleStatementLine( - body=[m.Return(SUPER_CALL_NODE(func_name)) | m.Expr(SUPER_CALL_NODE(func_name))] - ), - ): - if idx != 0 and func_name == "__init__": - raise ValueError(f"The call to super() in {self.class_name} should be at the top of the init") - new_body.extend(self.update_body(self.original_methods[func_name].body.body, node.body)) + + for expr in node.body: + if is_call_to_super(expr, func_name): has_super_call = True - elif m.matches(expr, DOCSTRING_NODE): + new_body.extend(self.update_body(self.original_methods[func_name].body.body, node.body)) + else: + expr = expr.visit(self.transformer) + if m.matches(expr, DOCSTRING_NODE): self.has_docstring = True if parent_has_docstring: # actually here we ought to de-duplicate? original_docstring = self.original_methods[func_name].body.body[0].body[0].value.value @@ -406,15 +476,17 @@ class SuperTransformer(cst.CSTTransformer): return updated_node -def replace_call_to_super(class_finder: ClassFinder, updated_node: cst.ClassDef, class_name: str): +def replace_call_to_super( + class_finder: ClassFinder, updated_node: cst.ClassDef, class_name: str, all_bases: List[str] +): """ Given the `class_name`, the `updated_node`'s call to super are unpacked. | ```python | | ```python | class GemmaModel(LlamaModel): | | class GemmaModel(nn.Module): | def __init__(self): | | def __init__(self): - Going from: | self.dropout = 0.2 | to: | self.dropout = 0.2 - | super().__init__() | | super().__init__(config) + Going from: | super().__init__() | to: | super().__init__(config) + | self.dropout = 0.2 | | self.dropout = 0.2 | ``` | | self.padding_idx = config.pad_token_id | self.vocab_size = config.vocab_size | self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) @@ -453,7 +525,14 @@ def replace_call_to_super(class_finder: ClassFinder, updated_node: cst.ClassDef, new_params = new_params.with_changes( params=list(parent_params.values()), star_kwarg=func.params.star_kwarg ) - func = func.with_changes(body=updated_methods[name].body, params=new_params) + if not re.match( + r"\ndef .*\(.*\):\n raise.*Error\(.*", + class_finder.python_module.code_for_node(updated_methods[name]), + ): + func = func.with_changes(body=updated_methods[name].body, params=new_params) + else: + continue + if m.matches(func, m.SimpleStatementLine(body=[m.Assign()])): target = class_finder.python_module.code_for_node(func.body[0].targets[0]) assign_targets[target] = func @@ -492,7 +571,7 @@ def replace_call_to_super(class_finder: ClassFinder, updated_node: cst.ClassDef, temp_module = cst.Module(body=[result_node]) new_module = MetadataWrapper(temp_module) new_replacement_class = new_module.visit( - SuperTransformer(temp_module, original_methods, updated_methods, class_name) + SuperTransformer(temp_module, original_methods, updated_methods, class_name, all_bases) ) new_replacement_body = new_replacement_class.body[0].body # get the indented block @@ -508,6 +587,31 @@ TYPE_TO_FILE_TYPE = { } +def get_new_part(class_name, base_class): + """ + When `MyClassNameAttention` inherits from `MistralAttention`, we need + to process the name to properly find dependencies. + + Here we take what is the same (Attention) and what is different + when finding the dependencies. 
+ """ + common_suffix_len = 0 + for i in range(1, min(len(class_name), len(base_class)) + 1): + if class_name[-i] == base_class[-i]: + common_suffix_len += 1 + else: + break + + if common_suffix_len > 0: + new_part = class_name[:-common_suffix_len] + else: + new_part = class_name + + # Convert the remaining new part to snake_case + snake_case = re.sub(r"(? None: """When visiting imports from `transformers.models.xxx` we need to: @@ -630,13 +735,33 @@ class ModularConverterTransformer(CSTTransformer): self.given_new_name, ) visited_module[super_file_name] = class_finder + list_dependencies = { + dep: class_finder.class_start_line.get(dep, 1000) + for dep in class_finder.class_dependency_mapping.get(class_name, []) + } else: # we are re-using the previously parsed data class_finder = visited_module[super_file_name] - list_dependencies = { - dep: class_finder.class_start_line.get(dep, 1000) - for dep in class_finder.class_dependency_mapping.get(class_name, []) - } + list_dependencies = { + dep: class_finder.class_start_line.get(dep, 1000) + for dep in class_finder.class_dependency_mapping.get(class_name, []) + } + if list_dependencies == []: + # so, maybe standard renaming did not work (the class name is different) + # we try with another renaming pattern + potential_given_name = get_new_part(class_name, super_class) + del visited_module[super_file_name] + class_finder = find_classes_in_file( + self.transformers_imports[super_file_name], + model_name, + potential_given_name, + self.model_name, + potential_given_name, + ) + list_dependencies = { + dep: class_finder.class_start_line.get(dep, 1000) + for dep in class_finder.class_dependency_mapping.get(class_name, []) + } list_dependencies = sorted(list_dependencies.items(), key=lambda x: x[1], reverse=True) start_insert_idx = self.global_scope_index @@ -668,10 +793,12 @@ class ModularConverterTransformer(CSTTransformer): self.inserted_deps.append(dependency) if len(list_dependencies) > 0: - updated_node = replace_call_to_super(class_finder, updated_node, class_name) + updated_node = replace_call_to_super(class_finder, updated_node, class_name, all_bases) else: raise ValueError( - f"Unable to find dependencies for {super_class} in {super_file_name}. Here are the dependencies found: {class_finder.class_dependency_mapping}. (The automatic renaming might have gone wrong!)" + f"We were unable to find dependencies for {class_name} (based on inheriting from {super_class})" + f" Here are all the global dependencies that we found in you modular file: {list(class_finder.class_dependency_mapping.keys())}." 
+ f" This usually means that the name of `{class_name}` does not match the pattern of `{super_class}`" ) # Now, if a class was defined without parents, we look for the name @@ -679,8 +806,10 @@ class ModularConverterTransformer(CSTTransformer): match = re.search(rf"({match_pattern})$", class_name) if match: key = TYPE_TO_FILE_TYPE[match.group(1)] + self.class_to_file_type[class_name] = key self.files[key][class_name] = {"insert_idx": self.global_scope_index, "node": updated_node} else: + self.class_to_file_type[class_name] = "modeling" self.files["modeling"][class_name] = {"insert_idx": self.global_scope_index, "node": updated_node} return updated_node @@ -690,14 +819,37 @@ class ModularConverterTransformer(CSTTransformer): self.all_definitions[node.name.value] = node return node + def visit_Assign(self, node: cst.Assign) -> None: + # Check if the assignment target is '__all__' + if isinstance(node.targets[0].target, cst.Name) and node.targets[0].target.value == "__all__": + if isinstance(node.value, cst.List): + # Extract the elements from the list + all_all_to_add = defaultdict(list) + for elt in node.value.elements: + if isinstance(elt.value, cst.SimpleString): + # Remove quotes and add the string to the elements list + class_name = elt.value.value + file = self.class_to_file_type[ + elt.value.evaluated_value + ] # evaluated value give the content of the string + all_all_to_add[file] += [class_name] + for f_type, new_alls in all_all_to_add.items(): + updated_node = node.with_changes( + value=cst.List(elements=[cst.Element(value=cst.SimpleString(value=k)) for k in new_alls]) + ) + self.files[f_type][class_name] = { + "insert_idx": self.global_scope_index + 100, + "node": updated_node, + } + def leave_If(self, original_node, node): parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node) if m.matches(parent_node, m.Module()): full_statement = self.python_module.code_for_node(original_node.test) if re.search(r"[\s\S]*is_.*available", full_statement): self.all_safe_imports.append(node) - elif full_statement not in self.new_body: - self.new_body[node] = {"insert_idx": self.global_scope_index, "node": node} + elif full_statement not in self.all_imports: + logger.warning(f"one import is protected with `if`. Hard guess where it's used {full_statement}") return node def leave_Module(self, original_node: cst.Assign, node): @@ -764,7 +916,7 @@ if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( "--files_to_parse", - default=["examples/modular-transformers/modular_dummy.py"], + default=["src/transformers/models/gemma/modular_gemma.py"], nargs="+", help="A list of `modular_xxxx` files that should be converted to single model file", ) diff --git a/utils/not_doctested.txt b/utils/not_doctested.txt index 9eb43e7b90a..232eed95b9d 100644 --- a/utils/not_doctested.txt +++ b/utils/not_doctested.txt @@ -373,7 +373,6 @@ src/transformers/data/processors/squad.py src/transformers/data/processors/utils.py src/transformers/data/processors/xnli.py src/transformers/debug_utils.py -src/transformers/deepspeed.py src/transformers/dependency_versions_check.py src/transformers/dependency_versions_table.py src/transformers/dynamic_module_utils.py
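A closing note on the flexible-naming fallback that produces the improved error message above: `get_new_part(class_name, base_class)` strips the suffix the two names share and snake-cases what is left, so the converter can retry renaming when a class such as `GemmaVisionModel(CLIPModel)` does not follow the usual model-name pattern. A minimal usage sketch, assuming `utils/` is importable and that the partially garbled regex in `get_new_part` performs the standard CamelCase-to-snake_case conversion its comment describes:

```python
# Illustrative only: exercises get_new_part() from utils/modular_model_converter.py.
from modular_model_converter import get_new_part  # assumes utils/ is on sys.path

# "GemmaVisionModel" and "CLIPModel" share the suffix "Model"; the differing part
# "GemmaVision" is snake-cased and used as the renaming pattern to retry with.
print(get_new_part("GemmaVisionModel", "CLIPModel"))             # expected: "gemma_vision"

# Matches the docstring example: the shared "Attention" suffix is stripped.
print(get_new_part("MyClassNameAttention", "MistralAttention"))  # expected: "my_class_name"
```

To regenerate a model file end to end, the converter itself can be run with, for example, `python utils/modular_model_converter.py --files_to_parse src/transformers/models/gemma/modular_gemma.py`, the new default shown in the argparse block above.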