From 1dba608df93ffb10a9c268ef35191adf2424c5ca Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Mon, 30 Sep 2024 16:43:55 +0200 Subject: [PATCH] [`modular`] fixes! (#33820) * fix converter for function definitions * small changes * no prints * style --- .../configuration_my_new_model.py | 17 +- .../configuration_my_new_model2.py | 110 +++++- .../configuration_new_model.py | 8 +- .../modular-transformers/convert_examples.sh | 2 +- .../modular-transformers/modeling_dummy.py | 58 ++-- .../modeling_dummy_bert.py | 7 +- .../modeling_my_new_model2.py | 314 +++++++++--------- .../modular-transformers/modular_dummy.py | 42 +-- utils/modular_model_converter.py | 11 +- 9 files changed, 322 insertions(+), 247 deletions(-) diff --git a/examples/modular-transformers/configuration_my_new_model.py b/examples/modular-transformers/configuration_my_new_model.py index d7c946dbe31..3c7848e6956 100644 --- a/examples/modular-transformers/configuration_my_new_model.py +++ b/examples/modular-transformers/configuration_my_new_model.py @@ -1,8 +1,8 @@ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -# This file was automatically generated from . +# This file was automatically generated from . # Do NOT edit this file manually as any edits will be overwritten by the generation of -# the file from the diff. If any change should be done, please apply the change to the -# diff.py file directly. One of our CI enforces this +# the file from the modular. If any change should be done, please apply the change to the +# modular_xxx.py file directly. One of our CI enforces this # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 from ...configuration_utils import PretrainedConfig @@ -111,8 +111,6 @@ class MyNewModelConfig(PretrainedConfig): Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers. head_dim (`int`, *optional*): The attention head dimension. If None, it will default to hidden_size // num_heads - new_param (`int`, *optional*, defaults to `False`): - A fun new parameter ```python >>> from transformers import MyNewModelModel, MyNewModelConfig @@ -125,7 +123,10 @@ class MyNewModelConfig(PretrainedConfig): >>> # Accessing the model configuration >>> configuration = model.config - ```""" + ``` + new_param (`int`, *optional*, defaults to `False`): + A fun new parameter + """ model_type = "my_new_model" keys_to_ignore_at_inference = ["past_key_values"] @@ -178,12 +179,14 @@ class MyNewModelConfig(PretrainedConfig): self.rope_scaling = rope_scaling self.attention_bias = attention_bias self.attention_dropout = attention_dropout + self.mlp_bias = mlp_bias self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads # Validate the correctness of rotary position embeddings parameters # BC: if there is a 'type' field, copy it it to 'rope_type'. 
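Note on the hunk above: the generated config now assigns `mlp_bias` and `new_param` before calling `super().__init__(...)` instead of after it. A minimal sketch of why that ordering matters, with a hypothetical `MiniConfig` standing in for the generated class (the base-class behavior here is illustrative, not the actual `PretrainedConfig` internals):

```python
# Hypothetical MiniConfig illustrating the attribute-ordering fix: the base
# constructor may inspect the instance, so custom attributes should exist first.
class BaseConfig:
    def __init__(self, **kwargs):
        for key, value in kwargs.items():
            setattr(self, key, value)
        # Anything the base class does here only sees attributes set before the call.
        self.known_attributes = sorted(vars(self))

class MiniConfig(BaseConfig):
    def __init__(self, mlp_bias=False, new_param=0, **kwargs):
        self.mlp_bias = mlp_bias  # set before super().__init__, as in the hunk above
        self.new_param = new_param
        super().__init__(**kwargs)

config = MiniConfig(tie_word_embeddings=False)
print("mlp_bias" in config.known_attributes)  # True only with the new ordering
```

With the old ordering, anything the base constructor does while it runs (hooks, validation) executes before the extra attributes exist.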
if self.rope_scaling is not None and "type" in self.rope_scaling: self.rope_scaling["rope_type"] = self.rope_scaling["type"] rope_config_validation(self) + self.new_param = new_param super().__init__( pad_token_id=pad_token_id, @@ -192,5 +195,3 @@ class MyNewModelConfig(PretrainedConfig): tie_word_embeddings=tie_word_embeddings, **kwargs, ) - self.mlp_bias = mlp_bias - self.new_param = new_param diff --git a/examples/modular-transformers/configuration_my_new_model2.py b/examples/modular-transformers/configuration_my_new_model2.py index b940d8d93b3..5fef1cecc70 100644 --- a/examples/modular-transformers/configuration_my_new_model2.py +++ b/examples/modular-transformers/configuration_my_new_model2.py @@ -1,15 +1,116 @@ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -# This file was automatically generated from . +# This file was automatically generated from . # Do NOT edit this file manually as any edits will be overwritten by the generation of -# the file from the diff. If any change should be done, please apply the change to the -# diff.py file directly. One of our CI enforces this +# the file from the modular. If any change should be done, please apply the change to the +# modular_xxx.py file directly. One of our CI enforces this # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 + from ...configuration_utils import PretrainedConfig from ...modeling_rope_utils import rope_config_validation class MyNewModel2Config(PretrainedConfig): r""" + This is the configuration class to store the configuration of a [`MyNewModel2Model`]. It is used to instantiate an MyNewModel2 + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the MyNewModel2-7B. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the MyNewModel2 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`MyNewModel2Model`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. 
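The `num_key_value_heads` entry in the docstring above describes grouped-query attention. The attention classes later in this patch realize that sharing with the `repeat_kv` helper; a self-contained sketch (shapes are illustrative):

```python
import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Expand (batch, num_kv_heads, seq, head_dim) so each KV head serves n_rep query heads."""
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

# 32 query heads sharing 8 KV heads -> n_rep = 4 (GQA). With num_key_value_heads = 1,
# every query head would share the single KV head (MQA).
kv = torch.randn(2, 8, 16, 64)
print(repeat_kv(kv, 4).shape)  # torch.Size([2, 32, 16, 64])
```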
+ max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. MyNewModel2 1 supports up to 2048 tokens, + MyNewModel2 2 up to 4096, CodeMyNewModel2 up to 16384. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 1): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 2): + End of stream token id. + pretraining_tp (`int`, *optional*, defaults to 1): + Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this + document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to + understand more about it. This value is necessary to ensure exact reproducibility of the pretraining + results. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/76232). + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'my_new_model23'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'my_new_model23'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`List[float]`, *optional*): + Only used with 'longrope'. 
The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'my_new_model23'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'my_new_model23'. Scaling factor applied to high frequency components of the RoPE + attention_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + mlp_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers. + head_dim (`int`, *optional*): + The attention head dimension. If None, it will default to hidden_size // num_heads This is the configuration class to store the configuration of a [`GemmaModel`]. It is used to instantiate an Gemma model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the Gemma-7B. @@ -20,6 +121,7 @@ class MyNewModel2Config(PretrainedConfig): vocab_size (`int`, *optional*, defaults to 256000): Vocabulary size of the Gemma model. Defines the number of different tokens that can be represented by the `inputs_ids` passed when calling [`GemmaModel`] + ```python >>> from transformers import GemmaModel, GemmaConfig >>> # Initializing a Gemma gemma-7b style configuration @@ -83,7 +185,7 @@ class MyNewModel2Config(PretrainedConfig): self.mlp_bias = mlp_bias self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads # Validate the correctness of rotary position embeddings parameters - # BC: if there is a 'type' field, move it to 'rope_type'. + # BC: if there is a 'type' field, copy it it to 'rope_type'. if self.rope_scaling is not None and "type" in self.rope_scaling: self.rope_scaling["rope_type"] = self.rope_scaling["type"] rope_config_validation(self) diff --git a/examples/modular-transformers/configuration_new_model.py b/examples/modular-transformers/configuration_new_model.py index 7d57f9fe25b..8bc8ef52cee 100644 --- a/examples/modular-transformers/configuration_new_model.py +++ b/examples/modular-transformers/configuration_new_model.py @@ -1,12 +1,12 @@ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -# This file was automatically generated from . +# This file was automatically generated from . # Do NOT edit this file manually as any edits will be overwritten by the generation of -# the file from the diff. If any change should be done, please apply the change to the -# diff.py file directly. One of our CI enforces this +# the file from the modular. If any change should be done, please apply the change to the +# modular_xxx.py file directly. 
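One detail worth pulling out of the `rope_scaling` docstring above: `short_factor` and `long_factor` must have length `hidden_size // num_attention_heads // 2`, i.e. half the per-head dimension. A quick check, using assumed Llama-7B-like sizes (the concrete numbers are illustrative):

```python
# The 'longrope' factor lists must be half the per-head dimension long.
hidden_size = 4096
num_attention_heads = 32
expected_len = hidden_size // num_attention_heads // 2
print(expected_len)  # 64

rope_scaling = {
    "rope_type": "longrope",
    "short_factor": [1.0] * expected_len,
    "long_factor": [4.0] * expected_len,
}
assert len(rope_scaling["short_factor"]) == expected_len
```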
One of our CI enforces this # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # Example where we only want to overwrite the defaults of an init -from transformers import PretrainedConfig +from ...configuration_utils import PretrainedConfig class NewModelConfig(PretrainedConfig): diff --git a/examples/modular-transformers/convert_examples.sh b/examples/modular-transformers/convert_examples.sh index 4af31f1b426..49666ab1154 100644 --- a/examples/modular-transformers/convert_examples.sh +++ b/examples/modular-transformers/convert_examples.sh @@ -5,6 +5,6 @@ for file in examples/modular-transformers/modular_*; do # Check if it's a regular file if [ -f "$file" ]; then # Call the Python script with the file name as an argument - python utils/diff_model_converter.py --files_to_parse "$file" + python utils/modular_model_converter.py --files_to_parse "$file" fi done \ No newline at end of file diff --git a/examples/modular-transformers/modeling_dummy.py b/examples/modular-transformers/modeling_dummy.py index 5dd76c60303..c67787fbd8a 100644 --- a/examples/modular-transformers/modeling_dummy.py +++ b/examples/modular-transformers/modeling_dummy.py @@ -1,11 +1,11 @@ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -# This file was automatically generated from . +# This file was automatically generated from . # Do NOT edit this file manually as any edits will be overwritten by the generation of -# the file from the diff. If any change should be done, please apply the change to the -# diff.py file directly. One of our CI enforces this +# the file from the modular. If any change should be done, please apply the change to the +# modular_xxx.py file directly. One of our CI enforces this # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 + import math -from math import log from typing import List, Optional, Tuple, Union import torch @@ -31,11 +31,6 @@ from ...utils import ( from .configuration_dummy import DummyConfig -def _pre_process_input(input_ids): - print(log(input_ids)) - return input_ids - - logger = logging.get_logger(__name__) @@ -129,7 +124,7 @@ class DummyRotaryEmbedding(nn.Module): if config is None: logger.warning_once( "`DummyRotaryEmbedding` can now be fully parameterized by passing the model config through the " - "`config` argument. All other arguments will be removed in v4.45" + "`config` argument. 
All other arguments will be removed in v4.46" ) self.rope_kwargs = { "rope_type": rope_type, @@ -201,8 +196,8 @@ class DummyRotaryEmbedding(nn.Module): def rotate_half(x): """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] + x1 = x[..., : x.shape[-1] // 4] + x2 = x[..., x.shape[-1] // 4 :] return torch.cat((-x2, x1), dim=-1) @@ -308,7 +303,7 @@ class DummyAttention(nn.Module): self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) - # TODO (joao): remove in v4.45 (RoPE is computed in the model, not in the decoder layers) + # TODO (joao): remove in v4.46 (RoPE is computed in the model, not in the decoder layers) self.rotary_emb = DummyRotaryEmbedding(config=self.config) def forward( @@ -320,7 +315,7 @@ class DummyAttention(nn.Module): output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -355,7 +350,7 @@ class DummyAttention(nn.Module): logger.warning_once( "The attention layers in this model are transitioning from computing the RoPE embeddings internally " "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " "removed and `position_embeddings` will be mandatory." ) cos, sin = self.rotary_emb(value_states, position_ids) @@ -428,7 +423,7 @@ class DummyFlashAttention2(DummyAttention): output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if isinstance(past_key_value, StaticCache): raise ValueError( @@ -455,7 +450,7 @@ class DummyFlashAttention2(DummyAttention): logger.warning_once( "The attention layers in this model are transitioning from computing the RoPE embeddings internally " "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " "removed and `position_embeddings` will be mandatory." 
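The `rotate_half` hunk above (the `// 2` to `// 4` change) is the visible effect of the dummy example: `modular_dummy.py` at the bottom of this patch redefines the helper, and the fixed converter now propagates that override into the generated file. For reference, a standalone version of the conventional `// 2` helper being replaced:

```python
import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Standard RoPE helper: split the last dim in half and rotate the halves."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

q = torch.arange(8.0).reshape(1, 1, 1, 8)
print(rotate_half(q))  # second half negated and moved to the front
```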
) cos, sin = self.rotary_emb(value_states, position_ids) @@ -541,7 +536,7 @@ class DummySdpaAttention(DummyAttention): output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions: @@ -575,7 +570,7 @@ class DummySdpaAttention(DummyAttention): logger.warning_once( "The attention layers in this model are transitioning from computing the RoPE embeddings internally " "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " "removed and `position_embeddings` will be mandatory." ) cos, sin = self.rotary_emb(value_states, position_ids) @@ -650,7 +645,7 @@ class DummyDecoderLayer(nn.Module): output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -796,7 +791,8 @@ DUMMY_INPUTS_DOCSTRING = r""" returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. Two formats are allowed: - - a [`~cache_utils.Cache`] instance; + - a [`~cache_utils.Cache`] instance, see our + [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy cache format. @@ -877,7 +873,6 @@ class DummyModel(DummyPreTrainedModel): return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: - input_ids = _pre_process_input(input_ids) output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -899,16 +894,19 @@ class DummyModel(DummyPreTrainedModel): if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) + # kept for BC (non `Cache` `past_key_values` inputs) return_legacy_cache = False - if ( - use_cache and not isinstance(past_key_values, Cache) and not self.training - ): # kept for BC (non `Cache` `past_key_values` inputs) + if use_cache and not isinstance(past_key_values, Cache): return_legacy_cache = True - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. 
" - "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)" - ) + if past_key_values is None: + past_key_values = DynamicCache() + else: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + logger.warning_once( + "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " + "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class " + "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" + ) if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 diff --git a/examples/modular-transformers/modeling_dummy_bert.py b/examples/modular-transformers/modeling_dummy_bert.py index bdedd1f5f5a..611d7be961f 100644 --- a/examples/modular-transformers/modeling_dummy_bert.py +++ b/examples/modular-transformers/modeling_dummy_bert.py @@ -1,8 +1,8 @@ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -# This file was automatically generated from . +# This file was automatically generated from . # Do NOT edit this file manually as any edits will be overwritten by the generation of -# the file from the diff. If any change should be done, please apply the change to the -# diff.py file directly. One of our CI enforces this +# the file from the modular. If any change should be done, please apply the change to the +# modular_xxx.py file directly. One of our CI enforces this # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 import math import os @@ -1027,6 +1027,7 @@ class DummyBertModel(DummyBertPreTrainedModel): if not return_dict: return (sequence_output, pooled_output) + encoder_outputs[1:] + return super().forward(input_ids) return BaseModelOutputWithPoolingAndCrossAttentions( last_hidden_state=sequence_output, diff --git a/examples/modular-transformers/modeling_my_new_model2.py b/examples/modular-transformers/modeling_my_new_model2.py index fea7994a53e..5484b3890fb 100644 --- a/examples/modular-transformers/modeling_my_new_model2.py +++ b/examples/modular-transformers/modeling_my_new_model2.py @@ -1,8 +1,8 @@ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -# This file was automatically generated from . +# This file was automatically generated from . # Do NOT edit this file manually as any edits will be overwritten by the generation of -# the file from the diff. If any change should be done, please apply the change to the -# diff.py file directly. One of our CI enforces this +# the file from the modular. If any change should be done, please apply the change to the +# modular_xxx.py file directly. One of our CI enforces this # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 import math from typing import List, Optional, Tuple, Union @@ -30,63 +30,6 @@ from ...utils import ( from .configuration_my_new_model2 import MyNewModel2Config -logger = logging.get_logger(__name__) - - -# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position -def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - device: torch.device, - min_dtype: float, - cache_position: torch.Tensor, - batch_size: int, -): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. 
- - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to plcae the 4D attention mask on. - min_dtype (`float`): - The minimum value representable with the dtype `dtype`. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - - class MyNewModel2RMSNorm(nn.Module): def __init__(self, dim: int, eps: float = 1e-6): super().__init__() @@ -107,6 +50,9 @@ class MyNewModel2RMSNorm(nn.Module): return f"{tuple(self.weight.shape)}, eps={self.eps}" +logger = logging.get_logger(__name__) + + class MyNewModel2RotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() @@ -305,6 +251,94 @@ class MyNewModel2Attention(nn.Module): return attn_output, attn_weights, past_key_value +class MyNewModel2SdpaAttention(MyNewModel2Attention): + """ + MyNewModel2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `MyNewModel2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from MyNewModel2Attention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "MyNewModel2Model is using MyNewModel2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. 
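The `_prepare_4d_causal_attention_mask_with_cache_position` body shown above can be exercised in isolation; here is its core recipe with tiny illustrative sizes:

```python
import torch

# Build a (sequence_length, target_length) grid, keep the strict upper triangle
# masked, then also unmask everything at or before the current cache positions.
sequence_length, target_length = 3, 5
dtype = torch.float32
min_dtype = torch.finfo(dtype).min
cache_position = torch.arange(2, 2 + sequence_length)  # decoding tokens 2, 3, 4

causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype)
if sequence_length != 1:
    causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length) > cache_position.reshape(-1, 1)
print(causal_mask == 0)  # True exactly where key index j <= cache_position[i]
```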
Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + class MyNewModel2FlashAttention2(MyNewModel2Attention): """ MyNewModel2 flash attention module. 
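The `past_key_value.update(...)` call in the SDPA forward above goes through the public cache API; a minimal sketch with illustrative shapes (the RoPE `sin`/`cos` entries of `cache_kwargs` matter for e.g. static caches, while `DynamicCache` simply concatenates, so they are omitted here):

```python
import torch
from transformers import DynamicCache

cache = DynamicCache()
key_states = torch.randn(1, 8, 4, 64)    # (batch, num_kv_heads, seq_len, head_dim)
value_states = torch.randn(1, 8, 4, 64)

# Returns the full keys/values accumulated so far for this layer.
key_states, value_states = cache.update(key_states, value_states, layer_idx=0)
print(cache.get_seq_length())  # 4
```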
This module inherits from `MyNewModel2Attention` as the weights of the module stays @@ -405,7 +439,6 @@ class MyNewModel2FlashAttention2(MyNewModel2Attention): is_causal=self.is_causal, use_top_left_mask=self._flash_attn_uses_top_left_mask, ) - attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() attn_output = self.o_proj(attn_output) @@ -415,92 +448,57 @@ class MyNewModel2FlashAttention2(MyNewModel2Attention): return attn_output, attn_weights, past_key_value -class MyNewModel2SdpaAttention(MyNewModel2Attention): +def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + min_dtype: float, + cache_position: torch.Tensor, + batch_size: int, +): """ - MyNewModel2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `MyNewModel2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to - SDPA API. + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + min_dtype (`float`): + The minimum value representable with the dtype `dtype`. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. """ - - # Adapted from MyNewModel2Attention.forward - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "MyNewModel2Model is using MyNewModel2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
- ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. causal_mask = attention_mask + else: + causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) if attention_mask is not None: - causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and causal_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. 
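The `is_causal` comment above (kept verbatim as the class moves) is easy to verify in isolation: with no explicit mask and `q_len > 1`, SDPA is asked to build the causal mask itself. A runnable sketch:

```python
import torch
import torch.nn.functional as F

bsz, n_heads, q_len, head_dim = 1, 2, 4, 8
q = torch.randn(bsz, n_heads, q_len, head_dim)
k = torch.randn(bsz, n_heads, q_len, head_dim)
v = torch.randn(bsz, n_heads, q_len, head_dim)

causal_mask = None
# Explicit if/else (not an inline conditional) so torch.compile's dynamic shapes work.
is_causal = True if causal_mask is None and q_len > 1 else False

out = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask, is_causal=is_causal)
print(out.shape)  # torch.Size([1, 2, 4, 8])
```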
- is_causal = True if causal_mask is None and q_len > 1 else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=is_causal, - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, -1) - - attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value + return causal_mask MY_NEW_MODEL2_ATTENTION_CLASSES = { @@ -514,11 +512,9 @@ class MyNewModel2DecoderLayer(nn.Module): def __init__(self, config: MyNewModel2Config, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = MY_NEW_MODEL2_ATTENTION_CLASSES[config._attn_implementation]( config=config, layer_idx=layer_idx ) - self.mlp = MyNewModel2MLP(config) self.input_layernorm = MyNewModel2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = MyNewModel2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -673,7 +669,8 @@ MY_NEW_MODEL2_INPUTS_DOCSTRING = r""" returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. Two formats are allowed: - - a [`~cache_utils.Cache`] instance; + - a [`~cache_utils.Cache`] instance, see our + [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy cache format. @@ -774,12 +771,19 @@ class MyNewModel2Model(MyNewModel2PreTrainedModel): if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) + # kept for BC (non `Cache` `past_key_values` inputs) return_legacy_cache = False # noqa: F841 - if ( - use_cache and not isinstance(past_key_values, Cache) and not self.training - ): # kept for BC (non `Cache` `past_key_values` inputs) + if use_cache and not isinstance(past_key_values, Cache): return_legacy_cache = True # noqa: F841 - past_key_values = DynamicCache.from_legacy_cache(past_key_values) + if past_key_values is None: + past_key_values = DynamicCache() + else: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + logger.warning_once( + "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " + "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class " + "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" + ) if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 @@ -802,15 +806,6 @@ class MyNewModel2Model(MyNewModel2PreTrainedModel): # See https://github.com/huggingface/transformers/pull/29402 normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype) hidden_states = hidden_states * normalizer - if ( - use_cache and not isinstance(past_key_values, Cache) and not self.training - ): # kept for BC (non `Cache` `past_key_values` inputs) - return_legacy_cache = True - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. 
" - "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)" - ) # decoder layers all_hidden_states = () if output_hidden_states else None @@ -922,6 +917,7 @@ class MyNewModel2Model(MyNewModel2PreTrainedModel): cache_position=cache_position, batch_size=input_tensor.shape[0], ) + if ( self.config._attn_implementation == "sdpa" and attention_mask is not None @@ -970,7 +966,7 @@ class MyNewModel2ForSequenceClassification(MyNewModel2PreTrainedModel): @add_start_docstrings_to_model_forward(MY_NEW_MODEL2_INPUTS_DOCSTRING) def forward( self, - input_ids: torch.LongTensor = None, + input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, diff --git a/examples/modular-transformers/modular_dummy.py b/examples/modular-transformers/modular_dummy.py index 33dc38d0b44..fb64ba4d856 100644 --- a/examples/modular-transformers/modular_dummy.py +++ b/examples/modular-transformers/modular_dummy.py @@ -1,45 +1,15 @@ -from math import log -from typing import List, Optional, Tuple, Union - import torch -from transformers.modeling_outputs import CausalLMOutputWithPast from transformers.models.llama.modeling_llama import LlamaModel -from ...cache_utils import Cache - -def _pre_process_input(input_ids): - print(log(input_ids)) - return input_ids +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 4] + x2 = x[..., x.shape[-1] // 4 :] + return torch.cat((-x2, x1), dim=-1) # example where we need some deps and some functions class DummyModel(LlamaModel): - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - ) -> Union[Tuple, CausalLMOutputWithPast]: - input_ids = _pre_process_input(input_ids) - - return super().forward( - None, - attention_mask, - position_ids, - past_key_values, - inputs_embeds, - use_cache, - output_attentions, - output_hidden_states, - return_dict, - cache_position, - ) + pass diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py index c5bf769f928..1bfc1230a91 100644 --- a/utils/modular_model_converter.py +++ b/utils/modular_model_converter.py @@ -537,7 +537,7 @@ class ModularConverterTransformer(CSTTransformer): "feature_extractor": {}, } self.match_patterns = "|".join(self.files.keys()) - self.all_functions = {} + self.all_definitions = {} def visit_ImportFrom(self, node: cst.ImportFrom) -> None: """When visiting imports from `transformers.models.xxx` we need to: @@ -647,6 +647,7 @@ class ModularConverterTransformer(CSTTransformer): node = class_finder.global_nodes.get(dependency, None) if node is not None: if dependency not in file_to_update: + node = self.all_definitions.get(dependency, node) start_insert_idx -= 1 file_to_update[dependency] = {"insert_idx": start_insert_idx, "node": node} elif dependency not in self.inserted_deps: @@ -683,6 +684,12 @@ class ModularConverterTransformer(CSTTransformer): self.files["modeling"][class_name] = 
{"insert_idx": self.global_scope_index, "node": updated_node} return updated_node + def leave_FunctionDef(self, original_node, node): + parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node) + if m.matches(parent_node, m.Module()): + self.all_definitions[node.name.value] = node + return node + def leave_If(self, original_node, node): parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node) if m.matches(parent_node, m.Module()): @@ -757,7 +764,7 @@ if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( "--files_to_parse", - default=["all"], + default=["examples/modular-transformers/modular_dummy.py"], nargs="+", help="A list of `modular_xxxx` files that should be converted to single model file", )