Support for easier multimodal use of modular (#35056)

* update modular and add examples

* style

* improve example comments

* style

* fix small logic issue for imports

* fix relative order issue when files do not make sense

* Improve comments

* trigger CIs
Cyril Vallez 2024-12-04 15:13:11 +01:00 committed by GitHub
parent 46df859975
commit 1da1e0d7f2
13 changed files with 2424 additions and 103 deletions


@ -36,7 +36,7 @@ logger = logging.get_logger(__name__)
class ImgprocModelImageProcessor(BaseImageProcessor):
r"""
Constructs a NEW_IMGPROC_MODEL image processor.
Constructs a IMGPROC_MODEL image processor.
Args:
do_resize (`bool`, *optional*, defaults to `True`):


@ -0,0 +1,357 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from examples/modular-transformers/modular_from_uppercase_model.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_from_uppercase_model.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
from typing import Optional, Tuple
import torch
from torch import nn
from ...activations import ACT2FN
from ...pytorch_utils import is_torch_greater_or_equal_than_2_2
from ...utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, logging
from .configuration_from_uppercase_model import FromUppercaseModelConfig
if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
logger = logging.get_logger(__name__)
class FromUppercaseModelAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
f" {self.num_heads})."
)
self.scale = self.head_dim**-0.5
self.dropout = config.attention_dropout
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
causal_attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Input shape: Batch x Time x Channel"""
bsz, tgt_len, embed_dim = hidden_states.size()
# get query proj
query_states = self.q_proj(hidden_states) * self.scale
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
key_states = key_states.view(*proj_shape)
value_states = value_states.view(*proj_shape)
src_len = key_states.size(1)
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
f" {attn_weights.size()}"
)
# apply the causal_attention_mask first
if causal_attention_mask is not None:
if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
f" {causal_attention_mask.size()}"
)
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
if output_attentions:
# this operation is a bit awkward, but it's required to
# make sure that attn_weights keeps its gradient.
# In order to do so, attn_weights have to be reshaped
# twice and reused in the following
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
else:
attn_weights_reshaped = None
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = torch.bmm(attn_probs, value_states)
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights_reshaped
class FromUppercaseModelFlashAttention2(FromUppercaseModelAttention):
"""
FromUppercaseModelAttention flash attention module. This module inherits from `FromUppercaseModelAttention` as the weights of the module stay
untouched. The only required change would be on the forward pass, where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
# flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
# Adapted from transformers.models.llama.modeling_llama.LlamaFlashAttention2.forward
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
causal_attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
output_attentions = False
batch_size, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
# Flash attention requires the input to have the shape
# batch_size x seq_length x num_heads x head_dim
# therefore we just need to keep the original shape
query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim)
key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim)
value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim)
dropout_rate = self.dropout if self.training else 0.0
# In PEFT, we usually cast the layer norms to float32 for training stability reasons,
# therefore the input hidden states get silently cast to float32. Hence, we need to
# cast them back to the correct dtype just to be sure everything works as expected.
# This might slow down training & inference, so it is recommended not to cast the LayerNorms
# to fp32.
input_dtype = query_states.dtype
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
else:
target_dtype = self.q_proj.weight.dtype
logger.warning_once(
f"The input hidden states seem to be silently cast to float32, this might be related to"
f" the fact that you have upcast embedding or layer norm layers to float32. We will cast the input back to"
f" {target_dtype}."
)
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
attn_output = _flash_attention_forward(
query_states,
key_states,
value_states,
attention_mask,
q_len,
dropout=dropout_rate,
is_causal=causal_attention_mask is not None,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
)
attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim).contiguous()
attn_output = self.out_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights
class FromUppercaseModelSdpaAttention(FromUppercaseModelAttention):
"""
SDPA attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
`FromUppercaseModelAttention` as the weights of the module stay untouched. The only changes are on the forward pass to adapt to
the SDPA API.
"""
# Adapted from FromUppercaseModelAttention.forward
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
causal_attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
if output_attentions:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
logger.warning_once(
"FromUppercaseModelModel is using FromUppercaseModelSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not "
"support `output_attentions=True`. Falling back to the manual attention implementation, but specifying "
"the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can "
'be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
causal_attention_mask=causal_attention_mask,
output_attentions=output_attentions,
)
# FROM_UPPERCASE_MODEL text model uses both `causal_attention_mask` and `attention_mask`
if attention_mask is not None and causal_attention_mask is not None:
attn_mask = attention_mask + causal_attention_mask
elif causal_attention_mask is not None:
attn_mask = causal_attention_mask
else:
attn_mask = attention_mask
bsz, tgt_len, embed_dim = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if not is_torch_greater_or_equal_than_2_2 and query_states.device.type == "cuda" and attn_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
# FROM_UPPERCASE_MODEL text model uses both `causal_attention_mask` and `attention_mask` sequentially.
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attn_mask,
dropout_p=self.dropout if self.training else 0.0,
scale=self.scale,
)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output, None
class FromUppercaseModelMLP(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.activation_fn = ACT2FN[config.hidden_act]
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states = self.fc2(hidden_states)
return hidden_states
FROM_UPPERCASE_MODEL_ATTENTION_CLASSES = {
"eager": FromUppercaseModelAttention,
"sdpa": FromUppercaseModelSdpaAttention,
"flash_attention_2": FromUppercaseModelFlashAttention2,
}
class FromUppercaseModelEncoderLayer(nn.Module):
def __init__(self, config: FromUppercaseModelConfig):
super().__init__()
self.embed_dim = config.hidden_size
self.self_attn = FROM_UPPERCASE_MODEL_ATTENTION_CLASSES[config._attn_implementation](config)
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = FromUppercaseModelMLP(config)
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
causal_attention_mask: torch.Tensor,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.FloatTensor]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
"""
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
hidden_states, attn_weights = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
causal_attention_mask=causal_attention_mask,
output_attentions=output_attentions,
)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.layer_norm2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (attn_weights,)
return outputs

File diff suppressed because it is too large.


@ -0,0 +1,705 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from examples/modular-transformers/modular_multimodal2.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_multimodal2.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
from typing import Optional, Tuple, Union
import torch
from torch import nn
from transformers.utils import add_start_docstrings
from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import is_torch_greater_or_equal_than_2_2
from ...utils import (
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
logging,
replace_return_docstrings,
torch_int,
)
from .configuration_multimodal2 import Multimodal2Config, Multimodal2VisionConfig
if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
logger = logging.get_logger(__name__)
class Multimodal2VisionAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
f" {self.num_heads})."
)
self.scale = self.head_dim**-0.5
self.dropout = config.attention_dropout
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
causal_attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Input shape: Batch x Time x Channel"""
bsz, tgt_len, embed_dim = hidden_states.size()
# get query proj
query_states = self.q_proj(hidden_states) * self.scale
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
key_states = key_states.view(*proj_shape)
value_states = value_states.view(*proj_shape)
src_len = key_states.size(1)
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
f" {attn_weights.size()}"
)
# apply the causal_attention_mask first
if causal_attention_mask is not None:
if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
f" {causal_attention_mask.size()}"
)
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
if output_attentions:
# this operation is a bit awkward, but it's required to
# make sure that attn_weights keeps its gradient.
# In order to do so, attn_weights have to be reshaped
# twice and reused in the following
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
else:
attn_weights_reshaped = None
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = torch.bmm(attn_probs, value_states)
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights_reshaped
class Multimodal2VisionSdpaAttention(Multimodal2VisionAttention):
"""
SDPA attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
`Multimodal2VisionAttention` as the weights of the module stay untouched. The only changes are on the forward pass to adapt to
the SDPA API.
"""
# Adapted from Multimodal2VisionAttention.forward
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
causal_attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
if output_attentions:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
logger.warning_once(
"Multimodal2VisionModel is using Multimodal2VisionSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not "
"support `output_attentions=True`. Falling back to the manual attention implementation, but specifying "
"the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can "
'be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
causal_attention_mask=causal_attention_mask,
output_attentions=output_attentions,
)
# MULTIMODAL2_VISION text model uses both `causal_attention_mask` and `attention_mask`
if attention_mask is not None and causal_attention_mask is not None:
attn_mask = attention_mask + causal_attention_mask
elif causal_attention_mask is not None:
attn_mask = causal_attention_mask
else:
attn_mask = attention_mask
bsz, tgt_len, embed_dim = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if not is_torch_greater_or_equal_than_2_2 and query_states.device.type == "cuda" and attn_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
# MULTIMODAL2_VISION text model uses both `causal_attention_mask` and `attention_mask` sequentially.
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attn_mask,
dropout_p=self.dropout if self.training else 0.0,
scale=self.scale,
)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output, None
class Multimodal2VisionFlashAttention2(Multimodal2VisionAttention):
"""
Multimodal2VisionAttention flash attention module. This module inherits from `Multimodal2VisionAttention` as the weights of the module stay
untouched. The only required change would be on the forward pass, where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
# flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
# Adapted from transformers.models.llama.modeling_llama.LlamaFlashAttention2.forward
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
causal_attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
output_attentions = False
batch_size, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
# Flash attention requires the input to have the shape
# batch_size x seq_length x num_heads x head_dim
# therefore we just need to keep the original shape
query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim)
key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim)
value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim)
dropout_rate = self.dropout if self.training else 0.0
# In PEFT, we usually cast the layer norms to float32 for training stability reasons,
# therefore the input hidden states get silently cast to float32. Hence, we need to
# cast them back to the correct dtype just to be sure everything works as expected.
# This might slow down training & inference, so it is recommended not to cast the LayerNorms
# to fp32.
input_dtype = query_states.dtype
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
else:
target_dtype = self.q_proj.weight.dtype
logger.warning_once(
f"The input hidden states seem to be silently cast to float32, this might be related to"
f" the fact that you have upcast embedding or layer norm layers to float32. We will cast the input back to"
f" {target_dtype}."
)
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
attn_output = _flash_attention_forward(
query_states,
key_states,
value_states,
attention_mask,
q_len,
dropout=dropout_rate,
is_causal=causal_attention_mask is not None,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
)
attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim).contiguous()
attn_output = self.out_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights
class Multimodal2VisionMLP(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.activation_fn = ACT2FN[config.hidden_act]
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states = self.fc2(hidden_states)
return hidden_states
MULTIMODAL2_VISION_ATTENTION_CLASSES = {
"eager": Multimodal2VisionAttention,
"sdpa": Multimodal2VisionSdpaAttention,
"flash_attention_2": Multimodal2VisionFlashAttention2,
}
class Multimodal2VisionEncoderLayer(nn.Module):
def __init__(self, config):
super().__init__()
self.embed_dim = config.hidden_size
self.self_attn = MULTIMODAL2_VISION_ATTENTION_CLASSES[config._attn_implementation](config)
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = Multimodal2VisionMLP(config)
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
causal_attention_mask: torch.Tensor,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.FloatTensor]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
"""
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
hidden_states, attn_weights = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
causal_attention_mask=causal_attention_mask,
output_attentions=output_attentions,
)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.layer_norm2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (attn_weights,)
return outputs
class Multimodal2VisionEncoder(nn.Module):
"""
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
[`Multimodal2VisionEncoderLayer`].
Args:
config: Multimodal2VisionConfig
"""
def __init__(self, config):
super().__init__()
self.config = config
self.layers = nn.ModuleList([Multimodal2VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
def forward(
self,
inputs_embeds,
attention_mask: Optional[torch.Tensor] = None,
causal_attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutput]:
r"""
Args:
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Causal mask for the text model. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
hidden_states = inputs_embeds
for idx, encoder_layer in enumerate(self.layers):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
causal_attention_mask,
output_attentions,
)
else:
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
causal_attention_mask,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
)
class Multimodal2VisionEmbeddings(nn.Module):
def __init__(self, config: Multimodal2VisionConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.image_size = config.image_size
self.patch_size = config.patch_size
self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
self.patch_embedding = nn.Conv2d(
in_channels=config.num_channels,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
stride=self.patch_size,
bias=False,
)
self.num_patches = (self.image_size // self.patch_size) ** 2
self.num_positions = self.num_patches + 1
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
This method allows interpolating the pre-trained position encodings so that the model can be used on higher-resolution
images. This method is also adapted to support torch.jit tracing.
Adapted from:
- https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
- https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
"""
num_patches = embeddings.shape[1] - 1
position_embedding = self.position_embedding.weight.unsqueeze(0)
num_positions = position_embedding.shape[1] - 1
# always interpolate when tracing to ensure the exported model works for dynamic input shapes
if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
return self.position_embedding(self.position_ids)
class_pos_embed = position_embedding[:, :1]
patch_pos_embed = position_embedding[:, 1:]
dim = embeddings.shape[-1]
new_height = height // self.patch_size
new_width = width // self.patch_size
sqrt_num_positions = torch_int(num_positions**0.5)
patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
size=(new_height, new_width),
mode="bicubic",
align_corners=False,
)
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:
batch_size, _, height, width = pixel_values.shape
if not interpolate_pos_encoding and (height != self.image_size or width != self.image_size):
raise ValueError(
f"Input image size ({height}*{width}) doesn't match model" f" ({self.image_size}*{self.image_size})."
)
target_dtype = self.patch_embedding.weight.dtype
patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
class_embeds = self.class_embedding.expand(batch_size, 1, -1)
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
if interpolate_pos_encoding:
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
else:
embeddings = embeddings + self.position_embedding(self.position_ids)
return embeddings
MULTIMODAL2_VISION_INPUTS_DOCSTRING = r"""
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
[`AutoImageProcessor`]. See [`Multimodal2ImageProcessor.__call__`] for details.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
Whether to interpolate the pre-trained position encodings.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
class Multimodal2VisionTransformer(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
embed_dim = config.hidden_size
self.embeddings = Multimodal2VisionEmbeddings(config)
self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
self.encoder = Multimodal2VisionEncoder(config)
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
@add_start_docstrings_to_model_forward(MULTIMODAL2_VISION_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=Multimodal2VisionConfig)
def forward(
self,
pixel_values: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
interpolate_pos_encoding: Optional[bool] = False,
) -> Union[Tuple, BaseModelOutputWithPooling]:
r"""
Returns:
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if pixel_values is None:
raise ValueError("You have to specify pixel_values")
hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
hidden_states = self.pre_layrnorm(hidden_states)
encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
last_hidden_state = encoder_outputs[0]
pooled_output = last_hidden_state[:, 0, :]
pooled_output = self.post_layernorm(pooled_output)
if not return_dict:
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
return BaseModelOutputWithPooling(
last_hidden_state=last_hidden_state,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
class Multimodal2VisionPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = Multimodal2Config
base_model_prefix = "multimodal2_vision"
supports_gradient_checkpointing = True
_supports_sdpa = True
_supports_flash_attn_2 = True
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, Multimodal2VisionMLP):
pass
MULTIMODAL2_VISION_START_DOCSTRING = "doc"
@add_start_docstrings("New doc", MULTIMODAL2_VISION_START_DOCSTRING)
class Multimodal2VisionModel(Multimodal2VisionPreTrainedModel):
config_class = Multimodal2VisionConfig
main_input_name = "pixel_values"
_no_split_modules = ["Multimodal2VisionEncoderLayer"]
def __init__(self, config: Multimodal2VisionConfig):
super().__init__(config)
self.vision_model = Multimodal2VisionTransformer(config)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self) -> nn.Module:
return self.vision_model.embeddings.patch_embedding
@add_start_docstrings_to_model_forward(MULTIMODAL2_VISION_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=Multimodal2VisionConfig)
def forward(
self,
pixel_values: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPooling]:
r"""
Returns:
Examples:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, Multimodal2VisionModel
>>> model = Multimodal2VisionModel.from_pretrained("openai/multimodal2-vit-base-patch32")
>>> processor = AutoProcessor.from_pretrained("openai/multimodal2-vit-base-patch32")
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(images=image, return_tensors="pt")
>>> outputs = model(**inputs)
>>> last_hidden_state = outputs.last_hidden_state
>>> pooled_output = outputs.pooler_output # pooled CLS states
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
return self.vision_model(
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
interpolate_pos_encoding=interpolate_pos_encoding,
)


@ -265,7 +265,7 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
min_dtype = torch.finfo(dtype).min
sequence_length = inputs_embeds.shape[1]
if using_static_cache:
target_length = past_key_values.get_max_length()
target_length = past_key_values.get_max_cache_shape()
else:
target_length = (
attention_mask.shape[-1]
@ -358,9 +358,9 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, NewTaskModelForConditionalGeneration
>>> from transformers import AutoProcessor, NewTaskModelForNewTask
>>> model = NewTaskModelForConditionalGeneration.from_pretrained("google/NewTaskModel-test-224px-hf")
>>> model = NewTaskModelForNewTask.from_pretrained("google/NewTaskModel-test-224px-hf")
>>> processor = AutoProcessor.from_pretrained("google/NewTaskModel-test-224px-hf")
>>> prompt = "answer en Where is the cow standing?"


@ -0,0 +1,6 @@
from transformers.models.clip.modeling_clip import CLIPEncoderLayer
# Check that we can correctly grab dependencies, with correct naming, from an all-UPPERCASE old model
class FromUppercaseModelEncoderLayer(CLIPEncoderLayer):
pass
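# A minimal sketch of how this example is consumed (the exact command/flags are assumed from the repo
# layout and may differ): the generated modeling file shown earlier in this diff is produced by running
# the modular converter on this file, e.g.
#   python utils/modular_model_converter.py --files_to_parse examples/modular-transformers/modular_from_uppercase_model.py
# which copies the whole CLIPEncoderLayer dependency graph (attention variants, MLP, ...) and rewrites
# the CLIP / CLIP_* prefixes to FromUppercaseModel / FROM_UPPERCASE_MODEL.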


@ -0,0 +1,6 @@
from transformers.models.llama.modeling_llama import LlamaModel
# Check that we can correctly change the prefix (here, add the "Text" part at the end of the name)
class Multimodal1TextModel(LlamaModel):
pass


@ -0,0 +1,88 @@
"""
Here, because clip is not consistent with the use of the "Text" and "Vision" prefixes, we cannot simply use
```
class Multimodal2VisionModel(CLIPVisionModel):
pass
```
with the hope that all dependencies will be renamed to `Multimodal2VisionClass`. For this reason, if we want consistency and
use the "Vision" part everywhere, we need to overwrite the intermediate classes and add the prefix every time.
This adds noise to the modular file, but it is unfortunately unavoidable.
"""
from torch import nn
from transformers.models.clip.modeling_clip import (
CLIPMLP,
CLIPAttention,
CLIPEncoder,
CLIPEncoderLayer,
CLIPFlashAttention2,
CLIPPreTrainedModel,
CLIPSdpaAttention,
CLIPVisionModel,
CLIPVisionTransformer,
)
from transformers.utils import add_start_docstrings
class Multimodal2VisionAttention(CLIPAttention):
pass
# Check that adding the second base class correctly sets the parent, even though in clip it does not have the "Vision" part
class Multimodal2VisionSdpaAttention(CLIPSdpaAttention, Multimodal2VisionAttention):
pass
# Check that adding the second base class correctly sets the parent, even though in clip it does not have the "Vision" part
class Multimodal2VisionFlashAttention2(CLIPFlashAttention2, Multimodal2VisionAttention):
pass
MULTIMODAL2_VISION_ATTENTION_CLASSES = {
"eager": Multimodal2VisionAttention,
"sdpa": Multimodal2VisionSdpaAttention,
"flash_attention_2": Multimodal2VisionFlashAttention2,
}
class Multimodal2VisionMLP(CLIPMLP):
pass
class Multimodal2VisionEncoderLayer(CLIPEncoderLayer):
def __init__(self, config):
super().__init__()
self.self_attn = MULTIMODAL2_VISION_ATTENTION_CLASSES[config._attn_implementation](config)
self.mlp = Multimodal2VisionMLP(config)
class Multimodal2VisionEncoder(CLIPEncoder):
def __init__(self, config):
super().__init__(config)
self.layers = nn.ModuleList([Multimodal2VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
# Finally, here the `Vision` part was correct in CLIP, but we still need to tell it that the encoder attribute should use it as well
class Multimodal2VisionTransformer(CLIPVisionTransformer):
def __init__(self, config):
super().__init__(config)
self.encoder = Multimodal2VisionEncoder(config)
class Multimodal2VisionPreTrainedModel(CLIPPreTrainedModel):
def _init_weights(self, module):
if isinstance(module, Multimodal2VisionMLP):
pass
MULTIMODAL2_VISION_START_DOCSTRING = "doc"
# Here the only attribute `self.vision_model = CLIPVisionTransformer(config)` in CLIPVisionModel already has the "Vision" part, so
# there is no need to overwrite it: it will look for `Multimodal2VisionTransformer`, which has already been redefined above
# Note: we may want to redefine decorator as well for full consistency, as CLIP does not use "CLIP_VISION_START_DOCSTRING" but only
# "CLIP_START_DOCSTRING"
@add_start_docstrings("New doc", MULTIMODAL2_VISION_START_DOCSTRING)
class Multimodal2VisionModel(CLIPVisionModel, Multimodal2VisionPreTrainedModel):
_no_split_modules = ["Multimodal2VisionEncoderLayer"]


@ -20,7 +20,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ...configuration_utils import PretrainedConfig


@ -27,7 +27,6 @@ import torch.nn as nn
from ...activations import ACT2FN
from ...cache_utils import Cache, HybridCache
from ...generation import GenerationMixin
from ...modeling_flash_attention_utils import _flash_attention_forward
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,


@ -14,7 +14,6 @@ from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import _flash_attention_forward
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_utils import PreTrainedModel
from ...utils import (


@ -34,7 +34,6 @@ from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import _flash_attention_forward
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,


@ -18,7 +18,7 @@ import importlib
import os
import re
from abc import ABC, abstractmethod
from collections import defaultdict, deque
from collections import Counter, defaultdict, deque
from typing import Dict, Set
import libcst as cst
@ -48,7 +48,7 @@ def get_module_source_from_name(module_name: str) -> str:
# Extract the source code from the module name
spec = importlib.util.find_spec(module_name)
if spec is None or spec.origin is None:
return f"Module {module_name} not found"
raise ValueError(f"Cannot open file associated with {module_name} module.")
with open(spec.origin, "r", encoding="utf-8") as file:
source_code = file.read()
@ -58,20 +58,40 @@ def get_module_source_from_name(module_name: str) -> str:
def preserve_case_replace(text, patterns: dict, default_name: str):
# Create a regex pattern to match all variations
regex_pattern = "|".join(re.escape(key) for key in patterns.keys())
compiled_regex = re.compile(regex_pattern, re.IGNORECASE)
compiled_regex = re.compile(f"({regex_pattern})(.|$)", re.IGNORECASE | re.DOTALL)
def replace(match):
word = match.group(0)
result = patterns.get(word, default_name)
return result
matched_pattern = match.group(1)
next_char = match.group(2)
new_pattern = patterns.get(matched_pattern, default_name)
# In this case, the cased old model did not respect CamelCase and was all UPPERCASE, so we need to rely on the next char
# The heuristic is: if the next char is not a letter, then it is not part of a model name and the result should be `new_name`.upper()
if len(patterns) == 2 and matched_pattern.isupper():
if not next_char.isalpha():
# `new_name.upper()` is just the other entry for `matched_pattern.lower()`, uppercased
new_pattern = patterns[matched_pattern.lower()].upper()
return new_pattern + next_char
return compiled_regex.sub(replace, text)
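# Illustrative trace of the heuristic above, using the `from_uppercase_model` example added in this PR.
# For an all-UPPERCASE old model the cased name collides with `old_name.upper()`, so only two entries remain:
#   patterns = {"clip": "from_uppercase_model", "CLIP": "FromUppercaseModel"}
#   preserve_case_replace("CLIPMLP", patterns, "FromUppercaseModel")
#       -> "FromUppercaseModelMLP"                    # next char "M" is a letter: keep the cased mapping
#   preserve_case_replace("CLIP_ATTENTION_CLASSES", patterns, "FromUppercaseModel")
#       -> "FROM_UPPERCASE_MODEL_ATTENTION_CLASSES"   # next char "_" triggers the `.upper()` branch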
def convert_to_camelcase(text, old_name: str, default_old_name: str):
# Regex pattern to match consecutive uppercase letters and lowercase the first set
result = re.sub(rf"^({old_name})(?=[a-z]+)", lambda m: default_old_name, text, flags=re.IGNORECASE, count=1)
return result
def get_cased_name(lowercase_name: str) -> str:
"""From a model name in lowercase in the format `my_model`, return the cased name in the format `MyModel`."""
if lowercase_name in CONFIG_MAPPING_NAMES:
return CONFIG_MAPPING_NAMES[lowercase_name].replace("Config", "")
else:
return "".join(x.title() for x in lowercase_name.split("_"))
def get_lowercase_name(cased_name: str) -> str:
"""From a model name in Camelcase in the format `MyModel`, return the lowercase name in the format `my_model`."""
inverse_mapping = {value: key for key, value in CONFIG_MAPPING_NAMES.items()}
if cased_name + "Config" in inverse_mapping:
return inverse_mapping[cased_name + "Config"]
else:
return "_".join([s.lower() for s in re.findall(r"[A-Z][^A-Z]*", cased_name)])
class ReplaceNameTransformer(m.MatcherDecoratableTransformer):
@ -84,43 +104,47 @@ class ReplaceNameTransformer(m.MatcherDecoratableTransformer):
- LLaMa -> MyNewModel and MyNewModel -> Llama
"""
def __init__(
self,
old_name,
new_name,
given_old_name=None,
given_new_name=None,
):
def __init__(self, old_name: str, new_name: str, original_new_model_name: str = "", only_doc: bool = False):
super().__init__()
self.old_name = old_name
self.new_name = new_name
self.default_name = "".join(x.title() for x in new_name.split("_"))
if self.new_name in CONFIG_MAPPING_NAMES:
self.default_name = CONFIG_MAPPING_NAMES[self.new_name].replace(
"Config", ""
) # the best source of truth for class names. Could also just use the ones de
self.cased_new_name = get_cased_name(self.new_name)
self.cased_old_name = get_cased_name(self.old_name)
self.patterns = {
old_name: new_name,
old_name.upper(): new_name.upper(),
"".join(x.title() for x in old_name.split("_")): self.default_name,
# For some old models, `self.cased_old_name` == `old_name.upper()`, in which case this overwrites the previous entry
self.cased_old_name: self.cased_new_name,
}
if given_old_name is not None and given_new_name is not None and given_old_name not in self.patterns:
self.patterns[given_old_name] = given_new_name
if self.old_name in CONFIG_MAPPING_NAMES:
self.default_old_name = CONFIG_MAPPING_NAMES[self.old_name].replace("Config", "")
if self.default_old_name.isupper():
self.default_old_name = self.default_old_name.capitalize()
# In case new_name is a prefix alias, and not the original new model name
self.original_new_model_name = original_new_model_name
self.only_doc = only_doc
@m.leave(m.Name() | m.SimpleString() | m.Comment())
def replace_name(self, original_node, updated_node):
def _replace_name(self, original_node, updated_node):
if re.findall(r"# Copied from", updated_node.value):
return cst.RemoveFromParent()
update = preserve_case_replace(updated_node.value, self.patterns, self.default_name)
update = preserve_case_replace(updated_node.value, self.patterns, self.cased_new_name)
return updated_node.with_changes(value=update)
def leave_ClassDef(self, original_node, updated_node):
new_name = convert_to_camelcase(updated_node.name.value, self.old_name, self.default_old_name)
return updated_node.with_changes(name=cst.Name(new_name))
@m.leave(m.SimpleString() | m.Comment())
def replace_name(self, original_node, updated_node):
return self._replace_name(original_node, updated_node)
def leave_Name(self, original_node, updated_node):
if not self.only_doc:
return self._replace_name(original_node, updated_node)
return updated_node
def leave_ImportFrom(self, original_node, updated_node):
"""The imports from other file types (configuration, processing etc) should use original model name."""
if self.original_new_model_name != self.new_name and m.matches(updated_node.module, m.Name()):
patterns = "|".join(ALL_FILE_TYPES)
regex = rf"({patterns})_{self.new_name}"
new_source = re.sub(
regex, lambda m: f"{m.group(1)}_{self.original_new_model_name}", updated_node.module.value
)
updated_node = updated_node.with_changes(module=updated_node.module.with_changes(value=new_source))
return updated_node
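# For instance, assuming the converter runs with `new_name="multimodal2_vision"` as a prefix alias for
# `original_new_model_name="multimodal2"` (as in the example added in this PR), a renamed import such as
#   from .configuration_multimodal2_vision import Multimodal2VisionConfig
# is rewritten back to
#   from .configuration_multimodal2 import Multimodal2VisionConfig
# because the configuration/processing files are named after the real model, as in the generated file above.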
DOCSTRING_NODE = m.SimpleStatementLine(
@ -760,10 +784,12 @@ class ModelFileMapper(ModuleMapper):
remaining_dependencies.remove(dep)
relative_order[dep] = idx
idx += 1
# Add the class itself
remaining_dependencies.remove(class_name)
relative_order[class_name] = idx
idx += 1
# Add the class itself (it can sometimes already be present if the order of classes in the source file
# does not make sense, i.e. a class is used somewhere before being defined like in `rt_detr`...)
if class_name in remaining_dependencies:
remaining_dependencies.remove(class_name)
relative_order[class_name] = idx
idx += 1
# Now add what still remains
remaining_dependencies = tuple(remaining_dependencies)
@ -859,7 +885,24 @@ class ModelFileMapper(ModuleMapper):
return mapper
def replace_class_node(mapper: ModelFileMapper, class_node: cst.ClassDef, renamed_super_class: str):
def common_partial_suffix(str1: str, str2: str) -> str:
"""Return the biggest common suffix between 2 strings. If one string is a full suffix of the other string,
we do not consider it a common suffix and return `""`"""
common_suffix = ""
for i in range(1, min(len(str1), len(str2)) + 1):
if str1[-i] == str2[-i]:
common_suffix = str1[-i] + common_suffix
else:
break
# We do not allow full string suffix
if common_suffix == str1 or common_suffix == str2:
common_suffix = ""
return common_suffix
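# Examples, with class names taken from the multimodal2 example in this PR:
#   common_partial_suffix("CLIPAttention", "Multimodal2VisionAttention")  -> "Attention"
#   common_partial_suffix("Attention", "Multimodal2VisionAttention")      -> ""   # full-string suffix is rejected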
def replace_class_node(
mapper: ModelFileMapper, class_node: cst.ClassDef, renamed_super_class: str, original_super_class: str
):
"""
Replace a class node which inherits from another modeling class. This function works in the following way:
- start from the base class node of the inherited class (a cst.Node)
@ -889,6 +932,36 @@ def replace_class_node(mapper: ModelFileMapper, class_node: cst.ClassDef, rename
raise ValueError(f"Could not parse the name of the bases for {class_node.name.value}")
original_node = mapper.classes[renamed_super_class]
# Always use the new name of the class (in case we use e.g. `ColPaliForRetrieval` inheriting from `PaliGemmaForConditionalGeneration`)
new_name = class_node.name
# If the new class name is different from the renamed super class name, we need to update the docstrings/comments accordingly
if new_name.value != renamed_super_class:
common_suffix = common_partial_suffix(new_name.value, renamed_super_class)
# Note that this works even without common prefix, in which case it does not replace anything
old, new = renamed_super_class.replace(common_suffix, ""), new_name.value.replace(common_suffix, "")
temp_module = cst.Module(body=[original_node])
original_node = temp_module.visit(
ReplaceNameTransformer(get_lowercase_name(old), get_lowercase_name(new), only_doc=True)
).body[0]
# If we explicitly passed a new base with common suffix to an old base, it is for switching the prefix
# e.g. if the "natural" parent class is `PreTrainedModel` but we wanted to rename it to `PreTrainedVisionModel`
additional_bases = [base for base in all_bases if base != original_super_class]
new_bases = []
for original_base in original_node.bases:
new_base = original_base
# we only potentially switch base for Name-based bases, not Attribute
if m.matches(original_base.value, m.Name()):
original_base_name = original_base.value.value
for additional_base_name in additional_bases:
suffix = common_partial_suffix(original_base_name, additional_base_name)
if len(suffix) > 0 and suffix[0].isupper():
new_name_node = original_base.value.with_changes(value=additional_base_name)
new_base = original_base.with_changes(value=new_name_node)
break
new_bases.append(new_base)
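# E.g. for `class Multimodal2VisionSdpaAttention(CLIPSdpaAttention, Multimodal2VisionAttention)` in the
# modular file: renaming alone would give a parent without the "Vision" part (roughly `Multimodal2SdpaAttention`
# inheriting from `Multimodal2Attention`, names assumed here), but the explicitly added base
# `Multimodal2VisionAttention` shares the suffix "Attention" (starting with an uppercase letter) with that
# parent, so the base is switched and the generated class inherits from `Multimodal2VisionAttention`,
# as in the generated file above.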
original_methods = {
f.name.value if hasattr(f, "name") else mapper.python_module.code_for_node(f): f
for f in original_node.body.body
@ -942,12 +1015,17 @@ def replace_class_node(mapper: ModelFileMapper, class_node: cst.ClassDef, rename
if m.matches(func, DOCSTRING_NODE): # This processes the docstring of the class!
# Extract the original docstring
updated_docstring = func.body[0].value.value
original_docstring = docstring_node[0].body[0].value.value
merged_doc = merge_docstrings(original_docstring, updated_docstring)
# Update the docstring in the original function
docstring_node = [
docstring_node[0].with_changes(body=[cst.Expr(value=cst.SimpleString(value=merged_doc))])
]
if len(docstring_node) == 0: # If the original docstring is empty, just create one from the updated.
docstring_node = [
cst.SimpleStatementLine(body=[cst.Expr(value=cst.SimpleString(value=updated_docstring))])
]
else:
original_docstring = docstring_node[0].body[0].value.value
merged_doc = merge_docstrings(original_docstring, updated_docstring)
# Update the docstring in the original function
docstring_node = [
docstring_node[0].with_changes(body=[cst.Expr(value=cst.SimpleString(value=merged_doc))])
]
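        # Illustrative effect (hypothetical docstrings): if the original class documents an `x` argument and the
        # modular class adds documentation for a new `y` argument, `merge_docstrings` is expected to combine them;
        # when the original class has no docstring at all, the modular docstring is used as-is.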
if name not in original_methods and func is not None and isinstance(func, cst.FunctionDef):
end_meth.append(func)
if m.matches(func, m.SimpleStatementLine(body=[m.Assign()])):
@ -970,10 +1048,10 @@ def replace_class_node(mapper: ModelFileMapper, class_node: cst.ClassDef, rename
# Use decorators redefined in `modular_xxx.py` if any
new_decorators = class_node.decorators if len(class_node.decorators) > 0 else original_node.decorators
# Always use the new name of the class (in case we use e.g. `ColPaliForRetrieval` inheriting from `PaliGemmaForConditionalGeneration`)
name = class_node.name
return original_node.with_changes(body=new_replacement_body, decorators=new_decorators, name=name)
return original_node.with_changes(
body=new_replacement_body, decorators=new_decorators, bases=new_bases, name=new_name
)
TYPE_TO_FILE_TYPE = {
@ -1014,14 +1092,18 @@ VARIABLES_AT_THE_BEGINNING = (
IMPORTS_TO_SKIP_IN_MODULAR = ("auto.modeling_auto",)
def append_new_import_node(node: cst.CSTNode, unused_imports: set[str], imports_to_keep: list[cst.CSTNode]):
"""Insert the new `node` to the list of `imports_to_keep` in-place, if it is not part of the `unused_imports`."""
def append_new_import_node(
node: cst.CSTNode, unused_imports: set[str], added_names: set, imports_to_keep: list[cst.CSTNode]
):
"""Insert the new `node` to the list of `imports_to_keep` in-place, if it is not part of the `unused_imports` or `added_names`.
Also modifies `added_names` in-place accordingly."""
import_node = node.body[0]
names_to_keep = []
for name in import_node.names:
name_value = name.evaluated_name
if name_value not in unused_imports:
if name_value not in unused_imports and name_value not in added_names:
names_to_keep.append(name.with_changes(comma=cst.MaybeSentinel.DEFAULT))
added_names.add(name_value)
if len(names_to_keep) > 0:
new_node = node.with_changes(body=[import_node.with_changes(names=names_to_keep)])
imports_to_keep.append(new_node)
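# Quick sketch of the deduplication above (hypothetical case): if `nn` is imported both by a plain
#   `from torch import nn` and again inside an `if is_torch_available():` block, whichever statement is
#   processed first keeps the name; the later occurrence is filtered out through `added_names`.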
@ -1036,40 +1118,38 @@ def get_needed_imports(body: dict[str, dict], all_imports: list[cst.CSTNode]) ->
wrapper = MetadataWrapper(cst.Module(body=all_imports + new_body))
scopes = set(wrapper.resolve(ScopeProvider).values())
unused_imports = set()
import_ref_count = {}
import_ref_count = defaultdict(lambda: 0)
for scope in scopes:
for assignment in scope.assignments:
node = assignment.node
if isinstance(assignment, cst.metadata.Assignment) and isinstance(node, (cst.Import, cst.ImportFrom)):
ref_count = len(assignment.references)
name = assignment.name
# Similar imports may be redefined, and only used between their 1st and 2nd definition
# so if we already have a ref count > 0, the imports is actually used
if (ref_count == 0 and import_ref_count.get(name, -1) <= 0) or name in body.keys():
unused_imports.add(name)
import_ref_count[name] = ref_count
import_ref_count[name] = max(ref_count, import_ref_count[name])
    # Similar imports may be redefined, and only used between their 1st and 2nd definition, so if we already have
    # a ref count > 0 at any point, the import is actually used
unused_imports = {name for name, count in import_ref_count.items() if count <= 0 or name in body.keys()}
imports_to_keep = []
    # We need to keep track of which names were already imported, because some imports may be duplicated across
    # multiple sources or be both protected and unprotected due to inconsistencies between models
added_names = set()
existing_protected_statements = set() # str repr of the import nodes - does not work with the nodes directly
for node in all_imports:
if m.matches(node, m.If()): # handle safe imports
new_statements = []
for stmt_node in node.body.body:
append_new_import_node(stmt_node, unused_imports, new_statements)
append_new_import_node(stmt_node, unused_imports, added_names, new_statements)
new_statements = [stmt for stmt in new_statements if str(stmt) not in existing_protected_statements]
if len(new_statements) > 0:
new_node = node.with_changes(body=node.body.with_changes(body=new_statements))
imports_to_keep.append(new_node)
existing_protected_statements.update({str(stmt) for stmt in new_statements})
else:
append_new_import_node(node, unused_imports, imports_to_keep)
append_new_import_node(node, unused_imports, added_names, imports_to_keep)
protected_import_nodes = [node for node in imports_to_keep if m.matches(node, m.If())]
usual_import_nodes = [node for node in imports_to_keep if not m.matches(node, m.If())]
# If the same import is both protected and unprotected, only keep the protected one
for protected_node in protected_import_nodes:
for stmt_node in protected_node.body.body:
usual_import_nodes = [node for node in usual_import_nodes if node.body[0] != stmt_node.body[0]]
# Protected imports always appear at the end of all imports
return usual_import_nodes + protected_import_nodes
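# Illustrative outcome (hypothetical input): with a used `import torch`, an unused `import math`, and a
#   `_flash_attention_forward` import guarded by `if is_flash_attn_2_available():`, this would keep the plain
#   `import torch`, drop `import math`, and return the guarded block last (protected imports always come last).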
@ -1102,12 +1182,10 @@ class ModularFileMapper(ModuleMapper):
Calling the method `create_modules()` after visit will create all modules based on this modular file.
"""
def __init__(self, python_module, new_name, given_old_name=None, given_new_name=None):
def __init__(self, python_module, new_name):
super().__init__(python_module)
# fmt: off
self.model_name = new_name # name of the model being defined. Should be in the format of `llama` or `layout_xlm` or `phi3`
self.given_old_name = given_old_name
self.given_new_name = given_new_name
self.model_specific_imported_objects: Dict[str, str] = {} # e.g. {"LlamaModel": "transformers.models.llama.modeling_llama"}
self.model_specific_modules: Dict[str, cst.Module] = {} # e.g. {"transformers.models.llama.modeling_llama": cst.Module}
@ -1191,11 +1269,11 @@ class ModularFileMapper(ModuleMapper):
# 1. for each modeling file found in the imports, rename it with the new model name, visit it, and update dependencies
self.visited_modules = {}
self.renamers = {}
name_prefixes = self.infer_new_model_name()
for file, module in self.model_specific_modules.items():
file_model_name = file.split(".")[-2]
renamer = ReplaceNameTransformer(
file_model_name, self.model_name, self.given_old_name, self.given_new_name
)
new_name = name_prefixes[file]
renamer = ReplaceNameTransformer(file_model_name, new_name, self.model_name)
renamed_module = module.visit(renamer)
self.visited_modules[file] = ModelFileMapper.visit_and_merge_dependencies(
renamed_module,
@ -1288,6 +1366,87 @@ class ModularFileMapper(ModuleMapper):
return relative_order
def infer_new_model_name(self) -> dict:
"""Infer whether we are using a model name prefix different from the usual model name as defined from the filename.
This is useful e.g. when we define a new multi-modal model, and only the text part inherits from `LlamaModel`,
so we have something like:
```python
class NewModelNameTextDecoderLayer(LlamaDecoderLayer):
pass
```
with the `Text` prefix added to the model name.
        However, in case multiple prefixes are used, we raise a warning and use the most frequent prefix, to avoid parsing
the same file multiple times and inconsistencies in the objects added from dependencies.
If the new prefix collides with a prefix of another class in the file where we are importing from, then we also
raise a warning, and use the default prefix (model name) to avoid collisions in dependencies.
"""
prefix_model_name_mapping = defaultdict(Counter)
cased_default_name = get_cased_name(self.model_name)
# Iterate over all new classes to get modeling super classes
for class_name, class_node in self.classes.items():
modeling_bases = [
k.value.value for k in class_node.bases if k.value.value in self.model_specific_imported_objects
]
if len(modeling_bases) > 1:
raise ValueError(
f"{class_name} was defined with more than 1 model-specific super class. This is unsupported. We found {*modeling_bases,}."
)
if len(modeling_bases) == 1:
filename = self.model_specific_imported_objects[modeling_bases[0]]
cased_model_name = cased_default_name # the default name prefix
suffix = common_partial_suffix(class_name, modeling_bases[0])
if len(suffix) > 0 and suffix[0].isupper():
cased_model_name = class_name.replace(suffix, "")
prefix_model_name_mapping[filename].update([cased_model_name])
# Check if we found multiple prefixes for some modeling files
final_name_mapping = {}
for file, prefixes_counter in prefix_model_name_mapping.items():
if len(prefixes_counter) > 1:
_, total = prefixes_counter.most_common(1)[0]
most_used_entities = [name for name, count in prefixes_counter.most_common() if count == total]
# if the default name is in the pool of equally used prefixes, use it, otherwise last encountered
final_name = cased_default_name if cased_default_name in most_used_entities else most_used_entities[-1]
else:
final_name = list(prefixes_counter)[0]
# Check if the prefix can be used without collisions in the names
old_cased_model_name = get_cased_name(file.split(".")[-2])
old_model_name_prefix = final_name.replace(cased_default_name, old_cased_model_name)
            # Log an appropriate warning depending on the situation
has_prefix_collision = f"\nclass {old_model_name_prefix}" in get_module_source_from_name(file)
if final_name != cased_default_name and has_prefix_collision:
if len(prefixes_counter) > 1:
logger.warning(
f"We detected multiple prefix names when inheriting from {file}: {*set(prefixes_counter),}. However, the "
f"most used one, '{final_name}', is already present in the source file and will likely cause consistency "
f"issues. For this reason we fallback to the default prefix '{cased_default_name}' when grabbing args "
"and dependencies. Make sure to subclass the intermediate classes with the prefix you want (if different "
f"from '{cased_default_name}') or use a single prefix in all the modular (best)."
)
else:
logger.warning(
f"We detected the use of the new default prefix {final_name} when inheriting from {file}. However, it is "
"already present in the source file and will likely cause consistency issues. For this reason we fallback "
f"to the default prefix '{cased_default_name}' when grabbing args and dependencies. Make sure to subclass "
f"the intermediate classes with the prefix you want (if different from '{cased_default_name}')"
)
final_name = cased_default_name
elif len(prefixes_counter) > 1:
logger.warning(
f"We detected multiple prefix names when inheriting from {file}: {*set(prefixes_counter),}. We will only "
f"use the most used '{final_name}' prefix when grabbing args and dependencies. Make sure to subclass the "
f"intermediate classes with the prefix you want (if different from '{final_name}') or use a single prefix "
"in all the modular (best)."
)
final_name_mapping[file] = get_lowercase_name(final_name)
# Check we are not missing imported files
for file in self.model_specific_modules.keys():
if file not in final_name_mapping.keys():
final_name_mapping[file] = self.model_name
return final_name_mapping
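    # Illustrative sketch of the method above (hypothetical model): for `modular_new_model_name.py` defining only
    #   class NewModelNameTextDecoderLayer(LlamaDecoderLayer): ...
    # the common suffix with the base is "DecoderLayer", so the inferred prefix is `NewModelNameText` and the
    # method would return roughly {"transformers.models.llama.modeling_llama": "new_model_name_text"}
    # (assuming `get_lowercase_name` converts CamelCase back to snake_case).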
def check_dependencies_and_create_import_node(
file_type: str, new_dependencies: set[str], mapper: ModuleMapper, new_name: str
@ -1338,11 +1497,11 @@ def get_class_node_and_dependencies(
class node based on the inherited classes if needed. Also returns any new imports of a new class defined in
    the modular that we may need.
"""
bases = [k.value.value for k in node.bases if k.value.value in modular_mapper.model_specific_imported_objects]
if len(bases) > 1:
raise ValueError(
f"{class_name} was defined with more than 1 model-specific super class. This is unsupported. We found {*bases,}."
)
# An exception was already raised if this has len > 1
model_specific_bases = [
k.value.value for k in node.bases if k.value.value in modular_mapper.model_specific_imported_objects
]
super_class = model_specific_bases[0] if len(model_specific_bases) == 1 else None
file_type = find_file_type(class_name)
file_to_update = files[file_type]
@ -1352,19 +1511,17 @@ def get_class_node_and_dependencies(
imported_objects = modular_mapper.imported_objects_per_file[file_type]
# We need to replace the class node with the transformers (modeling file) super class node
if len(bases) == 1:
super_class = bases[0]
if super_class is not None:
super_file_name = modular_mapper.model_specific_imported_objects[super_class]
# Get the mapper corresponding to the inherited class
mapper = modular_mapper.visited_modules[super_file_name]
# Rename the super class according to the exact same rule we used when renaming the whole module
renamer = modular_mapper.renamers[super_file_name]
renamed_super_class = preserve_case_replace(super_class, renamer.patterns, renamer.default_name)
renamed_super_class = convert_to_camelcase(renamed_super_class, renamer.old_name, renamer.default_old_name)
renamed_super_class = preserve_case_replace(super_class, renamer.patterns, renamer.cased_new_name)
# Create the new class node
updated_node = replace_class_node(mapper, node, renamed_super_class)
updated_node = replace_class_node(mapper, node, renamed_super_class, super_class)
# Grab all immediate dependencies of the new node
new_node_dependencies = augmented_dependencies_for_class_node(updated_node, mapper, imported_objects)
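        # Illustrative renaming (hypothetical names): with an inferred prefix of `MyModelText` for the llama file,
        # `preserve_case_replace` turns the super class `LlamaModel` into `MyModelTextModel` here, matching the
        # renaming already applied to the whole visited module.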
@ -1468,7 +1625,7 @@ def create_modules(modular_mapper: ModularFileMapper) -> dict[str, cst.Module]:
return files
def convert_modular_file(modular_file, old_model_name=None, new_model_name=None, cst_transformers=None):
def convert_modular_file(modular_file):
pattern = re.search(r"modular_(.*)(?=\.py$)", modular_file)
output = {}
if pattern is not None:
@ -1478,8 +1635,7 @@ def convert_modular_file(modular_file, old_model_name=None, new_model_name=None,
code = file.read()
module = cst.parse_module(code)
wrapper = MetadataWrapper(module)
if cst_transformers is None:
cst_transformers = ModularFileMapper(module, model_name, old_model_name, new_model_name)
cst_transformers = ModularFileMapper(module, model_name)
wrapper.visit(cst_transformers)
for file, module in create_modules(cst_transformers).items():
if module != {}:
@ -1522,20 +1678,10 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--files_to_parse",
default=["src/transformers/models/starcoder2/modular_starcoder2.py"],
default=["src/transformers/models/gemma/modular_gemma.py"],
nargs="+",
help="A list of `modular_xxxx` files that should be converted to single model file",
)
parser.add_argument(
"--old_model_name",
required=False,
help="The name of the model from which the copying is done in CamelCase. If not provided is inferred from modular-file",
)
parser.add_argument(
"--new_model_name",
required=False,
help="The name of the new model being added in CamelCase. If not provided is inferred from modular-file",
)
args = parser.parse_args()
if args.files_to_parse == ["all"]:
args.files_to_parse = glob.glob("src/transformers/models/**/modular_*.py", recursive=True)
@ -1544,5 +1690,5 @@ if __name__ == "__main__":
for file_name in find_priority_list(args.files_to_parse):
print(f"Converting {file_name} to a single model single file format")
module_path = file_name.replace("/", ".").replace(".py", "").replace("src.", "")
converted_files = convert_modular_file(file_name, args.old_model_name, args.new_model_name)
converted_files = convert_modular_file(file_name)
converter = save_modeling_file(file_name, converted_files)
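# Example invocation (illustrative; the script path below is an assumption):
#   python utils/modular_model_converter.py --files_to_parse src/transformers/models/gemma/modular_gemma.py
# Passing `--files_to_parse all` converts every `modular_*.py` file found under src/transformers/models.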