Support for easier multimodal use of modular (#35056)
* update modular and add examples
* style
* improve example comments
* style
* fix small logic issue for imports
* fix relative order issue when files do not make sense
* Improve comments
* trigger CIs
This commit is contained in:
parent 46df859975
commit 1da1e0d7f2
@@ -36,7 +36,7 @@ logger = logging.get_logger(__name__)
class ImgprocModelImageProcessor(BaseImageProcessor):
    r"""
    Constructs a NEW_IMGPROC_MODEL image processor.
    Constructs a IMGPROC_MODEL image processor.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
357 examples/modular-transformers/modeling_from_uppercase_model.py (new file)
@@ -0,0 +1,357 @@
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from examples/modular-transformers/modular_from_uppercase_model.py.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the modular. If any change should be done, please apply the change to the
|
||||
# modular_from_uppercase_model.py file directly. One of our CI enforces this.
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...pytorch_utils import is_torch_greater_or_equal_than_2_2
|
||||
from ...utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, logging
|
||||
from .configuration_from_uppercase_model import FromUppercaseModelConfig
|
||||
|
||||
|
||||
if is_flash_attn_2_available():
|
||||
from ...modeling_flash_attention_utils import _flash_attention_forward
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class FromUppercaseModelAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_dim = self.embed_dim // self.num_heads
|
||||
if self.head_dim * self.num_heads != self.embed_dim:
|
||||
raise ValueError(
|
||||
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
|
||||
f" {self.num_heads})."
|
||||
)
|
||||
self.scale = self.head_dim**-0.5
|
||||
self.dropout = config.attention_dropout
|
||||
|
||||
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
|
||||
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
causal_attention_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
|
||||
bsz, tgt_len, embed_dim = hidden_states.size()
|
||||
|
||||
# get query proj
|
||||
query_states = self.q_proj(hidden_states) * self.scale
|
||||
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
||||
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
||||
|
||||
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
|
||||
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
|
||||
key_states = key_states.view(*proj_shape)
|
||||
value_states = value_states.view(*proj_shape)
|
||||
|
||||
src_len = key_states.size(1)
|
||||
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
|
||||
|
||||
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
|
||||
raise ValueError(
|
||||
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
||||
f" {attn_weights.size()}"
|
||||
)
|
||||
|
||||
# apply the causal_attention_mask first
|
||||
if causal_attention_mask is not None:
|
||||
if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
|
||||
f" {causal_attention_mask.size()}"
|
||||
)
|
||||
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
|
||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
|
||||
if attention_mask is not None:
|
||||
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
|
||||
)
|
||||
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
|
||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
|
||||
if output_attentions:
|
||||
# this operation is a bit akward, but it's required to
|
||||
# make sure that attn_weights keeps its gradient.
|
||||
# In order to do so, attn_weights have to reshaped
|
||||
# twice and have to be reused in the following
|
||||
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||
attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
else:
|
||||
attn_weights_reshaped = None
|
||||
|
||||
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
|
||||
|
||||
attn_output = torch.bmm(attn_probs, value_states)
|
||||
|
||||
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
|
||||
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output, attn_weights_reshaped
|
||||
|
||||
|
||||
class FromUppercaseModelFlashAttention2(FromUppercaseModelAttention):
|
||||
"""
|
||||
FromUppercaseModelAttention flash attention module. This module inherits from `FromUppercaseModelAttention` as the weights of the module stays
|
||||
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
|
||||
flash attention and deal with padding tokens in case the input contains any of them.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
||||
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
||||
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
||||
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
|
||||
|
||||
# Adapted from transformers.models.llama.modeling_llama.LlamaFlashAttention2.forward
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
causal_attention_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
output_attentions = False
|
||||
|
||||
batch_size, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
# Flash attention requires the input to have the shape
|
||||
# batch_size x seq_length x head_dim x hidden_dim
|
||||
# therefore we just need to keep the original shape
|
||||
query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim)
|
||||
key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim)
|
||||
value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim)
|
||||
|
||||
dropout_rate = self.dropout if self.training else 0.0
|
||||
|
||||
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
||||
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
||||
# cast them back in the correct dtype just to be sure everything works as expected.
|
||||
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
|
||||
# in fp32.
|
||||
|
||||
input_dtype = query_states.dtype
|
||||
if input_dtype == torch.float32:
|
||||
if torch.is_autocast_enabled():
|
||||
target_dtype = torch.get_autocast_gpu_dtype()
|
||||
# Handle the case where the model is quantized
|
||||
elif hasattr(self.config, "_pre_quantization_dtype"):
|
||||
target_dtype = self.config._pre_quantization_dtype
|
||||
else:
|
||||
target_dtype = self.q_proj.weight.dtype
|
||||
|
||||
logger.warning_once(
|
||||
f"The input hidden states seems to be silently casted in float32, this might be related to"
|
||||
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
||||
f" {target_dtype}."
|
||||
)
|
||||
|
||||
query_states = query_states.to(target_dtype)
|
||||
key_states = key_states.to(target_dtype)
|
||||
value_states = value_states.to(target_dtype)
|
||||
|
||||
attn_output = _flash_attention_forward(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
q_len,
|
||||
dropout=dropout_rate,
|
||||
is_causal=causal_attention_mask is not None,
|
||||
use_top_left_mask=self._flash_attn_uses_top_left_mask,
|
||||
)
|
||||
|
||||
attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim).contiguous()
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights
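# --- Editor's illustrative sketch (not part of the generated file) ---
# The float32 fallback used in the forward pass above reduces to the following priority order;
# `_pick_fa2_dtype` is a hypothetical helper name used only for this sketch:
#
#     def _pick_fa2_dtype(config, proj_weight):
#         if torch.is_autocast_enabled():
#             return torch.get_autocast_gpu_dtype()
#         if hasattr(config, "_pre_quantization_dtype"):
#             return config._pre_quantization_dtype
#         return proj_weight.dtype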
|
||||
|
||||
|
||||
class FromUppercaseModelSdpaAttention(FromUppercaseModelAttention):
|
||||
"""
|
||||
SDPA attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
|
||||
`FromUppercaseModelAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
|
||||
SDPA API.
|
||||
"""
|
||||
|
||||
# Adapted from FromUppercaseModelAttention.forward
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
causal_attention_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
if output_attentions:
|
||||
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
|
||||
logger.warning_once(
|
||||
"FromUppercaseModelModel is using FromUppercaseModelSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not "
|
||||
"support `output_attentions=True`. Falling back to the manual attention implementation, but specifying "
|
||||
"the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can "
|
||||
'be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
return super().forward(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
causal_attention_mask=causal_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
# FROM_UPPERCASE_MODEL text model uses both `causal_attention_mask` and `attention_mask`
|
||||
if attention_mask is not None and causal_attention_mask is not None:
|
||||
attn_mask = attention_mask + causal_attention_mask
|
||||
elif causal_attention_mask is not None:
|
||||
attn_mask = causal_attention_mask
|
||||
else:
|
||||
attn_mask = attention_mask
|
||||
|
||||
bsz, tgt_len, embed_dim = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
|
||||
# Reference: https://github.com/pytorch/pytorch/issues/112577.
|
||||
if not is_torch_greater_or_equal_than_2_2 and query_states.device.type == "cuda" and attn_mask is not None:
|
||||
query_states = query_states.contiguous()
|
||||
key_states = key_states.contiguous()
|
||||
value_states = value_states.contiguous()
|
||||
|
||||
# FROM_UPPERCASE_MODEL text model uses both `causal_attention_mask` and `attention_mask` sequentially.
|
||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_mask=attn_mask,
|
||||
dropout_p=self.dropout if self.training else 0.0,
|
||||
scale=self.scale,
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
|
||||
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output, None
|
||||
|
||||
|
||||
class FromUppercaseModelMLP(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.activation_fn = ACT2FN[config.hidden_act]
|
||||
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
|
||||
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.fc1(hidden_states)
|
||||
hidden_states = self.activation_fn(hidden_states)
|
||||
hidden_states = self.fc2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
FROM_UPPERCASE_MODEL_ATTENTION_CLASSES = {
|
||||
"eager": FromUppercaseModelAttention,
|
||||
"sdpa": FromUppercaseModelSdpaAttention,
|
||||
"flash_attention_2": FromUppercaseModelFlashAttention2,
|
||||
}
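# --- Editor's illustrative sketch (not part of the generated file) ---
# The mapping above is indexed with `config._attn_implementation`, which `from_pretrained(...)`
# fills in from its `attn_implementation` argument ("eager", "sdpa" or "flash_attention_2").
# A minimal manual dispatch, assuming `config` is a FromUppercaseModelConfig instance:
#
#     config._attn_implementation = "sdpa"
#     attn = FROM_UPPERCASE_MODEL_ATTENTION_CLASSES[config._attn_implementation](config)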
|
||||
|
||||
|
||||
class FromUppercaseModelEncoderLayer(nn.Module):
|
||||
def __init__(self, config: FromUppercaseModelConfig):
|
||||
super().__init__()
|
||||
self.embed_dim = config.hidden_size
|
||||
self.self_attn = FROM_UPPERCASE_MODEL_ATTENTION_CLASSES[config._attn_implementation](config)
|
||||
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
||||
self.mlp = FromUppercaseModelMLP(config)
|
||||
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
causal_attention_mask: torch.Tensor,
|
||||
output_attentions: Optional[bool] = False,
|
||||
) -> Tuple[torch.FloatTensor]:
|
||||
"""
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
||||
attention_mask (`torch.FloatTensor`): attention mask of size
|
||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||
`(config.encoder_attention_heads,)`.
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||
returned tensors for more detail.
|
||||
"""
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.layer_norm1(hidden_states)
|
||||
hidden_states, attn_weights = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
causal_attention_mask=causal_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.layer_norm2(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
outputs = (hidden_states,)
|
||||
|
||||
if output_attentions:
|
||||
outputs += (attn_weights,)
|
||||
|
||||
return outputs
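# --- Editor's illustrative sketch (not part of the generated file) ---
# Smoke test for the encoder layer above; a SimpleNamespace stands in for the real
# FromUppercaseModelConfig so the example does not depend on its default values:
#
#     from types import SimpleNamespace
#
#     cfg = SimpleNamespace(hidden_size=64, num_attention_heads=4, intermediate_size=128,
#                           hidden_act="gelu", layer_norm_eps=1e-5, attention_dropout=0.0,
#                           _attn_implementation="eager")
#     layer = FromUppercaseModelEncoderLayer(cfg)
#     out = layer(torch.randn(2, 10, 64), attention_mask=None, causal_attention_mask=None)
#     assert out[0].shape == (2, 10, 64)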
1017 examples/modular-transformers/modeling_multimodal1.py (new file)
File diff suppressed because it is too large
705 examples/modular-transformers/modeling_multimodal2.py (new file)
@@ -0,0 +1,705 @@
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from examples/modular-transformers/modular_multimodal2.py.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the modular. If any change should be done, please apply the change to the
|
||||
# modular_multimodal2.py file directly. One of our CI enforces this.
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from transformers.utils import add_start_docstrings
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...pytorch_utils import is_torch_greater_or_equal_than_2_2
|
||||
from ...utils import (
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_flash_attn_2_available,
|
||||
is_flash_attn_greater_or_equal_2_10,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
torch_int,
|
||||
)
|
||||
from .configuration_multimodal2 import Multimodal2Config, Multimodal2VisionConfig
|
||||
|
||||
|
||||
if is_flash_attn_2_available():
|
||||
from ...modeling_flash_attention_utils import _flash_attention_forward
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class Multimodal2VisionAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_dim = self.embed_dim // self.num_heads
|
||||
if self.head_dim * self.num_heads != self.embed_dim:
|
||||
raise ValueError(
|
||||
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
|
||||
f" {self.num_heads})."
|
||||
)
|
||||
self.scale = self.head_dim**-0.5
|
||||
self.dropout = config.attention_dropout
|
||||
|
||||
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
|
||||
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
causal_attention_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
|
||||
bsz, tgt_len, embed_dim = hidden_states.size()
|
||||
|
||||
# get query proj
|
||||
query_states = self.q_proj(hidden_states) * self.scale
|
||||
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
||||
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
||||
|
||||
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
|
||||
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
|
||||
key_states = key_states.view(*proj_shape)
|
||||
value_states = value_states.view(*proj_shape)
|
||||
|
||||
src_len = key_states.size(1)
|
||||
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
|
||||
|
||||
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
|
||||
raise ValueError(
|
||||
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
||||
f" {attn_weights.size()}"
|
||||
)
|
||||
|
||||
# apply the causal_attention_mask first
|
||||
if causal_attention_mask is not None:
|
||||
if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
|
||||
f" {causal_attention_mask.size()}"
|
||||
)
|
||||
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
|
||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
|
||||
if attention_mask is not None:
|
||||
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
|
||||
)
|
||||
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
|
||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
|
||||
if output_attentions:
|
||||
# this operation is a bit akward, but it's required to
|
||||
# make sure that attn_weights keeps its gradient.
|
||||
# In order to do so, attn_weights have to reshaped
|
||||
# twice and have to be reused in the following
|
||||
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||
attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
else:
|
||||
attn_weights_reshaped = None
|
||||
|
||||
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
|
||||
|
||||
attn_output = torch.bmm(attn_probs, value_states)
|
||||
|
||||
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
|
||||
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output, attn_weights_reshaped
|
||||
|
||||
|
||||
class Multimodal2VisionSdpaAttention(Multimodal2VisionAttention):
|
||||
"""
|
||||
SDPA attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
|
||||
`Multimodal2VisionAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
|
||||
SDPA API.
|
||||
"""
|
||||
|
||||
# Adapted from Multimodal2VisionAttention.forward
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
causal_attention_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
if output_attentions:
|
||||
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
|
||||
logger.warning_once(
|
||||
"Multimodal2VisionModel is using Multimodal2VisionSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not "
|
||||
"support `output_attentions=True`. Falling back to the manual attention implementation, but specifying "
|
||||
"the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can "
|
||||
'be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
return super().forward(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
causal_attention_mask=causal_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
# MULTIMODAL2_VISION text model uses both `causal_attention_mask` and `attention_mask`
|
||||
if attention_mask is not None and causal_attention_mask is not None:
|
||||
attn_mask = attention_mask + causal_attention_mask
|
||||
elif causal_attention_mask is not None:
|
||||
attn_mask = causal_attention_mask
|
||||
else:
|
||||
attn_mask = attention_mask
|
||||
|
||||
bsz, tgt_len, embed_dim = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
|
||||
# Reference: https://github.com/pytorch/pytorch/issues/112577.
|
||||
if not is_torch_greater_or_equal_than_2_2 and query_states.device.type == "cuda" and attn_mask is not None:
|
||||
query_states = query_states.contiguous()
|
||||
key_states = key_states.contiguous()
|
||||
value_states = value_states.contiguous()
|
||||
|
||||
# MULTIMODAL2_VISION text model uses both `causal_attention_mask` and `attention_mask` sequentially.
|
||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_mask=attn_mask,
|
||||
dropout_p=self.dropout if self.training else 0.0,
|
||||
scale=self.scale,
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
|
||||
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output, None
|
||||
|
||||
|
||||
class Multimodal2VisionFlashAttention2(Multimodal2VisionAttention):
|
||||
"""
|
||||
Multimodal2VisionAttention flash attention module. This module inherits from `Multimodal2VisionAttention` as the weights of the module stays
|
||||
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
|
||||
flash attention and deal with padding tokens in case the input contains any of them.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
||||
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
||||
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
||||
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
|
||||
|
||||
# Adapted from transformers.models.llama.modeling_llama.LlamaFlashAttention2.forward
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
causal_attention_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
output_attentions = False
|
||||
|
||||
batch_size, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
# Flash attention requires the input to have the shape
|
||||
# batch_size x seq_length x head_dim x hidden_dim
|
||||
# therefore we just need to keep the original shape
|
||||
query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim)
|
||||
key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim)
|
||||
value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim)
|
||||
|
||||
dropout_rate = self.dropout if self.training else 0.0
|
||||
|
||||
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
||||
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
||||
# cast them back in the correct dtype just to be sure everything works as expected.
|
||||
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
|
||||
# in fp32.
|
||||
|
||||
input_dtype = query_states.dtype
|
||||
if input_dtype == torch.float32:
|
||||
if torch.is_autocast_enabled():
|
||||
target_dtype = torch.get_autocast_gpu_dtype()
|
||||
# Handle the case where the model is quantized
|
||||
elif hasattr(self.config, "_pre_quantization_dtype"):
|
||||
target_dtype = self.config._pre_quantization_dtype
|
||||
else:
|
||||
target_dtype = self.q_proj.weight.dtype
|
||||
|
||||
logger.warning_once(
|
||||
f"The input hidden states seems to be silently casted in float32, this might be related to"
|
||||
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
||||
f" {target_dtype}."
|
||||
)
|
||||
|
||||
query_states = query_states.to(target_dtype)
|
||||
key_states = key_states.to(target_dtype)
|
||||
value_states = value_states.to(target_dtype)
|
||||
|
||||
attn_output = _flash_attention_forward(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
q_len,
|
||||
dropout=dropout_rate,
|
||||
is_causal=causal_attention_mask is not None,
|
||||
use_top_left_mask=self._flash_attn_uses_top_left_mask,
|
||||
)
|
||||
|
||||
attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim).contiguous()
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class Multimodal2VisionMLP(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.activation_fn = ACT2FN[config.hidden_act]
|
||||
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
|
||||
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.fc1(hidden_states)
|
||||
hidden_states = self.activation_fn(hidden_states)
|
||||
hidden_states = self.fc2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
MULTIMODAL2_VISION_ATTENTION_CLASSES = {
|
||||
"eager": Multimodal2VisionAttention,
|
||||
"sdpa": Multimodal2VisionSdpaAttention,
|
||||
"flash_attention_2": Multimodal2VisionFlashAttention2,
|
||||
}
|
||||
|
||||
|
||||
class Multimodal2VisionEncoderLayer(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.embed_dim = config.hidden_size
|
||||
self.self_attn = MULTIMODAL2_VISION_ATTENTION_CLASSES[config._attn_implementation](config)
|
||||
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
||||
self.mlp = Multimodal2VisionMLP(config)
|
||||
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
causal_attention_mask: torch.Tensor,
|
||||
output_attentions: Optional[bool] = False,
|
||||
) -> Tuple[torch.FloatTensor]:
|
||||
"""
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
||||
attention_mask (`torch.FloatTensor`): attention mask of size
|
||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||
`(config.encoder_attention_heads,)`.
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||
returned tensors for more detail.
|
||||
"""
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.layer_norm1(hidden_states)
|
||||
hidden_states, attn_weights = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
causal_attention_mask=causal_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.layer_norm2(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
outputs = (hidden_states,)
|
||||
|
||||
if output_attentions:
|
||||
outputs += (attn_weights,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class Multimodal2VisionEncoder(nn.Module):
|
||||
"""
|
||||
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
|
||||
[`Multimodal2VisionEncoderLayer`].
|
||||
|
||||
Args:
|
||||
config: Multimodal2VisionConfig
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layers = nn.ModuleList([Multimodal2VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inputs_embeds,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
causal_attention_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, BaseModelOutput]:
|
||||
r"""
|
||||
Args:
|
||||
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
||||
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
|
||||
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
||||
than the model's internal embedding lookup matrix.
|
||||
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
[What are attention masks?](../glossary#attention-mask)
|
||||
causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Causal mask for the text model. Mask values selected in `[0, 1]`:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
[What are attention masks?](../glossary#attention-mask)
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||
returned tensors for more detail.
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
|
||||
for more detail.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
"""
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
encoder_states = () if output_hidden_states else None
|
||||
all_attentions = () if output_attentions else None
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
for idx, encoder_layer in enumerate(self.layers):
|
||||
if output_hidden_states:
|
||||
encoder_states = encoder_states + (hidden_states,)
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
encoder_layer.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
causal_attention_mask,
|
||||
output_attentions,
|
||||
)
|
||||
else:
|
||||
layer_outputs = encoder_layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
causal_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if output_attentions:
|
||||
all_attentions = all_attentions + (layer_outputs[1],)
|
||||
|
||||
if output_hidden_states:
|
||||
encoder_states = encoder_states + (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
|
||||
return BaseModelOutput(
|
||||
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
|
||||
)
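# --- Editor's illustrative sketch (not part of the generated file) ---
# `self.gradient_checkpointing` above is normally toggled through the standard PreTrainedModel
# API rather than set by hand, e.g. (assuming a built Multimodal2VisionModel instance):
#
#     model.gradient_checkpointing_enable()  # recompute activations during backward to save memory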
|
||||
|
||||
|
||||
class Multimodal2VisionEmbeddings(nn.Module):
|
||||
def __init__(self, config: Multimodal2VisionConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
self.image_size = config.image_size
|
||||
self.patch_size = config.patch_size
|
||||
|
||||
self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
|
||||
|
||||
self.patch_embedding = nn.Conv2d(
|
||||
in_channels=config.num_channels,
|
||||
out_channels=self.embed_dim,
|
||||
kernel_size=self.patch_size,
|
||||
stride=self.patch_size,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
self.num_patches = (self.image_size // self.patch_size) ** 2
|
||||
self.num_positions = self.num_patches + 1
|
||||
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
|
||||
self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
|
||||
|
||||
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
||||
"""
|
||||
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
|
||||
images. This method is also adapted to support torch.jit tracing.
|
||||
|
||||
Adapted from:
|
||||
- https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
|
||||
- https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
|
||||
"""
|
||||
|
||||
num_patches = embeddings.shape[1] - 1
|
||||
position_embedding = self.position_embedding.weight.unsqueeze(0)
|
||||
num_positions = position_embedding.shape[1] - 1
|
||||
|
||||
# always interpolate when tracing to ensure the exported model works for dynamic input shapes
|
||||
if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
|
||||
return self.position_embedding(self.position_ids)
|
||||
|
||||
class_pos_embed = position_embedding[:, :1]
|
||||
patch_pos_embed = position_embedding[:, 1:]
|
||||
|
||||
dim = embeddings.shape[-1]
|
||||
|
||||
new_height = height // self.patch_size
|
||||
new_width = width // self.patch_size
|
||||
|
||||
sqrt_num_positions = torch_int(num_positions**0.5)
|
||||
patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
|
||||
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
|
||||
|
||||
patch_pos_embed = nn.functional.interpolate(
|
||||
patch_pos_embed,
|
||||
size=(new_height, new_width),
|
||||
mode="bicubic",
|
||||
align_corners=False,
|
||||
)
|
||||
|
||||
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
||||
|
||||
return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
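# --- Editor's illustrative shape walk-through (not part of the generated file) ---
# Assuming a checkpoint trained at 224px with 32px patches (7x7 = 49 patches + 1 class token)
# applied to a 448px image (14x14 = 196 patches):
#
#     position_embedding.weight            -> (50, dim)
#     patch_pos_embed reshaped to a grid   -> (1, dim, 7, 7)
#     after bicubic interpolation          -> (1, dim, 14, 14)
#     flattened and re-joined with CLS     -> (1, 197, dim)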
|
||||
|
||||
def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:
|
||||
batch_size, _, height, width = pixel_values.shape
|
||||
if not interpolate_pos_encoding and (height != self.image_size or width != self.image_size):
|
||||
raise ValueError(
|
||||
f"Input image size ({height}*{width}) doesn't match model" f" ({self.image_size}*{self.image_size})."
|
||||
)
|
||||
target_dtype = self.patch_embedding.weight.dtype
|
||||
patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
|
||||
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
|
||||
|
||||
class_embeds = self.class_embedding.expand(batch_size, 1, -1)
|
||||
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
|
||||
if interpolate_pos_encoding:
|
||||
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
|
||||
else:
|
||||
embeddings = embeddings + self.position_embedding(self.position_ids)
|
||||
return embeddings
|
||||
|
||||
|
||||
MULTIMODAL2_VISION_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
||||
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
|
||||
[`AutoImageProcessor`]. See [`Multimodal2ImageProcessor.__call__`] for details.
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
||||
tensors for more detail.
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||||
more detail.
|
||||
interpolate_pos_encoding (`bool`, *optional*, defaults `False`):
|
||||
Whether to interpolate the pre-trained position encodings.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
"""
|
||||
|
||||
|
||||
class Multimodal2VisionTransformer(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
embed_dim = config.hidden_size
|
||||
|
||||
self.embeddings = Multimodal2VisionEmbeddings(config)
|
||||
self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
||||
self.encoder = Multimodal2VisionEncoder(config)
|
||||
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
||||
|
||||
@add_start_docstrings_to_model_forward(MULTIMODAL2_VISION_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=Multimodal2VisionConfig)
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: Optional[torch.FloatTensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
interpolate_pos_encoding: Optional[bool] = False,
|
||||
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
||||
r"""
|
||||
Returns:
|
||||
|
||||
"""
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if pixel_values is None:
|
||||
raise ValueError("You have to specify pixel_values")
|
||||
|
||||
hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
|
||||
hidden_states = self.pre_layrnorm(hidden_states)
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
inputs_embeds=hidden_states,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
last_hidden_state = encoder_outputs[0]
|
||||
pooled_output = last_hidden_state[:, 0, :]
|
||||
pooled_output = self.post_layernorm(pooled_output)
|
||||
|
||||
if not return_dict:
|
||||
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
|
||||
|
||||
return BaseModelOutputWithPooling(
|
||||
last_hidden_state=last_hidden_state,
|
||||
pooler_output=pooled_output,
|
||||
hidden_states=encoder_outputs.hidden_states,
|
||||
attentions=encoder_outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
class Multimodal2VisionPreTrainedModel(PreTrainedModel):
|
||||
"""
|
||||
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
||||
models.
|
||||
"""
|
||||
|
||||
config_class = Multimodal2Config
|
||||
base_model_prefix = "multimodal2_vision"
|
||||
supports_gradient_checkpointing = True
|
||||
_supports_sdpa = True
|
||||
_supports_flash_attn_2 = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
if isinstance(module, Multimodal2VisionMLP):
|
||||
pass
|
||||
|
||||
|
||||
MULTIMODAL2_VISION_START_DOCSTRING = "doc"
|
||||
|
||||
|
||||
@add_start_docstrings("New doc", MULTIMODAL2_VISION_START_DOCSTRING)
|
||||
class Multimodal2VisionModel(Multimodal2VisionPreTrainedModel):
|
||||
config_class = Multimodal2VisionConfig
|
||||
main_input_name = "pixel_values"
|
||||
_no_split_modules = ["Multimodal2VisionEncoderLayer"]
|
||||
|
||||
def __init__(self, config: Multimodal2VisionConfig):
|
||||
super().__init__(config)
|
||||
self.vision_model = Multimodal2VisionTransformer(config)
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self) -> nn.Module:
|
||||
return self.vision_model.embeddings.patch_embedding
|
||||
|
||||
@add_start_docstrings_to_model_forward(MULTIMODAL2_VISION_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=Multimodal2VisionConfig)
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: Optional[torch.FloatTensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
||||
r"""
|
||||
Returns:
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
>>> from transformers import AutoProcessor, Multimodal2VisionModel
|
||||
|
||||
>>> model = Multimodal2VisionModel.from_pretrained("openai/multimodal2-vit-base-patch32")
|
||||
>>> processor = AutoProcessor.from_pretrained("openai/multimodal2-vit-base-patch32")
|
||||
|
||||
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
>>> inputs = processor(images=image, return_tensors="pt")
|
||||
|
||||
>>> outputs = model(**inputs)
|
||||
>>> last_hidden_state = outputs.last_hidden_state
|
||||
>>> pooled_output = outputs.pooler_output # pooled CLS states
|
||||
```"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
return self.vision_model(
|
||||
pixel_values=pixel_values,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||
)
|
@@ -265,7 +265,7 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
        min_dtype = torch.finfo(dtype).min
        sequence_length = inputs_embeds.shape[1]
        if using_static_cache:
            target_length = past_key_values.get_max_length()
            target_length = past_key_values.get_max_cache_shape()
        else:
            target_length = (
                attention_mask.shape[-1]
@@ -358,9 +358,9 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, NewTaskModelForConditionalGeneration
        >>> from transformers import AutoProcessor, NewTaskModelForNewTask

        >>> model = NewTaskModelForConditionalGeneration.from_pretrained("google/NewTaskModel-test-224px-hf")
        >>> model = NewTaskModelForNewTask.from_pretrained("google/NewTaskModel-test-224px-hf")
        >>> processor = AutoProcessor.from_pretrained("google/NewTaskModel-test-224px-hf")

        >>> prompt = "answer en Where is the cow standing?"
6 examples/modular-transformers/modular_from_uppercase_model.py (new file)
@@ -0,0 +1,6 @@
from transformers.models.clip.modeling_clip import CLIPEncoderLayer


# Check if we can correctly grab dependencies with correct naming from all UPPERCASE old model
class FromUppercaseModelEncoderLayer(CLIPEncoderLayer):
    pass
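(Editor's note, not part of the diff: the generated modeling_from_uppercase_model.py shown earlier
is produced from this 6-line modular file by the converter script changed below; the usual
invocation is something like `python utils/modular_model_converter.py --files_to_parse examples/modular-transformers/modular_from_uppercase_model.py`,
though the exact flag name should be checked against the script's --help.)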
6 examples/modular-transformers/modular_multimodal1.py (new file)
@@ -0,0 +1,6 @@
from transformers.models.llama.modeling_llama import LlamaModel


# Check that we can correctly change the prefix (here add Text part at the end of the name)
class Multimodal1TextModel(LlamaModel):
    pass
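(Editor's note, not part of the diff: the suppressed modeling_multimodal1.py above is what this
6-line modular generates. Given the renaming logic changed below, the generated module is expected
to contain the Llama graph re-prefixed, e.g. Multimodal1TextModel, Multimodal1TextDecoderLayer and
Multimodal1TextAttention; the exact class list is an assumption here since the diff is suppressed.)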
88 examples/modular-transformers/modular_multimodal2.py (new file)
@@ -0,0 +1,88 @@
"""
Here, because clip is not consistent with the use of the "Text" and "Vision" prefixes, we cannot simply use
```
class Multimodal2VisionModel(CLIPVisionModel):
    pass
```
with the hope that all dependencies will be renamed as `Multimodal2VisionClass`. For this reason, if we want consistency and
use the "Vision" part everywhere, we need to overwrite the intermediate classes and add the prefix every time.
This adds noise to the modular, but is unfortunately unavoidable.
"""

from torch import nn

from transformers.models.clip.modeling_clip import (
    CLIPMLP,
    CLIPAttention,
    CLIPEncoder,
    CLIPEncoderLayer,
    CLIPFlashAttention2,
    CLIPPreTrainedModel,
    CLIPSdpaAttention,
    CLIPVisionModel,
    CLIPVisionTransformer,
)
from transformers.utils import add_start_docstrings


class Multimodal2VisionAttention(CLIPAttention):
    pass


# Check that adding the second base class correctly set the parent, even though in clip it does not have the "Vision" part
class Multimodal2VisionSdpaAttention(CLIPSdpaAttention, Multimodal2VisionAttention):
    pass


# Check that adding the second base class correctly set the parent, even though in clip it does not have the "Vision" part
class Multimodal2VisionFlashAttention2(CLIPFlashAttention2, Multimodal2VisionAttention):
    pass


MULTIMODAL2_VISION_ATTENTION_CLASSES = {
    "eager": Multimodal2VisionAttention,
    "sdpa": Multimodal2VisionSdpaAttention,
    "flash_attention_2": Multimodal2VisionFlashAttention2,
}


class Multimodal2VisionMLP(CLIPMLP):
    pass


class Multimodal2VisionEncoderLayer(CLIPEncoderLayer):
    def __init__(self, config):
        super().__init__()
        self.self_attn = MULTIMODAL2_VISION_ATTENTION_CLASSES[config._attn_implementation](config)
        self.mlp = Multimodal2VisionMLP(config)


class Multimodal2VisionEncoder(CLIPEncoder):
    def __init__(self, config):
        super().__init__(config)
        self.layers = nn.ModuleList([Multimodal2VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])


# Finally here the `Vision` part was correct in CLIP, but we still need to tell it that the encoder arg should use it as well
class Multimodal2VisionTransformer(CLIPVisionTransformer):
    def __init__(self, config):
        super().__init__(config)
        self.encoder = Multimodal2VisionEncoder(config)


class Multimodal2VisionPreTrainedModel(CLIPPreTrainedModel):
    def _init_weights(self, module):
        if isinstance(module, Multimodal2VisionMLP):
            pass


MULTIMODAL2_VISION_START_DOCSTRING = "doc"


# Here the only arg `self.vision_model = CLIPVisionTransformer(config)` in CLIPVisionModel already has the "Vision" part, so
# no need to overwrite it, it will look for `Multimodal2VisionTransformer` which has already been redefined above
# Note: we may want to redefine decorator as well for full consistency, as CLIP does not use "CLIP_VISION_START_DOCSTRING" but only
# "CLIP_START_DOCSTRING"
@add_start_docstrings("New doc", MULTIMODAL2_VISION_START_DOCSTRING)
class Multimodal2VisionModel(CLIPVisionModel, Multimodal2VisionPreTrainedModel):
    _no_split_modules = ["Multimodal2VisionEncoderLayer"]
@@ -20,7 +20,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.


from ...configuration_utils import PretrainedConfig
@@ -27,7 +27,6 @@ import torch.nn as nn
from ...activations import ACT2FN
from ...cache_utils import Cache, HybridCache
from ...generation import GenerationMixin
from ...modeling_flash_attention_utils import _flash_attention_forward
from ...modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
@@ -14,7 +14,6 @@ from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import _flash_attention_forward
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_utils import PreTrainedModel
from ...utils import (
@@ -34,7 +34,6 @@ from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import _flash_attention_forward
from ...modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
utils/modular_model_converter.py
@@ -18,7 +18,7 @@ import importlib
import os
import re
from abc import ABC, abstractmethod
from collections import defaultdict, deque
from collections import Counter, defaultdict, deque
from typing import Dict, Set

import libcst as cst
@@ -48,7 +48,7 @@ def get_module_source_from_name(module_name: str) -> str:
    # Extract the source code from the module name
    spec = importlib.util.find_spec(module_name)
    if spec is None or spec.origin is None:
        return f"Module {module_name} not found"
        raise ValueError(f"Cannot open file associated with {module_name} module.")

    with open(spec.origin, "r", encoding="utf-8") as file:
        source_code = file.read()
@ -58,20 +58,40 @@ def get_module_source_from_name(module_name: str) -> str:
|
||||
def preserve_case_replace(text, patterns: dict, default_name: str):
|
||||
# Create a regex pattern to match all variations
|
||||
regex_pattern = "|".join(re.escape(key) for key in patterns.keys())
|
||||
compiled_regex = re.compile(regex_pattern, re.IGNORECASE)
|
||||
compiled_regex = re.compile(f"({regex_pattern})(.|$)", re.IGNORECASE | re.DOTALL)
|
||||
|
||||
def replace(match):
|
||||
word = match.group(0)
|
||||
result = patterns.get(word, default_name)
|
||||
return result
|
||||
matched_pattern = match.group(1)
|
||||
next_char = match.group(2)
|
||||
new_pattern = patterns.get(matched_pattern, default_name)
|
||||
|
||||
# In this case, the cased old model did not respect CamelCase and was all UPPERCASE, so we need to rely on next char
|
||||
# The heuristic is: if next char is not a letter, then it is not part of a model name and result should be `new_name`.upper()
|
||||
if len(patterns) == 2 and matched_pattern.isupper():
|
||||
if not next_char.isalpha():
|
||||
# `new_name.upper()` is just the other entry for `matched_pattern.lower()`, uppercased
|
||||
new_pattern = patterns[matched_pattern.lower()].upper()
|
||||
|
||||
return new_pattern + next_char
|
||||
|
||||
return compiled_regex.sub(replace, text)
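
For illustration, the new replacement logic above can be exercised in isolation; the `llama`/`my_new_model` names below are made-up sample inputs, not part of this diff, and the helper simply mirrors the function shown above:

```python
import re

def preserve_case_replace(text, patterns: dict, default_name: str):
    # Match any known case variant of the old name plus the character that follows it
    regex_pattern = "|".join(re.escape(key) for key in patterns.keys())
    compiled_regex = re.compile(f"({regex_pattern})(.|$)", re.IGNORECASE | re.DOTALL)

    def replace(match):
        matched_pattern, next_char = match.group(1), match.group(2)
        new_pattern = patterns.get(matched_pattern, default_name)
        # All-uppercase old names (e.g. "CLIP") look the same in CamelCase and UPPER_CASE,
        # so the next character decides which replacement to use
        if len(patterns) == 2 and matched_pattern.isupper() and not next_char.isalpha():
            new_pattern = patterns[matched_pattern.lower()].upper()
        return new_pattern + next_char

    return compiled_regex.sub(replace, text)

patterns = {"llama": "my_new_model", "LLAMA": "MY_NEW_MODEL", "Llama": "MyNewModel"}
print(preserve_case_replace("LLAMA_ATTENTION_CLASSES", patterns, "MyNewModel"))
# MY_NEW_MODEL_ATTENTION_CLASSES
print(preserve_case_replace("class LlamaAttention(nn.Module):", patterns, "MyNewModel"))
# class MyNewModelAttention(nn.Module):
```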


def convert_to_camelcase(text, old_name: str, default_old_name: str):
    # Regex pattern to match consecutive uppercase letters and lowercase the first set
    result = re.sub(rf"^({old_name})(?=[a-z]+)", lambda m: default_old_name, text, flags=re.IGNORECASE, count=1)
    return result
def get_cased_name(lowercase_name: str) -> str:
    """From a model name in lowercase in the format `my_model`, return the cased name in the format `MyModel`."""
    if lowercase_name in CONFIG_MAPPING_NAMES:
        return CONFIG_MAPPING_NAMES[lowercase_name].replace("Config", "")
    else:
        return "".join(x.title() for x in lowercase_name.split("_"))


def get_lowercase_name(cased_name: str) -> str:
    """From a model name in Camelcase in the format `MyModel`, return the lowercase name in the format `my_model`."""
    inverse_mapping = {value: key for key, value in CONFIG_MAPPING_NAMES.items()}
    if cased_name + "Config" in inverse_mapping:
        return inverse_mapping[cased_name + "Config"]
    else:
        return "_".join([s.lower() for s in re.findall(r"[A-Z][^A-Z]*", cased_name)])
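
A minimal sketch of the two helpers' fallback paths, assuming the name is not found in `CONFIG_MAPPING_NAMES` (the mapping is consulted first because this naive splitting breaks on all-uppercase names such as `CLIP`):

```python
import re

def cased_name_fallback(lowercase_name: str) -> str:
    # Fallback used when the name is not in CONFIG_MAPPING_NAMES
    return "".join(x.title() for x in lowercase_name.split("_"))

def lowercase_name_fallback(cased_name: str) -> str:
    return "_".join(s.lower() for s in re.findall(r"[A-Z][^A-Z]*", cased_name))

print(cased_name_fallback("my_new_model"))    # MyNewModel
print(lowercase_name_fallback("MyNewModel"))  # my_new_model
print(lowercase_name_fallback("CLIP"))        # c_l_i_p -> one reason the config mapping is preferred
```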


class ReplaceNameTransformer(m.MatcherDecoratableTransformer):
@ -84,43 +104,47 @@ class ReplaceNameTransformer(m.MatcherDecoratableTransformer):
    - LLaMa -> MyNewModel and MyNewModel -> Llama
    """

    def __init__(
        self,
        old_name,
        new_name,
        given_old_name=None,
        given_new_name=None,
    ):
    def __init__(self, old_name: str, new_name: str, original_new_model_name: str = "", only_doc: bool = False):
        super().__init__()
        self.old_name = old_name
        self.new_name = new_name
        self.default_name = "".join(x.title() for x in new_name.split("_"))
        if self.new_name in CONFIG_MAPPING_NAMES:
            self.default_name = CONFIG_MAPPING_NAMES[self.new_name].replace(
                "Config", ""
            ) # the best source of truth for class names. Could also just use the ones de
        self.cased_new_name = get_cased_name(self.new_name)
        self.cased_old_name = get_cased_name(self.old_name)
        self.patterns = {
            old_name: new_name,
            old_name.upper(): new_name.upper(),
            "".join(x.title() for x in old_name.split("_")): self.default_name,
            # For some old models, `self.cased_old_name` == `old_name.upper()` in which case this overwrites the previous entry
            self.cased_old_name: self.cased_new_name,
        }
        if given_old_name is not None and given_new_name is not None and given_old_name not in self.patterns:
            self.patterns[given_old_name] = given_new_name
        if self.old_name in CONFIG_MAPPING_NAMES:
            self.default_old_name = CONFIG_MAPPING_NAMES[self.old_name].replace("Config", "")
            if self.default_old_name.isupper():
                self.default_old_name = self.default_old_name.capitalize()
        # In case new_name is a prefix alias, and not the original new model name
        self.original_new_model_name = original_new_model_name
        self.only_doc = only_doc

    @m.leave(m.Name() | m.SimpleString() | m.Comment())
    def replace_name(self, original_node, updated_node):
    def _replace_name(self, original_node, updated_node):
        if re.findall(r"# Copied from", updated_node.value):
            return cst.RemoveFromParent()
        update = preserve_case_replace(updated_node.value, self.patterns, self.default_name)
        update = preserve_case_replace(updated_node.value, self.patterns, self.cased_new_name)
        return updated_node.with_changes(value=update)

    def leave_ClassDef(self, original_node, updated_node):
        new_name = convert_to_camelcase(updated_node.name.value, self.old_name, self.default_old_name)
        return updated_node.with_changes(name=cst.Name(new_name))
    @m.leave(m.SimpleString() | m.Comment())
    def replace_name(self, original_node, updated_node):
        return self._replace_name(original_node, updated_node)

    def leave_Name(self, original_node, updated_node):
        if not self.only_doc:
            return self._replace_name(original_node, updated_node)
        return updated_node

    def leave_ImportFrom(self, original_node, updated_node):
        """The imports from other file types (configuration, processing etc) should use original model name."""
        if self.original_new_model_name != self.new_name and m.matches(updated_node.module, m.Name()):
            patterns = "|".join(ALL_FILE_TYPES)
            regex = rf"({patterns})_{self.new_name}"
            new_source = re.sub(
                regex, lambda m: f"{m.group(1)}_{self.original_new_model_name}", updated_node.module.value
            )
            updated_node = updated_node.with_changes(module=updated_node.module.with_changes(value=new_source))
        return updated_node


DOCSTRING_NODE = m.SimpleStatementLine(
@ -760,10 +784,12 @@ class ModelFileMapper(ModuleMapper):
                    remaining_dependencies.remove(dep)
                    relative_order[dep] = idx
                    idx += 1
            # Add the class itself
            remaining_dependencies.remove(class_name)
            relative_order[class_name] = idx
            idx += 1
            # Add the class itself (it can sometimes already be present if the order of classes in the source file
            # does not make sense, i.e. a class is used somewhere before being defined like in `rt_detr`...)
            if class_name in remaining_dependencies:
                remaining_dependencies.remove(class_name)
                relative_order[class_name] = idx
                idx += 1

        # Now add what still remains
        remaining_dependencies = tuple(remaining_dependencies)
@ -859,7 +885,24 @@ class ModelFileMapper(ModuleMapper):
        return mapper


def replace_class_node(mapper: ModelFileMapper, class_node: cst.ClassDef, renamed_super_class: str):
def common_partial_suffix(str1: str, str2: str) -> str:
    """Return the biggest common suffix between 2 strings. If one string is a full suffix of the other string,
    we do not consider it a common suffix and return `""`"""
    common_suffix = ""
    for i in range(1, min(len(str1), len(str2)) + 1):
        if str1[-i] == str2[-i]:
            common_suffix = str1[-i] + common_suffix
        else:
            break
    # We do not allow full string suffix
    if common_suffix == str1 or common_suffix == str2:
        common_suffix = ""
    return common_suffix
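
Reproducing the helper above for a quick check on names like the ones in the modular example earlier (the second pair is invented to show the full-suffix rejection):

```python
def common_partial_suffix(str1: str, str2: str) -> str:
    # Longest common suffix; a suffix equal to either full string does not count
    common_suffix = ""
    for i in range(1, min(len(str1), len(str2)) + 1):
        if str1[-i] == str2[-i]:
            common_suffix = str1[-i] + common_suffix
        else:
            break
    return "" if common_suffix in (str1, str2) else common_suffix

print(common_partial_suffix("Multimodal2VisionModel", "CLIPVisionModel"))  # VisionModel
print(common_partial_suffix("LlamaModel", "Model"))                        # "" (full-string suffixes are rejected)
```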


def replace_class_node(
    mapper: ModelFileMapper, class_node: cst.ClassDef, renamed_super_class: str, original_super_class: str
):
    """
    Replace a class node which inherits from another modeling class. This function works in the following way:
    - start from the base class node of the inherited class (a cst.Node)
@ -889,6 +932,36 @@ def replace_class_node(mapper: ModelFileMapper, class_node: cst.ClassDef, rename
        raise ValueError(f"Could not parse the name of the bases for {class_node.name.value}")

    original_node = mapper.classes[renamed_super_class]
    # Always use the new name of the class (in case we use e.g. `ColPaliForRetrieval` inheriting from `PaliGemmaForConditionalGeneration`)
    new_name = class_node.name

    # If the new class name is different from the renamed super class name, we need to update the docstrings/comments accordingly
    if new_name.value != renamed_super_class:
        common_suffix = common_partial_suffix(new_name.value, renamed_super_class)
        # Note that this works even without common prefix, in which case it does not replace anything
        old, new = renamed_super_class.replace(common_suffix, ""), new_name.value.replace(common_suffix, "")
        temp_module = cst.Module(body=[original_node])
        original_node = temp_module.visit(
            ReplaceNameTransformer(get_lowercase_name(old), get_lowercase_name(new), only_doc=True)
        ).body[0]

    # If we explicitly passed a new base with common suffix to an old base, it is for switching the prefix
    # e.g. if the "natural" parent class is `PreTrainedModel` but we wanted to rename it to `PreTrainedVisionModel`
    additional_bases = [base for base in all_bases if base != original_super_class]
    new_bases = []
    for original_base in original_node.bases:
        new_base = original_base
        # we only potentially switch base for Name-based bases, not Attribute
        if m.matches(original_base.value, m.Name()):
            original_base_name = original_base.value.value
            for additional_base_name in additional_bases:
                suffix = common_partial_suffix(original_base_name, additional_base_name)
                if len(suffix) > 0 and suffix[0].isupper():
                    new_name_node = original_base.value.with_changes(value=additional_base_name)
                    new_base = original_base.with_changes(value=new_name_node)
                    break
        new_bases.append(new_base)

    original_methods = {
        f.name.value if hasattr(f, "name") else mapper.python_module.code_for_node(f): f
        for f in original_node.body.body
@ -942,12 +1015,17 @@ def replace_class_node(mapper: ModelFileMapper, class_node: cst.ClassDef, rename
        if m.matches(func, DOCSTRING_NODE): # This processes the docstring of the class!
            # Extract the original docstring
            updated_docstring = func.body[0].value.value
            original_docstring = docstring_node[0].body[0].value.value
            merged_doc = merge_docstrings(original_docstring, updated_docstring)
            # Update the docstring in the original function
            docstring_node = [
                docstring_node[0].with_changes(body=[cst.Expr(value=cst.SimpleString(value=merged_doc))])
            ]
            if len(docstring_node) == 0: # If the original docstring is empty, just create one from the updated.
                docstring_node = [
                    cst.SimpleStatementLine(body=[cst.Expr(value=cst.SimpleString(value=updated_docstring))])
                ]
            else:
                original_docstring = docstring_node[0].body[0].value.value
                merged_doc = merge_docstrings(original_docstring, updated_docstring)
                # Update the docstring in the original function
                docstring_node = [
                    docstring_node[0].with_changes(body=[cst.Expr(value=cst.SimpleString(value=merged_doc))])
                ]
        if name not in original_methods and func is not None and isinstance(func, cst.FunctionDef):
            end_meth.append(func)
        if m.matches(func, m.SimpleStatementLine(body=[m.Assign()])):
@ -970,10 +1048,10 @@ def replace_class_node(mapper: ModelFileMapper, class_node: cst.ClassDef, rename

    # Use decorators redefined in `modular_xxx.py` if any
    new_decorators = class_node.decorators if len(class_node.decorators) > 0 else original_node.decorators
    # Always use the new name of the class (in case we use e.g. `ColPaliForRetrieval` inheriting from `PaliGemmaForConditionalGeneration`)
    name = class_node.name

    return original_node.with_changes(body=new_replacement_body, decorators=new_decorators, name=name)
    return original_node.with_changes(
        body=new_replacement_body, decorators=new_decorators, bases=new_bases, name=new_name
    )
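
A hypothetical sketch of the base-switching rule above, reusing `common_partial_suffix` from the snippet earlier; the class names are invented for illustration:

```python
# Bases of the source (renamed) class vs. the extra base written explicitly in the modular file
original_bases = ["MyModelPreTrainedModel", "GenerationMixin"]
additional_bases = ["MyModelVisionPreTrainedModel"]  # e.g. to switch to a vision-specific base

new_bases = []
for base in original_bases:
    for extra in additional_bases:
        suffix = common_partial_suffix(base, extra)
        # A capitalized shared suffix ("PreTrainedModel") means "same kind of base, different prefix"
        if len(suffix) > 0 and suffix[0].isupper():
            base = extra
            break
    new_bases.append(base)

print(new_bases)  # ['MyModelVisionPreTrainedModel', 'GenerationMixin']
```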


TYPE_TO_FILE_TYPE = {
@ -1014,14 +1092,18 @@ VARIABLES_AT_THE_BEGINNING = (
IMPORTS_TO_SKIP_IN_MODULAR = ("auto.modeling_auto",)


def append_new_import_node(node: cst.CSTNode, unused_imports: set[str], imports_to_keep: list[cst.CSTNode]):
    """Insert the new `node` to the list of `imports_to_keep` in-place, if it is not part of the `unused_imports`."""
def append_new_import_node(
    node: cst.CSTNode, unused_imports: set[str], added_names: set, imports_to_keep: list[cst.CSTNode]
):
    """Insert the new `node` to the list of `imports_to_keep` in-place, if it is not part of the `unused_imports` or `added_names`.
    Also modifies `added_names` in-place accordingly."""
    import_node = node.body[0]
    names_to_keep = []
    for name in import_node.names:
        name_value = name.evaluated_name
        if name_value not in unused_imports:
        if name_value not in unused_imports and name_value not in added_names:
            names_to_keep.append(name.with_changes(comma=cst.MaybeSentinel.DEFAULT))
            added_names.add(name_value)
    if len(names_to_keep) > 0:
        new_node = node.with_changes(body=[import_node.with_changes(names=names_to_keep)])
        imports_to_keep.append(new_node)
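
A small usage sketch of the updated helper, assuming `append_new_import_node` and `libcst` are in scope; the duplicated `torch` imports are invented sample input meant to show the new `added_names` de-duplication:

```python
import libcst as cst

# Two imports of the same name, as could come from two different source modeling files
statements = [cst.parse_statement(s) for s in ("from torch import nn", "from torch import nn, Tensor")]

unused_imports: set[str] = set()
added_names: set[str] = set()
imports_to_keep: list[cst.CSTNode] = []
for stmt in statements:
    append_new_import_node(stmt, unused_imports, added_names, imports_to_keep)

print(cst.Module(body=imports_to_keep).code)
# from torch import nn
# from torch import Tensor   <- `nn` is dropped the second time because it was already added
```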
@ -1036,40 +1118,38 @@ def get_needed_imports(body: dict[str, dict], all_imports: list[cst.CSTNode]) ->
    wrapper = MetadataWrapper(cst.Module(body=all_imports + new_body))
    scopes = set(wrapper.resolve(ScopeProvider).values())
    unused_imports = set()
    import_ref_count = {}
    import_ref_count = defaultdict(lambda: 0)
    for scope in scopes:
        for assignment in scope.assignments:
            node = assignment.node
            if isinstance(assignment, cst.metadata.Assignment) and isinstance(node, (cst.Import, cst.ImportFrom)):
                ref_count = len(assignment.references)
                name = assignment.name
                # Similar imports may be redefined, and only used between their 1st and 2nd definition
                # so if we already have a ref count > 0, the imports is actually used
                if (ref_count == 0 and import_ref_count.get(name, -1) <= 0) or name in body.keys():
                    unused_imports.add(name)
                import_ref_count[name] = ref_count
                import_ref_count[name] = max(ref_count, import_ref_count[name])
    # Similar imports may be redefined, and only used between their 1st and 2nd definition so if we already have
    # a ref count > 0 at any point, the imports is actually used
    unused_imports = {name for name, count in import_ref_count.items() if count <= 0 or name in body.keys()}

    imports_to_keep = []
    # We need to keep track of which names were already imported, because some import may be duplicated from multiple sources
    # or be both protected and unprotected due to inconsistency between models
    added_names = set()
    existing_protected_statements = set() # str repr of the import nodes - does not work with the nodes directly
    for node in all_imports:
        if m.matches(node, m.If()): # handle safe imports
            new_statements = []
            for stmt_node in node.body.body:
                append_new_import_node(stmt_node, unused_imports, new_statements)
                append_new_import_node(stmt_node, unused_imports, added_names, new_statements)
            new_statements = [stmt for stmt in new_statements if str(stmt) not in existing_protected_statements]
            if len(new_statements) > 0:
                new_node = node.with_changes(body=node.body.with_changes(body=new_statements))
                imports_to_keep.append(new_node)
                existing_protected_statements.update({str(stmt) for stmt in new_statements})
        else:
            append_new_import_node(node, unused_imports, imports_to_keep)
            append_new_import_node(node, unused_imports, added_names, imports_to_keep)

    protected_import_nodes = [node for node in imports_to_keep if m.matches(node, m.If())]
    usual_import_nodes = [node for node in imports_to_keep if not m.matches(node, m.If())]
    # If the same import is both protected and unprotected, only keep the protected one
    for protected_node in protected_import_nodes:
        for stmt_node in protected_node.body.body:
            usual_import_nodes = [node for node in usual_import_nodes if node.body[0] != stmt_node.body[0]]

    # Protected imports always appear at the end of all imports
    return usual_import_nodes + protected_import_nodes
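
The switch to a `defaultdict` keeps the maximum reference count seen for a name across all scopes; a toy illustration of that rule with made-up counts:

```python
from collections import defaultdict

import_ref_count = defaultdict(lambda: 0)
# (name, ref_count) pairs as they might be observed while walking the scopes
for name, ref_count in [("torch", 0), ("torch", 3), ("math", 0)]:
    import_ref_count[name] = max(ref_count, import_ref_count[name])

unused_imports = {name for name, count in import_ref_count.items() if count <= 0}
print(unused_imports)  # {'math'}: `torch` was referenced in at least one scope, so it is kept
```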
@ -1102,12 +1182,10 @@ class ModularFileMapper(ModuleMapper):
    Calling the method `create_modules()` after visit will create all modules based on this modular file.
    """

    def __init__(self, python_module, new_name, given_old_name=None, given_new_name=None):
    def __init__(self, python_module, new_name):
        super().__init__(python_module)
        # fmt: off
        self.model_name = new_name # name of the model being defined. Should be in the format of `llama` or `layout_xlm` or `phi3`
        self.given_old_name = given_old_name
        self.given_new_name = given_new_name

        self.model_specific_imported_objects: Dict[str, str] = {} # e.g. {"LlamaModel": "transformers.models.llama.modeling_llama"}
        self.model_specific_modules: Dict[str, cst.Module] = {} # e.g. {"transformers.models.llama.modeling_llama": cst.Module}
@ -1191,11 +1269,11 @@ class ModularFileMapper(ModuleMapper):
        # 1. for each modeling file found in the imports, rename it with the new model name, visit it, and update dependencies
        self.visited_modules = {}
        self.renamers = {}
        name_prefixes = self.infer_new_model_name()
        for file, module in self.model_specific_modules.items():
            file_model_name = file.split(".")[-2]
            renamer = ReplaceNameTransformer(
                file_model_name, self.model_name, self.given_old_name, self.given_new_name
            )
            new_name = name_prefixes[file]
            renamer = ReplaceNameTransformer(file_model_name, new_name, self.model_name)
            renamed_module = module.visit(renamer)
            self.visited_modules[file] = ModelFileMapper.visit_and_merge_dependencies(
                renamed_module,
@ -1288,6 +1366,87 @@ class ModularFileMapper(ModuleMapper):

        return relative_order

    def infer_new_model_name(self) -> dict:
        """Infer whether we are using a model name prefix different from the usual model name as defined from the filename.
        This is useful e.g. when we define a new multi-modal model, and only the text part inherits from `LlamaModel`,
        so we have something like:
        ```python
        class NewModelNameTextDecoderLayer(LlamaDecoderLayer):
            pass
        ```
        with the `Text` prefix added to the model name.
        However, in case multiple prefixes are used, we raise a warning and use the most frequent prefix, to avoid parsing
        the same file multiple times and inconsistencies in the objects added from dependencies.
        If the new prefix collides with a prefix of another class in the file where we are importing from, then we also
        raise a warning, and use the default prefix (model name) to avoid collisions in dependencies.
        """
        prefix_model_name_mapping = defaultdict(Counter)
        cased_default_name = get_cased_name(self.model_name)
        # Iterate over all new classes to get modeling super classes
        for class_name, class_node in self.classes.items():
            modeling_bases = [
                k.value.value for k in class_node.bases if k.value.value in self.model_specific_imported_objects
            ]
            if len(modeling_bases) > 1:
                raise ValueError(
                    f"{class_name} was defined with more than 1 model-specific super class. This is unsupported. We found {*modeling_bases,}."
                )
            if len(modeling_bases) == 1:
                filename = self.model_specific_imported_objects[modeling_bases[0]]
                cased_model_name = cased_default_name # the default name prefix
                suffix = common_partial_suffix(class_name, modeling_bases[0])
                if len(suffix) > 0 and suffix[0].isupper():
                    cased_model_name = class_name.replace(suffix, "")
                prefix_model_name_mapping[filename].update([cased_model_name])

        # Check if we found multiple prefixes for some modeling files
        final_name_mapping = {}
        for file, prefixes_counter in prefix_model_name_mapping.items():
            if len(prefixes_counter) > 1:
                _, total = prefixes_counter.most_common(1)[0]
                most_used_entities = [name for name, count in prefixes_counter.most_common() if count == total]
                # if the default name is in the pool of equally used prefixes, use it, otherwise last encountered
                final_name = cased_default_name if cased_default_name in most_used_entities else most_used_entities[-1]
            else:
                final_name = list(prefixes_counter)[0]
            # Check if the prefix can be used without collisions in the names
            old_cased_model_name = get_cased_name(file.split(".")[-2])
            old_model_name_prefix = final_name.replace(cased_default_name, old_cased_model_name)
            # Raise adequate warning depending on the situation
            has_prefix_collision = f"\nclass {old_model_name_prefix}" in get_module_source_from_name(file)
            if final_name != cased_default_name and has_prefix_collision:
                if len(prefixes_counter) > 1:
                    logger.warning(
                        f"We detected multiple prefix names when inheriting from {file}: {*set(prefixes_counter),}. However, the "
                        f"most used one, '{final_name}', is already present in the source file and will likely cause consistency "
                        f"issues. For this reason we fallback to the default prefix '{cased_default_name}' when grabbing args "
                        "and dependencies. Make sure to subclass the intermediate classes with the prefix you want (if different "
                        f"from '{cased_default_name}') or use a single prefix in all the modular (best)."
                    )
                else:
                    logger.warning(
                        f"We detected the use of the new default prefix {final_name} when inheriting from {file}. However, it is "
                        "already present in the source file and will likely cause consistency issues. For this reason we fallback "
                        f"to the default prefix '{cased_default_name}' when grabbing args and dependencies. Make sure to subclass "
                        f"the intermediate classes with the prefix you want (if different from '{cased_default_name}')"
                    )
                final_name = cased_default_name
            elif len(prefixes_counter) > 1:
                logger.warning(
                    f"We detected multiple prefix names when inheriting from {file}: {*set(prefixes_counter),}. We will only "
                    f"use the most used '{final_name}' prefix when grabbing args and dependencies. Make sure to subclass the "
                    f"intermediate classes with the prefix you want (if different from '{final_name}') or use a single prefix "
                    "in all the modular (best)."
                )
            final_name_mapping[file] = get_lowercase_name(final_name)

        # Check we are not missing imported files
        for file in self.model_specific_modules.keys():
            if file not in final_name_mapping.keys():
                final_name_mapping[file] = self.model_name

        return final_name_mapping
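
Tying the docstring example above to the suffix helper (reusing `common_partial_suffix` from earlier; `NewModelNameText` is the hypothetical prefix from that example):

```python
class_name, modeling_base = "NewModelNameTextDecoderLayer", "LlamaDecoderLayer"

suffix = common_partial_suffix(class_name, modeling_base)
print(suffix)                          # DecoderLayer
print(class_name.replace(suffix, ""))  # NewModelNameText -> prefix used when renaming modeling_llama
```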


def check_dependencies_and_create_import_node(
    file_type: str, new_dependencies: set[str], mapper: ModuleMapper, new_name: str
@ -1338,11 +1497,11 @@ def get_class_node_and_dependencies(
    class node based on the inherited classes if needed. Also returns any new imports of a new class defined in
    the modular that we may need.
    """
    bases = [k.value.value for k in node.bases if k.value.value in modular_mapper.model_specific_imported_objects]
    if len(bases) > 1:
        raise ValueError(
            f"{class_name} was defined with more than 1 model-specific super class. This is unsupported. We found {*bases,}."
        )
    # An exception was already raised if this has len > 1
    model_specific_bases = [
        k.value.value for k in node.bases if k.value.value in modular_mapper.model_specific_imported_objects
    ]
    super_class = model_specific_bases[0] if len(model_specific_bases) == 1 else None

    file_type = find_file_type(class_name)
    file_to_update = files[file_type]
@ -1352,19 +1511,17 @@ def get_class_node_and_dependencies(
    imported_objects = modular_mapper.imported_objects_per_file[file_type]

    # We need to replace the class node with the transformers (modeling file) super class node
    if len(bases) == 1:
        super_class = bases[0]
    if super_class is not None:
        super_file_name = modular_mapper.model_specific_imported_objects[super_class]

        # Get the mapper corresponding to the inherited class
        mapper = modular_mapper.visited_modules[super_file_name]
        # Rename the super class according to the exact same rule we used when renaming the whole module
        renamer = modular_mapper.renamers[super_file_name]
        renamed_super_class = preserve_case_replace(super_class, renamer.patterns, renamer.default_name)
        renamed_super_class = convert_to_camelcase(renamed_super_class, renamer.old_name, renamer.default_old_name)
        renamed_super_class = preserve_case_replace(super_class, renamer.patterns, renamer.cased_new_name)

        # Create the new class node
        updated_node = replace_class_node(mapper, node, renamed_super_class)
        updated_node = replace_class_node(mapper, node, renamed_super_class, super_class)

        # Grab all immediate dependencies of the new node
        new_node_dependencies = augmented_dependencies_for_class_node(updated_node, mapper, imported_objects)
@ -1468,7 +1625,7 @@ def create_modules(modular_mapper: ModularFileMapper) -> dict[str, cst.Module]:
    return files


def convert_modular_file(modular_file, old_model_name=None, new_model_name=None, cst_transformers=None):
def convert_modular_file(modular_file):
    pattern = re.search(r"modular_(.*)(?=\.py$)", modular_file)
    output = {}
    if pattern is not None:
@ -1478,8 +1635,7 @@ def convert_modular_file(modular_file, old_model_name=None, new_model_name=None,
            code = file.read()
        module = cst.parse_module(code)
        wrapper = MetadataWrapper(module)
        if cst_transformers is None:
            cst_transformers = ModularFileMapper(module, model_name, old_model_name, new_model_name)
        cst_transformers = ModularFileMapper(module, model_name)
        wrapper.visit(cst_transformers)
        for file, module in create_modules(cst_transformers).items():
            if module != {}:
@ -1522,20 +1678,10 @@ if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--files_to_parse",
        default=["src/transformers/models/starcoder2/modular_starcoder2.py"],
        default=["src/transformers/models/gemma/modular_gemma.py"],
        nargs="+",
        help="A list of `modular_xxxx` files that should be converted to single model file",
    )
    parser.add_argument(
        "--old_model_name",
        required=False,
        help="The name of the model from which the copying is done in CamelCase. If not provided is inferred from modular-file",
    )
    parser.add_argument(
        "--new_model_name",
        required=False,
        help="The name of the new model being added in CamelCase. If not provided is inferred from modular-file",
    )
    args = parser.parse_args()
    if args.files_to_parse == ["all"]:
        args.files_to_parse = glob.glob("src/transformers/models/**/modular_*.py", recursive=True)
@ -1544,5 +1690,5 @@ if __name__ == "__main__":
    for file_name in find_priority_list(args.files_to_parse):
        print(f"Converting {file_name} to a single model single file format")
        module_path = file_name.replace("/", ".").replace(".py", "").replace("src.", "")
        converted_files = convert_modular_file(file_name, args.old_model_name, args.new_model_name)
        converted_files = convert_modular_file(file_name)
        converter = save_modeling_file(file_name, converted_files)