Add support for __all__ and potentially deleting functions (#33859)

* Add support for __all__ and potentially deleting functions

* updates

* update

* nits

* remove dummies

* fix warning

* fixup

* style

* update

* fixup

* skip copied from when # skip

* remove log

* bring dummies back

* fixup

* remove copied from

* fixup

* remove warnings from `make fix-copies`

* fix doc issues

* nits

* Better error message !

* add support for more flexible naming!

* style

* breaking style?

* fix super() renaming issues

* del not needed when you don't call super().__init__()

* style

* no more fmt on :)

* properly remove `self`

* fixup

* fix

* doc nits

* add some doc 🫡
Arthur 2024-10-08 10:19:17 +02:00 committed by GitHub
parent bead0fa8dc
commit a3add29097
15 changed files with 477 additions and 149 deletions

View File

@ -119,3 +119,59 @@ Additionally, you may find a list of examples here:
## What it is not
It is not a replacement for the modeling code (yet?), and if your model is not based on anything else that ever existed, then you can add a `modeling` file as usual.
## Advanced usage
### Removing attributes and functions
To remove attributes that are not used in your modular model and that you don't want to see in the unravelled modeling file:
```python
class GemmaModel(LlamaModel):                     | class GemmaModel(PreTrainedModel):
    def __init__(self, config):                   |     def __init__(self, config):
        super().__init__(config)                  |         super().__init__(config)
        del self.embed_tokens                     |         self.padding_idx = config.pad_token_id
                                                  |         self.vocab_size = config.vocab_size
                                                  |
                                                  |         self.layers = nn.ModuleList(
                                                  |             [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
                                                  |         )
                                                  |         self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
                                                  |         self.rotary_emb = LlamaRotaryEmbedding(config=config)
                                                  |         self.gradient_checkpointing = False
                                                  |
                                                  |         # Initialize weights and apply final processing
                                                  |         self.post_init()
```
If you check the original `LlamaModel`, it has an `embed_tokens` attribute, which was removed here (as you would expect!).
Removing a function is pretty similar: you just need to write it with a body that only raises an error (e.g. `raise AttributeError("Not needed for Gemma")`) to mimic the behaviour you actually want when you remove a parent function in Python.
```python
class GemmaTokenizer(LlamaTokenizer):
    ...

    def get_spm_processor(self):
        raise AttributeError("Not needed for Gemma")

    def unk_token_length(self):
        raise AttributeError("Not needed for Gemma")
```
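Under the hood (see `replace_call_to_super` in the modular converter script, `utils/modular_model_converter.py`, further down in this diff), the converter recognizes these stubs by matching the method's source against a regex and skips them when rebuilding the class, so they do not appear in the generated file. A minimal sketch of the idea; `STUB_PATTERN` below is an approximation of the regex added in this PR, not the exact string:
```python
import re

# Approximation of the stub-detection check from utils/modular_model_converter.py:
# a method whose body is a bare `raise ...Error(...)` is treated as "delete this
# parent method" and is skipped when the generated class is assembled.
STUB_PATTERN = r"\ndef .*\(.*\):\n\s+raise.*Error\(.*"

stub = '\ndef get_spm_processor(self):\n    raise AttributeError("Not needed for Gemma")\n'
kept = '\ndef _tokenize(self, text, **kwargs):\n    return self.sp_model.encode(text, out_type=str)\n'

print(bool(re.match(STUB_PATTERN, stub)))  # True  -> dropped from the generated file
print(bool(re.match(STUB_PATTERN, kept)))  # False -> kept and unravelled as usual
```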
### Calling `super()`
We recently shipped a few features that allow you to go from the modular definition on the left to the generated code on the right:
```python
class GemmaTokenizer(LlamaTokenizer, PretrainedTokenizerFast):   | class GemmaTokenizer(PretrainedTokenizerFast):
    def __init__(self, eos_token="</s>"):                        |     def __init__(self, eos_token="</s>"):
        eos_token = AddedToken(eos_token)                        |         eos_token = AddedToken(eos_token)
        PretrainedTokenizerFast.__init__(self, eos_token)        |         super().__init__(eos_token)
```
This is useful when you **don't** want to unravel the call to `super()`, and you want to differentiate which super init call you are doing!
### Special naming
We now also support special cases like
```python
class GemmaVisionModel(CLIPModel):
    pass
```
where the prefix of your class name, `GemmaVision`, is not the same as the modular prefix `Gemma`. This is super useful for composite models.
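Under the hood, the converter strips the longest common suffix between the two class names (`Model` here) to recover the prefix pair used for renaming (`GemmaVision` / `CLIP`, i.e. `gemma_vision` in snake_case); see `get_new_part` in the converter diff below.

The other half of this PR is `__all__` support. Based on the Gemma diffs in this commit, you declare a single `__all__` at the end of the modular file and the converter dispatches each name to the generated file where that class is emitted. A sketch:
```python
# modular_gemma.py -- one __all__ for everything defined in the modular file
__all__ = [
    "GemmaConfig",
    "GemmaTokenizer",
    "GemmaModel",
    "GemmaForCausalLM",
    "GemmaForSequenceClassification",
    "GemmaForTokenClassification",
]

# After conversion, each generated file only exports what it defines:
#   configuration_gemma.py -> __all__ = ["GemmaConfig"]
#   tokenization_gemma.py  -> __all__ = ["GemmaTokenizer"]
#   modeling_gemma.py      -> __all__ = ["GemmaModel", "GemmaForCausalLM",
#                                        "GemmaForSequenceClassification", "GemmaForTokenClassification"]
```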

View File

@ -4,7 +4,6 @@
# the file from the modular. If any change should be done, please apply the change to the # the file from the modular. If any change should be done, please apply the change to the
# modular_xxx.py file directly. One of our CI enforces this # modular_xxx.py file directly. One of our CI enforces this
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
import math import math
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union

View File

@ -115,7 +115,6 @@ _import_structure = {
"data.metrics": [], "data.metrics": [],
"data.processors": [], "data.processors": [],
"debug_utils": [], "debug_utils": [],
"deepspeed": [],
"dependency_versions_check": [], "dependency_versions_check": [],
"dependency_versions_table": [], "dependency_versions_table": [],
"dynamic_module_utils": [], "dynamic_module_utils": [],

View File

@ -1,41 +0,0 @@
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Integration with Deepspeed - kept for backward compatiblity, if you plan to make any edit, make sure to modify the file
in `integrations/deepspeed` instead.
Check: https://github.com/huggingface/transformers/pull/25599
"""
import warnings
warnings.warn(
"transformers.deepspeed module is deprecated and will be removed in a future version. Please import deepspeed modules directly from transformers.integrations",
FutureWarning,
)
# Backward compatibility imports, to make sure all those objects can be found in integrations/deepspeed
from .integrations.deepspeed import ( # noqa
HfDeepSpeedConfig,
HfTrainerDeepSpeedConfig,
deepspeed_config,
deepspeed_init,
deepspeed_load_checkpoint,
deepspeed_optim_sched,
is_deepspeed_available,
is_deepspeed_zero3_enabled,
set_hf_deepspeed_config,
unset_hf_deepspeed_config,
)

View File

@ -143,3 +143,6 @@ class GemmaConfig(PretrainedConfig):
tie_word_embeddings=tie_word_embeddings, tie_word_embeddings=tie_word_embeddings,
**kwargs, **kwargs,
) )
__all__ = ["GemmaConfig"]

View File

@ -1314,3 +1314,6 @@ class GemmaForTokenClassification(GemmaPreTrainedModel):
hidden_states=outputs.hidden_states, hidden_states=outputs.hidden_states,
attentions=outputs.attentions, attentions=outputs.attentions,
) )
__all__ = ["GemmaModel", "GemmaForCausalLM", "GemmaForSequenceClassification", "GemmaForTokenClassification"]

View File

@ -14,8 +14,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import math import math
from typing import List, Optional, Tuple, Union from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import sentencepiece as spm
import torch import torch
import torch.utils.checkpoint import torch.utils.checkpoint
from torch import nn from torch import nn
@ -27,6 +28,7 @@ from ...configuration_utils import PretrainedConfig
from ...modeling_flash_attention_utils import _flash_attention_forward from ...modeling_flash_attention_utils import _flash_attention_forward
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...pytorch_utils import ALL_LAYERNORM_LAYERS from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...tokenization_utils import AddedToken, PreTrainedTokenizer
from ...utils import logging from ...utils import logging
from ..llama.modeling_llama import ( from ..llama.modeling_llama import (
LlamaDecoderLayer, LlamaDecoderLayer,
@ -38,6 +40,15 @@ from ..llama.modeling_llama import (
apply_rotary_pos_emb, apply_rotary_pos_emb,
repeat_kv, repeat_kv,
) )
from ..llama.tokenization_llama import LlamaTokenizer
if TYPE_CHECKING:
from ...tokenization_utils_base import TextInput
VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}
SPIECE_UNDERLINE = ""
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
@ -164,6 +175,162 @@ class GemmaConfig(PretrainedConfig):
) )
class GemmaTokenizer(LlamaTokenizer, PreTrainedTokenizer):
"""
Construct a Gemma tokenizer. Based on byte-level Byte-Pair-Encoding. The default padding token is unset as there is
no padding token in the original model.
Args:
vocab_file (`str`):
Path to the vocabulary file.
unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<unk>"`):
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
token instead.
bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<bos>"`):
The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.
eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<eos>"`):
The end of sequence token.
pad_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<pad>"`):
A special token used to make arrays of tokens the same size for batching purpose. Will then be ignored by
attention mechanisms or loss computation.
sp_model_kwargs (`Dict[str, Any]`, `Optional`, *optional*):
Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
to set:
- `enable_sampling`: Enable subword regularization.
- `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
- `nbest_size = {0,1}`: No sampling is performed.
- `nbest_size > 1`: samples from the nbest_size results.
- `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice)
using forward-filtering-and-backward-sampling algorithm.
- `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
BPE-dropout.
add_bos_token (`bool`, *optional*, defaults to `True`):
Whether or not to add an `bos_token` at the start of sequences.
add_eos_token (`bool`, *optional*, defaults to `False`):
Whether or not to add an `eos_token` at the end of sequences.
clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
Whether or not to cleanup spaces after decoding, cleanup consists in removing potential artifacts like
extra spaces.
use_default_system_prompt (`bool`, *optional*, defaults to `False`):
Whether or not the default system prompt for Gemma should be used.
spaces_between_special_tokens (`bool`, *optional*, defaults to `False`):
Whether or not to add spaces between special tokens.
"""
def __init__(
self,
vocab_file,
unk_token="<unk>",
bos_token="<bos>",
eos_token="<eos>",
pad_token="<pad>",
sp_model_kwargs: Optional[Dict[str, Any]] = None,
add_bos_token=True,
add_eos_token=False,
clean_up_tokenization_spaces=False,
use_default_system_prompt=False,
spaces_between_special_tokens=False,
**kwargs,
):
self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
bos_token = AddedToken(bos_token, normalized=False, special=True) if isinstance(bos_token, str) else bos_token
eos_token = AddedToken(eos_token, normalized=False, special=True) if isinstance(eos_token, str) else eos_token
unk_token = AddedToken(unk_token, normalized=False, special=True) if isinstance(unk_token, str) else unk_token
pad_token = AddedToken(pad_token, normalized=False, special=True) if isinstance(pad_token, str) else pad_token
self.vocab_file = vocab_file
self.add_bos_token = add_bos_token
self.add_eos_token = add_eos_token
self.use_default_system_prompt = use_default_system_prompt
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
self.sp_model.Load(vocab_file)
PreTrainedTokenizer.__init__(
self,
bos_token=bos_token,
eos_token=eos_token,
unk_token=unk_token,
pad_token=pad_token,
add_bos_token=add_bos_token,
add_eos_token=add_eos_token,
sp_model_kwargs=sp_model_kwargs,
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
use_default_system_prompt=use_default_system_prompt,
spaces_between_special_tokens=spaces_between_special_tokens,
**kwargs,
)
def get_spm_processor(self):
raise AttributeError("Not needed for Gemma")
def unk_token_length(self):
raise AttributeError("Not needed for Gemma")
def tokenize(self, text: "TextInput", **kwargs) -> List[str]:
"""
Args:
text: TextInput
Simply calls PreTrainedTokenizer's method
"""
return PreTrainedTokenizer.tokenize(self, text, **kwargs)
def _tokenize(self, text, **kwargs):
"""
Args:
text: TextInput
Returns a tokenized string. The Gemma tokenizer never adds a prefix space.
"""
return self.sp_model.encode(text, out_type=str)
def _decode(
self,
token_ids: List[int],
skip_special_tokens: bool = False,
spaces_between_special_tokens: bool = False,
**kwargs,
) -> str:
sub_texts = []
current_sub_text = []
for ids in token_ids:
if skip_special_tokens and ids in self.all_special_ids:
continue
if ids in self._added_tokens_decoder:
if current_sub_text:
sub_texts.append(self.sp_model.decode(current_sub_text))
sub_texts.append(self._added_tokens_decoder[ids].content)
current_sub_text = []
else:
current_sub_text.append(ids)
if current_sub_text:
sub_texts.append(self.sp_model.decode(current_sub_text))
if spaces_between_special_tokens:
sub_texts = " ".join(sub_texts)
else:
sub_texts = "".join(sub_texts)
return sub_texts.replace(SPIECE_UNDERLINE, " ")
def convert_tokens_to_string(self, tokens):
"""Converts a sequence of tokens (string) in a single string."""
current_sub_tokens = []
out_string = ""
for token in tokens:
# make sure that special tokens are not decoded using sentencepiece model
if token in self._added_tokens_encoder:
out_string += self.sp_model.decode(current_sub_tokens) + token
current_sub_tokens = []
else:
current_sub_tokens.append(token)
out_string += self.sp_model.decode(current_sub_tokens)
return out_string
class GemmaRMSNorm(nn.Module): class GemmaRMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6): def __init__(self, dim: int, eps: float = 1e-6):
super().__init__() super().__init__()
@ -874,3 +1041,13 @@ class GemmaForTokenClassification(LlamaForTokenClassification):
super().__init__(config) super().__init__(config)
self.model = GemmaModel(config) self.model = GemmaModel(config)
self.post_init() self.post_init()
__all__ = [
"GemmaConfig",
"GemmaTokenizer",
"GemmaModel",
"GemmaForCausalLM",
"GemmaForSequenceClassification",
"GemmaForTokenClassification",
]

View File

@ -1,5 +1,12 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from <path_to_modular_file.py>.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_xxx.py file directly. One of our CI enforces this
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8 # coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved. # Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
#
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -12,9 +19,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Tokenization classes for Gemma."""
import os import os
from shutil import copyfile from shutil import copyfile
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
@ -26,7 +30,7 @@ from ...utils import logging
if TYPE_CHECKING: if TYPE_CHECKING:
pass from ...tokenization_utils_base import TextInput
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
@ -110,7 +114,6 @@ class GemmaTokenizer(PreTrainedTokenizer):
self.add_bos_token = add_bos_token self.add_bos_token = add_bos_token
self.add_eos_token = add_eos_token self.add_eos_token = add_eos_token
self.use_default_system_prompt = use_default_system_prompt self.use_default_system_prompt = use_default_system_prompt
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
self.sp_model.Load(vocab_file) self.sp_model.Load(vocab_file)
@ -121,85 +124,60 @@ class GemmaTokenizer(PreTrainedTokenizer):
pad_token=pad_token, pad_token=pad_token,
add_bos_token=add_bos_token, add_bos_token=add_bos_token,
add_eos_token=add_eos_token, add_eos_token=add_eos_token,
sp_model_kwargs=self.sp_model_kwargs, sp_model_kwargs=sp_model_kwargs,
clean_up_tokenization_spaces=clean_up_tokenization_spaces, clean_up_tokenization_spaces=clean_up_tokenization_spaces,
use_default_system_prompt=use_default_system_prompt, use_default_system_prompt=use_default_system_prompt,
spaces_between_special_tokens=spaces_between_special_tokens, spaces_between_special_tokens=spaces_between_special_tokens,
**kwargs, **kwargs,
) )
# Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.__getstate__
def __getstate__(self): def __getstate__(self):
state = self.__dict__.copy() state = self.__dict__.copy()
state["sp_model"] = None state["sp_model"] = None
state["sp_model_proto"] = self.sp_model.serialized_model_proto() state["sp_model_proto"] = self.sp_model.serialized_model_proto()
return state return state
# Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.__setstate__
def __setstate__(self, d): def __setstate__(self, d):
self.__dict__ = d self.__dict__ = d
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
self.sp_model.LoadFromSerializedProto(self.sp_model_proto) self.sp_model.LoadFromSerializedProto(self.sp_model_proto)
@property @property
# Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.vocab_size
def vocab_size(self): def vocab_size(self):
"""Returns vocab size""" """Returns vocab size"""
return self.sp_model.get_piece_size() return self.sp_model.get_piece_size()
# Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.get_vocab
def get_vocab(self): def get_vocab(self):
"""Returns vocab as a dict""" """Returns vocab as a dict"""
vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)} vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
vocab.update(self.added_tokens_encoder) vocab.update(self.added_tokens_encoder)
return vocab return vocab
def tokenize(self, text: "TextInput", **kwargs) -> List[str]:
"""
Args:
text: TextInput
Simply calls PreTrainedTokenizer's method
"""
return super().tokenize(text, **kwargs)
def _tokenize(self, text, **kwargs): def _tokenize(self, text, **kwargs):
""" """
Args:
text: TextInput
Returns a tokenized string. The Gemma tokenizer never adds a prefix space. Returns a tokenized string. The Gemma tokenizer never adds a prefix space.
""" """
return self.sp_model.encode(text, out_type=str) return self.sp_model.encode(text, out_type=str)
# Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer._convert_token_to_id
def _convert_token_to_id(self, token): def _convert_token_to_id(self, token):
"""Converts a token (str) in an id using the vocab.""" """Converts a token (str) in an id using the vocab."""
return self.sp_model.piece_to_id(token) return self.sp_model.piece_to_id(token)
# Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer._convert_id_to_token
def _convert_id_to_token(self, index): def _convert_id_to_token(self, index):
"""Converts an index (integer) in a token (str) using the vocab.""" """Converts an index (integer) in a token (str) using the vocab."""
token = self.sp_model.IdToPiece(index) token = self.sp_model.IdToPiece(index)
return token return token
def _decode(
self,
token_ids: List[int],
skip_special_tokens: bool = False,
spaces_between_special_tokens: bool = False,
**kwargs,
) -> str:
sub_texts = []
current_sub_text = []
for ids in token_ids:
if skip_special_tokens and ids in self.all_special_ids:
continue
if ids in self._added_tokens_decoder:
if current_sub_text:
sub_texts.append(self.sp_model.decode(current_sub_text))
sub_texts.append(self._added_tokens_decoder[ids].content)
current_sub_text = []
else:
current_sub_text.append(ids)
if current_sub_text:
sub_texts.append(self.sp_model.decode(current_sub_text))
if spaces_between_special_tokens:
sub_texts = " ".join(sub_texts)
else:
sub_texts = "".join(sub_texts)
return sub_texts.replace(SPIECE_UNDERLINE, " ")
def convert_tokens_to_string(self, tokens): def convert_tokens_to_string(self, tokens):
"""Converts a sequence of tokens (string) in a single string.""" """Converts a sequence of tokens (string) in a single string."""
current_sub_tokens = [] current_sub_tokens = []
@ -214,7 +192,6 @@ class GemmaTokenizer(PreTrainedTokenizer):
out_string += self.sp_model.decode(current_sub_tokens) out_string += self.sp_model.decode(current_sub_tokens)
return out_string return out_string
# Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.save_vocabulary
def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]: def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]:
""" """
Save the vocabulary and special tokens file to a directory. Save the vocabulary and special tokens file to a directory.
@ -242,7 +219,6 @@ class GemmaTokenizer(PreTrainedTokenizer):
return (out_vocab_file,) return (out_vocab_file,)
# Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.build_inputs_with_special_tokens
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
bos_token_id = [self.bos_token_id] if self.add_bos_token else [] bos_token_id = [self.bos_token_id] if self.add_bos_token else []
eos_token_id = [self.eos_token_id] if self.add_eos_token else [] eos_token_id = [self.eos_token_id] if self.add_eos_token else []
@ -254,7 +230,6 @@ class GemmaTokenizer(PreTrainedTokenizer):
return output return output
# Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.get_special_tokens_mask
def get_special_tokens_mask( def get_special_tokens_mask(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
) -> List[int]: ) -> List[int]:
@ -292,7 +267,6 @@ class GemmaTokenizer(PreTrainedTokenizer):
+ eos_token_id + eos_token_id
) )
# Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.create_token_type_ids_from_sequences
def create_token_type_ids_from_sequences( def create_token_type_ids_from_sequences(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]: ) -> List[int]:
@ -325,3 +299,35 @@ class GemmaTokenizer(PreTrainedTokenizer):
output += [1] * len(bos_token_id + token_ids_1 + eos_token_id) output += [1] * len(bos_token_id + token_ids_1 + eos_token_id)
return output return output
def _decode(
self,
token_ids: List[int],
skip_special_tokens: bool = False,
spaces_between_special_tokens: bool = False,
**kwargs,
) -> str:
sub_texts = []
current_sub_text = []
for ids in token_ids:
if skip_special_tokens and ids in self.all_special_ids:
continue
if ids in self._added_tokens_decoder:
if current_sub_text:
sub_texts.append(self.sp_model.decode(current_sub_text))
sub_texts.append(self._added_tokens_decoder[ids].content)
current_sub_text = []
else:
current_sub_text.append(ids)
if current_sub_text:
sub_texts.append(self.sp_model.decode(current_sub_text))
if spaces_between_special_tokens:
sub_texts = " ".join(sub_texts)
else:
sub_texts = "".join(sub_texts)
return sub_texts.replace(SPIECE_UNDERLINE, " ")
__all__ = ["GemmaTokenizer"]

View File

@ -58,7 +58,6 @@ logger = logging.get_logger(__name__)
@dataclass @dataclass
# Copied from transformers.models.blip_2.modeling_blip_2.Blip2ForConditionalGenerationModelOutput with Blip2->InstructBlipVideo
class InstructBlipVideoForConditionalGenerationModelOutput(ModelOutput): class InstructBlipVideoForConditionalGenerationModelOutput(ModelOutput):
""" """
Class defining the outputs of [`InstructBlipVideoForConditionalGeneration`]. Class defining the outputs of [`InstructBlipVideoForConditionalGeneration`].
@ -91,7 +90,6 @@ class InstructBlipVideoForConditionalGenerationModelOutput(ModelOutput):
) )
# Copied from transformers.models.blip.modeling_blip.BlipVisionEmbeddings with Blip->InstructBlipVideo
class InstructBlipVideoVisionEmbeddings(nn.Module): class InstructBlipVideoVisionEmbeddings(nn.Module):
def __init__(self, config: InstructBlipVideoVisionConfig): def __init__(self, config: InstructBlipVideoVisionConfig):
super().__init__() super().__init__()
@ -166,7 +164,6 @@ class InstructBlipVideoVisionEmbeddings(nn.Module):
return embeddings return embeddings
# Copied from transformers.models.blip_2.modeling_blip_2.Blip2Attention with Blip2->InstructBlipVideo
class InstructBlipVideoAttention(nn.Module): class InstructBlipVideoAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper""" """Multi-headed attention from 'Attention Is All You Need' paper"""
@ -248,7 +245,6 @@ class InstructBlipVideoAttention(nn.Module):
return outputs return outputs
# Copied from transformers.models.blip.modeling_blip.BlipMLP
class InstructBlipVideoMLP(nn.Module): class InstructBlipVideoMLP(nn.Module):
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()
@ -264,7 +260,6 @@ class InstructBlipVideoMLP(nn.Module):
return hidden_states return hidden_states
# Copied from transformers.models.blip.modeling_blip.BlipEncoderLayer with Blip->InstructBlipVideo
class InstructBlipVideoEncoderLayer(nn.Module): class InstructBlipVideoEncoderLayer(nn.Module):
def __init__(self, config: InstructBlipVideoConfig): def __init__(self, config: InstructBlipVideoConfig):
super().__init__() super().__init__()
@ -330,7 +325,6 @@ class InstructBlipVideoPreTrainedModel(PreTrainedModel):
] ]
_keep_in_fp32_modules = [] _keep_in_fp32_modules = []
# Copied from transformers.models.blip_2.modeling_blip_2.Blip2PreTrainedModel._init_weights with Blip2->InstructBlipVideo
def _init_weights(self, module): def _init_weights(self, module):
"""Initialize the weights""" """Initialize the weights"""
factor = self.config.initializer_range factor = self.config.initializer_range
@ -450,7 +444,6 @@ INSTRUCTBLIPVIDEO_INPUTS_DOCSTRING = r"""
""" """
# Copied from transformers.models.blip.modeling_blip.BlipEncoder with Blip->InstructBlipVideo
class InstructBlipVideoEncoder(nn.Module): class InstructBlipVideoEncoder(nn.Module):
""" """
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
@ -537,7 +530,6 @@ class InstructBlipVideoEncoder(nn.Module):
) )
# Copied from transformers.models.blip.modeling_blip.BlipVisionModel with Blip->InstructBlipVideo, BLIP->INSTRUCTBLIPVIDEO
class InstructBlipVideoVisionModel(InstructBlipVideoPreTrainedModel): class InstructBlipVideoVisionModel(InstructBlipVideoPreTrainedModel):
main_input_name = "pixel_values" main_input_name = "pixel_values"
config_class = InstructBlipVideoVisionConfig config_class = InstructBlipVideoVisionConfig
@ -738,7 +730,6 @@ class InstructBlipVideoQFormerMultiHeadAttention(nn.Module):
return outputs return outputs
# Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->InstructBlipVideoQFormer
class InstructBlipVideoQFormerSelfOutput(nn.Module): class InstructBlipVideoQFormerSelfOutput(nn.Module):
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()
@ -753,7 +744,6 @@ class InstructBlipVideoQFormerSelfOutput(nn.Module):
return hidden_states return hidden_states
# Copied from transformers.models.blip_2.modeling_blip_2.Blip2QFormerAttention with Blip2->InstructBlipVideo
class InstructBlipVideoQFormerAttention(nn.Module): class InstructBlipVideoQFormerAttention(nn.Module):
def __init__(self, config, is_cross_attention=False): def __init__(self, config, is_cross_attention=False):
super().__init__() super().__init__()
@ -803,7 +793,6 @@ class InstructBlipVideoQFormerAttention(nn.Module):
return outputs return outputs
# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->InstructBlipVideoQFormer
class InstructBlipVideoQFormerIntermediate(nn.Module): class InstructBlipVideoQFormerIntermediate(nn.Module):
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()
@ -819,7 +808,6 @@ class InstructBlipVideoQFormerIntermediate(nn.Module):
return hidden_states return hidden_states
# Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->InstructBlipVideoQFormer
class InstructBlipVideoQFormerOutput(nn.Module): class InstructBlipVideoQFormerOutput(nn.Module):
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()
@ -937,7 +925,6 @@ class InstructBlipVideoQFormerLayer(nn.Module):
return layer_output return layer_output
# Copied from transformers.models.blip_2.modeling_blip_2.Blip2QFormerEncoder with Blip2->InstructBlipVideo
class InstructBlipVideoQFormerEncoder(nn.Module): class InstructBlipVideoQFormerEncoder(nn.Module):
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()

View File

@ -43,14 +43,12 @@ SPIECE_UNDERLINE = "▁"
B_INST, E_INST = "[INST]", "[/INST]" B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n" B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
# fmt: off
DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your \ DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your \
answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\ answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\
that your responses are socially unbiased and positive in nature. that your responses are socially unbiased and positive in nature.
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \ If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \
correct. If you don't know the answer to a question, please don't share false information.""" correct. If you don't know the answer to a question, please don't share false information.""" # fmt: skip
# fmt: on
class LlamaTokenizer(PreTrainedTokenizer): class LlamaTokenizer(PreTrainedTokenizer):

View File

@ -235,7 +235,6 @@ class LlavaNextVideoPooler(nn.Module):
return image_features_spatial_pool.flatten(2).transpose(1, 2).contiguous() return image_features_spatial_pool.flatten(2).transpose(1, 2).contiguous()
# Copied from transformers.models.llava.modeling_llava.LlavaMultiModalProjector with Llava->LlavaNextVideo
class LlavaNextVideoMultiModalProjector(nn.Module): class LlavaNextVideoMultiModalProjector(nn.Module):
def __init__(self, config: LlavaNextVideoConfig): def __init__(self, config: LlavaNextVideoConfig):
super().__init__() super().__init__()
@ -272,7 +271,6 @@ LLAVA_NEXT_VIDEO_START_DOCSTRING = r"""
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.", "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
LLAVA_NEXT_VIDEO_START_DOCSTRING, LLAVA_NEXT_VIDEO_START_DOCSTRING,
) )
# Copied from transformers.models.llava.modeling_llava.LlavaPreTrainedModel with Llava->LlavaNextVideo,llava->llava_next_video
class LlavaNextVideoPreTrainedModel(PreTrainedModel): class LlavaNextVideoPreTrainedModel(PreTrainedModel):
config_class = LlavaNextVideoConfig config_class = LlavaNextVideoConfig
base_model_prefix = "model" base_model_prefix = "model"
@ -426,35 +424,27 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene
raise ValueError(f"{padding_side} is not `left` or `right`.") raise ValueError(f"{padding_side} is not `left` or `right`.")
self._padding_side = padding_side self._padding_side = padding_side
# Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_input_embeddings
def get_input_embeddings(self): def get_input_embeddings(self):
return self.language_model.get_input_embeddings() return self.language_model.get_input_embeddings()
# Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_input_embeddings
def set_input_embeddings(self, value): def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value) self.language_model.set_input_embeddings(value)
# Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_output_embeddings
def get_output_embeddings(self): def get_output_embeddings(self):
return self.language_model.get_output_embeddings() return self.language_model.get_output_embeddings()
# Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_output_embeddings
def set_output_embeddings(self, new_embeddings): def set_output_embeddings(self, new_embeddings):
self.language_model.set_output_embeddings(new_embeddings) self.language_model.set_output_embeddings(new_embeddings)
# Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_decoder
def set_decoder(self, decoder): def set_decoder(self, decoder):
self.language_model.set_decoder(decoder) self.language_model.set_decoder(decoder)
# Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_decoder
def get_decoder(self): def get_decoder(self):
return self.language_model.get_decoder() return self.language_model.get_decoder()
# Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.tie_weights
def tie_weights(self): def tie_weights(self):
return self.language_model.tie_weights() return self.language_model.tie_weights()
# Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.resize_token_embeddings
def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding: def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of) model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
# update vocab size # update vocab size

View File

@ -25,8 +25,8 @@ from torch import Tensor, nn
from torch.nn import CrossEntropyLoss from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN from ...activations import ACT2FN
from ...deepspeed import is_deepspeed_zero3_enabled
from ...generation import GenerationMixin from ...generation import GenerationMixin
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
from ...modeling_outputs import ( from ...modeling_outputs import (
BaseModelOutput, BaseModelOutput,

View File

@ -25,8 +25,8 @@ from torch import Tensor, nn
from torch.nn import CrossEntropyLoss from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN from ...activations import ACT2FN
from ...deepspeed import is_deepspeed_zero3_enabled
from ...generation import GenerationMixin from ...generation import GenerationMixin
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
from ...modeling_outputs import ( from ...modeling_outputs import (
BaseModelOutput, BaseModelOutput,

View File

@ -16,7 +16,8 @@ import argparse
import glob import glob
import importlib import importlib
import re import re
from typing import Dict from collections import defaultdict
from typing import Dict, List, Set
import libcst as cst import libcst as cst
from check_copies import run_ruff from check_copies import run_ruff
@ -113,7 +114,11 @@ class ClassFinder(CSTVisitor):
if m.matches(node, m.SimpleStatementLine(body=[m.Assign()])) and m.matches( if m.matches(node, m.SimpleStatementLine(body=[m.Assign()])) and m.matches(
self.get_metadata(cst.metadata.ParentNodeProvider, node), m.Module() self.get_metadata(cst.metadata.ParentNodeProvider, node), m.Module()
): ):
if hasattr(node.body[0].targets[0].target, "value"):
self.assignments[node.body[0].targets[0].target.value] = node self.assignments[node.body[0].targets[0].target.value] = node
else:
for idx, target in enumerate(list(node.body[0].targets[0].target.elements)):
self.assignments[target.value.value] = node.body[0].value.elements[idx].value
if m.matches(node, m.SimpleStatementLine(body=[m.Import() | m.ImportFrom()])): if m.matches(node, m.SimpleStatementLine(body=[m.Import() | m.ImportFrom()])):
self.imports[node.body[0].names] = node self.imports[node.body[0].names] = node
@ -217,11 +222,21 @@ class ReplaceNameTransformer(m.MatcherDecoratableTransformer):
return compiled_regex.sub(replace, text) return compiled_regex.sub(replace, text)
def convert_to_camelcase(self, text):
# Regex pattern to match consecutive uppercase letters and lowercase the first set
result = re.sub(r"^[A-Z]+(?=[A-Z][a-z])", lambda m: m.group(0).capitalize(), text, count=1)
return result
@m.leave(m.Name() | m.SimpleString() | m.Comment()) @m.leave(m.Name() | m.SimpleString() | m.Comment())
def replace_name(self, original_node, updated_node): def replace_name(self, original_node, updated_node):
if re.findall(r"# Copied from", updated_node.value):
return cst.RemoveFromParent()
update = self.preserve_case_replace(updated_node.value) update = self.preserve_case_replace(updated_node.value)
return updated_node.with_changes(value=update) return updated_node.with_changes(value=update)
def leave_ClassDef(self, original_node, updated_node):
return updated_node.with_changes(name=cst.Name(self.convert_to_camelcase(updated_node.name.value)))
def find_classes_in_file(module: cst.Module, old_id="llama", new_id="gemma", given_old_name=None, given_new_name=None): def find_classes_in_file(module: cst.Module, old_id="llama", new_id="gemma", given_old_name=None, given_new_name=None):
"""Helper function to rename and then parse a source file using the ClassFinder""" """Helper function to rename and then parse a source file using the ClassFinder"""
@ -251,6 +266,63 @@ def SUPER_CALL_NODE(func_name):
return m.Call(func=m.Attribute(value=m.Call(func=m.Name("super")), attr=m.Name(func_name))) return m.Call(func=m.Attribute(value=m.Call(func=m.Name("super")), attr=m.Name(func_name)))
def is_call_to_super(node, func_name):
return m.matches(
node, m.SimpleStatementLine(body=[m.Return(SUPER_CALL_NODE(func_name)) | m.Expr(SUPER_CALL_NODE(func_name))])
)
# Transformer class to replace ClassB.call_to_method and ClassB().call_to_method with super().call_to_method
class ReplaceMethodCallTransformer(cst.CSTTransformer):
def __init__(self, all_bases: Set[str]):
self.all_bases = all_bases
def leave_Attribute(self, original_node: cst.Attribute, updated_node: cst.Attribute) -> cst.CSTNode:
# Handle ClassB.call_to_method
if (
isinstance(original_node.value, cst.Name)
and original_node.value.value in self.all_bases
and isinstance(original_node.attr, cst.Name)
):
# Replace with super().call_to_method
return updated_node.with_changes(
value=cst.Call(cst.Name("super")),
)
# Handle ClassB().call_to_method
elif (
isinstance(original_node.value, cst.Call)
and isinstance(original_node.value.func, cst.Name)
and original_node.value.func.value in self.all_bases
and isinstance(original_node.attr, cst.Name)
):
# Replace with super().call_to_method
return updated_node.with_changes(func=cst.Attribute(value=cst.Call(func=cst.Name("super"))))
return updated_node
def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.CSTNode:
# Check if the function being called is of the form ClassB().func_a or ClassB.func_a
if isinstance(original_node.func, cst.Attribute) and (
# Match ClassB().func_a(...)
(
isinstance(original_node.func.value, cst.Call)
and isinstance(original_node.func.value.func, cst.Name)
and original_node.func.value.func.value in self.all_bases
)
or
# Match ClassB.func_a(...)
(isinstance(original_node.func.value, cst.Name) and original_node.func.value.value in self.all_bases)
):
# Check if the first argument is 'self', and remove it
if len(original_node.args) > 0 and m.matches(original_node.args[0].value, m.Name("self")):
# Create the new argument list without 'self'
new_args = updated_node.args[1:]
else:
new_args = updated_node.args
return updated_node.with_changes(args=new_args)
return updated_node
def get_docstring_indent(docstring): def get_docstring_indent(docstring):
# Match the first line after the opening triple quotes # Match the first line after the opening triple quotes
match = re.search(r'(?:"""|\'\'\'|```)\n(\s+)', docstring) match = re.search(r'(?:"""|\'\'\'|```)\n(\s+)', docstring)
@ -263,7 +335,7 @@ def get_docstring_indent(docstring):
def merge_docstrings(original_docstring, updated_docstring): def merge_docstrings(original_docstring, updated_docstring):
# indent_level = get_docstring_indent(updated_docstring) # indent_level = get_docstring_indent(updated_docstring)
original_level = get_docstring_indent(original_docstring) original_level = get_docstring_indent(original_docstring)
if " Args:\n " not in updated_docstring: if not re.findall(r"\n\s*Args:\n", updated_docstring):
# Split the docstring at the example section, assuming `"""` is used to define the docstring # Split the docstring at the example section, assuming `"""` is used to define the docstring
parts = original_docstring.split("```") parts = original_docstring.split("```")
if "```" in updated_docstring and len(parts) > 1: if "```" in updated_docstring and len(parts) > 1:
@ -292,13 +364,15 @@ def merge_docstrings(original_docstring, updated_docstring):
class SuperTransformer(cst.CSTTransformer): class SuperTransformer(cst.CSTTransformer):
METADATA_DEPENDENCIES = (ParentNodeProvider,) METADATA_DEPENDENCIES = (ParentNodeProvider,)
def __init__(self, python_module: cst.Module, original_methods, updated_methods, class_name=""): def __init__(self, python_module: cst.Module, original_methods, updated_methods, class_name="", all_bases=None):
self.python_module = python_module self.python_module = python_module
self.original_methods = original_methods self.original_methods = original_methods
self.updated_methods = updated_methods self.updated_methods = updated_methods
self.all_assign_target = {} self.all_assign_target = {}
self.deleted_targets = {} # child node can delete some arguments self.deleted_targets = {} # child node can delete some arguments
self.class_name = class_name self.class_name = class_name
self.all_bases = all_bases or []
self.transformer = ReplaceMethodCallTransformer(set(self.all_bases))
def update_body(self, existing_body, new_statements): def update_body(self, existing_body, new_statements):
""" """
@ -356,18 +430,14 @@ class SuperTransformer(cst.CSTTransformer):
parent_has_docstring = m.matches(self.original_methods[func_name].body.body[0], DOCSTRING_NODE) parent_has_docstring = m.matches(self.original_methods[func_name].body.body[0], DOCSTRING_NODE)
new_body = [] new_body = []
has_super_call = False has_super_call = False
for idx, expr in enumerate(node.body):
if m.matches( for expr in node.body:
expr, if is_call_to_super(expr, func_name):
m.SimpleStatementLine(
body=[m.Return(SUPER_CALL_NODE(func_name)) | m.Expr(SUPER_CALL_NODE(func_name))]
),
):
if idx != 0 and func_name == "__init__":
raise ValueError(f"The call to super() in {self.class_name} should be at the top of the init")
new_body.extend(self.update_body(self.original_methods[func_name].body.body, node.body))
has_super_call = True has_super_call = True
elif m.matches(expr, DOCSTRING_NODE): new_body.extend(self.update_body(self.original_methods[func_name].body.body, node.body))
else:
expr = expr.visit(self.transformer)
if m.matches(expr, DOCSTRING_NODE):
self.has_docstring = True self.has_docstring = True
if parent_has_docstring: # actually here we ought to de-duplicate? if parent_has_docstring: # actually here we ought to de-duplicate?
original_docstring = self.original_methods[func_name].body.body[0].body[0].value.value original_docstring = self.original_methods[func_name].body.body[0].body[0].value.value
@ -406,15 +476,17 @@ class SuperTransformer(cst.CSTTransformer):
return updated_node return updated_node
def replace_call_to_super(class_finder: ClassFinder, updated_node: cst.ClassDef, class_name: str): def replace_call_to_super(
class_finder: ClassFinder, updated_node: cst.ClassDef, class_name: str, all_bases: List[str]
):
""" """
Given the `class_name`, the `updated_node`'s call to super are unpacked. Given the `class_name`, the `updated_node`'s call to super are unpacked.
| ```python | | ```python | ```python | | ```python
| class GemmaModel(LlamaModel): | | class GemmaModel(nn.Module): | class GemmaModel(LlamaModel): | | class GemmaModel(nn.Module):
| def __init__(self): | | def __init__(self): | def __init__(self): | | def __init__(self):
Going from: | self.dropout = 0.2 | to: | self.dropout = 0.2 Going from: | super().__init__() | to: | super().__init__(config)
| super().__init__() | | super().__init__(config) | self.dropout = 0.2 | | self.dropout = 0.2
| ``` | | self.padding_idx = config.pad_token_id | ``` | | self.padding_idx = config.pad_token_id
| self.vocab_size = config.vocab_size | self.vocab_size = config.vocab_size
| self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) | self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
@ -453,7 +525,14 @@ def replace_call_to_super(class_finder: ClassFinder, updated_node: cst.ClassDef,
new_params = new_params.with_changes( new_params = new_params.with_changes(
params=list(parent_params.values()), star_kwarg=func.params.star_kwarg params=list(parent_params.values()), star_kwarg=func.params.star_kwarg
) )
if not re.match(
r"\ndef .*\(.*\):\n raise.*Error\(.*",
class_finder.python_module.code_for_node(updated_methods[name]),
):
func = func.with_changes(body=updated_methods[name].body, params=new_params) func = func.with_changes(body=updated_methods[name].body, params=new_params)
else:
continue
if m.matches(func, m.SimpleStatementLine(body=[m.Assign()])): if m.matches(func, m.SimpleStatementLine(body=[m.Assign()])):
target = class_finder.python_module.code_for_node(func.body[0].targets[0]) target = class_finder.python_module.code_for_node(func.body[0].targets[0])
assign_targets[target] = func assign_targets[target] = func
@ -492,7 +571,7 @@ def replace_call_to_super(class_finder: ClassFinder, updated_node: cst.ClassDef,
temp_module = cst.Module(body=[result_node]) temp_module = cst.Module(body=[result_node])
new_module = MetadataWrapper(temp_module) new_module = MetadataWrapper(temp_module)
new_replacement_class = new_module.visit( new_replacement_class = new_module.visit(
SuperTransformer(temp_module, original_methods, updated_methods, class_name) SuperTransformer(temp_module, original_methods, updated_methods, class_name, all_bases)
) )
new_replacement_body = new_replacement_class.body[0].body # get the indented block new_replacement_body = new_replacement_class.body[0].body # get the indented block
@ -508,6 +587,31 @@ TYPE_TO_FILE_TYPE = {
} }
def get_new_part(class_name, base_class):
"""
When `MyClassNameAttention` inherits from `MistralAttention`, we need
to process the name to properly find dependencies.
Here we take what is the same (Attention) and what is different
when finding the dependencies.
"""
common_suffix_len = 0
for i in range(1, min(len(class_name), len(base_class)) + 1):
if class_name[-i] == base_class[-i]:
common_suffix_len += 1
else:
break
if common_suffix_len > 0:
new_part = class_name[:-common_suffix_len]
else:
new_part = class_name
# Convert the remaining new part to snake_case
snake_case = re.sub(r"(?<!^)(?=[A-Z])", "_", new_part).lower()
return snake_case
class ModularConverterTransformer(CSTTransformer): class ModularConverterTransformer(CSTTransformer):
METADATA_DEPENDENCIES = (ParentNodeProvider, ScopeProvider, PositionProvider) METADATA_DEPENDENCIES = (ParentNodeProvider, ScopeProvider, PositionProvider)
@ -538,6 +642,7 @@ class ModularConverterTransformer(CSTTransformer):
} }
self.match_patterns = "|".join(self.files.keys()) self.match_patterns = "|".join(self.files.keys())
self.all_definitions = {} self.all_definitions = {}
self.class_to_file_type = {}
def visit_ImportFrom(self, node: cst.ImportFrom) -> None: def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
"""When visiting imports from `transformers.models.xxx` we need to: """When visiting imports from `transformers.models.xxx` we need to:
@ -630,6 +735,10 @@ class ModularConverterTransformer(CSTTransformer):
self.given_new_name, self.given_new_name,
) )
visited_module[super_file_name] = class_finder visited_module[super_file_name] = class_finder
list_dependencies = {
dep: class_finder.class_start_line.get(dep, 1000)
for dep in class_finder.class_dependency_mapping.get(class_name, [])
}
else: # we are re-using the previously parsed data else: # we are re-using the previously parsed data
class_finder = visited_module[super_file_name] class_finder = visited_module[super_file_name]
@ -637,6 +746,22 @@ class ModularConverterTransformer(CSTTransformer):
dep: class_finder.class_start_line.get(dep, 1000) dep: class_finder.class_start_line.get(dep, 1000)
for dep in class_finder.class_dependency_mapping.get(class_name, []) for dep in class_finder.class_dependency_mapping.get(class_name, [])
} }
if list_dependencies == []:
# so, maybe standard renaming did not work (the class name is different)
# we try with another renaming pattern
potential_given_name = get_new_part(class_name, super_class)
del visited_module[super_file_name]
class_finder = find_classes_in_file(
self.transformers_imports[super_file_name],
model_name,
potential_given_name,
self.model_name,
potential_given_name,
)
list_dependencies = {
dep: class_finder.class_start_line.get(dep, 1000)
for dep in class_finder.class_dependency_mapping.get(class_name, [])
}
list_dependencies = sorted(list_dependencies.items(), key=lambda x: x[1], reverse=True) list_dependencies = sorted(list_dependencies.items(), key=lambda x: x[1], reverse=True)
start_insert_idx = self.global_scope_index start_insert_idx = self.global_scope_index
@ -668,10 +793,12 @@ class ModularConverterTransformer(CSTTransformer):
self.inserted_deps.append(dependency) self.inserted_deps.append(dependency)
if len(list_dependencies) > 0: if len(list_dependencies) > 0:
updated_node = replace_call_to_super(class_finder, updated_node, class_name) updated_node = replace_call_to_super(class_finder, updated_node, class_name, all_bases)
else: else:
raise ValueError( raise ValueError(
f"Unable to find dependencies for {super_class} in {super_file_name}. Here are the dependencies found: {class_finder.class_dependency_mapping}. (The automatic renaming might have gone wrong!)" f"We were unable to find dependencies for {class_name} (based on inheriting from {super_class})"
f" Here are all the global dependencies that we found in you modular file: {list(class_finder.class_dependency_mapping.keys())}."
f" This usually means that the name of `{class_name}` does not match the pattern of `{super_class}`"
) )
# Now, if a class was defined without parents, we look for the name # Now, if a class was defined without parents, we look for the name
@ -679,8 +806,10 @@ class ModularConverterTransformer(CSTTransformer):
match = re.search(rf"({match_pattern})$", class_name) match = re.search(rf"({match_pattern})$", class_name)
if match: if match:
key = TYPE_TO_FILE_TYPE[match.group(1)] key = TYPE_TO_FILE_TYPE[match.group(1)]
self.class_to_file_type[class_name] = key
self.files[key][class_name] = {"insert_idx": self.global_scope_index, "node": updated_node} self.files[key][class_name] = {"insert_idx": self.global_scope_index, "node": updated_node}
else: else:
self.class_to_file_type[class_name] = "modeling"
self.files["modeling"][class_name] = {"insert_idx": self.global_scope_index, "node": updated_node} self.files["modeling"][class_name] = {"insert_idx": self.global_scope_index, "node": updated_node}
return updated_node return updated_node
@ -690,14 +819,37 @@ class ModularConverterTransformer(CSTTransformer):
self.all_definitions[node.name.value] = node self.all_definitions[node.name.value] = node
return node return node
def visit_Assign(self, node: cst.Assign) -> None:
# Check if the assignment target is '__all__'
if isinstance(node.targets[0].target, cst.Name) and node.targets[0].target.value == "__all__":
if isinstance(node.value, cst.List):
# Extract the elements from the list
all_all_to_add = defaultdict(list)
for elt in node.value.elements:
if isinstance(elt.value, cst.SimpleString):
# Remove quotes and add the string to the elements list
class_name = elt.value.value
file = self.class_to_file_type[
elt.value.evaluated_value
] # evaluated value give the content of the string
all_all_to_add[file] += [class_name]
for f_type, new_alls in all_all_to_add.items():
updated_node = node.with_changes(
value=cst.List(elements=[cst.Element(value=cst.SimpleString(value=k)) for k in new_alls])
)
self.files[f_type][class_name] = {
"insert_idx": self.global_scope_index + 100,
"node": updated_node,
}
def leave_If(self, original_node, node): def leave_If(self, original_node, node):
parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node) parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node)
if m.matches(parent_node, m.Module()): if m.matches(parent_node, m.Module()):
full_statement = self.python_module.code_for_node(original_node.test) full_statement = self.python_module.code_for_node(original_node.test)
if re.search(r"[\s\S]*is_.*available", full_statement): if re.search(r"[\s\S]*is_.*available", full_statement):
self.all_safe_imports.append(node) self.all_safe_imports.append(node)
elif full_statement not in self.new_body: elif full_statement not in self.all_imports:
self.new_body[node] = {"insert_idx": self.global_scope_index, "node": node} logger.warning(f"one import is protected with `if`. Hard guess where it's used {full_statement}")
return node return node
def leave_Module(self, original_node: cst.Assign, node): def leave_Module(self, original_node: cst.Assign, node):
@ -764,7 +916,7 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument(
"--files_to_parse", "--files_to_parse",
default=["examples/modular-transformers/modular_dummy.py"], default=["src/transformers/models/gemma/modular_gemma.py"],
nargs="+", nargs="+",
help="A list of `modular_xxxx` files that should be converted to single model file", help="A list of `modular_xxxx` files that should be converted to single model file",
) )

View File

@ -373,7 +373,6 @@ src/transformers/data/processors/squad.py
src/transformers/data/processors/utils.py src/transformers/data/processors/utils.py
src/transformers/data/processors/xnli.py src/transformers/data/processors/xnli.py
src/transformers/debug_utils.py src/transformers/debug_utils.py
src/transformers/deepspeed.py
src/transformers/dependency_versions_check.py src/transformers/dependency_versions_check.py
src/transformers/dependency_versions_table.py src/transformers/dependency_versions_table.py
src/transformers/dynamic_module_utils.py src/transformers/dynamic_module_utils.py