diff --git a/examples/modular-transformers/image_processing_new_imgproc_model.py b/examples/modular-transformers/image_processing_new_imgproc_model.py index 8966b454882..a64eb17861a 100644 --- a/examples/modular-transformers/image_processing_new_imgproc_model.py +++ b/examples/modular-transformers/image_processing_new_imgproc_model.py @@ -36,7 +36,7 @@ logger = logging.get_logger(__name__) class ImgprocModelImageProcessor(BaseImageProcessor): r""" - Constructs a NEW_IMGPROC_MODEL image processor. + Constructs a IMGPROC_MODEL image processor. Args: do_resize (`bool`, *optional*, defaults to `True`): diff --git a/examples/modular-transformers/modeling_from_uppercase_model.py b/examples/modular-transformers/modeling_from_uppercase_model.py new file mode 100644 index 00000000000..d6c16c69743 --- /dev/null +++ b/examples/modular-transformers/modeling_from_uppercase_model.py @@ -0,0 +1,357 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from examples/modular-transformers/modular_from_uppercase_model.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_from_uppercase_model.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +from typing import Optional, Tuple + +import torch +from torch import nn + +from ...activations import ACT2FN +from ...pytorch_utils import is_torch_greater_or_equal_than_2_2 +from ...utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, logging +from .configuration_from_uppercase_model import FromUppercaseModelConfig + + +if is_flash_attn_2_available(): + from ...modeling_flash_attention_utils import _flash_attention_forward + + +logger = logging.get_logger(__name__) + + +class FromUppercaseModelAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." 
+ ) + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + causal_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Input shape: Batch x Time x Channel""" + + bsz, tgt_len, embed_dim = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scale + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {attn_weights.size()}" + ) + + # apply the causal_attention_mask first + if causal_attention_mask is not None: + if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is" + f" {causal_attention_mask.size()}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if output_attentions: + # this operation is a bit akward, but it's required to + # make sure that attn_weights keeps its gradient. 
+ # In order to do so, attn_weights have to reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, tgt_len, embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped + + +class FromUppercaseModelFlashAttention2(FromUppercaseModelAttention): + """ + FromUppercaseModelAttention flash attention module. This module inherits from `FromUppercaseModelAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + # Adapted from transformers.models.llama.modeling_llama.LlamaFlashAttention2.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + causal_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + output_attentions = False + + batch_size, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim) + key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim) + value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim) + + dropout_rate = self.dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. 
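+        # Note (added comment, not part of the upstream generated file): flash attention kernels only operate on
+        # fp16/bf16 inputs, which is why fp32 activations are cast back to the target dtype below before calling
+        # `_flash_attention_forward`.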
+ + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + is_causal=causal_attention_mask is not None, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) + + attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim).contiguous() + attn_output = self.out_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights + + +class FromUppercaseModelSdpaAttention(FromUppercaseModelAttention): + """ + SDPA attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `FromUppercaseModelAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from FromUppercaseModelAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + causal_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "FromUppercaseModelModel is using FromUppercaseModelSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not " + "support `output_attentions=True`. Falling back to the manual attention implementation, but specifying " + "the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can " + 'be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + causal_attention_mask=causal_attention_mask, + output_attentions=output_attentions, + ) + + # FROM_UPPERCASE_MODEL text model uses both `causal_attention_mask` and `attention_mask` + if attention_mask is not None and causal_attention_mask is not None: + attn_mask = attention_mask + causal_attention_mask + elif causal_attention_mask is not None: + attn_mask = causal_attention_mask + else: + attn_mask = attention_mask + + bsz, tgt_len, embed_dim = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if not is_torch_greater_or_equal_than_2_2 and query_states.device.type == "cuda" and attn_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # FROM_UPPERCASE_MODEL text model uses both `causal_attention_mask` and `attention_mask` sequentially. + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attn_mask, + dropout_p=self.dropout if self.training else 0.0, + scale=self.scale, + ) + + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, tgt_len, embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, None + + +class FromUppercaseModelMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +FROM_UPPERCASE_MODEL_ATTENTION_CLASSES = { + "eager": FromUppercaseModelAttention, + "sdpa": FromUppercaseModelSdpaAttention, + "flash_attention_2": FromUppercaseModelFlashAttention2, +} + + +class FromUppercaseModelEncoderLayer(nn.Module): + def __init__(self, config: FromUppercaseModelConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = FROM_UPPERCASE_MODEL_ATTENTION_CLASSES[config._attn_implementation](config) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = FromUppercaseModelMLP(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + causal_attention_mask: torch.Tensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative 
values. + `(config.encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + causal_attention_mask=causal_attention_mask, + output_attentions=output_attentions, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs diff --git a/examples/modular-transformers/modeling_multimodal1.py b/examples/modular-transformers/modeling_multimodal1.py new file mode 100644 index 00000000000..02876996e0e --- /dev/null +++ b/examples/modular-transformers/modeling_multimodal1.py @@ -0,0 +1,1017 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from examples/modular-transformers/modular_multimodal1.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_multimodal1.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +import math +from typing import List, Optional, Tuple, Union + +import torch +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, StaticCache +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import FlashAttentionKwargs, _flash_attention_forward +from ...modeling_outputs import BaseModelOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS +from ...modeling_utils import PreTrainedModel +from ...processing_utils import Unpack +from ...utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_greater_or_equal_2_10, + logging, +) +from .configuration_multimodal1 import Multimodal1TextConfig + + +logger = logging.get_logger(__name__) + + +class Multimodal1TextRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Multimodal1TextRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class Multimodal1TextRotaryEmbedding(nn.Module): + def __init__( + self, + dim=None, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + rope_type="default", + config: Optional[Multimodal1TextConfig] = None, + ): + super().__init__() + # TODO (joao): remove the `if` below, only used for BC + self.rope_kwargs = {} + if config is None: + logger.warning_once( + "`Multimodal1TextRotaryEmbedding` can now be fully parameterized by passing the model config through the " + "`config` argument. 
All other arguments will be removed in v4.46" + ) + self.rope_kwargs = { + "rope_type": rope_type, + "factor": scaling_factor, + "dim": dim, + "base": base, + "max_position_embeddings": max_position_embeddings, + } + self.rope_type = rope_type + self.max_seq_len_cached = max_position_embeddings + self.original_max_seq_len = max_position_embeddings + else: + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. 
yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class Multimodal1TextMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class Multimodal1TextAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Multimodal1TextConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. 
Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads) + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) + + # TODO (joao): remove in v4.46 (RoPE is computed in the model, not in the decoder layers) + self.rotary_emb = Multimodal1TextRotaryEmbedding(config=self.config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." 
+ ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class Multimodal1TextFlashAttention2(Multimodal1TextAttention): + """ + Multimodal1Text flash attention module. This module inherits from `Multimodal1TextAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if isinstance(past_key_value, StaticCache): + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" + ) + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. 
(Multimodal1TextRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + dropout=dropout_rate, + sliding_window=getattr(self, "sliding_window", None), + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=self.is_causal, + **kwargs, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class Multimodal1TextSdpaAttention(Multimodal1TextAttention): + """ + Multimodal1Text attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `Multimodal1TextAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from Multimodal1TextAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "Multimodal1TextModel is using Multimodal1TextSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. 
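+        # Note (added comment, not part of the upstream generated file): `is_causal=True` is only used when no
+        # explicit mask is passed (the prepared 4D `causal_mask` already encodes causality) and there is more than
+        # one query token, since a single-token decoding step needs no causal masking.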
+ is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +MULTIMODAL1_TEXT_ATTENTION_CLASSES = { + "eager": Multimodal1TextAttention, + "flash_attention_2": Multimodal1TextFlashAttention2, + "sdpa": Multimodal1TextSdpaAttention, +} + + +class Multimodal1TextDecoderLayer(nn.Module): + def __init__(self, config: Multimodal1TextConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = MULTIMODAL1_TEXT_ATTENTION_CLASSES[config._attn_implementation]( + config=config, layer_idx=layer_idx + ) + + self.mlp = Multimodal1TextMLP(config) + self.input_layernorm = Multimodal1TextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Multimodal1TextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. 
+ kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +MULTIMODAL1_TEXT_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`Multimodal1TextConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Multimodal1Text Model outputting raw hidden-states without any specific head on top.", + MULTIMODAL1_TEXT_START_DOCSTRING, +) +class Multimodal1TextPreTrainedModel(PreTrainedModel): + config_class = Multimodal1TextConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Multimodal1TextDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +MULTIMODAL1_TEXT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. 
+ + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance, see our + [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. 
+""" + + +@add_start_docstrings( + "The bare Multimodal1Text Model outputting raw hidden-states without any specific head on top.", + MULTIMODAL1_TEXT_START_DOCSTRING, +) +class Multimodal1TextModel(Multimodal1TextPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Multimodal1TextDecoderLayer`] + + Args: + config: Multimodal1TextConfig + """ + + def __init__(self, config: Multimodal1TextConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [Multimodal1TextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = Multimodal1TextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Multimodal1TextRotaryEmbedding(config=config) + + self.gradient_checkpointing = False + if getattr(config, "pretraining_tp", 1) != 1: + logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.") + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(MULTIMODAL1_TEXT_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + if past_key_values is None: + past_key_values = DynamicCache() + else: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + logger.warning_once( + "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " + "will be removed in v4.47. 
Please convert your cache or use an appropriate `Cache` class " + "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" + ) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. 
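+        # Note (added comment, not part of the upstream generated file): `_ignore_causal_mask_sdpa` below is expected
+        # to return True when the 2D mask carries no padding information (e.g. it is None or all ones), in which case
+        # we return None and let SDPA handle causality via its `is_causal` argument.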
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
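+            # Note (added comment, not part of the upstream generated file): "inverted form" here means an additive
+            # float mask with 0.0 at attended positions and a large negative value (typically the dtype minimum, as
+            # used in the branch below) at masked positions.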
+ causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask diff --git a/examples/modular-transformers/modeling_multimodal2.py b/examples/modular-transformers/modeling_multimodal2.py new file mode 100644 index 00000000000..b10b11b671a --- /dev/null +++ b/examples/modular-transformers/modeling_multimodal2.py @@ -0,0 +1,705 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from examples/modular-transformers/modular_multimodal2.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_multimodal2.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 + +from typing import Optional, Tuple, Union + +import torch +from torch import nn + +from transformers.utils import add_start_docstrings + +from ...activations import ACT2FN +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import is_torch_greater_or_equal_than_2_2 +from ...utils import ( + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, + torch_int, +) +from .configuration_multimodal2 import Multimodal2Config, Multimodal2VisionConfig + + +if is_flash_attn_2_available(): + from ...modeling_flash_attention_utils import _flash_attention_forward + + +logger = logging.get_logger(__name__) + + +class Multimodal2VisionAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." 
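To make the mask construction just shown more concrete, here is a standalone sketch (toy sizes, illustrative only, and omitting the final padding merge) of the same additive 4D mask for a single decoding step where the new token sits at position 5 of an 8-slot static cache:

```python
import torch

dtype = torch.float32
min_dtype = torch.finfo(dtype).min

batch_size, sequence_length, target_length = 1, 1, 8   # one new token, 8 cache slots
cache_position = torch.tensor([5])                      # index of the token being decoded

causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype)
if sequence_length != 1:
    causal_mask = torch.triu(causal_mask, diagonal=1)
# Positions strictly after the current token stay masked; everything already cached is visible
causal_mask *= torch.arange(target_length) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)

print(causal_mask[0, 0, 0])  # entries 0..5 are 0.0 (attended), entries 6..7 keep the large negative fill
```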
+ ) + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + causal_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Input shape: Batch x Time x Channel""" + + bsz, tgt_len, embed_dim = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scale + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {attn_weights.size()}" + ) + + # apply the causal_attention_mask first + if causal_attention_mask is not None: + if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is" + f" {causal_attention_mask.size()}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if output_attentions: + # this operation is a bit akward, but it's required to + # make sure that attn_weights keeps its gradient. 
+ # In order to do so, attn_weights have to reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, tgt_len, embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped + + +class Multimodal2VisionSdpaAttention(Multimodal2VisionAttention): + """ + SDPA attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `Multimodal2VisionAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from Multimodal2VisionAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + causal_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "Multimodal2VisionModel is using Multimodal2VisionSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not " + "support `output_attentions=True`. Falling back to the manual attention implementation, but specifying " + "the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can " + 'be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + causal_attention_mask=causal_attention_mask, + output_attentions=output_attentions, + ) + + # MULTIMODAL2_VISION text model uses both `causal_attention_mask` and `attention_mask` + if attention_mask is not None and causal_attention_mask is not None: + attn_mask = attention_mask + causal_attention_mask + elif causal_attention_mask is not None: + attn_mask = causal_attention_mask + else: + attn_mask = attention_mask + + bsz, tgt_len, embed_dim = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. 
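The SDPA path above folds `attention_mask` and `causal_attention_mask` into a single additive mask by summation. A minimal illustration (toy sizes, not part of the patch) of why plain addition is enough when both masks use the `0 / large-negative` convention: a position stays masked if either mask masks it.

```python
import torch
import torch.nn.functional as F

neg = torch.finfo(torch.float32).min
seq_len = 4

# Causal mask: token i may only attend to positions <= i
causal = torch.full((seq_len, seq_len), neg).triu(1)
# Padding mask: the last position is padding and must never be attended to
padding = torch.zeros(seq_len, seq_len)
padding[:, -1] = neg

combined = causal + padding
scores = torch.randn(seq_len, seq_len)
probs = F.softmax(scores + combined, dim=-1)

print(probs[2])  # weight spread over positions 0..2, exactly 0.0 on the padded last column
```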
+ if not is_torch_greater_or_equal_than_2_2 and query_states.device.type == "cuda" and attn_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # MULTIMODAL2_VISION text model uses both `causal_attention_mask` and `attention_mask` sequentially. + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attn_mask, + dropout_p=self.dropout if self.training else 0.0, + scale=self.scale, + ) + + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, tgt_len, embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, None + + +class Multimodal2VisionFlashAttention2(Multimodal2VisionAttention): + """ + Multimodal2VisionAttention flash attention module. This module inherits from `Multimodal2VisionAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + # Adapted from transformers.models.llama.modeling_llama.LlamaFlashAttention2.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + causal_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + output_attentions = False + + batch_size, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim) + key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim) + value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim) + + dropout_rate = self.dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. 
+ + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + is_causal=causal_attention_mask is not None, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) + + attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim).contiguous() + attn_output = self.out_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights + + +class Multimodal2VisionMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +MULTIMODAL2_VISION_ATTENTION_CLASSES = { + "eager": Multimodal2VisionAttention, + "sdpa": Multimodal2VisionSdpaAttention, + "flash_attention_2": Multimodal2VisionFlashAttention2, +} + + +class Multimodal2VisionEncoderLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = MULTIMODAL2_VISION_ATTENTION_CLASSES[config._attn_implementation](config) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = Multimodal2VisionMLP(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + causal_attention_mask: torch.Tensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + `(config.encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. 
+ """ + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + causal_attention_mask=causal_attention_mask, + output_attentions=output_attentions, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class Multimodal2VisionEncoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a + [`Multimodal2VisionEncoderLayer`]. + + Args: + config: Multimodal2VisionConfig + """ + + def __init__(self, config): + super().__init__() + self.config = config + self.layers = nn.ModuleList([Multimodal2VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + inputs_embeds, + attention_mask: Optional[torch.Tensor] = None, + causal_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Causal mask for the text model. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + hidden_states = inputs_embeds + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + encoder_layer.__call__, + hidden_states, + attention_mask, + causal_attention_mask, + output_attentions, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + causal_attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +class Multimodal2VisionEmbeddings(nn.Module): + def __init__(self, config: Multimodal2VisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.class_embedding = nn.Parameter(torch.randn(self.embed_dim)) + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + bias=False, + ) + + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = self.num_patches + 1 + self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) + self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False) + + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution + images. This method is also adapted to support torch.jit tracing. 
+ + Adapted from: + - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and + - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 + """ + + num_patches = embeddings.shape[1] - 1 + position_embedding = self.position_embedding.weight.unsqueeze(0) + num_positions = position_embedding.shape[1] - 1 + + # always interpolate when tracing to ensure the exported model works for dynamic input shapes + if not torch.jit.is_tracing() and num_patches == num_positions and height == width: + return self.position_embedding(self.position_ids) + + class_pos_embed = position_embedding[:, :1] + patch_pos_embed = position_embedding[:, 1:] + + dim = embeddings.shape[-1] + + new_height = height // self.patch_size + new_width = width // self.patch_size + + sqrt_num_positions = torch_int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + size=(new_height, new_width), + mode="bicubic", + align_corners=False, + ) + + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + + return torch.cat((class_pos_embed, patch_pos_embed), dim=1) + + def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor: + batch_size, _, height, width = pixel_values.shape + if not interpolate_pos_encoding and (height != self.image_size or width != self.image_size): + raise ValueError( + f"Input image size ({height}*{width}) doesn't match model" f" ({self.image_size}*{self.image_size})." + ) + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid] + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + + class_embeds = self.class_embedding.expand(batch_size, 1, -1) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + else: + embeddings = embeddings + self.position_embedding(self.position_ids) + return embeddings + + +MULTIMODAL2_VISION_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`Multimodal2ImageProcessor.__call__`] for details. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + interpolate_pos_encoding (`bool`, *optional*, defaults `False`): + Whether to interpolate the pre-trained position encodings. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+""" + + +class Multimodal2VisionTransformer(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + embed_dim = config.hidden_size + + self.embeddings = Multimodal2VisionEmbeddings(config) + self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + self.encoder = Multimodal2VisionEncoder(config) + self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + + @add_start_docstrings_to_model_forward(MULTIMODAL2_VISION_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=Multimodal2VisionConfig) + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = False, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) + hidden_states = self.pre_layrnorm(hidden_states) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] + pooled_output = last_hidden_state[:, 0, :] + pooled_output = self.post_layernorm(pooled_output) + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +class Multimodal2VisionPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = Multimodal2Config + base_model_prefix = "multimodal2_vision" + supports_gradient_checkpointing = True + _supports_sdpa = True + _supports_flash_attn_2 = True + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, Multimodal2VisionMLP): + pass + + +MULTIMODAL2_VISION_START_DOCSTRING = "doc" + + +@add_start_docstrings("New doc", MULTIMODAL2_VISION_START_DOCSTRING) +class Multimodal2VisionModel(Multimodal2VisionPreTrainedModel): + config_class = Multimodal2VisionConfig + main_input_name = "pixel_values" + _no_split_modules = ["Multimodal2VisionEncoderLayer"] + + def __init__(self, config: Multimodal2VisionConfig): + super().__init__(config) + self.vision_model = Multimodal2VisionTransformer(config) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.vision_model.embeddings.patch_embedding + + @add_start_docstrings_to_model_forward(MULTIMODAL2_VISION_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=Multimodal2VisionConfig) + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, Multimodal2VisionModel + + >>> model = Multimodal2VisionModel.from_pretrained("openai/multimodal2-vit-base-patch32") + >>> processor = AutoProcessor.from_pretrained("openai/multimodal2-vit-base-patch32") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + >>> pooled_output = outputs.pooler_output # pooled CLS states + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + return self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + interpolate_pos_encoding=interpolate_pos_encoding, + ) diff --git a/examples/modular-transformers/modeling_new_task_model.py b/examples/modular-transformers/modeling_new_task_model.py index 4556308f1ea..d303d328e88 100644 --- a/examples/modular-transformers/modeling_new_task_model.py +++ b/examples/modular-transformers/modeling_new_task_model.py @@ -265,7 +265,7 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin): min_dtype = torch.finfo(dtype).min sequence_length = inputs_embeds.shape[1] if using_static_cache: - target_length = past_key_values.get_max_length() + target_length = past_key_values.get_max_cache_shape() else: target_length = ( attention_mask.shape[-1] @@ -358,9 +358,9 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin): ```python >>> from PIL import Image >>> import requests - >>> from transformers import AutoProcessor, NewTaskModelForConditionalGeneration + >>> from transformers import AutoProcessor, NewTaskModelForNewTask - >>> model = NewTaskModelForConditionalGeneration.from_pretrained("google/NewTaskModel-test-224px-hf") + >>> model = 
NewTaskModelForNewTask.from_pretrained("google/NewTaskModel-test-224px-hf") >>> processor = AutoProcessor.from_pretrained("google/NewTaskModel-test-224px-hf") >>> prompt = "answer en Where is the cow standing?" diff --git a/examples/modular-transformers/modular_from_uppercase_model.py b/examples/modular-transformers/modular_from_uppercase_model.py new file mode 100644 index 00000000000..ef3044e7ee2 --- /dev/null +++ b/examples/modular-transformers/modular_from_uppercase_model.py @@ -0,0 +1,6 @@ +from transformers.models.clip.modeling_clip import CLIPEncoderLayer + + +# Check if we can correctly grab dependencies with correct naming from all UPPERCASE old model +class FromUppercaseModelEncoderLayer(CLIPEncoderLayer): + pass diff --git a/examples/modular-transformers/modular_multimodal1.py b/examples/modular-transformers/modular_multimodal1.py new file mode 100644 index 00000000000..8f8eaf91a37 --- /dev/null +++ b/examples/modular-transformers/modular_multimodal1.py @@ -0,0 +1,6 @@ +from transformers.models.llama.modeling_llama import LlamaModel + + +# Check that we can correctly change the prefix (here add Text part at the end of the name) +class Multimodal1TextModel(LlamaModel): + pass diff --git a/examples/modular-transformers/modular_multimodal2.py b/examples/modular-transformers/modular_multimodal2.py new file mode 100644 index 00000000000..bc11e0b2869 --- /dev/null +++ b/examples/modular-transformers/modular_multimodal2.py @@ -0,0 +1,88 @@ +""" +Here, because clip is not consistent with the use of the "Text" and "Vision" prefixes, we cannot simply use +``` +class Multimodal2VisionModel(CLIPVisionModel): + pass +``` +with the hope that all dependencies will be renamed as `Multimodal2VisionClass`. For this reason, if we want consistency and +use the "Vision" part everywhere, we need to overwrite the intermediate classes and add the prefix everytime. +This adds noise to the modular, but is unfortunately unavoidable. 
+""" + +from torch import nn + +from transformers.models.clip.modeling_clip import ( + CLIPMLP, + CLIPAttention, + CLIPEncoder, + CLIPEncoderLayer, + CLIPFlashAttention2, + CLIPPreTrainedModel, + CLIPSdpaAttention, + CLIPVisionModel, + CLIPVisionTransformer, +) +from transformers.utils import add_start_docstrings + + +class Multimodal2VisionAttention(CLIPAttention): + pass + + +# Check that adding the second base class correctly set the parent, even though in clip it does not have the "Vision" part +class Multimodal2VisionSdpaAttention(CLIPSdpaAttention, Multimodal2VisionAttention): + pass + + +# Check that adding the second base class correctly set the parent, even though in clip it does not have the "Vision" part +class Multimodal2VisionFlashAttention2(CLIPFlashAttention2, Multimodal2VisionAttention): + pass + + +MULTIMODAL2_VISION_ATTENTION_CLASSES = { + "eager": Multimodal2VisionAttention, + "sdpa": Multimodal2VisionSdpaAttention, + "flash_attention_2": Multimodal2VisionFlashAttention2, +} + + +class Multimodal2VisionMLP(CLIPMLP): + pass + + +class Multimodal2VisionEncoderLayer(CLIPEncoderLayer): + def __init__(self, config): + super().__init__() + self.self_attn = MULTIMODAL2_VISION_ATTENTION_CLASSES[config._attn_implementation](config) + self.mlp = Multimodal2VisionMLP(config) + + +class Multimodal2VisionEncoder(CLIPEncoder): + def __init__(self, config): + super().__init__(config) + self.layers = nn.ModuleList([Multimodal2VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)]) + + +# Finally here the `Vision` part was correct in CLIP, but we still need to tell it that the encoder arg should use it as well +class Multimodal2VisionTransformer(CLIPVisionTransformer): + def __init__(self, config): + super().__init__(config) + self.encoder = Multimodal2VisionEncoder(config) + + +class Multimodal2VisionPreTrainedModel(CLIPPreTrainedModel): + def _init_weights(self, module): + if isinstance(module, Multimodal2VisionMLP): + pass + + +MULTIMODAL2_VISION_START_DOCSTRING = "doc" + + +# Here the only arg `self.vision_model = CLIPVisionTransformer(config)` in CLIPVisionModel already has the "Vision" part, so +# no need to overwrite it, it will look for `Multimodal2VisionTransformer` which has already being redefined above +# Note: we may want to redefine decorator as well for full consistency, as CLIP does not use "CLIP_VISION_START_DOCSTRING" but only +# "CLIP_START_DOCSTRING" +@add_start_docstrings("New doc", MULTIMODAL2_VISION_START_DOCSTRING) +class Multimodal2VisionModel(CLIPVisionModel, Multimodal2VisionPreTrainedModel): + _no_split_modules = ["Multimodal2VisionEncoderLayer"] diff --git a/src/transformers/models/gemma/configuration_gemma.py b/src/transformers/models/gemma/configuration_gemma.py index e170803ccca..346f386ba69 100644 --- a/src/transformers/models/gemma/configuration_gemma.py +++ b/src/transformers/models/gemma/configuration_gemma.py @@ -20,7 +20,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
- from ...configuration_utils import PretrainedConfig diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 5dd4ffe0c8a..58836a5631c 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -27,7 +27,6 @@ import torch.nn as nn from ...activations import ACT2FN from ...cache_utils import Cache, HybridCache from ...generation import GenerationMixin -from ...modeling_flash_attention_utils import _flash_attention_forward from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py index bdf53376a1e..c042669e1ed 100644 --- a/src/transformers/models/olmo2/modeling_olmo2.py +++ b/src/transformers/models/olmo2/modeling_olmo2.py @@ -14,7 +14,6 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter -from ...modeling_flash_attention_utils import _flash_attention_forward from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_utils import PreTrainedModel from ...utils import ( diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 89b4194d1b1..5ecffc8719b 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -34,7 +34,6 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter -from ...modeling_flash_attention_utils import _flash_attention_forward from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py index 28e76ca19ac..cf1a0cfd95c 100644 --- a/utils/modular_model_converter.py +++ b/utils/modular_model_converter.py @@ -18,7 +18,7 @@ import importlib import os import re from abc import ABC, abstractmethod -from collections import defaultdict, deque +from collections import Counter, defaultdict, deque from typing import Dict, Set import libcst as cst @@ -48,7 +48,7 @@ def get_module_source_from_name(module_name: str) -> str: # Extract the source code from the module name spec = importlib.util.find_spec(module_name) if spec is None or spec.origin is None: - return f"Module {module_name} not found" + raise ValueError(f"Cannot open file associated with {module_name} module.") with open(spec.origin, "r", encoding="utf-8") as file: source_code = file.read() @@ -58,20 +58,40 @@ def get_module_source_from_name(module_name: str) -> str: def preserve_case_replace(text, patterns: dict, default_name: str): # Create a regex pattern to match all variations regex_pattern = "|".join(re.escape(key) for key in patterns.keys()) - compiled_regex = re.compile(regex_pattern, re.IGNORECASE) + compiled_regex = re.compile(f"({regex_pattern})(.|$)", re.IGNORECASE | re.DOTALL) def replace(match): - word = match.group(0) - result = patterns.get(word, default_name) - return result + matched_pattern = match.group(1) + next_char = match.group(2) + new_pattern = patterns.get(matched_pattern, default_name) + + # In this case, the cased old model did not 
respect CamelCase and was all UPPERCASE, so we need to rely on next char + # The heuristic is: if next char is not a letter, then it is not part of a model name and result should be `new_name`.upper() + if len(patterns) == 2 and matched_pattern.isupper(): + if not next_char.isalpha(): + # `new_name.upper()` is just the other entry for `matched_pattern.lower()`, uppercased + new_pattern = patterns[matched_pattern.lower()].upper() + + return new_pattern + next_char return compiled_regex.sub(replace, text) -def convert_to_camelcase(text, old_name: str, default_old_name: str): - # Regex pattern to match consecutive uppercase letters and lowercase the first set - result = re.sub(rf"^({old_name})(?=[a-z]+)", lambda m: default_old_name, text, flags=re.IGNORECASE, count=1) - return result +def get_cased_name(lowercase_name: str) -> str: + """From a model name in lowercase in the format `my_model`, return the cased name in the format `MyModel`.""" + if lowercase_name in CONFIG_MAPPING_NAMES: + return CONFIG_MAPPING_NAMES[lowercase_name].replace("Config", "") + else: + return "".join(x.title() for x in lowercase_name.split("_")) + + +def get_lowercase_name(cased_name: str) -> str: + """From a model name in Camelcase in the format `MyModel`, return the lowercase name in the format `my_model`.""" + inverse_mapping = {value: key for key, value in CONFIG_MAPPING_NAMES.items()} + if cased_name + "Config" in inverse_mapping: + return inverse_mapping[cased_name + "Config"] + else: + return "_".join([s.lower() for s in re.findall(r"[A-Z][^A-Z]*", cased_name)]) class ReplaceNameTransformer(m.MatcherDecoratableTransformer): @@ -84,43 +104,47 @@ class ReplaceNameTransformer(m.MatcherDecoratableTransformer): - LLaMa -> MyNewModel abd MyNewModel -> Llama """ - def __init__( - self, - old_name, - new_name, - given_old_name=None, - given_new_name=None, - ): + def __init__(self, old_name: str, new_name: str, original_new_model_name: str = "", only_doc: bool = False): super().__init__() self.old_name = old_name self.new_name = new_name - self.default_name = "".join(x.title() for x in new_name.split("_")) - if self.new_name in CONFIG_MAPPING_NAMES: - self.default_name = CONFIG_MAPPING_NAMES[self.new_name].replace( - "Config", "" - ) # the best source of truth for class names. 
Could also just use the ones de + self.cased_new_name = get_cased_name(self.new_name) + self.cased_old_name = get_cased_name(self.old_name) self.patterns = { old_name: new_name, old_name.upper(): new_name.upper(), - "".join(x.title() for x in old_name.split("_")): self.default_name, + # For some old models, `self.cased_old_name` == `old_name.upper()` in which case this overwrite previous entry + self.cased_old_name: self.cased_new_name, } - if given_old_name is not None and given_new_name is not None and given_old_name not in self.patterns: - self.patterns[given_old_name] = given_new_name - if self.old_name in CONFIG_MAPPING_NAMES: - self.default_old_name = CONFIG_MAPPING_NAMES[self.old_name].replace("Config", "") - if self.default_old_name.isupper(): - self.default_old_name = self.default_old_name.capitalize() + # In case new_name is a prefix alias, and not the original new model name + self.original_new_model_name = original_new_model_name + self.only_doc = only_doc - @m.leave(m.Name() | m.SimpleString() | m.Comment()) - def replace_name(self, original_node, updated_node): + def _replace_name(self, original_node, updated_node): if re.findall(r"# Copied from", updated_node.value): return cst.RemoveFromParent() - update = preserve_case_replace(updated_node.value, self.patterns, self.default_name) + update = preserve_case_replace(updated_node.value, self.patterns, self.cased_new_name) return updated_node.with_changes(value=update) - def leave_ClassDef(self, original_node, updated_node): - new_name = convert_to_camelcase(updated_node.name.value, self.old_name, self.default_old_name) - return updated_node.with_changes(name=cst.Name(new_name)) + @m.leave(m.SimpleString() | m.Comment()) + def replace_name(self, original_node, updated_node): + return self._replace_name(original_node, updated_node) + + def leave_Name(self, original_node, updated_node): + if not self.only_doc: + return self._replace_name(original_node, updated_node) + return updated_node + + def leave_ImportFrom(self, original_node, updated_node): + """The imports from other file types (configuration, processing etc) should use original model name.""" + if self.original_new_model_name != self.new_name and m.matches(updated_node.module, m.Name()): + patterns = "|".join(ALL_FILE_TYPES) + regex = rf"({patterns})_{self.new_name}" + new_source = re.sub( + regex, lambda m: f"{m.group(1)}_{self.original_new_model_name}", updated_node.module.value + ) + updated_node = updated_node.with_changes(module=updated_node.module.with_changes(value=new_source)) + return updated_node DOCSTRING_NODE = m.SimpleStatementLine( @@ -760,10 +784,12 @@ class ModelFileMapper(ModuleMapper): remaining_dependencies.remove(dep) relative_order[dep] = idx idx += 1 - # Add the class itself - remaining_dependencies.remove(class_name) - relative_order[class_name] = idx - idx += 1 + # Add the class itself (it can sometimes already be present if the order of classes in the source file + # does not make sense, i.e. a class is used somewhere before being defined like in `rt_detr`...) 
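Stepping back to the renaming helpers introduced a few hunks above (`get_cased_name`, `get_lowercase_name` and the reworked `preserve_case_replace`): the idea is that the lowercase, UPPERCASE and CamelCase spellings of the old model name each map to the matching spelling of the new name. The toy sketch below captures that idea only; it is not the converter's exact code and deliberately skips the extra next-character heuristic the patch adds for old names that are all-uppercase.

```python
import re


def to_camel(name: str) -> str:
    return "".join(part.title() for part in name.split("_"))


def toy_preserve_case_replace(text: str, old: str, new: str) -> str:
    # Map each spelling of the old name to the corresponding spelling of the new one
    patterns = {old: new, old.upper(): new.upper(), to_camel(old): to_camel(new)}
    regex = re.compile("|".join(re.escape(k) for k in patterns), re.IGNORECASE)
    return regex.sub(lambda m: patterns.get(m.group(0), to_camel(new)), text)


source = "class LlamaModel:  # uses LLAMA_ATTENTION_CLASSES and the llama tokenizer"
print(toy_preserve_case_replace(source, "llama", "my_new_model"))
# class MyNewModelModel:  # uses MY_NEW_MODEL_ATTENTION_CLASSES and the my_new_model tokenizer
```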
+ if class_name in remaining_dependencies: + remaining_dependencies.remove(class_name) + relative_order[class_name] = idx + idx += 1 # Now add what still remains remaining_dependencies = tuple(remaining_dependencies) @@ -859,7 +885,24 @@ class ModelFileMapper(ModuleMapper): return mapper -def replace_class_node(mapper: ModelFileMapper, class_node: cst.ClassDef, renamed_super_class: str): +def common_partial_suffix(str1: str, str2: str) -> str: + """Return the biggest common suffix between 2 strings. If one string is a full suffix of the other string, + we do not consider it a common suffix and return `""`""" + common_suffix = "" + for i in range(1, min(len(str1), len(str2)) + 1): + if str1[-i] == str2[-i]: + common_suffix = str1[-i] + common_suffix + else: + break + # We do not allow full string suffix + if common_suffix == str1 or common_suffix == str2: + common_suffix = "" + return common_suffix + + +def replace_class_node( + mapper: ModelFileMapper, class_node: cst.ClassDef, renamed_super_class: str, original_super_class: str +): """ Replace a class node which inherits from another modeling class. This function works in the following way: - start from the base class node of the inherited class (a cst.Node) @@ -889,6 +932,36 @@ def replace_class_node(mapper: ModelFileMapper, class_node: cst.ClassDef, rename raise ValueError(f"Could not parse the name of the bases for {class_node.name.value}") original_node = mapper.classes[renamed_super_class] + # Always use the new name of the class (in case we use e.g. `ColPaliForRetrieval` inheriting from `PaliGemmaForConditionalGeneration`) + new_name = class_node.name + + # If the new class name is different from the renamed super class name, we need to update the docstrings/comments accordingly + if new_name.value != renamed_super_class: + common_suffix = common_partial_suffix(new_name.value, renamed_super_class) + # Note that this works even without common prefix, in which case it does not replace anything + old, new = renamed_super_class.replace(common_suffix, ""), new_name.value.replace(common_suffix, "") + temp_module = cst.Module(body=[original_node]) + original_node = temp_module.visit( + ReplaceNameTransformer(get_lowercase_name(old), get_lowercase_name(new), only_doc=True) + ).body[0] + + # If we explicitly passed a new base with common suffix to an old base, it is for switching the prefix + # e.g. if the "natural" parent class is `PreTrainedModel` but we wanted to rename it to `PreTrainedVisionModel` + additional_bases = [base for base in all_bases if base != original_super_class] + new_bases = [] + for original_base in original_node.bases: + new_base = original_base + # we only potentially switch base for Name-based bases, not Attribute + if m.matches(original_base.value, m.Name()): + original_base_name = original_base.value.value + for additional_base_name in additional_bases: + suffix = common_partial_suffix(original_base_name, additional_base_name) + if len(suffix) > 0 and suffix[0].isupper(): + new_name_node = original_base.value.with_changes(value=additional_base_name) + new_base = original_base.with_changes(value=new_name_node) + break + new_bases.append(new_base) + original_methods = { f.name.value if hasattr(f, "name") else mapper.python_module.code_for_node(f): f for f in original_node.body.body @@ -942,12 +1015,17 @@ def replace_class_node(mapper: ModelFileMapper, class_node: cst.ClassDef, rename if m.matches(func, DOCSTRING_NODE): # This processes the docstring of the class! 
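The new `common_partial_suffix` helper shown above drives both the docstring renaming and the base-class switching. A few representative calls, with the function body reproduced (docstring shortened) from the hunk above so the example runs on its own; the class-name pairs are taken from the examples mentioned in this patch:

```python
def common_partial_suffix(str1: str, str2: str) -> str:
    """Biggest common suffix of the two strings; a full-string suffix does not count."""
    common_suffix = ""
    for i in range(1, min(len(str1), len(str2)) + 1):
        if str1[-i] == str2[-i]:
            common_suffix = str1[-i] + common_suffix
        else:
            break
    if common_suffix == str1 or common_suffix == str2:
        common_suffix = ""
    return common_suffix


print(common_partial_suffix("Multimodal1TextDecoderLayer", "LlamaDecoderLayer"))          # "DecoderLayer"
print(common_partial_suffix("ColPaliForRetrieval", "PaliGemmaForConditionalGeneration"))  # ""
print(common_partial_suffix("PreTrainedVisionModel", "PreTrainedModel"))                   # "Model"
```

The first call is how a `Text` prefix is detected, the second shows that unrelated names yield no suffix, and the third (a suffix starting with an uppercase letter) is the case that triggers the base-class switch described in the comments above.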
# Extract the original docstring updated_docstring = func.body[0].value.value - original_docstring = docstring_node[0].body[0].value.value - merged_doc = merge_docstrings(original_docstring, updated_docstring) - # Update the docstring in the original function - docstring_node = [ - docstring_node[0].with_changes(body=[cst.Expr(value=cst.SimpleString(value=merged_doc))]) - ] + if len(docstring_node) == 0: # If the original docstring is empty, just create one from the updated. + docstring_node = [ + cst.SimpleStatementLine(body=[cst.Expr(value=cst.SimpleString(value=updated_docstring))]) + ] + else: + original_docstring = docstring_node[0].body[0].value.value + merged_doc = merge_docstrings(original_docstring, updated_docstring) + # Update the docstring in the original function + docstring_node = [ + docstring_node[0].with_changes(body=[cst.Expr(value=cst.SimpleString(value=merged_doc))]) + ] if name not in original_methods and func is not None and isinstance(func, cst.FunctionDef): end_meth.append(func) if m.matches(func, m.SimpleStatementLine(body=[m.Assign()])): @@ -970,10 +1048,10 @@ def replace_class_node(mapper: ModelFileMapper, class_node: cst.ClassDef, rename # Use decorators redefined in `modular_xxx.py` if any new_decorators = class_node.decorators if len(class_node.decorators) > 0 else original_node.decorators - # Always use the new name of the class (in case we use e.g. `ColPaliForRetrieval` inheriting from `PaliGemmaForConditionalGeneration`) - name = class_node.name - return original_node.with_changes(body=new_replacement_body, decorators=new_decorators, name=name) + return original_node.with_changes( + body=new_replacement_body, decorators=new_decorators, bases=new_bases, name=new_name + ) TYPE_TO_FILE_TYPE = { @@ -1014,14 +1092,18 @@ VARIABLES_AT_THE_BEGINNING = ( IMPORTS_TO_SKIP_IN_MODULAR = ("auto.modeling_auto",) -def append_new_import_node(node: cst.CSTNode, unused_imports: set[str], imports_to_keep: list[cst.CSTNode]): - """Insert the new `node` to the list of `imports_to_keep` in-place, if it is not part of the `unused_imports`.""" +def append_new_import_node( + node: cst.CSTNode, unused_imports: set[str], added_names: set, imports_to_keep: list[cst.CSTNode] +): + """Insert the new `node` to the list of `imports_to_keep` in-place, if it is not part of the `unused_imports` or `added_names`. 
+ Also modifies `added_names` in-place accordingly.""" import_node = node.body[0] names_to_keep = [] for name in import_node.names: name_value = name.evaluated_name - if name_value not in unused_imports: + if name_value not in unused_imports and name_value not in added_names: names_to_keep.append(name.with_changes(comma=cst.MaybeSentinel.DEFAULT)) + added_names.add(name_value) if len(names_to_keep) > 0: new_node = node.with_changes(body=[import_node.with_changes(names=names_to_keep)]) imports_to_keep.append(new_node) @@ -1036,40 +1118,38 @@ def get_needed_imports(body: dict[str, dict], all_imports: list[cst.CSTNode]) -> wrapper = MetadataWrapper(cst.Module(body=all_imports + new_body)) scopes = set(wrapper.resolve(ScopeProvider).values()) unused_imports = set() - import_ref_count = {} + import_ref_count = defaultdict(lambda: 0) for scope in scopes: for assignment in scope.assignments: node = assignment.node if isinstance(assignment, cst.metadata.Assignment) and isinstance(node, (cst.Import, cst.ImportFrom)): ref_count = len(assignment.references) name = assignment.name - # Similar imports may be redefined, and only used between their 1st and 2nd definition - # so if we already have a ref count > 0, the imports is actually used - if (ref_count == 0 and import_ref_count.get(name, -1) <= 0) or name in body.keys(): - unused_imports.add(name) - import_ref_count[name] = ref_count + import_ref_count[name] = max(ref_count, import_ref_count[name]) + # Similar imports may be redefined, and only used between their 1st and 2nd definition so if we already have + # a ref count > 0 at any point, the imports is actually used + unused_imports = {name for name, count in import_ref_count.items() if count <= 0 or name in body.keys()} imports_to_keep = [] + # We need to keep track of which names were already imported, because some import may be duplicated from multiple sources + # or be both protected and unprotected due to inconsistency between models + added_names = set() existing_protected_statements = set() # str repr of the import nodes - does not work with the nodes directly for node in all_imports: if m.matches(node, m.If()): # handle safe imports new_statements = [] for stmt_node in node.body.body: - append_new_import_node(stmt_node, unused_imports, new_statements) + append_new_import_node(stmt_node, unused_imports, added_names, new_statements) new_statements = [stmt for stmt in new_statements if str(stmt) not in existing_protected_statements] if len(new_statements) > 0: new_node = node.with_changes(body=node.body.with_changes(body=new_statements)) imports_to_keep.append(new_node) existing_protected_statements.update({str(stmt) for stmt in new_statements}) else: - append_new_import_node(node, unused_imports, imports_to_keep) + append_new_import_node(node, unused_imports, added_names, imports_to_keep) protected_import_nodes = [node for node in imports_to_keep if m.matches(node, m.If())] usual_import_nodes = [node for node in imports_to_keep if not m.matches(node, m.If())] - # If the same import is both protected and unprotected, only keep the protected one - for protected_node in protected_import_nodes: - for stmt_node in protected_node.body.body: - usual_import_nodes = [node for node in usual_import_nodes if node.body[0] != stmt_node.body[0]] # Protected imports always appear at the end of all imports return usual_import_nodes + protected_import_nodes @@ -1102,12 +1182,10 @@ class ModularFileMapper(ModuleMapper): Calling the method `create_modules()` after visit will create all modules based on this 
modular file. """ - def __init__(self, python_module, new_name, given_old_name=None, given_new_name=None): + def __init__(self, python_module, new_name): super().__init__(python_module) # fmt: off self.model_name = new_name # name of the model being defined. Should be in the format of `llama` or `layout_xlm` or `phi3` - self.given_old_name = given_old_name - self.given_new_name = given_new_name self.model_specific_imported_objects: Dict[str, str] = {} # e.g. {"LlamaModel": "transformers.models.llama.modeling_llama"} self.model_specific_modules: Dict[str, cst.Module] = {} # e.g. {"transformers.models.llama.modeling_llama": cst.Module} @@ -1191,11 +1269,11 @@ class ModularFileMapper(ModuleMapper): # 1. for each modeling file found in the imports, rename it with the new model name, visit it, and update dependencies self.visited_modules = {} self.renamers = {} + name_prefixes = self.infer_new_model_name() for file, module in self.model_specific_modules.items(): file_model_name = file.split(".")[-2] - renamer = ReplaceNameTransformer( - file_model_name, self.model_name, self.given_old_name, self.given_new_name - ) + new_name = name_prefixes[file] + renamer = ReplaceNameTransformer(file_model_name, new_name, self.model_name) renamed_module = module.visit(renamer) self.visited_modules[file] = ModelFileMapper.visit_and_merge_dependencies( renamed_module, @@ -1288,6 +1366,87 @@ class ModularFileMapper(ModuleMapper): return relative_order + def infer_new_model_name(self) -> dict: + """Infer whether we are using a model name prefix different from the usual model name as defined from the filename. + This is useful e.g. when we define a new multi-modal model, and only the text part inherits from `LlamaModel`, + so we have something like: + ```python + class NewModelNameTextDecoderLayer(LlamaDecoderLayer): + pass + ``` + with the `Text` prefix added to the model name. + However, in case of multiple prefix used, we raise a warning and use the most frequent prefix, to avoid parsing + the same file multiple times and inconsistencies in the objects added from dependencies. + If the new prefix collides with a prefix of another class in the file where we are importing from, then we also + raise a warning, and use the default prefix (model name) to avoid collisions in dependencies. + """ + prefix_model_name_mapping = defaultdict(Counter) + cased_default_name = get_cased_name(self.model_name) + # Iterate over all new classes to get modeling super classes + for class_name, class_node in self.classes.items(): + modeling_bases = [ + k.value.value for k in class_node.bases if k.value.value in self.model_specific_imported_objects + ] + if len(modeling_bases) > 1: + raise ValueError( + f"{class_name} was defined with more than 1 model-specific super class. This is unsupported. We found {*modeling_bases,}." 
+ ) + if len(modeling_bases) == 1: + filename = self.model_specific_imported_objects[modeling_bases[0]] + cased_model_name = cased_default_name # the default name prefix + suffix = common_partial_suffix(class_name, modeling_bases[0]) + if len(suffix) > 0 and suffix[0].isupper(): + cased_model_name = class_name.replace(suffix, "") + prefix_model_name_mapping[filename].update([cased_model_name]) + + # Check if we found multiple prefixes for some modeling files + final_name_mapping = {} + for file, prefixes_counter in prefix_model_name_mapping.items(): + if len(prefixes_counter) > 1: + _, total = prefixes_counter.most_common(1)[0] + most_used_entities = [name for name, count in prefixes_counter.most_common() if count == total] + # if the default name is in the pool of equally used prefixes, use it, otherwise last encountered + final_name = cased_default_name if cased_default_name in most_used_entities else most_used_entities[-1] + else: + final_name = list(prefixes_counter)[0] + # Check if the prefix can be used without collisions in the names + old_cased_model_name = get_cased_name(file.split(".")[-2]) + old_model_name_prefix = final_name.replace(cased_default_name, old_cased_model_name) + # Raise adequate warning depending on the situation + has_prefix_collision = f"\nclass {old_model_name_prefix}" in get_module_source_from_name(file) + if final_name != cased_default_name and has_prefix_collision: + if len(prefixes_counter) > 1: + logger.warning( + f"We detected multiple prefix names when inheriting from {file}: {*set(prefixes_counter),}. However, the " + f"most used one, '{final_name}', is already present in the source file and will likely cause consistency " + f"issues. For this reason we fallback to the default prefix '{cased_default_name}' when grabbing args " + "and dependencies. Make sure to subclass the intermediate classes with the prefix you want (if different " + f"from '{cased_default_name}') or use a single prefix in all the modular (best)." + ) + else: + logger.warning( + f"We detected the use of the new default prefix {final_name} when inheriting from {file}. However, it is " + "already present in the source file and will likely cause consistency issues. For this reason we fallback " + f"to the default prefix '{cased_default_name}' when grabbing args and dependencies. Make sure to subclass " + f"the intermediate classes with the prefix you want (if different from '{cased_default_name}')" + ) + final_name = cased_default_name + elif len(prefixes_counter) > 1: + logger.warning( + f"We detected multiple prefix names when inheriting from {file}: {*set(prefixes_counter),}. We will only " + f"use the most used '{final_name}' prefix when grabbing args and dependencies. Make sure to subclass the " + f"intermediate classes with the prefix you want (if different from '{final_name}') or use a single prefix " + "in all the modular (best)." + ) + final_name_mapping[file] = get_lowercase_name(final_name) + + # Check we are not missing imported files + for file in self.model_specific_modules.keys(): + if file not in final_name_mapping.keys(): + final_name_mapping[file] = self.model_name + + return final_name_mapping + def check_dependencies_and_create_import_node( file_type: str, new_dependencies: set[str], mapper: ModuleMapper, new_name: str @@ -1338,11 +1497,11 @@ def get_class_node_and_dependencies( class node based on the inherited classes if needed. Also returns any new imports of a new class defined in the modular that we nay need. 
""" - bases = [k.value.value for k in node.bases if k.value.value in modular_mapper.model_specific_imported_objects] - if len(bases) > 1: - raise ValueError( - f"{class_name} was defined with more than 1 model-specific super class. This is unsupported. We found {*bases,}." - ) + # An exception was already raised if this has len > 1 + model_specific_bases = [ + k.value.value for k in node.bases if k.value.value in modular_mapper.model_specific_imported_objects + ] + super_class = model_specific_bases[0] if len(model_specific_bases) == 1 else None file_type = find_file_type(class_name) file_to_update = files[file_type] @@ -1352,19 +1511,17 @@ def get_class_node_and_dependencies( imported_objects = modular_mapper.imported_objects_per_file[file_type] # We need to replace the class node with the transformers (modeling file) super class node - if len(bases) == 1: - super_class = bases[0] + if super_class is not None: super_file_name = modular_mapper.model_specific_imported_objects[super_class] # Get the mapper corresponding to the inherited class mapper = modular_mapper.visited_modules[super_file_name] # Rename the super class according to the exact same rule we used when renaming the whole module renamer = modular_mapper.renamers[super_file_name] - renamed_super_class = preserve_case_replace(super_class, renamer.patterns, renamer.default_name) - renamed_super_class = convert_to_camelcase(renamed_super_class, renamer.old_name, renamer.default_old_name) + renamed_super_class = preserve_case_replace(super_class, renamer.patterns, renamer.cased_new_name) # Create the new class node - updated_node = replace_class_node(mapper, node, renamed_super_class) + updated_node = replace_class_node(mapper, node, renamed_super_class, super_class) # Grab all immediate dependencies of the new node new_node_dependencies = augmented_dependencies_for_class_node(updated_node, mapper, imported_objects) @@ -1468,7 +1625,7 @@ def create_modules(modular_mapper: ModularFileMapper) -> dict[str, cst.Module]: return files -def convert_modular_file(modular_file, old_model_name=None, new_model_name=None, cst_transformers=None): +def convert_modular_file(modular_file): pattern = re.search(r"modular_(.*)(?=\.py$)", modular_file) output = {} if pattern is not None: @@ -1478,8 +1635,7 @@ def convert_modular_file(modular_file, old_model_name=None, new_model_name=None, code = file.read() module = cst.parse_module(code) wrapper = MetadataWrapper(module) - if cst_transformers is None: - cst_transformers = ModularFileMapper(module, model_name, old_model_name, new_model_name) + cst_transformers = ModularFileMapper(module, model_name) wrapper.visit(cst_transformers) for file, module in create_modules(cst_transformers).items(): if module != {}: @@ -1522,20 +1678,10 @@ if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( "--files_to_parse", - default=["src/transformers/models/starcoder2/modular_starcoder2.py"], + default=["src/transformers/models/gemma/modular_gemma.py"], nargs="+", help="A list of `modular_xxxx` files that should be converted to single model file", ) - parser.add_argument( - "--old_model_name", - required=False, - help="The name of the model from which the copying is done in CamelCase. If not provided is inferred from modular-file", - ) - parser.add_argument( - "--new_model_name", - required=False, - help="The name of the new model being added in CamelCase. 
-    )
     args = parser.parse_args()
     if args.files_to_parse == ["all"]:
         args.files_to_parse = glob.glob("src/transformers/models/**/modular_*.py", recursive=True)
@@ -1544,5 +1690,5 @@ if __name__ == "__main__":
     for file_name in find_priority_list(args.files_to_parse):
         print(f"Converting {file_name} to a single model single file format")
         module_path = file_name.replace("/", ".").replace(".py", "").replace("src.", "")
-        converted_files = convert_modular_file(file_name, args.old_model_name, args.new_model_name)
+        converted_files = convert_modular_file(file_name)
         converter = save_modeling_file(file_name, converted_files)
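For illustration, here is a minimal sketch of the kind of modular file the new `infer_new_model_name` logic targets. The `new_model` model name, the file path, and the `NewModelText*` class names are hypothetical and not part of this diff: only the text backbone inherits from Llama, so everything pulled in from `modeling_llama.py` should be renamed with the inferred `NewModelText` prefix, while the default `NewModel` prefix (derived from the filename) is used for files where no other prefix is detected.

```python
# Hypothetical src/transformers/models/new_model/modular_new_model.py
# Both subclasses share a cased suffix ("DecoderLayer", "Model") with their Llama bases,
# so infer_new_model_name() should map modeling_llama.py to the "NewModelText" prefix
# instead of the default "NewModel" prefix derived from the filename.
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel


class NewModelTextDecoderLayer(LlamaDecoderLayer):
    pass


class NewModelTextModel(LlamaModel):
    pass
```

With the two name-override flags removed above, converting such a file should only require `--files_to_parse`, e.g. `python utils/modular_model_converter.py --files_to_parse src/transformers/models/new_model/modular_new_model.py` (assuming the converter script keeps its usual `utils/modular_model_converter.py` location); the old and new model names are now inferred from the modular file itself.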