mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
Add support for __all__ and potentially deleting functions (#33859)
* Add support for __all__ and potentially deleting functions
* updates
* update
* nits
* remove dummies
* fix warning
* fixup
* style
* update
* fixup
* skip copied from when # skip
* remove log
* bring dummies back
* fixup
* remove copied from
* fixup
* remove warnings from `make fix-copies`
* fix doc issues
* nits
* Better error message !
* add support for more flexible naming!
* style
* breaking style?
* fix super() renaming issues
* del not needed when you don't call super().__init__()
* style
* no more fmt on :)
* properly remove `self`
* fixup
* fix
* doc nits
* add some doc 🫡
This commit is contained in:
parent
bead0fa8dc
commit
a3add29097
@@ -118,4 +118,60 @@ Additionally, you may find a list of examples here:

## What it is not

It is not a replacement for the modeling code (yet?), and if your model is not based on anything else that ever existed, then you can add a `modeling` file as usual.

## Advanced usage

### Removing attributes and functions

To remove attributes that are not used in your modular model, and that you don't want to see in the unravelled modeling file:

```python
class GemmaModel(LlamaModel):                    |    class GemmaModel(PreTrainedModel):
    def __init__(self, config):                  |        def __init__(self, config):
        super().__init__(config)                 |            super().__init__(config)
        del self.embed_tokens                    |            self.padding_idx = config.pad_token_id
                                                 |            self.vocab_size = config.vocab_size
                                                 |
                                                 |            self.layers = nn.ModuleList(
                                                 |                [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
                                                 |            )
                                                 |            self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
                                                 |            self.rotary_emb = LlamaRotaryEmbedding(config=config)
                                                 |            self.gradient_checkpointing = False
                                                 |
                                                 |            # Initialize weights and apply final processing
                                                 |            self.post_init()
```

If you check the original `LlamaModel`, it has an `embed_tokens` attribute that was removed here (as you would expect!).

Removing a function is pretty similar: you just need to write it so that it raises an error (for example `raise AttributeError("Not needed for Gemma")`) to mimic the behaviour you actually want when you remove a parent function in Python.

```python
class GemmaTokenizer(LlamaTokenizer):
    ...

    def get_spm_processor(self):
        raise AttributeError("Not needed for Gemma")

    def unk_token_length(self):
        raise AttributeError("Not needed for Gemma")
```

### Calling `super()`

We recently shipped a few features that allow you to go from:

```python
class GemmaTokenizer(LlamaTokenizer, PretrainedTokenizerFast):     |    class GemmaTokenizer(PretrainedTokenizerFast):
    def __init__(self, eos_token="</s>"):                          |        def __init__(self, eos_token="</s>"):
        eos_token = AddedToken(eos_token)                          |            eos_token = AddedToken(eos_token)
        PretrainedTokenizerFast.__init__(self, eos_token)          |            super().__init__(eos_token)
```

This is useful when you **don't** want to unravel the call to `super()`, and you want to differentiate which super init call you are doing!
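
To make this concrete, here is a trimmed, illustrative sketch of the `GemmaTokenizer` that `modular_gemma.py` adds later in this diff (most arguments are omitted here): by inheriting from `LlamaTokenizer` but calling `PreTrainedTokenizer.__init__` explicitly, the generated `tokenization_gemma.py` keeps that exact call instead of unravelling `LlamaTokenizer.__init__`.

```python
class GemmaTokenizer(LlamaTokenizer, PreTrainedTokenizer):
    def __init__(self, vocab_file, eos_token="<eos>", **kwargs):
        # Trimmed sketch: the real class (see modular_gemma.py in this diff) also
        # sets up the sentencepiece model and the other special tokens here.
        self.vocab_file = vocab_file
        # Explicit parent call: this is preserved as-is in the generated file,
        # rather than being unravelled into LlamaTokenizer.__init__.
        PreTrainedTokenizer.__init__(self, eos_token=eos_token, **kwargs)
```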

### Special naming

We now also support special cases like

```python
class GemmaVisionModel(CLIPModel):
    pass
```

where the name of your class `GemmaVision` is not the same as the modular prefix `Gemma`. This is super useful for composite models.
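
For composite models, the same prefix can map onto different parents. A purely illustrative sketch (these class names are hypothetical and not part of this diff): the converter matches on the shared `Gemma` prefix while each class follows whichever parent it inherits from.

```python
# Hypothetical modular file for a composite model: both classes share the
# `Gemma` prefix, but their dependencies are resolved against different parents.
class GemmaVisionModel(CLIPModel):
    pass


class GemmaTextModel(LlamaModel):
    pass
```
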
@@ -4,7 +4,6 @@
 # the file from the modular. If any change should be done, please apply the change to the
 # modular_xxx.py file directly. One of our CI enforces this
 # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨

 import math
 from typing import List, Optional, Tuple, Union

@@ -115,7 +115,6 @@ _import_structure = {
     "data.metrics": [],
     "data.processors": [],
     "debug_utils": [],
-    "deepspeed": [],
     "dependency_versions_check": [],
     "dependency_versions_table": [],
     "dynamic_module_utils": [],
@@ -1,41 +0,0 @@
-# Copyright 2020 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""
-Integration with Deepspeed - kept for backward compatiblity, if you plan to make any edit, make sure to modify the file
-in `integrations/deepspeed` instead.
-
-Check: https://github.com/huggingface/transformers/pull/25599
-"""
-
-import warnings
-
-
-warnings.warn(
-    "transformers.deepspeed module is deprecated and will be removed in a future version. Please import deepspeed modules directly from transformers.integrations",
-    FutureWarning,
-)
-
-# Backward compatibility imports, to make sure all those objects can be found in integrations/deepspeed
-from .integrations.deepspeed import (  # noqa
-    HfDeepSpeedConfig,
-    HfTrainerDeepSpeedConfig,
-    deepspeed_config,
-    deepspeed_init,
-    deepspeed_load_checkpoint,
-    deepspeed_optim_sched,
-    is_deepspeed_available,
-    is_deepspeed_zero3_enabled,
-    set_hf_deepspeed_config,
-    unset_hf_deepspeed_config,
-)
@@ -143,3 +143,6 @@ class GemmaConfig(PretrainedConfig):
             tie_word_embeddings=tie_word_embeddings,
             **kwargs,
         )
+
+
+__all__ = ["GemmaConfig"]
@@ -1314,3 +1314,6 @@ class GemmaForTokenClassification(GemmaPreTrainedModel):
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
         )
+
+
+__all__ = ["GemmaModel", "GemmaForCausalLM", "GemmaForSequenceClassification", "GemmaForTokenClassification"]
@@ -14,8 +14,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import math
-from typing import List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

+import sentencepiece as spm
 import torch
 import torch.utils.checkpoint
 from torch import nn
@@ -27,6 +28,7 @@ from ...configuration_utils import PretrainedConfig
 from ...modeling_flash_attention_utils import _flash_attention_forward
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from ...pytorch_utils import ALL_LAYERNORM_LAYERS
+from ...tokenization_utils import AddedToken, PreTrainedTokenizer
 from ...utils import logging
 from ..llama.modeling_llama import (
     LlamaDecoderLayer,
@@ -38,6 +40,15 @@ from ..llama.modeling_llama import (
     apply_rotary_pos_emb,
     repeat_kv,
 )
+from ..llama.tokenization_llama import LlamaTokenizer
+
+
+if TYPE_CHECKING:
+    from ...tokenization_utils_base import TextInput
+
+VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}
+
+SPIECE_UNDERLINE = "▁"


 logger = logging.get_logger(__name__)
@@ -164,6 +175,162 @@ class GemmaConfig(PretrainedConfig):
         )


+class GemmaTokenizer(LlamaTokenizer, PreTrainedTokenizer):
+    """
+    Construct a Gemma tokenizer. Based on byte-level Byte-Pair-Encoding. The default padding token is unset as there is
+    no padding token in the original model.
+
+    Args:
+        vocab_file (`str`):
+            Path to the vocabulary file.
+        unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<unk>"`):
+            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+            token instead.
+        bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<bos>"`):
+            The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.
+        eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<eos>"`):
+            The end of sequence token.
+        pad_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<pad>"`):
+            A special token used to make arrays of tokens the same size for batching purpose. Will then be ignored by
+            attention mechanisms or loss computation.
+        sp_model_kwargs (`Dict[str, Any]`, `Optional`, *optional*):
+            Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
+            SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
+            to set:
+
+            - `enable_sampling`: Enable subword regularization.
+            - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
+
+              - `nbest_size = {0,1}`: No sampling is performed.
+              - `nbest_size > 1`: samples from the nbest_size results.
+              - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice)
+                using forward-filtering-and-backward-sampling algorithm.
+
+            - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
+              BPE-dropout.
+
+        add_bos_token (`bool`, *optional*, defaults to `True`):
+            Whether or not to add an `bos_token` at the start of sequences.
+        add_eos_token (`bool`, *optional*, defaults to `False`):
+            Whether or not to add an `eos_token` at the end of sequences.
+        clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
+            Whether or not to cleanup spaces after decoding, cleanup consists in removing potential artifacts like
+            extra spaces.
+        use_default_system_prompt (`bool`, *optional*, defaults to `False`):
+            Whether or not the default system prompt for Gemma should be used.
+        spaces_between_special_tokens (`bool`, *optional*, defaults to `False`):
+            Whether or not to add spaces between special tokens.
+    """
+
+    def __init__(
+        self,
+        vocab_file,
+        unk_token="<unk>",
+        bos_token="<bos>",
+        eos_token="<eos>",
+        pad_token="<pad>",
+        sp_model_kwargs: Optional[Dict[str, Any]] = None,
+        add_bos_token=True,
+        add_eos_token=False,
+        clean_up_tokenization_spaces=False,
+        use_default_system_prompt=False,
+        spaces_between_special_tokens=False,
+        **kwargs,
+    ):
+        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
+        bos_token = AddedToken(bos_token, normalized=False, special=True) if isinstance(bos_token, str) else bos_token
+        eos_token = AddedToken(eos_token, normalized=False, special=True) if isinstance(eos_token, str) else eos_token
+        unk_token = AddedToken(unk_token, normalized=False, special=True) if isinstance(unk_token, str) else unk_token
+        pad_token = AddedToken(pad_token, normalized=False, special=True) if isinstance(pad_token, str) else pad_token
+
+        self.vocab_file = vocab_file
+        self.add_bos_token = add_bos_token
+        self.add_eos_token = add_eos_token
+        self.use_default_system_prompt = use_default_system_prompt
+        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+        self.sp_model.Load(vocab_file)
+
+        PreTrainedTokenizer.__init__(
+            self,
+            bos_token=bos_token,
+            eos_token=eos_token,
+            unk_token=unk_token,
+            pad_token=pad_token,
+            add_bos_token=add_bos_token,
+            add_eos_token=add_eos_token,
+            sp_model_kwargs=sp_model_kwargs,
+            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+            use_default_system_prompt=use_default_system_prompt,
+            spaces_between_special_tokens=spaces_between_special_tokens,
+            **kwargs,
+        )
+
+    def get_spm_processor(self):
+        raise AttributeError("Not needed for Gemma")
+
+    def unk_token_length(self):
+        raise AttributeError("Not needed for Gemma")
+
+    def tokenize(self, text: "TextInput", **kwargs) -> List[str]:
+        """
+        Args:
+            text: TextInput
+        Simply calls PreTrainedTokenizer's method
+        """
+        return PreTrainedTokenizer.tokenize(self, text, **kwargs)
+
+    def _tokenize(self, text, **kwargs):
+        """
+        Args:
+            text: TextInput
+        Returns a tokenized string. The Gemma tokenizer never adds a prefix space.
+        """
+        return self.sp_model.encode(text, out_type=str)
+
+    def _decode(
+        self,
+        token_ids: List[int],
+        skip_special_tokens: bool = False,
+        spaces_between_special_tokens: bool = False,
+        **kwargs,
+    ) -> str:
+        sub_texts = []
+        current_sub_text = []
+        for ids in token_ids:
+            if skip_special_tokens and ids in self.all_special_ids:
+                continue
+            if ids in self._added_tokens_decoder:
+                if current_sub_text:
+                    sub_texts.append(self.sp_model.decode(current_sub_text))
+                sub_texts.append(self._added_tokens_decoder[ids].content)
+                current_sub_text = []
+            else:
+                current_sub_text.append(ids)
+        if current_sub_text:
+            sub_texts.append(self.sp_model.decode(current_sub_text))
+
+        if spaces_between_special_tokens:
+            sub_texts = " ".join(sub_texts)
+        else:
+            sub_texts = "".join(sub_texts)
+
+        return sub_texts.replace(SPIECE_UNDERLINE, " ")
+
+    def convert_tokens_to_string(self, tokens):
+        """Converts a sequence of tokens (string) in a single string."""
+        current_sub_tokens = []
+        out_string = ""
+        for token in tokens:
+            # make sure that special tokens are not decoded using sentencepiece model
+            if token in self._added_tokens_encoder:
+                out_string += self.sp_model.decode(current_sub_tokens) + token
+                current_sub_tokens = []
+            else:
+                current_sub_tokens.append(token)
+        out_string += self.sp_model.decode(current_sub_tokens)
+        return out_string
+
+
 class GemmaRMSNorm(nn.Module):
     def __init__(self, dim: int, eps: float = 1e-6):
         super().__init__()
@@ -874,3 +1041,13 @@ class GemmaForTokenClassification(LlamaForTokenClassification):
         super().__init__(config)
         self.model = GemmaModel(config)
         self.post_init()
+
+
+__all__ = [
+    "GemmaConfig",
+    "GemmaTokenizer",
+    "GemmaModel",
+    "GemmaForCausalLM",
+    "GemmaForSequenceClassification",
+    "GemmaForTokenClassification",
+]
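
An orientation note (editorial, not part of the diff): the single `__all__` declared at the bottom of `modular_gemma.py` above is split per generated file, which is exactly what the surrounding hunks show.

```python
# modular_gemma.py (source of truth):
#   __all__ = ["GemmaConfig", "GemmaTokenizer", "GemmaModel", "GemmaForCausalLM", ...]
#
# Generated files:
#   configuration_gemma.py -> __all__ = ["GemmaConfig"]
#   tokenization_gemma.py  -> __all__ = ["GemmaTokenizer"]
#   modeling_gemma.py      -> __all__ = ["GemmaModel", "GemmaForCausalLM", ...]
```
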
@ -1,5 +1,12 @@
|
|||||||
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||||
|
# This file was automatically generated from <path_to_modular_file.py>.
|
||||||
|
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||||
|
# the file from the modular. If any change should be done, please apply the change to the
|
||||||
|
# modular_xxx.py file directly. One of our CI enforces this
|
||||||
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@@ -12,9 +19,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-"""Tokenization classes for Gemma."""
-
 import os
 from shutil import copyfile
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
@@ -26,7 +30,7 @@ from ...utils import logging


 if TYPE_CHECKING:
-    pass
+    from ...tokenization_utils_base import TextInput

 logger = logging.get_logger(__name__)

@@ -110,7 +114,6 @@ class GemmaTokenizer(PreTrainedTokenizer):
         self.add_bos_token = add_bos_token
         self.add_eos_token = add_eos_token
         self.use_default_system_prompt = use_default_system_prompt

         self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
         self.sp_model.Load(vocab_file)

@@ -121,85 +124,60 @@ class GemmaTokenizer(PreTrainedTokenizer):
             pad_token=pad_token,
             add_bos_token=add_bos_token,
             add_eos_token=add_eos_token,
-            sp_model_kwargs=self.sp_model_kwargs,
+            sp_model_kwargs=sp_model_kwargs,
             clean_up_tokenization_spaces=clean_up_tokenization_spaces,
             use_default_system_prompt=use_default_system_prompt,
             spaces_between_special_tokens=spaces_between_special_tokens,
             **kwargs,
         )

-    # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.__getstate__
     def __getstate__(self):
         state = self.__dict__.copy()
         state["sp_model"] = None
         state["sp_model_proto"] = self.sp_model.serialized_model_proto()
         return state

-    # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.__setstate__
     def __setstate__(self, d):
         self.__dict__ = d
         self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
         self.sp_model.LoadFromSerializedProto(self.sp_model_proto)

     @property
-    # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.vocab_size
     def vocab_size(self):
         """Returns vocab size"""
         return self.sp_model.get_piece_size()

-    # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.get_vocab
     def get_vocab(self):
         """Returns vocab as a dict"""
         vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
         vocab.update(self.added_tokens_encoder)
         return vocab

+    def tokenize(self, text: "TextInput", **kwargs) -> List[str]:
+        """
+        Args:
+            text: TextInput
+        Simply calls PreTrainedTokenizer's method
+        """
+        return super().tokenize(text, **kwargs)
+
     def _tokenize(self, text, **kwargs):
         """
+        Args:
+            text: TextInput
         Returns a tokenized string. The Gemma tokenizer never adds a prefix space.
         """
         return self.sp_model.encode(text, out_type=str)

-    # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer._convert_token_to_id
     def _convert_token_to_id(self, token):
         """Converts a token (str) in an id using the vocab."""
         return self.sp_model.piece_to_id(token)

-    # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer._convert_id_to_token
     def _convert_id_to_token(self, index):
         """Converts an index (integer) in a token (str) using the vocab."""
         token = self.sp_model.IdToPiece(index)
         return token

-    def _decode(
-        self,
-        token_ids: List[int],
-        skip_special_tokens: bool = False,
-        spaces_between_special_tokens: bool = False,
-        **kwargs,
-    ) -> str:
-        sub_texts = []
-        current_sub_text = []
-        for ids in token_ids:
-            if skip_special_tokens and ids in self.all_special_ids:
-                continue
-            if ids in self._added_tokens_decoder:
-                if current_sub_text:
-                    sub_texts.append(self.sp_model.decode(current_sub_text))
-                sub_texts.append(self._added_tokens_decoder[ids].content)
-                current_sub_text = []
-            else:
-                current_sub_text.append(ids)
-        if current_sub_text:
-            sub_texts.append(self.sp_model.decode(current_sub_text))
-
-        if spaces_between_special_tokens:
-            sub_texts = " ".join(sub_texts)
-        else:
-            sub_texts = "".join(sub_texts)
-
-        return sub_texts.replace(SPIECE_UNDERLINE, " ")
-
     def convert_tokens_to_string(self, tokens):
         """Converts a sequence of tokens (string) in a single string."""
         current_sub_tokens = []
@@ -214,7 +192,6 @@ class GemmaTokenizer(PreTrainedTokenizer):
         out_string += self.sp_model.decode(current_sub_tokens)
         return out_string

-    # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.save_vocabulary
     def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]:
         """
         Save the vocabulary and special tokens file to a directory.
@@ -242,7 +219,6 @@ class GemmaTokenizer(PreTrainedTokenizer):

         return (out_vocab_file,)

-    # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.build_inputs_with_special_tokens
     def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
         bos_token_id = [self.bos_token_id] if self.add_bos_token else []
         eos_token_id = [self.eos_token_id] if self.add_eos_token else []
@@ -254,7 +230,6 @@ class GemmaTokenizer(PreTrainedTokenizer):

         return output

-    # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.get_special_tokens_mask
     def get_special_tokens_mask(
         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
     ) -> List[int]:
@@ -292,7 +267,6 @@ class GemmaTokenizer(PreTrainedTokenizer):
             + eos_token_id
         )

-    # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.create_token_type_ids_from_sequences
     def create_token_type_ids_from_sequences(
         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
     ) -> List[int]:
@@ -325,3 +299,35 @@ class GemmaTokenizer(PreTrainedTokenizer):
             output += [1] * len(bos_token_id + token_ids_1 + eos_token_id)

         return output
+
+    def _decode(
+        self,
+        token_ids: List[int],
+        skip_special_tokens: bool = False,
+        spaces_between_special_tokens: bool = False,
+        **kwargs,
+    ) -> str:
+        sub_texts = []
+        current_sub_text = []
+        for ids in token_ids:
+            if skip_special_tokens and ids in self.all_special_ids:
+                continue
+            if ids in self._added_tokens_decoder:
+                if current_sub_text:
+                    sub_texts.append(self.sp_model.decode(current_sub_text))
+                sub_texts.append(self._added_tokens_decoder[ids].content)
+                current_sub_text = []
+            else:
+                current_sub_text.append(ids)
+        if current_sub_text:
+            sub_texts.append(self.sp_model.decode(current_sub_text))
+
+        if spaces_between_special_tokens:
+            sub_texts = " ".join(sub_texts)
+        else:
+            sub_texts = "".join(sub_texts)
+
+        return sub_texts.replace(SPIECE_UNDERLINE, " ")
+
+
+__all__ = ["GemmaTokenizer"]
@@ -58,7 +58,6 @@ logger = logging.get_logger(__name__)


 @dataclass
-# Copied from transformers.models.blip_2.modeling_blip_2.Blip2ForConditionalGenerationModelOutput with Blip2->InstructBlipVideo
 class InstructBlipVideoForConditionalGenerationModelOutput(ModelOutput):
     """
     Class defining the outputs of [`InstructBlipVideoForConditionalGeneration`].
@@ -91,7 +90,6 @@ class InstructBlipVideoForConditionalGenerationModelOutput(ModelOutput):
         )


-# Copied from transformers.models.blip.modeling_blip.BlipVisionEmbeddings with Blip->InstructBlipVideo
 class InstructBlipVideoVisionEmbeddings(nn.Module):
     def __init__(self, config: InstructBlipVideoVisionConfig):
         super().__init__()
@@ -166,7 +164,6 @@ class InstructBlipVideoVisionEmbeddings(nn.Module):
         return embeddings


-# Copied from transformers.models.blip_2.modeling_blip_2.Blip2Attention with Blip2->InstructBlipVideo
 class InstructBlipVideoAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -248,7 +245,6 @@ class InstructBlipVideoAttention(nn.Module):
         return outputs


-# Copied from transformers.models.blip.modeling_blip.BlipMLP
 class InstructBlipVideoMLP(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -264,7 +260,6 @@ class InstructBlipVideoMLP(nn.Module):
         return hidden_states


-# Copied from transformers.models.blip.modeling_blip.BlipEncoderLayer with Blip->InstructBlipVideo
 class InstructBlipVideoEncoderLayer(nn.Module):
     def __init__(self, config: InstructBlipVideoConfig):
         super().__init__()
@@ -330,7 +325,6 @@ class InstructBlipVideoPreTrainedModel(PreTrainedModel):
     ]
     _keep_in_fp32_modules = []

-    # Copied from transformers.models.blip_2.modeling_blip_2.Blip2PreTrainedModel._init_weights with Blip2->InstructBlipVideo
     def _init_weights(self, module):
         """Initialize the weights"""
         factor = self.config.initializer_range
@@ -450,7 +444,6 @@ INSTRUCTBLIPVIDEO_INPUTS_DOCSTRING = r"""
 """


-# Copied from transformers.models.blip.modeling_blip.BlipEncoder with Blip->InstructBlipVideo
 class InstructBlipVideoEncoder(nn.Module):
     """
     Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
@@ -537,7 +530,6 @@ class InstructBlipVideoEncoder(nn.Module):
         )


-# Copied from transformers.models.blip.modeling_blip.BlipVisionModel with Blip->InstructBlipVideo, BLIP->INSTRUCTBLIPVIDEO
 class InstructBlipVideoVisionModel(InstructBlipVideoPreTrainedModel):
     main_input_name = "pixel_values"
     config_class = InstructBlipVideoVisionConfig
@@ -738,7 +730,6 @@ class InstructBlipVideoQFormerMultiHeadAttention(nn.Module):
         return outputs


-# Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->InstructBlipVideoQFormer
 class InstructBlipVideoQFormerSelfOutput(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -753,7 +744,6 @@ class InstructBlipVideoQFormerSelfOutput(nn.Module):
         return hidden_states


-# Copied from transformers.models.blip_2.modeling_blip_2.Blip2QFormerAttention with Blip2->InstructBlipVideo
 class InstructBlipVideoQFormerAttention(nn.Module):
     def __init__(self, config, is_cross_attention=False):
         super().__init__()
@@ -803,7 +793,6 @@ class InstructBlipVideoQFormerAttention(nn.Module):
         return outputs


-# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->InstructBlipVideoQFormer
 class InstructBlipVideoQFormerIntermediate(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -819,7 +808,6 @@ class InstructBlipVideoQFormerIntermediate(nn.Module):
         return hidden_states


-# Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->InstructBlipVideoQFormer
 class InstructBlipVideoQFormerOutput(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -937,7 +925,6 @@ class InstructBlipVideoQFormerLayer(nn.Module):
         return layer_output


-# Copied from transformers.models.blip_2.modeling_blip_2.Blip2QFormerEncoder with Blip2->InstructBlipVideo
 class InstructBlipVideoQFormerEncoder(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -43,14 +43,12 @@ SPIECE_UNDERLINE = "▁"
 B_INST, E_INST = "[INST]", "[/INST]"
 B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

-# fmt: off
 DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your \
 answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\
 that your responses are socially unbiased and positive in nature.

 If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \
-correct. If you don't know the answer to a question, please don't share false information."""
-# fmt: on
+correct. If you don't know the answer to a question, please don't share false information."""  # fmt: skip


 class LlamaTokenizer(PreTrainedTokenizer):
@@ -235,7 +235,6 @@ class LlavaNextVideoPooler(nn.Module):
         return image_features_spatial_pool.flatten(2).transpose(1, 2).contiguous()


-# Copied from transformers.models.llava.modeling_llava.LlavaMultiModalProjector with Llava->LlavaNextVideo
 class LlavaNextVideoMultiModalProjector(nn.Module):
     def __init__(self, config: LlavaNextVideoConfig):
         super().__init__()
@@ -272,7 +271,6 @@ LLAVA_NEXT_VIDEO_START_DOCSTRING = r"""
     "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
     LLAVA_NEXT_VIDEO_START_DOCSTRING,
 )
-# Copied from transformers.models.llava.modeling_llava.LlavaPreTrainedModel with Llava->LlavaNextVideo,llava->llava_next_video
 class LlavaNextVideoPreTrainedModel(PreTrainedModel):
     config_class = LlavaNextVideoConfig
     base_model_prefix = "model"
@@ -426,35 +424,27 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene
             raise ValueError(f"{padding_side} is not `left` or `right`.")
         self._padding_side = padding_side

-    # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_input_embeddings
     def get_input_embeddings(self):
         return self.language_model.get_input_embeddings()

-    # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_input_embeddings
     def set_input_embeddings(self, value):
         self.language_model.set_input_embeddings(value)

-    # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_output_embeddings
     def get_output_embeddings(self):
         return self.language_model.get_output_embeddings()

-    # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_output_embeddings
     def set_output_embeddings(self, new_embeddings):
         self.language_model.set_output_embeddings(new_embeddings)

-    # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_decoder
     def set_decoder(self, decoder):
         self.language_model.set_decoder(decoder)

-    # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_decoder
     def get_decoder(self):
         return self.language_model.get_decoder()

-    # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.tie_weights
     def tie_weights(self):
         return self.language_model.tie_weights()

-    # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.resize_token_embeddings
     def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
         model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
         # update vocab size
@@ -25,8 +25,8 @@ from torch import Tensor, nn
 from torch.nn import CrossEntropyLoss

 from ...activations import ACT2FN
-from ...deepspeed import is_deepspeed_zero3_enabled
 from ...generation import GenerationMixin
+from ...integrations.deepspeed import is_deepspeed_zero3_enabled
 from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
 from ...modeling_outputs import (
     BaseModelOutput,
@@ -25,8 +25,8 @@ from torch import Tensor, nn
 from torch.nn import CrossEntropyLoss

 from ...activations import ACT2FN
-from ...deepspeed import is_deepspeed_zero3_enabled
 from ...generation import GenerationMixin
+from ...integrations.deepspeed import is_deepspeed_zero3_enabled
 from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
 from ...modeling_outputs import (
     BaseModelOutput,
@@ -16,7 +16,8 @@ import argparse
 import glob
 import importlib
 import re
-from typing import Dict
+from collections import defaultdict
+from typing import Dict, List, Set

 import libcst as cst
 from check_copies import run_ruff
@@ -113,7 +114,11 @@ class ClassFinder(CSTVisitor):
         if m.matches(node, m.SimpleStatementLine(body=[m.Assign()])) and m.matches(
             self.get_metadata(cst.metadata.ParentNodeProvider, node), m.Module()
         ):
-            self.assignments[node.body[0].targets[0].target.value] = node
+            if hasattr(node.body[0].targets[0].target, "value"):
+                self.assignments[node.body[0].targets[0].target.value] = node
+            else:
+                for idx, target in enumerate(list(node.body[0].targets[0].target.elements)):
+                    self.assignments[target.value.value] = node.body[0].value.elements[idx].value
         if m.matches(node, m.SimpleStatementLine(body=[m.Import() | m.ImportFrom()])):
             self.imports[node.body[0].names] = node

@@ -217,11 +222,21 @@ class ReplaceNameTransformer(m.MatcherDecoratableTransformer):

         return compiled_regex.sub(replace, text)

+    def convert_to_camelcase(self, text):
+        # Regex pattern to match consecutive uppercase letters and lowercase the first set
+        result = re.sub(r"^[A-Z]+(?=[A-Z][a-z])", lambda m: m.group(0).capitalize(), text, count=1)
+        return result
+
     @m.leave(m.Name() | m.SimpleString() | m.Comment())
     def replace_name(self, original_node, updated_node):
+        if re.findall(r"# Copied from", updated_node.value):
+            return cst.RemoveFromParent()
         update = self.preserve_case_replace(updated_node.value)
         return updated_node.with_changes(value=update)

+    def leave_ClassDef(self, original_node, updated_node):
+        return updated_node.with_changes(name=cst.Name(self.convert_to_camelcase(updated_node.name.value)))
+

 def find_classes_in_file(module: cst.Module, old_id="llama", new_id="gemma", given_old_name=None, given_new_name=None):
     """Helper function to rename and then parse a source file using the ClassFinder"""
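
For reference (an illustrative note, not part of the diff), the new `convert_to_camelcase` helper only lowercases a leading run of capitals when it is followed by a normal CamelCase word; a standalone re-implementation of the same regex shows the effect:

```python
import re


def convert_to_camelcase(text):
    # Lowercase a leading run of capitals (keeping its first letter) when it is
    # followed by a CamelCase word, e.g. "GEMMAModel" -> "GemmaModel".
    return re.sub(r"^[A-Z]+(?=[A-Z][a-z])", lambda m: m.group(0).capitalize(), text, count=1)


assert convert_to_camelcase("GEMMAModel") == "GemmaModel"
assert convert_to_camelcase("GemmaModel") == "GemmaModel"  # already CamelCase: unchanged
```
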
@@ -251,6 +266,63 @@ def SUPER_CALL_NODE(func_name):
     return m.Call(func=m.Attribute(value=m.Call(func=m.Name("super")), attr=m.Name(func_name)))


+def is_call_to_super(node, func_name):
+    return m.matches(
+        node, m.SimpleStatementLine(body=[m.Return(SUPER_CALL_NODE(func_name)) | m.Expr(SUPER_CALL_NODE(func_name))])
+    )
+
+
+# Transformer class to replace ClassB.call_to_method and ClassB().call_to_method with super().call_to_method
+class ReplaceMethodCallTransformer(cst.CSTTransformer):
+    def __init__(self, all_bases: Set[str]):
+        self.all_bases = all_bases
+
+    def leave_Attribute(self, original_node: cst.Attribute, updated_node: cst.Attribute) -> cst.CSTNode:
+        # Handle ClassB.call_to_method
+        if (
+            isinstance(original_node.value, cst.Name)
+            and original_node.value.value in self.all_bases
+            and isinstance(original_node.attr, cst.Name)
+        ):
+            # Replace with super().call_to_method
+            return updated_node.with_changes(
+                value=cst.Call(cst.Name("super")),
+            )
+        # Handle ClassB().call_to_method
+        elif (
+            isinstance(original_node.value, cst.Call)
+            and isinstance(original_node.value.func, cst.Name)
+            and original_node.value.func.value in self.all_bases
+            and isinstance(original_node.attr, cst.Name)
+        ):
+            # Replace with super().call_to_method
+            return updated_node.with_changes(func=cst.Attribute(value=cst.Call(func=cst.Name("super"))))
+        return updated_node
+
+    def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.CSTNode:
+        # Check if the function being called is of the form ClassB().func_a or ClassB.func_a
+        if isinstance(original_node.func, cst.Attribute) and (
+            # Match ClassB().func_a(...)
+            (
+                isinstance(original_node.func.value, cst.Call)
+                and isinstance(original_node.func.value.func, cst.Name)
+                and original_node.func.value.func.value in self.all_bases
+            )
+            or
+            # Match ClassB.func_a(...)
+            (isinstance(original_node.func.value, cst.Name) and original_node.func.value.value in self.all_bases)
+        ):
+            # Check if the first argument is 'self', and remove it
+            if len(original_node.args) > 0 and m.matches(original_node.args[0].value, m.Name("self")):
+                # Create the new argument list without 'self'
+                new_args = updated_node.args[1:]
+            else:
+                new_args = updated_node.args
+
+            return updated_node.with_changes(args=new_args)
+        return updated_node
+
+
 def get_docstring_indent(docstring):
     # Match the first line after the opening triple quotes
     match = re.search(r'(?:"""|\'\'\'|```)\n(\s+)', docstring)
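
As a sanity check on the new `ReplaceMethodCallTransformer` (an illustrative snippet, not part of the diff), this is the kind of rewrite it performs when `LlamaModel` is among the modular class's bases:

```python
# Inside a modular class whose bases include LlamaModel:
#
#   before: LlamaModel.forward(self, input_ids)    # explicit parent-class call
#   after:  super().forward(input_ids)             # leave_Attribute rewrites the base, leave_Call drops `self`
#
#   before: LlamaModel().forward(input_ids)
#   after:  super().forward(input_ids)
```
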
@@ -263,7 +335,7 @@ def get_docstring_indent(docstring):
 def merge_docstrings(original_docstring, updated_docstring):
     # indent_level = get_docstring_indent(updated_docstring)
     original_level = get_docstring_indent(original_docstring)
-    if " Args:\n " not in updated_docstring:
+    if not re.findall(r"\n\s*Args:\n", updated_docstring):
         # Split the docstring at the example section, assuming `"""` is used to define the docstring
         parts = original_docstring.split("```")
         if "```" in updated_docstring and len(parts) > 1:
@@ -292,13 +364,15 @@ def merge_docstrings(original_docstring, updated_docstring):
 class SuperTransformer(cst.CSTTransformer):
     METADATA_DEPENDENCIES = (ParentNodeProvider,)

-    def __init__(self, python_module: cst.Module, original_methods, updated_methods, class_name=""):
+    def __init__(self, python_module: cst.Module, original_methods, updated_methods, class_name="", all_bases=None):
         self.python_module = python_module
         self.original_methods = original_methods
         self.updated_methods = updated_methods
         self.all_assign_target = {}
         self.deleted_targets = {}  # child node can delete some arguments
         self.class_name = class_name
+        self.all_bases = all_bases or []
+        self.transformer = ReplaceMethodCallTransformer(set(self.all_bases))

     def update_body(self, existing_body, new_statements):
         """
@@ -356,18 +430,14 @@ class SuperTransformer(cst.CSTTransformer):
             parent_has_docstring = m.matches(self.original_methods[func_name].body.body[0], DOCSTRING_NODE)
         new_body = []
         has_super_call = False
-        for idx, expr in enumerate(node.body):
-            if m.matches(
-                expr,
-                m.SimpleStatementLine(
-                    body=[m.Return(SUPER_CALL_NODE(func_name)) | m.Expr(SUPER_CALL_NODE(func_name))]
-                ),
-            ):
-                if idx != 0 and func_name == "__init__":
-                    raise ValueError(f"The call to super() in {self.class_name} should be at the top of the init")
-                new_body.extend(self.update_body(self.original_methods[func_name].body.body, node.body))
+
+        for expr in node.body:
+            if is_call_to_super(expr, func_name):
                 has_super_call = True
-            elif m.matches(expr, DOCSTRING_NODE):
+                new_body.extend(self.update_body(self.original_methods[func_name].body.body, node.body))
+            else:
+                expr = expr.visit(self.transformer)
+            if m.matches(expr, DOCSTRING_NODE):
                 self.has_docstring = True
                 if parent_has_docstring:  # actually here we ought to de-duplicate?
                     original_docstring = self.original_methods[func_name].body.body[0].body[0].value.value
@@ -406,15 +476,17 @@ class SuperTransformer(cst.CSTTransformer):
         return updated_node


-def replace_call_to_super(class_finder: ClassFinder, updated_node: cst.ClassDef, class_name: str):
+def replace_call_to_super(
+    class_finder: ClassFinder, updated_node: cst.ClassDef, class_name: str, all_bases: List[str]
+):
     """
     Given the `class_name`, the `updated_node`'s call to super are unpacked.

                    | ```python                          |           | ```python
                    | class GemmaModel(LlamaModel):      |           | class GemmaModel(nn.Module):
                    |     def __init__(self):            |           |     def __init__(self):
-    Going from:    |         self.dropout = 0.2         |    to:    |         self.dropout = 0.2
-                   |         super().__init__()         |           |         super().__init__(config)
+    Going from:    |         super().__init__()         |    to:    |         super().__init__(config)
+                   |         self.dropout = 0.2         |           |         self.dropout = 0.2
                    | ```                                |           |         self.padding_idx = config.pad_token_id
                    |                                    |           |         self.vocab_size = config.vocab_size
                    |                                    |           |         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
@ -453,7 +525,14 @@ def replace_call_to_super(class_finder: ClassFinder, updated_node: cst.ClassDef,
                 new_params = new_params.with_changes(
                     params=list(parent_params.values()), star_kwarg=func.params.star_kwarg
                 )
-            func = func.with_changes(body=updated_methods[name].body, params=new_params)
+            if not re.match(
+                r"\ndef .*\(.*\):\n    raise.*Error\(.*",
+                class_finder.python_module.code_for_node(updated_methods[name]),
+            ):
+                func = func.with_changes(body=updated_methods[name].body, params=new_params)
+            else:
+                continue
+
         if m.matches(func, m.SimpleStatementLine(body=[m.Assign()])):
             target = class_finder.python_module.code_for_node(func.body[0].targets[0])
             assign_targets[target] = func

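The `re.match` guard added here is what lets a modular override delete a parent method: when the rendered source of `updated_methods[name]` is nothing more than a `raise ...Error(...)`, the `continue` skips copying the parent's body, so the method is dropped from the generated class (the "deleting functions" support from the commit title). A quick, self-contained check of that pattern with a made-up method name:

```python
import re

pattern = r"\ndef .*\(.*\):\n    raise.*Error\(.*"
rendered = '\ndef some_method(self):\n    raise ValueError("removed in the child class")\n'
print(bool(re.match(pattern, rendered)))  # True -> the converter skips this method entirely
```
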
@ -492,7 +571,7 @@ def replace_call_to_super(class_finder: ClassFinder, updated_node: cst.ClassDef,
     temp_module = cst.Module(body=[result_node])
     new_module = MetadataWrapper(temp_module)
     new_replacement_class = new_module.visit(
-        SuperTransformer(temp_module, original_methods, updated_methods, class_name)
+        SuperTransformer(temp_module, original_methods, updated_methods, class_name, all_bases)
     )
     new_replacement_body = new_replacement_class.body[0].body  # get the indented block

@ -508,6 +587,31 @@ TYPE_TO_FILE_TYPE = {
 }


+def get_new_part(class_name, base_class):
+    """
+    When `MyClassNameAttention` inherits from `MistralAttention`, we need
+    to process the name to properly find dependencies.
+
+    Here we take what is the same (Attention) and what is different
+    when finding the dependencies.
+    """
+    common_suffix_len = 0
+    for i in range(1, min(len(class_name), len(base_class)) + 1):
+        if class_name[-i] == base_class[-i]:
+            common_suffix_len += 1
+        else:
+            break
+
+    if common_suffix_len > 0:
+        new_part = class_name[:-common_suffix_len]
+    else:
+        new_part = class_name
+
+    # Convert the remaining new part to snake_case
+    snake_case = re.sub(r"(?<!^)(?=[A-Z])", "_", new_part).lower()
+    return snake_case
+
+
 class ModularConverterTransformer(CSTTransformer):
     METADATA_DEPENDENCIES = (ParentNodeProvider, ScopeProvider, PositionProvider)

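As a quick sanity check (assuming `get_new_part` from the hunk above is in scope), the docstring's own example resolves as expected:

```python
# Shared suffix "Attention" is stripped, and the remaining prefix is snake_cased.
print(get_new_part("MyClassNameAttention", "MistralAttention"))  # -> "my_class_name"
```
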
@ -538,6 +642,7 @@ class ModularConverterTransformer(CSTTransformer):
         }
         self.match_patterns = "|".join(self.files.keys())
         self.all_definitions = {}
+        self.class_to_file_type = {}

     def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
         """When visiting imports from `transformers.models.xxx` we need to:

@ -630,13 +735,33 @@ class ModularConverterTransformer(CSTTransformer):
                     self.given_new_name,
                 )
                 visited_module[super_file_name] = class_finder
+                list_dependencies = {
+                    dep: class_finder.class_start_line.get(dep, 1000)
+                    for dep in class_finder.class_dependency_mapping.get(class_name, [])
+                }
             else:  # we are re-using the previously parsed data
                 class_finder = visited_module[super_file_name]

             list_dependencies = {
                 dep: class_finder.class_start_line.get(dep, 1000)
                 for dep in class_finder.class_dependency_mapping.get(class_name, [])
             }
+            if list_dependencies == []:
+                # so, maybe standard renaming did not work (the class name is different)
+                # we try with another renaming pattern
+                potential_given_name = get_new_part(class_name, super_class)
+                del visited_module[super_file_name]
+                class_finder = find_classes_in_file(
+                    self.transformers_imports[super_file_name],
+                    model_name,
+                    potential_given_name,
+                    self.model_name,
+                    potential_given_name,
+                )
+                list_dependencies = {
+                    dep: class_finder.class_start_line.get(dep, 1000)
+                    for dep in class_finder.class_dependency_mapping.get(class_name, [])
+                }

             list_dependencies = sorted(list_dependencies.items(), key=lambda x: x[1], reverse=True)
             start_insert_idx = self.global_scope_index

@ -668,10 +793,12 @@ class ModularConverterTransformer(CSTTransformer):
                 self.inserted_deps.append(dependency)

             if len(list_dependencies) > 0:
-                updated_node = replace_call_to_super(class_finder, updated_node, class_name)
+                updated_node = replace_call_to_super(class_finder, updated_node, class_name, all_bases)
             else:
                 raise ValueError(
-                    f"Unable to find dependencies for {super_class} in {super_file_name}. Here are the dependencies found: {class_finder.class_dependency_mapping}. (The automatic renaming might have gone wrong!)"
+                    f"We were unable to find dependencies for {class_name} (based on inheriting from {super_class})"
+                    f" Here are all the global dependencies that we found in you modular file: {list(class_finder.class_dependency_mapping.keys())}."
+                    f" This usually means that the name of `{class_name}` does not match the pattern of `{super_class}`"
                 )

         # Now, if a class was defined without parents, we look for the name

@ -679,8 +806,10 @@ class ModularConverterTransformer(CSTTransformer):
             match = re.search(rf"({match_pattern})$", class_name)
             if match:
                 key = TYPE_TO_FILE_TYPE[match.group(1)]
+                self.class_to_file_type[class_name] = key
                 self.files[key][class_name] = {"insert_idx": self.global_scope_index, "node": updated_node}
             else:
+                self.class_to_file_type[class_name] = "modeling"
                 self.files["modeling"][class_name] = {"insert_idx": self.global_scope_index, "node": updated_node}
         return updated_node

@ -690,14 +819,37 @@ class ModularConverterTransformer(CSTTransformer):
             self.all_definitions[node.name.value] = node
         return node

+    def visit_Assign(self, node: cst.Assign) -> None:
+        # Check if the assignment target is '__all__'
+        if isinstance(node.targets[0].target, cst.Name) and node.targets[0].target.value == "__all__":
+            if isinstance(node.value, cst.List):
+                # Extract the elements from the list
+                all_all_to_add = defaultdict(list)
+                for elt in node.value.elements:
+                    if isinstance(elt.value, cst.SimpleString):
+                        # Remove quotes and add the string to the elements list
+                        class_name = elt.value.value
+                        file = self.class_to_file_type[
+                            elt.value.evaluated_value
+                        ]  # evaluated value give the content of the string
+                        all_all_to_add[file] += [class_name]
+                for f_type, new_alls in all_all_to_add.items():
+                    updated_node = node.with_changes(
+                        value=cst.List(elements=[cst.Element(value=cst.SimpleString(value=k)) for k in new_alls])
+                    )
+                    self.files[f_type][class_name] = {
+                        "insert_idx": self.global_scope_index + 100,
+                        "node": updated_node,
+                    }
+
     def leave_If(self, original_node, node):
         parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node)
         if m.matches(parent_node, m.Module()):
             full_statement = self.python_module.code_for_node(original_node.test)
             if re.search(r"[\s\S]*is_.*available", full_statement):
                 self.all_safe_imports.append(node)
-            elif full_statement not in self.new_body:
-                self.new_body[node] = {"insert_idx": self.global_scope_index, "node": node}
+            elif full_statement not in self.all_imports:
+                logger.warning(f"one import is protected with `if`. Hard guess where it's used {full_statement}")
         return node

     def leave_Module(self, original_node: cst.Assign, node):

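The new `visit_Assign` hook splits a modular file's `__all__` across the generated files: each exported name is looked up in `class_to_file_type` and grouped per target file type, and a per-file `__all__` node is stored in `self.files`. A toy reproduction of just the grouping step, with made-up names:

```python
from collections import defaultdict

# Hypothetical mapping, as class_to_file_type would be filled while visiting class definitions.
class_to_file_type = {"MyDummyConfig": "configuration", "MyDummyModel": "modeling"}

all_all_to_add = defaultdict(list)
for name in ["MyDummyConfig", "MyDummyModel"]:  # names listed in the modular __all__
    all_all_to_add[class_to_file_type[name]].append(name)

print(dict(all_all_to_add))
# {'configuration': ['MyDummyConfig'], 'modeling': ['MyDummyModel']}
```
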
@ -764,7 +916,7 @@ if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--files_to_parse",
-        default=["examples/modular-transformers/modular_dummy.py"],
+        default=["src/transformers/models/gemma/modular_gemma.py"],
        nargs="+",
        help="A list of `modular_xxxx` files that should be converted to single model file",
    )

@ -373,7 +373,6 @@ src/transformers/data/processors/squad.py
 src/transformers/data/processors/utils.py
 src/transformers/data/processors/xnli.py
 src/transformers/debug_utils.py
-src/transformers/deepspeed.py
 src/transformers/dependency_versions_check.py
 src/transformers/dependency_versions_table.py
 src/transformers/dynamic_module_utils.py