[modular] Fix the prefix-based renaming if the old and new model share a common name suffix (#37829)

* first try

* Fix and set examples

* style

* fix

* Update modular_test_detr.py

* Update image_processing_new_imgproc_model.py

* Update modular_model_converter.py
Cyril Vallez authored 2025-04-29 10:43:23 +02:00, committed by GitHub
parent a847d4aa6b
commit 4602059aae
16 changed files with 2340 additions and 939 deletions
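
The fix targets the converter's prefix-based class renaming. As a rough illustration of the collision the title describes (a hypothetical sketch with made-up helper names, not the converter's actual code), consider regenerating `modular_test_detr.py`, where the old prefix `Detr` is a suffix of the new prefix `TestDetr`: a plain substitution renames already-correct names a second time, while guarding the match avoids it.

```python
import re

OLD, NEW = "Detr", "TestDetr"  # the new model prefix ends with the old one

def rename_naive(source: str) -> str:
    # Plain substitution: names that already carry the new prefix get renamed again.
    return source.replace(OLD, NEW)

def rename_suffix_safe(source: str) -> str:
    # Skip occurrences of the old prefix that are already preceded by the extra leading
    # part of the new prefix ("Test" here), so already-renamed classes are left alone.
    extra = NEW[: -len(OLD)] if NEW.endswith(OLD) else ""
    guard = f"(?<!{re.escape(extra)})" if extra else ""
    return re.sub(guard + re.escape(OLD), NEW, source)

src = "class TestDetrModel(DetrModel): ..."
print(rename_naive(src))        # class TestTestDetrModel(TestDetrModel): ...  (double-renamed)
print(rename_suffix_safe(src))  # class TestDetrModel(TestDetrModel): ...      (intended)
```

Whether the actual change in `modular_model_converter.py` uses a look-behind like this is not visible from the regenerated example files below; the sketch only shows the failure mode being fixed.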


@@ -4,7 +4,7 @@
 # the file from the modular. If any change should be done, please apply the change to the
 # modular_new_imgproc_model.py file directly. One of our CI enforces this.
 # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
-from typing import Optional, Union
+from typing import Dict, List, Optional, Union
 import numpy as np
 import torch
@@ -74,13 +74,13 @@ class ImgprocModelImageProcessor(BaseImageProcessor):
     def __init__(
         self,
         do_resize: bool = True,
-        size: Optional[dict[str, int]] = None,
+        size: Optional[Dict[str, int]] = None,
         resample: PILImageResampling = PILImageResampling.BICUBIC,
         do_rescale: bool = True,
         rescale_factor: Union[int, float] = 1 / 255,
         do_normalize: bool = True,
-        image_mean: Optional[Union[float, list[float]]] = None,
-        image_std: Optional[Union[float, list[float]]] = None,
+        image_mean: Optional[Union[float, List[float]]] = None,
+        image_std: Optional[Union[float, List[float]]] = None,
         do_convert_rgb: bool = True,
         **kwargs,
     ) -> None:
@@ -101,7 +101,7 @@ class ImgprocModelImageProcessor(BaseImageProcessor):
     def resize(
         self,
         image: np.ndarray,
-        size: dict[str, int],
+        size: Dict[str, int],
         resample: PILImageResampling = PILImageResampling.BICUBIC,
         data_format: Optional[Union[str, ChannelDimension]] = None,
         input_data_format: Optional[Union[str, ChannelDimension]] = None,
@@ -151,13 +151,13 @@ class ImgprocModelImageProcessor(BaseImageProcessor):
         self,
         images: ImageInput,
         do_resize: Optional[bool] = None,
-        size: Optional[dict[str, int]] = None,
+        size: Optional[Dict[str, int]] = None,
         resample: PILImageResampling = None,
         do_rescale: Optional[bool] = None,
         rescale_factor: Optional[float] = None,
         do_normalize: Optional[bool] = None,
-        image_mean: Optional[Union[float, list[float]]] = None,
-        image_std: Optional[Union[float, list[float]]] = None,
+        image_mean: Optional[Union[float, List[float]]] = None,
+        image_std: Optional[Union[float, List[float]]] = None,
         return_tensors: Optional[Union[str, TensorType]] = None,
         do_convert_rgb: Optional[bool] = None,
         data_format: ChannelDimension = ChannelDimension.FIRST,


@@ -5,7 +5,7 @@
 # modular_add_function.py file directly. One of our CI enforces this.
 # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
 # Note that zamba does not have the `apply_rotary_pos_emb` function!
-from typing import Optional
+from typing import Optional, Tuple
 import torch
 from torch import nn
@@ -62,5 +62,5 @@ class TestAttention(nn.Module):
     def __init__(self):
         pass
-    def forward(self) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+    def forward(self) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         _ = apply_rotary_pos_emb(1, 1, 1, 1)


@@ -4,27 +4,41 @@
# the file from the modular. If any change should be done, please apply the change to the # the file from the modular. If any change should be done, please apply the change to the
# modular_dummy.py file directly. One of our CI enforces this. # modular_dummy.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
from functools import partial from typing import Callable, Optional, Tuple, Union
from typing import Callable, Optional, Union
import torch import torch
from torch import nn from torch import nn
from ...activations import ACT2FN from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache from ...cache_utils import Cache, DynamicCache, StaticCache
from ...integrations import use_kernel_forward_from_hub
from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast from ...modeling_outputs import BaseModelOutputWithPast
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack from ...processing_utils import Unpack
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
can_return_tuple,
is_torch_flex_attn_available,
logging,
)
from .configuration_dummy import DummyConfig from .configuration_dummy import DummyConfig
if is_torch_flex_attn_available():
from torch.nn.attention.flex_attention import BlockMask
from ...integrations.flex_attention import make_flex_block_causal_mask
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
@use_kernel_forward_from_hub("RMSNorm")
class DummyRMSNorm(nn.Module): class DummyRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6): def __init__(self, hidden_size, eps=1e-6):
""" """
@@ -63,45 +77,18 @@ class DummyRotaryEmbedding(nn.Module):
self.register_buffer("inv_freq", inv_freq, persistent=False) self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq self.original_inv_freq = self.inv_freq
def _dynamic_frequency_update(self, position_ids, device):
"""
dynamic RoPE layers should recompute `inv_freq` in the following situations:
1 - growing beyond the cached sequence length (allow scaling)
2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
"""
seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
self.max_seq_len_cached = seq_len
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
# This .to() is needed if the model has been moved to a device after being initialized (because
# the buffer is automatically moved, but not the original copy)
self.original_inv_freq = self.original_inv_freq.to(device)
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
self.max_seq_len_cached = self.original_max_seq_len
@torch.no_grad() @torch.no_grad()
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
def forward(self, x, position_ids): def forward(self, x, position_ids):
if "dynamic" in self.rope_type: inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
self._dynamic_frequency_update(position_ids, device=x.device)
# Core RoPE block
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float() position_ids_expanded = position_ids[:, None, :].float()
# Force float32 (see https://github.com/huggingface/transformers/pull/29285)
device_type = x.device.type device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): # Force float32
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1) emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos() cos = emb.cos() * self.attention_scaling
sin = emb.sin() sin = emb.sin() * self.attention_scaling
# Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
cos = cos * self.attention_scaling
sin = sin * self.attention_scaling
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
@@ -223,12 +210,12 @@ class DummyAttention(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor], position_embeddings: Tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor], attention_mask: Optional[torch.Tensor],
past_key_value: Optional[Cache] = None, past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs], **kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
input_shape = hidden_states.shape[:-1] input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim) hidden_shape = (*input_shape, -1, self.head_dim)
@@ -245,6 +232,7 @@ class DummyAttention(nn.Module):
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
attention_interface: Callable = eager_attention_forward attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager": if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
logger.warning_once( logger.warning_once(
@@ -270,7 +258,7 @@ class DummyAttention(nn.Module):
return attn_output, attn_weights return attn_output, attn_weights
class DummyDecoderLayer(nn.Module): class DummyDecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: DummyConfig, layer_idx: int): def __init__(self, config: DummyConfig, layer_idx: int):
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
@@ -290,11 +278,10 @@ class DummyDecoderLayer(nn.Module):
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False, use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
**kwargs: Unpack[FlashAttentionKwargs], **kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states residual = hidden_states
hidden_states = self.input_layernorm(hidden_states) hidden_states = self.input_layernorm(hidden_states)
# Self Attention # Self Attention
@@ -369,6 +356,8 @@ class DummyPreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std) module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None: if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_() module.weight.data[module.padding_idx].zero_()
elif isinstance(module, DummyRMSNorm):
module.weight.data.fill_(1.0)
DUMMY_INPUTS_DOCSTRING = r""" DUMMY_INPUTS_DOCSTRING = r"""
@@ -381,12 +370,15 @@ DUMMY_INPUTS_DOCSTRING = r"""
[`PreTrainedTokenizer.__call__`] for details. [`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids) [What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length) or `BlockMask`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**, - 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**. - 0 for tokens that are **masked**.
If the model is configured to use flex_attention, it will attempt to convert the mask Tensor into a BlockMask,
but you can also pass a `BlockMask` object directly here.
[What are attention masks?](../glossary#attention-mask) [What are attention masks?](../glossary#attention-mask)
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
@@ -406,20 +398,12 @@ DUMMY_INPUTS_DOCSTRING = r"""
config.n_positions - 1]`. config.n_positions - 1]`.
[What are position IDs?](../glossary#position-ids) [What are position IDs?](../glossary#position-ids)
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): past_key_values (`Cache`, *optional*):
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
Two formats are allowed: It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
- a [`~cache_utils.Cache`] instance, see our
[kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
cache format.
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
legacy cache format will be returned.
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
@@ -480,10 +464,11 @@ class DummyModel(DummyPreTrainedModel):
def set_input_embeddings(self, value): def set_input_embeddings(self, value):
self.embed_tokens = value self.embed_tokens = value
@can_return_tuple
@add_start_docstrings_to_model_forward(DUMMY_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward(DUMMY_INPUTS_DOCSTRING)
def forward( def forward(
self, self,
input_ids: torch.LongTensor = None, input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None, past_key_values: Optional[Cache] = None,
@@ -491,16 +476,14 @@ class DummyModel(DummyPreTrainedModel):
use_cache: Optional[bool] = None, use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs], **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> Union[tuple, BaseModelOutputWithPast]: ) -> BaseModelOutputWithPast:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = ( output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
) )
use_cache = use_cache if use_cache is not None else self.config.use_cache use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if (input_ids is None) ^ (inputs_embeds is not None): if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds") raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
@@ -511,6 +494,10 @@ class DummyModel(DummyPreTrainedModel):
) )
use_cache = False use_cache = False
# TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
if not isinstance(past_key_values, (type(None), Cache)):
raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")
if inputs_embeds is None: if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) inputs_embeds = self.embed_tokens(input_ids)
@@ -543,19 +530,6 @@ class DummyModel(DummyPreTrainedModel):
if output_hidden_states: if output_hidden_states:
all_hidden_states += (hidden_states,) all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
partial(decoder_layer.__call__, **flash_attn_kwargs),
hidden_states,
causal_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
position_embeddings,
)
else:
layer_outputs = decoder_layer( layer_outputs = decoder_layer(
hidden_states, hidden_states,
attention_mask=causal_mask, attention_mask=causal_mask,
@@ -579,26 +553,29 @@ class DummyModel(DummyPreTrainedModel):
if output_hidden_states: if output_hidden_states:
all_hidden_states += (hidden_states,) all_hidden_states += (hidden_states,)
output = BaseModelOutputWithPast( return BaseModelOutputWithPast(
last_hidden_state=hidden_states, last_hidden_state=hidden_states,
past_key_values=past_key_values if use_cache else None, past_key_values=past_key_values if use_cache else None,
hidden_states=all_hidden_states, hidden_states=all_hidden_states,
attentions=all_self_attns, attentions=all_self_attns,
) )
return output if return_dict else output.to_tuple()
def _update_causal_mask( def _update_causal_mask(
self, self,
attention_mask: torch.Tensor, attention_mask: Union[torch.Tensor, "BlockMask"],
input_tensor: torch.Tensor, input_tensor: torch.Tensor,
cache_position: torch.Tensor, cache_position: torch.Tensor,
past_key_values: Cache, past_key_values: Cache,
output_attentions: bool, output_attentions: bool = False,
): ):
if self.config._attn_implementation == "flash_attention_2": if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and (attention_mask == 0.0).any(): if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask return attention_mask
return None return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
return attention_mask
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
@@ -616,7 +593,7 @@ class DummyModel(DummyPreTrainedModel):
): ):
return None return None
dtype, device = input_tensor.dtype, input_tensor.device dtype = input_tensor.dtype
sequence_length = input_tensor.shape[1] sequence_length = input_tensor.shape[1]
if using_static_cache: if using_static_cache:
target_length = past_key_values.get_max_cache_shape() target_length = past_key_values.get_max_cache_shape()
@@ -633,7 +610,6 @@ class DummyModel(DummyPreTrainedModel):
sequence_length=sequence_length, sequence_length=sequence_length,
target_length=target_length, target_length=target_length,
dtype=dtype, dtype=dtype,
device=device,
cache_position=cache_position, cache_position=cache_position,
batch_size=input_tensor.shape[0], batch_size=input_tensor.shape[0],
) )
@@ -641,7 +617,7 @@ class DummyModel(DummyPreTrainedModel):
if ( if (
self.config._attn_implementation == "sdpa" self.config._attn_implementation == "sdpa"
and attention_mask is not None and attention_mask is not None
and attention_mask.device.type in ["cuda", "xpu"] and attention_mask.device.type in ["cuda", "xpu", "npu"]
and not output_attentions and not output_attentions
): ):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
@@ -658,7 +634,6 @@ class DummyModel(DummyPreTrainedModel):
sequence_length: int, sequence_length: int,
target_length: int, target_length: int,
dtype: torch.dtype, dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor, cache_position: torch.Tensor,
batch_size: int, batch_size: int,
**kwargs, **kwargs,
@@ -678,8 +653,6 @@ class DummyModel(DummyPreTrainedModel):
to account for the 0 padding, the part of the cache that is not filled yet. to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`): dtype (`torch.dtype`):
The dtype to use for the 4D attention mask. The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to plcae the 4D attention mask on.
cache_position (`torch.Tensor`): cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence. Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`): batch_size (`torch.Tensor`):
@@ -691,11 +664,11 @@ class DummyModel(DummyPreTrainedModel):
else: else:
min_dtype = torch.finfo(dtype).min min_dtype = torch.finfo(dtype).min
causal_mask = torch.full( causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
) )
if sequence_length != 1: if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1) causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None: if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit


@@ -6,7 +6,7 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
import math import math
import os import os
from typing import Optional, Union from typing import Optional, Tuple, Union
import torch import torch
from packaging import version from packaging import version
@@ -136,9 +136,9 @@ class DummyBertSelfAttention(nn.Module):
head_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
) -> tuple[torch.Tensor]: ) -> Tuple[torch.Tensor]:
mixed_query_layer = self.query(hidden_states) mixed_query_layer = self.query(hidden_states)
# If this is instantiated as a cross-attention module, the keys # If this is instantiated as a cross-attention module, the keys
@@ -245,9 +245,9 @@ class DummyBertSdpaSelfAttention(DummyBertSelfAttention):
head_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
) -> tuple[torch.Tensor]: ) -> Tuple[torch.Tensor]:
if self.position_embedding_type != "absolute" or output_attentions or head_mask is not None: if self.position_embedding_type != "absolute" or output_attentions or head_mask is not None:
# TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented. # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented.
logger.warning_once( logger.warning_once(
@@ -386,9 +386,9 @@ class DummyBertAttention(nn.Module):
head_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
) -> tuple[torch.Tensor]: ) -> Tuple[torch.Tensor]:
self_outputs = self.self( self_outputs = self.self(
hidden_states, hidden_states,
attention_mask, attention_mask,
@@ -454,9 +454,9 @@ class DummyBertLayer(nn.Module):
head_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
) -> tuple[torch.Tensor]: ) -> Tuple[torch.Tensor]:
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2 # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
self_attention_outputs = self.attention( self_attention_outputs = self.attention(
@@ -532,12 +532,12 @@ class DummyBertEncoder(nn.Module):
head_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
use_cache: Optional[bool] = None, use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False, output_hidden_states: Optional[bool] = False,
return_dict: Optional[bool] = True, return_dict: Optional[bool] = True,
) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
@@ -626,6 +626,46 @@ class DummyBertPooler(nn.Module):
return pooled_output return pooled_output
class DummyBertPredictionHeadTransform(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
if isinstance(config.hidden_act, str):
self.transform_act_fn = ACT2FN[config.hidden_act]
else:
self.transform_act_fn = config.hidden_act
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.transform_act_fn(hidden_states)
hidden_states = self.LayerNorm(hidden_states)
return hidden_states
class DummyBertLMPredictionHead(nn.Module):
def __init__(self, config):
super().__init__()
self.transform = DummyBertPredictionHeadTransform(config)
# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias
def _tie_weights(self):
self.decoder.bias = self.bias
def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states)
return hidden_states
def load_tf_weights_in_dummy_bert(model, config, tf_checkpoint_path): def load_tf_weights_in_dummy_bert(model, config, tf_checkpoint_path):
"""Load tf checkpoints in a pytorch model.""" """Load tf checkpoints in a pytorch model."""
try: try:
@@ -726,6 +766,8 @@ class DummyBertPreTrainedModel(PreTrainedModel):
elif isinstance(module, nn.LayerNorm): elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_() module.bias.data.zero_()
module.weight.data.fill_(1.0) module.weight.data.fill_(1.0)
elif isinstance(module, DummyBertLMPredictionHead):
module.bias.data.zero_()
DUMMY_BERT_START_DOCSTRING = r""" DUMMY_BERT_START_DOCSTRING = r"""


@@ -4,28 +4,48 @@
# the file from the modular. If any change should be done, please apply the change to the # the file from the modular. If any change should be done, please apply the change to the
# modular_from_uppercase_model.py file directly. One of our CI enforces this. # modular_from_uppercase_model.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
from typing import Optional from typing import Callable, Optional, Tuple, Union
import torch import torch
from torch import nn from torch import nn
from ...activations import ACT2FN from ...activations import ACT2FN
from ...pytorch_utils import is_torch_greater_or_equal_than_2_2 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
from ...utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, logging from ...utils import logging
from .configuration_from_uppercase_model import FromUppercaseModelConfig from .configuration_from_uppercase_model import FromUppercaseModelTextConfig, FromUppercaseModelVisionConfig
if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
output_attentions: bool = True,
**kwargs,
):
attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
attn_output = torch.matmul(attn_weights, value)
attn_output = attn_output.transpose(1, 2).contiguous()
if not output_attentions:
attn_weights = None
return attn_output, attn_weights
class FromUppercaseModelAttention(nn.Module): class FromUppercaseModelAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper""" """Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config): def __init__(self, config: Union[FromUppercaseModelVisionConfig, FromUppercaseModelTextConfig]):
super().__init__() super().__init__()
self.config = config self.config = config
self.embed_dim = config.hidden_size self.embed_dim = config.hidden_size
@@ -38,253 +58,71 @@ class FromUppercaseModelAttention(nn.Module):
) )
self.scale = self.head_dim**-0.5 self.scale = self.head_dim**-0.5
self.dropout = config.attention_dropout self.dropout = config.attention_dropout
self.is_causal = False
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
causal_attention_mask: Optional[torch.Tensor] = None, causal_attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Input shape: Batch x Time x Channel""" """Input shape: Batch x Time x Channel"""
bsz, tgt_len, embed_dim = hidden_states.size() batch_size, seq_length, embed_dim = hidden_states.shape
# get query proj queries = self.q_proj(hidden_states)
query_states = self.q_proj(hidden_states) * self.scale keys = self.k_proj(hidden_states)
key_states = self._shape(self.k_proj(hidden_states), -1, bsz) values = self.v_proj(hidden_states)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
proj_shape = (bsz * self.num_heads, -1, self.head_dim) queries = queries.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2)
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) keys = keys.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2)
key_states = key_states.view(*proj_shape) values = values.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2)
value_states = value_states.view(*proj_shape) # FROM_UPPERCASE_MODEL text model uses both `causal_attention_mask` and `attention_mask`
# in case FA2 kernel is called, `is_causal` should be inferred from `causal_attention_mask`
src_len = key_states.size(1) if self.config._attn_implementation == "flash_attention_2":
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) self.is_causal = causal_attention_mask is not None
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
f" {attn_weights.size()}"
)
# apply the causal_attention_mask first
if causal_attention_mask is not None:
if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
f" {causal_attention_mask.size()}"
)
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
if output_attentions:
# this operation is a bit akward, but it's required to
# make sure that attn_weights keeps its gradient.
# In order to do so, attn_weights have to reshaped
# twice and have to be reused in the following
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
else: else:
attn_weights_reshaped = None if attention_mask is not None and causal_attention_mask is not None:
attention_mask = attention_mask + causal_attention_mask
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) elif causal_attention_mask is not None:
attention_mask = causal_attention_mask
attn_output = torch.bmm(attn_probs, value_states)
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights_reshaped
class FromUppercaseModelFlashAttention2(FromUppercaseModelAttention):
"""
FromUppercaseModelAttention flash attention module. This module inherits from `FromUppercaseModelAttention` as the weights of the module stays
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
# Adapted from transformers.models.llama.modeling_llama.LlamaFlashAttention2.forward
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
causal_attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
output_attentions = False
batch_size, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
# Flash attention requires the input to have the shape
# batch_size x seq_length x head_dim x hidden_dim
# therefore we just need to keep the original shape
query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim)
key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim)
value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim)
dropout_rate = self.dropout if self.training else 0.0
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need
# cast them back in the correct dtype just to be sure everything works as expected.
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
# in fp32.
input_dtype = query_states.dtype
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
else:
target_dtype = self.q_proj.weight.dtype
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and output_attentions:
logger.warning_once( logger.warning_once(
f"The input hidden states seems to be silently casted in float32, this might be related to" "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
f" {target_dtype}."
) )
else:
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
query_states = query_states.to(target_dtype) attn_output, attn_weights = attention_interface(
key_states = key_states.to(target_dtype) self,
value_states = value_states.to(target_dtype) queries,
keys,
attn_output = _flash_attention_forward( values,
query_states,
key_states,
value_states,
attention_mask, attention_mask,
q_len, is_causal=self.is_causal,
dropout=dropout_rate, scaling=self.scale,
is_causal=causal_attention_mask is not None, dropout=0.0 if not self.training else self.dropout,
use_top_left_mask=self._flash_attn_uses_top_left_mask, output_attentions=output_attentions,
) )
attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim).contiguous() attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
attn_output = self.out_proj(attn_output) attn_output = self.out_proj(attn_output)
if not output_attentions: if not output_attentions:
attn_weights = None attn_weights = None
return attn_output, attn_weights return attn_output, attn_weights
class FromUppercaseModelSdpaAttention(FromUppercaseModelAttention):
"""
SDPA attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
`FromUppercaseModelAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
SDPA API.
"""
# Adapted from FromUppercaseModelAttention.forward
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
causal_attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
if output_attentions:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
logger.warning_once(
"FromUppercaseModelModel is using FromUppercaseModelSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not "
"support `output_attentions=True`. Falling back to the manual attention implementation, but specifying "
"the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can "
'be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
causal_attention_mask=causal_attention_mask,
output_attentions=output_attentions,
)
# FROM_UPPERCASE_MODEL text model uses both `causal_attention_mask` and `attention_mask`
if attention_mask is not None and causal_attention_mask is not None:
attn_mask = attention_mask + causal_attention_mask
elif causal_attention_mask is not None:
attn_mask = causal_attention_mask
else:
attn_mask = attention_mask
bsz, tgt_len, embed_dim = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if not is_torch_greater_or_equal_than_2_2 and query_states.device.type == "cuda" and attn_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
# FROM_UPPERCASE_MODEL text model uses both `causal_attention_mask` and `attention_mask` sequentially.
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attn_mask,
dropout_p=self.dropout if self.training else 0.0,
scale=self.scale,
)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output, None
class FromUppercaseModelMLP(nn.Module): class FromUppercaseModelMLP(nn.Module):
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()
@@ -300,18 +138,11 @@ class FromUppercaseModelMLP(nn.Module):
return hidden_states return hidden_states
FROM_UPPERCASE_MODEL_ATTENTION_CLASSES = {
"eager": FromUppercaseModelAttention,
"sdpa": FromUppercaseModelSdpaAttention,
"flash_attention_2": FromUppercaseModelFlashAttention2,
}
class FromUppercaseModelEncoderLayer(nn.Module): class FromUppercaseModelEncoderLayer(nn.Module):
def __init__(self, config: FromUppercaseModelConfig): def __init__(self, config: Union[FromUppercaseModelVisionConfig, FromUppercaseModelTextConfig]):
super().__init__() super().__init__()
self.embed_dim = config.hidden_size self.embed_dim = config.hidden_size
self.self_attn = FROM_UPPERCASE_MODEL_ATTENTION_CLASSES[config._attn_implementation](config) self.self_attn = FromUppercaseModelAttention(config)
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = FromUppercaseModelMLP(config) self.mlp = FromUppercaseModelMLP(config)
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
@@ -322,7 +153,7 @@ class FromUppercaseModelEncoderLayer(nn.Module):
attention_mask: torch.Tensor, attention_mask: torch.Tensor,
causal_attention_mask: torch.Tensor, causal_attention_mask: torch.Tensor,
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
) -> tuple[torch.FloatTensor]: ) -> Tuple[torch.FloatTensor]:
""" """
Args: Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`


@@ -4,27 +4,41 @@
# the file from the modular. If any change should be done, please apply the change to the # the file from the modular. If any change should be done, please apply the change to the
# modular_multimodal1.py file directly. One of our CI enforces this. # modular_multimodal1.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
from functools import partial from typing import Callable, Optional, Tuple, Union
from typing import Callable, Optional, Union
import torch import torch
from torch import nn from torch import nn
from ...activations import ACT2FN from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache from ...cache_utils import Cache, DynamicCache, StaticCache
from ...integrations import use_kernel_forward_from_hub
from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast from ...modeling_outputs import BaseModelOutputWithPast
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack from ...processing_utils import Unpack
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
can_return_tuple,
is_torch_flex_attn_available,
logging,
)
from .configuration_multimodal1 import Multimodal1TextConfig from .configuration_multimodal1 import Multimodal1TextConfig
if is_torch_flex_attn_available():
from torch.nn.attention.flex_attention import BlockMask
from ...integrations.flex_attention import make_flex_block_causal_mask
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
@use_kernel_forward_from_hub("RMSNorm")
class Multimodal1TextRMSNorm(nn.Module): class Multimodal1TextRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6): def __init__(self, hidden_size, eps=1e-6):
""" """
@@ -63,45 +77,18 @@ class Multimodal1TextRotaryEmbedding(nn.Module):
self.register_buffer("inv_freq", inv_freq, persistent=False) self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq self.original_inv_freq = self.inv_freq
def _dynamic_frequency_update(self, position_ids, device):
"""
dynamic RoPE layers should recompute `inv_freq` in the following situations:
1 - growing beyond the cached sequence length (allow scaling)
2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
"""
seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
self.max_seq_len_cached = seq_len
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
# This .to() is needed if the model has been moved to a device after being initialized (because
# the buffer is automatically moved, but not the original copy)
self.original_inv_freq = self.original_inv_freq.to(device)
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
self.max_seq_len_cached = self.original_max_seq_len
@torch.no_grad() @torch.no_grad()
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
def forward(self, x, position_ids): def forward(self, x, position_ids):
if "dynamic" in self.rope_type: inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
self._dynamic_frequency_update(position_ids, device=x.device)
# Core RoPE block
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float() position_ids_expanded = position_ids[:, None, :].float()
# Force float32 (see https://github.com/huggingface/transformers/pull/29285)
device_type = x.device.type device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): # Force float32
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1) emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos() cos = emb.cos() * self.attention_scaling
sin = emb.sin() sin = emb.sin() * self.attention_scaling
# Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
cos = cos * self.attention_scaling
sin = sin * self.attention_scaling
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
@@ -223,12 +210,12 @@ class Multimodal1TextAttention(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor], position_embeddings: Tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor], attention_mask: Optional[torch.Tensor],
past_key_value: Optional[Cache] = None, past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs], **kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
input_shape = hidden_states.shape[:-1] input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim) hidden_shape = (*input_shape, -1, self.head_dim)
@@ -245,6 +232,7 @@ class Multimodal1TextAttention(nn.Module):
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
attention_interface: Callable = eager_attention_forward attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager": if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
logger.warning_once( logger.warning_once(
@ -270,7 +258,7 @@ class Multimodal1TextAttention(nn.Module):
return attn_output, attn_weights return attn_output, attn_weights
class Multimodal1TextDecoderLayer(nn.Module): class Multimodal1TextDecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: Multimodal1TextConfig, layer_idx: int): def __init__(self, config: Multimodal1TextConfig, layer_idx: int):
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
@ -290,11 +278,10 @@ class Multimodal1TextDecoderLayer(nn.Module):
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False, use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
**kwargs: Unpack[FlashAttentionKwargs], **kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states residual = hidden_states
hidden_states = self.input_layernorm(hidden_states) hidden_states = self.input_layernorm(hidden_states)
# Self Attention # Self Attention
@ -369,6 +356,8 @@ class Multimodal1TextPreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std) module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None: if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_() module.weight.data[module.padding_idx].zero_()
elif isinstance(module, Multimodal1TextRMSNorm):
module.weight.data.fill_(1.0)
 MULTIMODAL1_TEXT_INPUTS_DOCSTRING = r"""

@@ -381,12 +370,15 @@ MULTIMODAL1_TEXT_INPUTS_DOCSTRING = r"""
             [`PreTrainedTokenizer.__call__`] for details.

             [What are input IDs?](../glossary#input-ids)
-        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)` or `BlockMask`, *optional*):
             Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

             - 1 for tokens that are **not masked**,
             - 0 for tokens that are **masked**.

+            If the model is configured to use flex_attention, it will attempt to convert the mask Tensor into a BlockMask,
+            but you can also pass a `BlockMask` object directly here.
+
             [What are attention masks?](../glossary#attention-mask)

             Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and

@@ -406,20 +398,12 @@ MULTIMODAL1_TEXT_INPUTS_DOCSTRING = r"""
             config.n_positions - 1]`.

             [What are position IDs?](../glossary#position-ids)
-        past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
+        past_key_values (`Cache`, *optional*):
             Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
             blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
             returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.

-            Two formats are allowed:
-            - a [`~cache_utils.Cache`] instance, see our
-              [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
-            - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
-              shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
-              cache format.
-
-            The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
-            legacy cache format will be returned.
+            It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

             If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
             have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
@@ -480,10 +464,11 @@ class Multimodal1TextModel(Multimodal1TextPreTrainedModel):
     def set_input_embeddings(self, value):
         self.embed_tokens = value

+    @can_return_tuple
     @add_start_docstrings_to_model_forward(MULTIMODAL1_TEXT_INPUTS_DOCSTRING)
     def forward(
         self,
-        input_ids: torch.LongTensor = None,
+        input_ids: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Cache] = None,

@@ -491,16 +476,14 @@ class Multimodal1TextModel(Multimodal1TextPreTrainedModel):
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
-    ) -> Union[tuple, BaseModelOutputWithPast]:
+    ) -> BaseModelOutputWithPast:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
         use_cache = use_cache if use_cache is not None else self.config.use_cache
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

@@ -511,6 +494,10 @@ class Multimodal1TextModel(Multimodal1TextPreTrainedModel):
             )
             use_cache = False

+        # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
+        if not isinstance(past_key_values, (type(None), Cache)):
+            raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")
+
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
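
The new `isinstance` guard above changes the caller-side contract: `past_key_values` must be a `Cache` instance (or `None`), and the legacy tuple-of-tuples format now raises. A hedged usage sketch, assuming `model` and `input_ids` already exist in scope:

from transformers.cache_utils import DynamicCache

# Sketch of the caller-side contract enforced above; `model` and `input_ids` are assumed
# to exist already. Only `Cache` instances (or None) are accepted as `past_key_values` now.
past_key_values = DynamicCache()
outputs = model(input_ids=input_ids, past_key_values=past_key_values, use_cache=True)

try:
    model(input_ids=input_ids, past_key_values=tuple())  # legacy tuple-of-tuples format
except ValueError as err:
    print(err)  # "The `past_key_values` should be either a `Cache` object or `None`."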
@@ -543,19 +530,6 @@ class Multimodal1TextModel(Multimodal1TextPreTrainedModel):
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)

-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    partial(decoder_layer.__call__, **flash_attn_kwargs),
-                    hidden_states,
-                    causal_mask,
-                    position_ids,
-                    past_key_values,
-                    output_attentions,
-                    use_cache,
-                    cache_position,
-                    position_embeddings,
-                )
-            else:
-                layer_outputs = decoder_layer(
+            layer_outputs = decoder_layer(
                 hidden_states,
                 attention_mask=causal_mask,

@@ -579,26 +553,29 @@ class Multimodal1TextModel(Multimodal1TextPreTrainedModel):
         if output_hidden_states:
             all_hidden_states += (hidden_states,)

-        output = BaseModelOutputWithPast(
+        return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
             past_key_values=past_key_values if use_cache else None,
             hidden_states=all_hidden_states,
             attentions=all_self_attns,
         )
-        return output if return_dict else output.to_tuple()
     def _update_causal_mask(
         self,
-        attention_mask: torch.Tensor,
+        attention_mask: Union[torch.Tensor, "BlockMask"],
         input_tensor: torch.Tensor,
         cache_position: torch.Tensor,
         past_key_values: Cache,
-        output_attentions: bool,
+        output_attentions: bool = False,
     ):
         if self.config._attn_implementation == "flash_attention_2":
             if attention_mask is not None and (attention_mask == 0.0).any():
                 return attention_mask
             return None

+        if self.config._attn_implementation == "flex_attention":
+            if isinstance(attention_mask, torch.Tensor):
+                attention_mask = make_flex_block_causal_mask(attention_mask)
+            return attention_mask
+
         # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
         # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail

@@ -616,7 +593,7 @@ class Multimodal1TextModel(Multimodal1TextPreTrainedModel):
         ):
             return None

-        dtype, device = input_tensor.dtype, input_tensor.device
+        dtype = input_tensor.dtype
         sequence_length = input_tensor.shape[1]
         if using_static_cache:
             target_length = past_key_values.get_max_cache_shape()

@@ -633,7 +610,6 @@ class Multimodal1TextModel(Multimodal1TextPreTrainedModel):
             sequence_length=sequence_length,
             target_length=target_length,
             dtype=dtype,
-            device=device,
             cache_position=cache_position,
             batch_size=input_tensor.shape[0],
         )

@@ -641,7 +617,7 @@ class Multimodal1TextModel(Multimodal1TextPreTrainedModel):
         if (
             self.config._attn_implementation == "sdpa"
             and attention_mask is not None
-            and attention_mask.device.type in ["cuda", "xpu"]
+            and attention_mask.device.type in ["cuda", "xpu", "npu"]
             and not output_attentions
         ):
             # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when

@@ -658,7 +634,6 @@ class Multimodal1TextModel(Multimodal1TextPreTrainedModel):
         sequence_length: int,
         target_length: int,
         dtype: torch.dtype,
-        device: torch.device,
         cache_position: torch.Tensor,
         batch_size: int,
         **kwargs,

@@ -678,8 +653,6 @@ class Multimodal1TextModel(Multimodal1TextPreTrainedModel):
                 to account for the 0 padding, the part of the cache that is not filled yet.
             dtype (`torch.dtype`):
                 The dtype to use for the 4D attention mask.
-            device (`torch.device`):
-                The device to plcae the 4D attention mask on.
             cache_position (`torch.Tensor`):
                 Indices depicting the position of the input sequence tokens in the sequence.
             batch_size (`torch.Tensor`):

@@ -691,11 +664,11 @@ class Multimodal1TextModel(Multimodal1TextPreTrainedModel):
         else:
             min_dtype = torch.finfo(dtype).min
             causal_mask = torch.full(
-                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
             )
             if sequence_length != 1:
                 causal_mask = torch.triu(causal_mask, diagonal=1)
-            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+            causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
             causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
             if attention_mask is not None:
                 causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
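
The mask arithmetic above is easier to follow on a toy example. This standalone sketch repeats the same `torch.full` / `torch.triu` / `cache_position` steps outside the model; the tensor sizes are made up for illustration.

import torch

# Toy reproduction of the 4D causal-mask construction above (illustrative sizes only).
batch_size, sequence_length, target_length = 2, 4, 6
dtype = torch.float32
cache_position = torch.arange(2, 2 + sequence_length)  # absolute positions of the new tokens in the cache

min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
    (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
    causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)

print(causal_mask.shape)               # torch.Size([2, 1, 4, 6])
print((causal_mask[0, 0] == 0).int())  # 1 where a query may attend, 0 (i.e. min_dtype) where it is masked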


@@ -5,7 +5,7 @@
 # modular_multimodal2.py file directly. One of our CI enforces this.
 # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
-from typing import Optional, Union
+from typing import Callable, Optional, Tuple, Union

 import torch
 from torch import nn

@@ -14,30 +14,48 @@ from transformers.utils import add_start_docstrings
 from ...activations import ACT2FN
 from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
-from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import is_torch_greater_or_equal_than_2_2
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...utils import (
     add_start_docstrings_to_model_forward,
-    is_flash_attn_2_available,
-    is_flash_attn_greater_or_equal_2_10,
+    can_return_tuple,
     logging,
     replace_return_docstrings,
     torch_int,
 )
-from .configuration_multimodal2 import Multimodal2Config, Multimodal2VisionConfig
+from .configuration_multimodal2 import Multimodal2Config, Multimodal2TextConfig, Multimodal2VisionConfig

-if is_flash_attn_2_available():
-    from ...modeling_flash_attention_utils import _flash_attention_forward

 logger = logging.get_logger(__name__)
+def eager_attention_forward(
+    module: nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    scaling: float,
+    dropout: float = 0.0,
+    output_attentions: bool = True,
+    **kwargs,
+):
+    attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
+    if attention_mask is not None:
+        attn_weights = attn_weights + attention_mask
+
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+    attn_output = torch.matmul(attn_weights, value)
+    attn_output = attn_output.transpose(1, 2).contiguous()
+
+    if not output_attentions:
+        attn_weights = None
+
+    return attn_output, attn_weights
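
A quick shape check for the helper above, with toy tensors; the dummy module and sizes are made up for the example, and only `.training` is read from the module.

import torch
from torch import nn

batch, heads, seq, head_dim = 1, 2, 5, 8
query = torch.randn(batch, heads, seq, head_dim)
key = torch.randn(batch, heads, seq, head_dim)
value = torch.randn(batch, heads, seq, head_dim)

dummy = nn.Module()  # stand-in module; the helper only reads `dummy.training`
attn_output, attn_weights = eager_attention_forward(
    dummy, query, key, value, attention_mask=None, scaling=head_dim**-0.5
)
print(attn_output.shape)   # torch.Size([1, 5, 2, 8]) -- transposed back to (batch, seq, heads, head_dim)
print(attn_weights.shape)  # torch.Size([1, 2, 5, 5])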
class Multimodal2VisionAttention(nn.Module): class Multimodal2VisionAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper""" """Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config): def __init__(self, config: Union[Multimodal2VisionConfig, Multimodal2TextConfig]):
super().__init__() super().__init__()
self.config = config self.config = config
self.embed_dim = config.hidden_size self.embed_dim = config.hidden_size
@ -50,250 +68,68 @@ class Multimodal2VisionAttention(nn.Module):
) )
self.scale = self.head_dim**-0.5 self.scale = self.head_dim**-0.5
self.dropout = config.attention_dropout self.dropout = config.attention_dropout
self.is_causal = False
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
causal_attention_mask: Optional[torch.Tensor] = None, causal_attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Input shape: Batch x Time x Channel""" """Input shape: Batch x Time x Channel"""
bsz, tgt_len, embed_dim = hidden_states.size() batch_size, seq_length, embed_dim = hidden_states.shape
# get query proj queries = self.q_proj(hidden_states)
query_states = self.q_proj(hidden_states) * self.scale keys = self.k_proj(hidden_states)
key_states = self._shape(self.k_proj(hidden_states), -1, bsz) values = self.v_proj(hidden_states)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
proj_shape = (bsz * self.num_heads, -1, self.head_dim) queries = queries.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2)
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) keys = keys.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2)
key_states = key_states.view(*proj_shape) values = values.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2)
value_states = value_states.view(*proj_shape) # MULTIMODAL2_VISION text model uses both `causal_attention_mask` and `attention_mask`
# in case FA2 kernel is called, `is_causal` should be inferred from `causal_attention_mask`
src_len = key_states.size(1) if self.config._attn_implementation == "flash_attention_2":
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) self.is_causal = causal_attention_mask is not None
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
f" {attn_weights.size()}"
)
# apply the causal_attention_mask first
if causal_attention_mask is not None:
if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
f" {causal_attention_mask.size()}"
)
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
if output_attentions:
# this operation is a bit akward, but it's required to
# make sure that attn_weights keeps its gradient.
# In order to do so, attn_weights have to reshaped
# twice and have to be reused in the following
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
else: else:
attn_weights_reshaped = None if attention_mask is not None and causal_attention_mask is not None:
attention_mask = attention_mask + causal_attention_mask
elif causal_attention_mask is not None:
attention_mask = causal_attention_mask
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attn_output = torch.bmm(attn_probs, value_states) if self.config._attn_implementation == "sdpa" and output_attentions:
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights_reshaped
class Multimodal2VisionSdpaAttention(Multimodal2VisionAttention):
"""
SDPA attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
`Multimodal2VisionAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
SDPA API.
"""
# Adapted from Multimodal2VisionAttention.forward
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
causal_attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
if output_attentions:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
logger.warning_once( logger.warning_once(
"Multimodal2VisionModel is using Multimodal2VisionSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not " "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
"support `output_attentions=True`. Falling back to the manual attention implementation, but specifying " 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
"the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can "
'be removed using the argument `attn_implementation="eager"` when loading the model.'
) )
return super().forward( else:
hidden_states=hidden_states, attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attention_mask=attention_mask,
causal_attention_mask=causal_attention_mask, attn_output, attn_weights = attention_interface(
self,
queries,
keys,
values,
attention_mask,
is_causal=self.is_causal,
scaling=self.scale,
dropout=0.0 if not self.training else self.dropout,
output_attentions=output_attentions, output_attentions=output_attentions,
) )
# MULTIMODAL2_VISION text model uses both `causal_attention_mask` and `attention_mask` attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
if attention_mask is not None and causal_attention_mask is not None:
attn_mask = attention_mask + causal_attention_mask
elif causal_attention_mask is not None:
attn_mask = causal_attention_mask
else:
attn_mask = attention_mask
bsz, tgt_len, embed_dim = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if not is_torch_greater_or_equal_than_2_2 and query_states.device.type == "cuda" and attn_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
# MULTIMODAL2_VISION text model uses both `causal_attention_mask` and `attention_mask` sequentially.
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attn_mask,
dropout_p=self.dropout if self.training else 0.0,
scale=self.scale,
)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output, None
class Multimodal2VisionFlashAttention2(Multimodal2VisionAttention):
"""
Multimodal2VisionAttention flash attention module. This module inherits from `Multimodal2VisionAttention` as the weights of the module stays
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
# Adapted from transformers.models.llama.modeling_llama.LlamaFlashAttention2.forward
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
causal_attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
output_attentions = False
batch_size, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
# Flash attention requires the input to have the shape
# batch_size x seq_length x head_dim x hidden_dim
# therefore we just need to keep the original shape
query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim)
key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim)
value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim)
dropout_rate = self.dropout if self.training else 0.0
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need
# cast them back in the correct dtype just to be sure everything works as expected.
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
# in fp32.
input_dtype = query_states.dtype
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
else:
target_dtype = self.q_proj.weight.dtype
logger.warning_once(
f"The input hidden states seems to be silently casted in float32, this might be related to"
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {target_dtype}."
)
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
attn_output = _flash_attention_forward(
query_states,
key_states,
value_states,
attention_mask,
q_len,
dropout=dropout_rate,
is_causal=causal_attention_mask is not None,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
)
attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim).contiguous()
attn_output = self.out_proj(attn_output) attn_output = self.out_proj(attn_output)
if not output_attentions: if not output_attentions:
attn_weights = None attn_weights = None
return attn_output, attn_weights return attn_output, attn_weights
@@ -312,18 +148,92 @@ class Multimodal2VisionMLP(nn.Module):
         return hidden_states


-MULTIMODAL2_VISION_ATTENTION_CLASSES = {
-    "eager": Multimodal2VisionAttention,
-    "sdpa": Multimodal2VisionSdpaAttention,
-    "flash_attention_2": Multimodal2VisionFlashAttention2,
-}
+class Multimodal2Attention(nn.Module):
+    """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+    def __init__(self, config: Union[Multimodal2VisionConfig, Multimodal2TextConfig]):
+        super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
f" {self.num_heads})."
)
self.scale = self.head_dim**-0.5
self.dropout = config.attention_dropout
self.is_causal = False
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
causal_attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Input shape: Batch x Time x Channel"""
batch_size, seq_length, embed_dim = hidden_states.shape
queries = self.q_proj(hidden_states)
keys = self.k_proj(hidden_states)
values = self.v_proj(hidden_states)
queries = queries.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2)
keys = keys.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2)
values = values.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2)
# MULTIMODAL2 text model uses both `causal_attention_mask` and `attention_mask`
# in case FA2 kernel is called, `is_causal` should be inferred from `causal_attention_mask`
if self.config._attn_implementation == "flash_attention_2":
self.is_causal = causal_attention_mask is not None
else:
if attention_mask is not None and causal_attention_mask is not None:
attention_mask = attention_mask + causal_attention_mask
elif causal_attention_mask is not None:
attention_mask = causal_attention_mask
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and output_attentions:
logger.warning_once(
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
else:
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attn_output, attn_weights = attention_interface(
self,
queries,
keys,
values,
attention_mask,
is_causal=self.is_causal,
scaling=self.scale,
dropout=0.0 if not self.training else self.dropout,
output_attentions=output_attentions,
)
attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
attn_output = self.out_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights
 class Multimodal2VisionEncoderLayer(nn.Module):
     def __init__(self, config):
         super().__init__()
         self.embed_dim = config.hidden_size
-        self.self_attn = MULTIMODAL2_VISION_ATTENTION_CLASSES[config._attn_implementation](config)
+        self.self_attn = Multimodal2Attention(config)
         self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
         self.mlp = Multimodal2VisionMLP(config)
         self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
@ -334,7 +244,7 @@ class Multimodal2VisionEncoderLayer(nn.Module):
attention_mask: torch.Tensor, attention_mask: torch.Tensor,
causal_attention_mask: torch.Tensor, causal_attention_mask: torch.Tensor,
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
) -> tuple[torch.FloatTensor]: ) -> Tuple[torch.FloatTensor]:
""" """
Args: Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
@ -384,6 +294,7 @@ class Multimodal2VisionEncoder(nn.Module):
self.layers = nn.ModuleList([Multimodal2VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)]) self.layers = nn.ModuleList([Multimodal2VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False self.gradient_checkpointing = False
@can_return_tuple
def forward( def forward(
self, self,
inputs_embeds, inputs_embeds,
@ -391,8 +302,7 @@ class Multimodal2VisionEncoder(nn.Module):
causal_attention_mask: Optional[torch.Tensor] = None, causal_attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, ) -> BaseModelOutput:
) -> Union[tuple, BaseModelOutput]:
r""" r"""
Args: Args:
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
@ -426,7 +336,6 @@ class Multimodal2VisionEncoder(nn.Module):
output_hidden_states = ( output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
) )
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
encoder_states = () if output_hidden_states else None encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None all_attentions = () if output_attentions else None
@ -459,10 +368,10 @@ class Multimodal2VisionEncoder(nn.Module):
if output_hidden_states: if output_hidden_states:
encoder_states = encoder_states + (hidden_states,) encoder_states = encoder_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
return BaseModelOutput( return BaseModelOutput(
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions last_hidden_state=hidden_states,
hidden_states=encoder_states,
attentions=all_attentions,
) )
@ -578,6 +487,7 @@ class Multimodal2VisionTransformer(nn.Module):
self.encoder = Multimodal2VisionEncoder(config) self.encoder = Multimodal2VisionEncoder(config)
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
@can_return_tuple
@add_start_docstrings_to_model_forward(MULTIMODAL2_VISION_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward(MULTIMODAL2_VISION_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=Multimodal2VisionConfig) @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=Multimodal2VisionConfig)
def forward( def forward(
@ -585,9 +495,8 @@ class Multimodal2VisionTransformer(nn.Module):
pixel_values: Optional[torch.FloatTensor] = None, pixel_values: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
interpolate_pos_encoding: Optional[bool] = False, interpolate_pos_encoding: Optional[bool] = False,
) -> Union[tuple, BaseModelOutputWithPooling]: ) -> BaseModelOutputWithPooling:
r""" r"""
Returns: Returns:
@ -596,7 +505,6 @@ class Multimodal2VisionTransformer(nn.Module):
output_hidden_states = ( output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
) )
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if pixel_values is None: if pixel_values is None:
raise ValueError("You have to specify pixel_values") raise ValueError("You have to specify pixel_values")
@ -604,20 +512,16 @@ class Multimodal2VisionTransformer(nn.Module):
hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
hidden_states = self.pre_layrnorm(hidden_states) hidden_states = self.pre_layrnorm(hidden_states)
encoder_outputs = self.encoder( encoder_outputs: BaseModelOutput = self.encoder(
inputs_embeds=hidden_states, inputs_embeds=hidden_states,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict,
) )
last_hidden_state = encoder_outputs[0] last_hidden_state = encoder_outputs.last_hidden_state
pooled_output = last_hidden_state[:, 0, :] pooled_output = last_hidden_state[:, 0, :]
pooled_output = self.post_layernorm(pooled_output) pooled_output = self.post_layernorm(pooled_output)
if not return_dict:
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
return BaseModelOutputWithPooling( return BaseModelOutputWithPooling(
last_hidden_state=last_hidden_state, last_hidden_state=last_hidden_state,
pooler_output=pooled_output, pooler_output=pooled_output,
@ -662,6 +566,7 @@ class Multimodal2VisionModel(Multimodal2VisionPreTrainedModel):
def get_input_embeddings(self) -> nn.Module: def get_input_embeddings(self) -> nn.Module:
return self.vision_model.embeddings.patch_embedding return self.vision_model.embeddings.patch_embedding
@can_return_tuple
@add_start_docstrings_to_model_forward(MULTIMODAL2_VISION_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward(MULTIMODAL2_VISION_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=Multimodal2VisionConfig) @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=Multimodal2VisionConfig)
def forward( def forward(
@ -670,8 +575,7 @@ class Multimodal2VisionModel(Multimodal2VisionPreTrainedModel):
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: bool = False, interpolate_pos_encoding: bool = False,
return_dict: Optional[bool] = None, ) -> BaseModelOutputWithPooling:
) -> Union[tuple, BaseModelOutputWithPooling]:
r""" r"""
Returns: Returns:
@ -694,12 +598,10 @@ class Multimodal2VisionModel(Multimodal2VisionPreTrainedModel):
>>> last_hidden_state = outputs.last_hidden_state >>> last_hidden_state = outputs.last_hidden_state
>>> pooled_output = outputs.pooler_output # pooled CLS states >>> pooled_output = outputs.pooler_output # pooled CLS states
```""" ```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
return self.vision_model( return self.vision_model(
pixel_values=pixel_values, pixel_values=pixel_values,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict,
interpolate_pos_encoding=interpolate_pos_encoding, interpolate_pos_encoding=interpolate_pos_encoding,
) )


@@ -4,7 +4,7 @@
 # the file from the modular. If any change should be done, please apply the change to the
 # modular_my_new_model2.py file directly. One of our CI enforces this.
 # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
-from typing import Callable, Optional, Union
+from typing import Callable, List, Optional, Tuple, Union

 import torch
 from torch import nn

@@ -13,14 +13,27 @@ from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache, StaticCache
 from ...modeling_attn_mask_utils import AttentionMaskConverter
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPast, SequenceClassifierOutputWithPast
-from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
-from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging
+from ...utils import (
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    can_return_tuple,
+    is_torch_flex_attn_available,
+    logging,
+)
 from .configuration_my_new_model2 import MyNewModel2Config

+
+if is_torch_flex_attn_available():
+    from torch.nn.attention.flex_attention import BlockMask
+
+    from ...integrations.flex_attention import make_flex_block_causal_mask
+

 logger = logging.get_logger(__name__)
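
For reference, the `BlockMask` objects gated behind `is_torch_flex_attn_available()` come from PyTorch's flex attention API. A minimal sketch of building one directly with `torch.nn.attention.flex_attention.create_block_mask` follows; it assumes a torch build that ships flex attention, and the sizes are arbitrary.

import torch
from torch.nn.attention.flex_attention import create_block_mask


def causal_mask_mod(b, h, q_idx, kv_idx):
    # keep only keys at or before the query position
    return q_idx >= kv_idx


# B=None / H=None broadcast the mask over batch and heads; device kept on CPU for the sketch.
block_mask = create_block_mask(causal_mask_mod, B=None, H=None, Q_LEN=128, KV_LEN=128, device="cpu")
print(block_mask)  # summarizes the block-sparse layout flex attention will use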
@@ -78,45 +91,18 @@ class MyNewModel2RotaryEmbedding(nn.Module):
         self.register_buffer("inv_freq", inv_freq, persistent=False)
         self.original_inv_freq = self.inv_freq

-    def _dynamic_frequency_update(self, position_ids, device):
-        """
-        dynamic RoPE layers should recompute `inv_freq` in the following situations:
-        1 - growing beyond the cached sequence length (allow scaling)
-        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
-        """
-        seq_len = torch.max(position_ids) + 1
-        if seq_len > self.max_seq_len_cached:  # growth
-            inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
-            self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: may break with compilation
-            self.max_seq_len_cached = seq_len
-        if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:  # reset
-            # This .to() is needed if the model has been moved to a device after being initialized (because
-            # the buffer is automatically moved, but not the original copy)
-            self.original_inv_freq = self.original_inv_freq.to(device)
-            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
-            self.max_seq_len_cached = self.original_max_seq_len
-
     @torch.no_grad()
+    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
     def forward(self, x, position_ids):
-        if "dynamic" in self.rope_type:
-            self._dynamic_frequency_update(position_ids, device=x.device)
-
-        # Core RoPE block
-        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
         position_ids_expanded = position_ids[:, None, :].float()
-        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
-        device_type = x.device.type
-        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):
+
+        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
-            cos = emb.cos()
-            sin = emb.sin()
-
-        # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
-        cos = cos * self.attention_scaling
-        sin = sin * self.attention_scaling
+            cos = emb.cos() * self.attention_scaling
+            sin = emb.sin() * self.attention_scaling

         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
@ -222,12 +208,12 @@ class MyNewModel2Attention(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor], position_embeddings: Tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor], attention_mask: Optional[torch.Tensor],
past_key_value: Optional[Cache] = None, past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs], **kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
input_shape = hidden_states.shape[:-1] input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim) hidden_shape = (*input_shape, -1, self.head_dim)
@ -244,6 +230,7 @@ class MyNewModel2Attention(nn.Module):
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
attention_interface: Callable = eager_attention_forward attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager": if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
logger.warning_once( logger.warning_once(
@ -269,7 +256,7 @@ class MyNewModel2Attention(nn.Module):
return attn_output, attn_weights return attn_output, attn_weights
class MyNewModel2DecoderLayer(nn.Module): class MyNewModel2DecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: MyNewModel2Config, layer_idx: int): def __init__(self, config: MyNewModel2Config, layer_idx: int):
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
@ -289,11 +276,10 @@ class MyNewModel2DecoderLayer(nn.Module):
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False, use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
**kwargs: Unpack[FlashAttentionKwargs], **kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states residual = hidden_states
hidden_states = self.input_layernorm(hidden_states) hidden_states = self.input_layernorm(hidden_states)
# Self Attention # Self Attention
@ -368,6 +354,8 @@ class MyNewModel2PreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std) module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None: if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_() module.weight.data[module.padding_idx].zero_()
elif isinstance(module, MyNewModel2RMSNorm):
module.weight.data.fill_(1.0)
MY_NEW_MODEL2_INPUTS_DOCSTRING = r""" MY_NEW_MODEL2_INPUTS_DOCSTRING = r"""
@ -380,12 +368,15 @@ MY_NEW_MODEL2_INPUTS_DOCSTRING = r"""
[`PreTrainedTokenizer.__call__`] for details. [`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids) [What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length) or `BlockMask`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**, - 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**. - 0 for tokens that are **masked**.
If the model is configured to use flex_attention, it will attempt to convert the mask Tensor into a BlockMask,
but you can also pass a `BlockMask` object directly here.
[What are attention masks?](../glossary#attention-mask) [What are attention masks?](../glossary#attention-mask)
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
@ -405,20 +396,12 @@ MY_NEW_MODEL2_INPUTS_DOCSTRING = r"""
config.n_positions - 1]`. config.n_positions - 1]`.
[What are position IDs?](../glossary#position-ids) [What are position IDs?](../glossary#position-ids)
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): past_key_values (`Cache`, *optional*):
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
Two formats are allowed: It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
- a [`~cache_utils.Cache`] instance, see our
[kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
cache format.
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
legacy cache format will be returned.
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
@ -479,27 +462,26 @@ class MyNewModel2Model(MyNewModel2PreTrainedModel):
def set_input_embeddings(self, value): def set_input_embeddings(self, value):
self.embed_tokens = value self.embed_tokens = value
@can_return_tuple
@add_start_docstrings_to_model_forward(MY_NEW_MODEL2_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward(MY_NEW_MODEL2_INPUTS_DOCSTRING)
def forward( def forward(
self, self,
input_ids: torch.LongTensor = None, input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None, use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
**kwargs, # NOOP kwarg for now **kwargs, # NOOP kwarg for now
) -> Union[tuple, BaseModelOutputWithPast]: ) -> BaseModelOutputWithPast:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = ( output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
) )
use_cache = use_cache if use_cache is not None else self.config.use_cache use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if (input_ids is None) ^ (inputs_embeds is not None): if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds") raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
@ -549,19 +531,6 @@ class MyNewModel2Model(MyNewModel2PreTrainedModel):
if output_hidden_states: if output_hidden_states:
all_hidden_states += (hidden_states,) all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
causal_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
position_embeddings,
)
else:
layer_outputs = decoder_layer( layer_outputs = decoder_layer(
hidden_states, hidden_states,
attention_mask=causal_mask, attention_mask=causal_mask,
@ -584,26 +553,29 @@ class MyNewModel2Model(MyNewModel2PreTrainedModel):
if output_hidden_states: if output_hidden_states:
all_hidden_states += (hidden_states,) all_hidden_states += (hidden_states,)
output = BaseModelOutputWithPast( return BaseModelOutputWithPast(
last_hidden_state=hidden_states, last_hidden_state=hidden_states,
past_key_values=past_key_values if use_cache else None, past_key_values=past_key_values if use_cache else None,
hidden_states=all_hidden_states, hidden_states=all_hidden_states,
attentions=all_self_attns, attentions=all_self_attns,
) )
return output if return_dict else output.to_tuple()
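Dropping `return_dict` and the trailing `return output if return_dict else output.to_tuple()` works because the new `@can_return_tuple` decorator on `forward` takes over the tuple conversion. A rough sketch of the mechanism, assuming the decorator downgrades the `ModelOutput` when a tuple return is requested (illustrative only, not the library's implementation):

import functools

def can_return_tuple_sketch(forward):
    @functools.wraps(forward)
    def wrapper(self, *args, return_dict=None, **kwargs):
        output = forward(self, *args, **kwargs)  # forward now always builds a ModelOutput
        use_return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        return output if use_return_dict else output.to_tuple()
    return wrapper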
def _update_causal_mask( def _update_causal_mask(
self, self,
attention_mask: torch.Tensor, attention_mask: Union[torch.Tensor, "BlockMask"],
input_tensor: torch.Tensor, input_tensor: torch.Tensor,
cache_position: torch.Tensor, cache_position: torch.Tensor,
past_key_values: Cache, past_key_values: Cache,
output_attentions: bool, output_attentions: bool = False,
): ):
if self.config._attn_implementation == "flash_attention_2": if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and (attention_mask == 0.0).any(): if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask return attention_mask
return None return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
return attention_mask
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
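The new `flex_attention` branch converts a plain 2D attention mask into a `BlockMask` (a `BlockMask` passed in directly is forwarded untouched). A hedged sketch of what such a conversion can look like with PyTorch's public flex-attention helper (requires PyTorch >= 2.5; the library's own `make_flex_block_causal_mask` may differ in details):

import torch
from torch.nn.attention.flex_attention import create_block_mask

def to_block_causal_mask(attention_mask: torch.Tensor):
    # attention_mask: (batch, seq_len) with 1 = real token, 0 = padding
    batch, seq_len = attention_mask.shape

    def mask_mod(b, h, q_idx, kv_idx):
        causal = q_idx >= kv_idx
        not_padding = attention_mask[b, kv_idx] > 0
        return causal & not_padding

    return create_block_mask(mask_mod, B=batch, H=None, Q_LEN=seq_len, KV_LEN=seq_len,
                             device=str(attention_mask.device))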
@ -621,7 +593,7 @@ class MyNewModel2Model(MyNewModel2PreTrainedModel):
): ):
return None return None
dtype, device = input_tensor.dtype, input_tensor.device dtype = input_tensor.dtype
sequence_length = input_tensor.shape[1] sequence_length = input_tensor.shape[1]
if using_static_cache: if using_static_cache:
target_length = past_key_values.get_max_cache_shape() target_length = past_key_values.get_max_cache_shape()
@ -638,7 +610,6 @@ class MyNewModel2Model(MyNewModel2PreTrainedModel):
sequence_length=sequence_length, sequence_length=sequence_length,
target_length=target_length, target_length=target_length,
dtype=dtype, dtype=dtype,
device=device,
cache_position=cache_position, cache_position=cache_position,
batch_size=input_tensor.shape[0], batch_size=input_tensor.shape[0],
) )
@ -646,7 +617,7 @@ class MyNewModel2Model(MyNewModel2PreTrainedModel):
if ( if (
self.config._attn_implementation == "sdpa" self.config._attn_implementation == "sdpa"
and attention_mask is not None and attention_mask is not None
and attention_mask.device.type in ["cuda", "xpu"] and attention_mask.device.type in ["cuda", "xpu", "npu"]
and not output_attentions and not output_attentions
): ):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
@ -663,7 +634,6 @@ class MyNewModel2Model(MyNewModel2PreTrainedModel):
sequence_length: int, sequence_length: int,
target_length: int, target_length: int,
dtype: torch.dtype, dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor, cache_position: torch.Tensor,
batch_size: int, batch_size: int,
**kwargs, **kwargs,
@ -683,8 +653,6 @@ class MyNewModel2Model(MyNewModel2PreTrainedModel):
to account for the 0 padding, the part of the cache that is not filled yet. to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`): dtype (`torch.dtype`):
The dtype to use for the 4D attention mask. The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`): cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence. Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`): batch_size (`torch.Tensor`):
@ -696,11 +664,11 @@ class MyNewModel2Model(MyNewModel2PreTrainedModel):
else: else:
min_dtype = torch.finfo(dtype).min min_dtype = torch.finfo(dtype).min
causal_mask = torch.full( causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
) )
if sequence_length != 1: if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1) causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None: if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
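A tiny worked example of the mask construction above, for a decode step that appends 2 new tokens to a cache already holding 2 (0 means "attend", `min_dtype` means "blocked"; the sizes are chosen purely for illustration):

import torch

dtype = torch.float32
min_dtype = torch.finfo(dtype).min
sequence_length, target_length = 2, 4
cache_position = torch.tensor([2, 3])  # absolute positions of the 2 new tokens

causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype)
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length) > cache_position.reshape(-1, 1)
# row 0 (query at position 2) attends to keys 0..2, row 1 (position 3) to keys 0..3:
# [[0, 0, 0, min_dtype],
#  [0, 0, 0, 0]]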
@ -747,29 +715,28 @@ class MyNewModel2ForSequenceClassification(MyNewModel2PreTrainedModel):
def set_input_embeddings(self, value): def set_input_embeddings(self, value):
self.model.embed_tokens = value self.model.embed_tokens = value
@can_return_tuple
@add_start_docstrings_to_model_forward(MY_NEW_MODEL2_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward(MY_NEW_MODEL2_INPUTS_DOCSTRING)
def forward( def forward(
self, self,
input_ids: Optional[torch.LongTensor] = None, input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None, past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None, labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None, use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, ) -> SequenceClassifierOutputWithPast:
) -> Union[tuple, SequenceClassifierOutputWithPast]:
r""" r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy). `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
""" """
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
transformer_outputs = self.model( transformer_outputs: BaseModelOutputWithPast = self.model(
input_ids, input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
position_ids=position_ids, position_ids=position_ids,
@ -778,9 +745,8 @@ class MyNewModel2ForSequenceClassification(MyNewModel2PreTrainedModel):
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict,
) )
hidden_states = transformer_outputs[0] hidden_states = transformer_outputs.last_hidden_state
logits = self.score(hidden_states) logits = self.score(hidden_states)
if input_ids is not None: if input_ids is not None:
@ -795,7 +761,7 @@ class MyNewModel2ForSequenceClassification(MyNewModel2PreTrainedModel):
elif input_ids is not None: elif input_ids is not None:
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32) non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
token_indices = torch.arange(input_ids.shape[-1], device=logits.device) token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1) last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
else: else:
last_non_pad_token = -1 last_non_pad_token = -1
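The `token_indices * non_pad_mask` trick above picks the rightmost non-pad token regardless of padding side; a quick self-contained illustration with `pad_token_id = 0` (values are made up for the example):

import torch

input_ids = torch.tensor([[5, 6, 7, 0, 0],   # right-padded
                          [0, 0, 5, 6, 7]])  # left-padded
non_pad_mask = (input_ids != 0).to(torch.int32)
token_indices = torch.arange(input_ids.shape[-1], dtype=torch.int32)
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
print(last_non_pad_token)  # tensor([2, 4]) -> index of the last real token in each row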
@ -810,10 +776,6 @@ class MyNewModel2ForSequenceClassification(MyNewModel2PreTrainedModel):
if labels is not None: if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
if not return_dict:
output = (pooled_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutputWithPast( return SequenceClassifierOutputWithPast(
loss=loss, loss=loss,
logits=pooled_logits, logits=pooled_logits,
@ -5,7 +5,7 @@
# modular_new_task_model.py file directly. One of our CI enforces this. # modular_new_task_model.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
from dataclasses import dataclass from dataclasses import dataclass
from typing import ClassVar, Optional, Union from typing import ClassVar, List, Optional, Tuple, Union
import torch import torch
from torch import nn from torch import nn
@ -59,10 +59,10 @@ class NewTaskModelCausalLMOutputWithPast(ModelOutput):
""" """
loss: Optional[torch.FloatTensor] = None loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None logits: Optional[torch.FloatTensor] = None
past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None
hidden_states: Optional[tuple[torch.FloatTensor]] = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None
image_hidden_states: Optional[torch.FloatTensor] = None image_hidden_states: Optional[torch.FloatTensor] = None
@ -113,23 +113,12 @@ class NewTaskModelPreTrainedModel(PreTrainedModel):
def _init_weights(self, module): def _init_weights(self, module):
# important: this ported version of NewTaskModel isn't meant for training from scratch - only # important: this ported version of NewTaskModel isn't meant for training from scratch - only
# inference and fine-tuning # inference and fine-tuning
std = ( std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
self.config.initializer_range
if hasattr(self.config, "initializer_range")
else self.config.text_config.initializer_range
)
if hasattr(module, "class_embedding"): if isinstance(module, nn.Linear):
module.class_embedding.data.normal_(mean=0.0, std=std)
if isinstance(module, (nn.Linear, nn.Conv2d)):
module.weight.data.normal_(mean=0.0, std=std) module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None: if module.bias is not None:
module.bias.data.zero_() module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
NEW_TASK_MODEL_INPUTS_DOCSTRING = r""" NEW_TASK_MODEL_INPUTS_DOCSTRING = r"""
@ -251,19 +240,22 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
def _update_causal_mask( def _update_causal_mask(
self, self,
attention_mask, attention_mask,
token_type_ids, token_type_ids=None,
past_key_values, past_key_values=None,
cache_position, cache_position=None,
input_tensor, input_tensor=None,
is_training: bool = False, is_training: Optional[bool] = None,
): ):
if self.config.text_config._attn_implementation == "flash_attention_2": if self.config.text_config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask: if attention_mask is not None and 0.0 in attention_mask:
return attention_mask return attention_mask
return None return None
is_training = is_training if is_training is not None else self.training
using_static_cache = isinstance(past_key_values, StaticCache) using_static_cache = isinstance(past_key_values, StaticCache)
min_dtype = torch.finfo(self.dtype).min min_dtype = torch.finfo(self.dtype).min
if input_tensor is None:
input_tensor = attention_mask
inputs_lead_dim, sequence_length = input_tensor.shape[:2] inputs_lead_dim, sequence_length = input_tensor.shape[:2]
if using_static_cache: if using_static_cache:
target_length = past_key_values.get_max_cache_shape() target_length = past_key_values.get_max_cache_shape()
@ -298,6 +290,8 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
# First unmask prefix tokens during training # First unmask prefix tokens during training
if is_training: if is_training:
if token_type_ids is None:
raise ValueError("Token type ids must be provided during training")
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0 token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0
) )
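The new `token_type_ids is None` guard protects the prefix-unmasking step that follows it: during training, positions whose token type is 0 (the prefix, e.g. image/prompt tokens) are made visible to every query. A small self-contained sketch of that `masked_fill`, with shapes chosen only for illustration:

import torch

min_dtype = torch.finfo(torch.float32).min
causal_mask = torch.triu(torch.full((1, 1, 4, 4), min_dtype), diagonal=1)
token_type_ids = torch.tensor([[0, 0, 1, 1]])  # first two tokens form the prefix
mask_length = 4
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
    token_type_ids[:, None, None, :] == 0, 0
)
# columns 0-1 are now fully unmasked for every query; columns 2-3 keep the causal pattern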
@ -345,7 +339,7 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
num_logits_to_keep: int = 0, num_logits_to_keep: int = 0,
) -> Union[tuple, NewTaskModelCausalLMOutputWithPast]: ) -> Union[Tuple, NewTaskModelCausalLMOutputWithPast]:
r""" r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
@ -368,19 +362,19 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
>>> import requests >>> import requests
>>> from transformers import AutoProcessor, NewTaskModelForNewTask >>> from transformers import AutoProcessor, NewTaskModelForNewTask
>>> model = NewTaskModelForNewTask.from_pretrained("google/NewTaskModel-test-224px-hf") >>> model = NewTaskModelForNewTask.from_pretrained("google/new_task_model2-3b-mix-224")
>>> processor = AutoProcessor.from_pretrained("google/NewTaskModel-test-224px-hf") >>> processor = AutoProcessor.from_pretrained("google/new_task_model2-3b-mix-224")
>>> prompt = "answer en Where is the cow standing?" >>> prompt = "Where is the cat standing?"
>>> url = "https://huggingface.co/gv-hf/NewTaskModel-test-224px-hf/resolve/main/cow_beach_1.png" >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
>>> image = Image.open(requests.get(url, stream=True).raw) >>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(images=image, text=prompt, return_tensors="pt") >>> inputs = processor(images=image, text=prompt, return_tensors="pt")
>>> # Generate >>> # Generate
>>> generate_ids = model.generate(**inputs, max_length=30) >>> generate_ids = model.generate(**inputs,)
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"answer en Where is the cow standing?\nbeach" "Where is the cat standing?\nsnow"
``` ```
Returns: Returns:
""" """
@ -6,7 +6,7 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
import math import math
import os import os
from typing import Optional, Union from typing import List, Optional, Tuple, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -139,9 +139,9 @@ class RobertaSelfAttention(nn.Module):
head_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
) -> tuple[torch.Tensor]: ) -> Tuple[torch.Tensor]:
mixed_query_layer = self.query(hidden_states) mixed_query_layer = self.query(hidden_states)
# If this is instantiated as a cross-attention module, the keys # If this is instantiated as a cross-attention module, the keys
@ -248,9 +248,9 @@ class RobertaSdpaSelfAttention(RobertaSelfAttention):
head_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
) -> tuple[torch.Tensor]: ) -> Tuple[torch.Tensor]:
if self.position_embedding_type != "absolute" or output_attentions or head_mask is not None: if self.position_embedding_type != "absolute" or output_attentions or head_mask is not None:
# TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented. # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented.
logger.warning_once( logger.warning_once(
@ -389,9 +389,9 @@ class RobertaAttention(nn.Module):
head_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
) -> tuple[torch.Tensor]: ) -> Tuple[torch.Tensor]:
self_outputs = self.self( self_outputs = self.self(
hidden_states, hidden_states,
attention_mask, attention_mask,
@ -457,9 +457,9 @@ class RobertaLayer(nn.Module):
head_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
) -> tuple[torch.Tensor]: ) -> Tuple[torch.Tensor]:
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2 # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
self_attention_outputs = self.attention( self_attention_outputs = self.attention(
@ -535,12 +535,12 @@ class RobertaEncoder(nn.Module):
head_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
use_cache: Optional[bool] = None, use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False, output_hidden_states: Optional[bool] = False,
return_dict: Optional[bool] = True, return_dict: Optional[bool] = True,
) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
@ -629,6 +629,46 @@ class RobertaPooler(nn.Module):
return pooled_output return pooled_output
class RobertaPredictionHeadTransform(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
if isinstance(config.hidden_act, str):
self.transform_act_fn = ACT2FN[config.hidden_act]
else:
self.transform_act_fn = config.hidden_act
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.transform_act_fn(hidden_states)
hidden_states = self.LayerNorm(hidden_states)
return hidden_states
class RobertaLMPredictionHead(nn.Module):
def __init__(self, config):
super().__init__()
self.transform = RobertaPredictionHeadTransform(config)
# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias
def _tie_weights(self):
self.decoder.bias = self.bias
def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states)
return hidden_states
def load_tf_weights_in_roberta(model, config, tf_checkpoint_path): def load_tf_weights_in_roberta(model, config, tf_checkpoint_path):
"""Load tf checkpoints in a pytorch model.""" """Load tf checkpoints in a pytorch model."""
try: try:
@ -729,6 +769,8 @@ class RobertaPreTrainedModel(PreTrainedModel):
elif isinstance(module, nn.LayerNorm): elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_() module.bias.data.zero_()
module.weight.data.fill_(1.0) module.weight.data.fill_(1.0)
elif isinstance(module, RobertaLMPredictionHead):
module.bias.data.zero_()
ROBERTA_START_DOCSTRING = r""" ROBERTA_START_DOCSTRING = r"""
@ -861,12 +903,12 @@ class RobertaModel(RobertaPreTrainedModel):
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[list[torch.FloatTensor]] = None, past_key_values: Optional[List[torch.FloatTensor]] = None,
use_cache: Optional[bool] = None, use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
) -> Union[tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
r""" r"""
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
@ -4,26 +4,42 @@
# the file from the modular. If any change should be done, please apply the change to the # the file from the modular. If any change should be done, please apply the change to the
# modular_super.py file directly. One of our CI enforces this. # modular_super.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
from typing import Callable, Optional, Union from typing import Callable, Optional, Tuple, Union
import torch import torch
from torch import nn from torch import nn
from transformers.modeling_outputs import CausalLMOutputWithPast
from ...activations import ACT2FN from ...activations import ACT2FN
from ...cache_utils import Cache, StaticCache from ...cache_utils import Cache, StaticCache
from ...integrations import use_kernel_forward_from_hub
from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import BaseModelOutputWithPast from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack from ...processing_utils import Unpack
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
can_return_tuple,
is_torch_flex_attn_available,
logging,
)
from .configuration_super import SuperConfig from .configuration_super import SuperConfig
if is_torch_flex_attn_available():
from torch.nn.attention.flex_attention import BlockMask
from ...integrations.flex_attention import make_flex_block_causal_mask
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
@use_kernel_forward_from_hub("RMSNorm")
class SuperRMSNorm(nn.Module): class SuperRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6): def __init__(self, hidden_size, eps=1e-6):
""" """
@ -62,45 +78,18 @@ class SuperRotaryEmbedding(nn.Module):
self.register_buffer("inv_freq", inv_freq, persistent=False) self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq self.original_inv_freq = self.inv_freq
def _dynamic_frequency_update(self, position_ids, device):
"""
dynamic RoPE layers should recompute `inv_freq` in the following situations:
1 - growing beyond the cached sequence length (allow scaling)
2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
"""
seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
self.max_seq_len_cached = seq_len
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
# This .to() is needed if the model has been moved to a device after being initialized (because
# the buffer is automatically moved, but not the original copy)
self.original_inv_freq = self.original_inv_freq.to(device)
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
self.max_seq_len_cached = self.original_max_seq_len
@torch.no_grad() @torch.no_grad()
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
def forward(self, x, position_ids): def forward(self, x, position_ids):
if "dynamic" in self.rope_type: inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
self._dynamic_frequency_update(position_ids, device=x.device)
# Core RoPE block
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float() position_ids_expanded = position_ids[:, None, :].float()
# Force float32 (see https://github.com/huggingface/transformers/pull/29285)
device_type = x.device.type device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): # Force float32
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1) emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos() cos = emb.cos() * self.attention_scaling
sin = emb.sin() sin = emb.sin() * self.attention_scaling
# Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
cos = cos * self.attention_scaling
sin = sin * self.attention_scaling
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
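For context on how the `cos`/`sin` returned here are consumed downstream, this is the standard rotate-half formulation used by the attention layers (a generic sketch, not copied from this diff):

import torch

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    # cos/sin: (batch, seq_len, head_dim); unsqueeze to broadcast over the heads dimension
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed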
@ -222,12 +211,12 @@ class SuperAttention(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor], position_embeddings: Tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor], attention_mask: Optional[torch.Tensor],
past_key_value: Optional[Cache] = None, past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs], **kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
input_shape = hidden_states.shape[:-1] input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim) hidden_shape = (*input_shape, -1, self.head_dim)
@ -244,6 +233,7 @@ class SuperAttention(nn.Module):
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
attention_interface: Callable = eager_attention_forward attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager": if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
logger.warning_once( logger.warning_once(
@ -269,7 +259,7 @@ class SuperAttention(nn.Module):
return attn_output, attn_weights return attn_output, attn_weights
class SuperDecoderLayer(nn.Module): class SuperDecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: SuperConfig, layer_idx: int): def __init__(self, config: SuperConfig, layer_idx: int):
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
@ -289,11 +279,10 @@ class SuperDecoderLayer(nn.Module):
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False, use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
**kwargs: Unpack[FlashAttentionKwargs], **kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states residual = hidden_states
hidden_states = self.input_layernorm(hidden_states) hidden_states = self.input_layernorm(hidden_states)
# Self Attention # Self Attention
@ -368,6 +357,8 @@ class SuperPreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std) module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None: if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_() module.weight.data[module.padding_idx].zero_()
elif isinstance(module, SuperRMSNorm):
module.weight.data.fill_(1.0)
SUPER_INPUTS_DOCSTRING = r""" SUPER_INPUTS_DOCSTRING = r"""
@ -380,12 +371,15 @@ SUPER_INPUTS_DOCSTRING = r"""
[`PreTrainedTokenizer.__call__`] for details. [`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids) [What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)` or `BlockMask`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**, - 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**. - 0 for tokens that are **masked**.
If the model is configured to use flex_attention, it will attempt to convert the mask Tensor into a BlockMask,
but you can also pass a `BlockMask` object directly here.
[What are attention masks?](../glossary#attention-mask) [What are attention masks?](../glossary#attention-mask)
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
@ -405,20 +399,12 @@ SUPER_INPUTS_DOCSTRING = r"""
config.n_positions - 1]`. config.n_positions - 1]`.
[What are position IDs?](../glossary#position-ids) [What are position IDs?](../glossary#position-ids)
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): past_key_values (`Cache`, *optional*):
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
Two formats are allowed: It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
- a [`~cache_utils.Cache`] instance, see our
[kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
cache format.
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
legacy cache format will be returned.
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
@ -479,6 +465,7 @@ class SuperModel(SuperPreTrainedModel):
def set_input_embeddings(self, value): def set_input_embeddings(self, value):
self.embed_tokens = value self.embed_tokens = value
@can_return_tuple
@add_start_docstrings_to_model_forward(SUPER_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward(SUPER_INPUTS_DOCSTRING)
def forward( def forward(
self, self,
@ -492,7 +479,7 @@ class SuperModel(SuperPreTrainedModel):
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
) -> Union[tuple, BaseModelOutputWithPast]: ) -> Union[tuple, CausalLMOutputWithPast]:
out = super().forward( out = super().forward(
input_ids, input_ids,
attention_mask, attention_mask,
@ -510,16 +497,20 @@ class SuperModel(SuperPreTrainedModel):
def _update_causal_mask( def _update_causal_mask(
self, self,
attention_mask: torch.Tensor, attention_mask: Union[torch.Tensor, "BlockMask"],
input_tensor: torch.Tensor, input_tensor: torch.Tensor,
cache_position: torch.Tensor, cache_position: torch.Tensor,
past_key_values: Cache, past_key_values: Cache,
output_attentions: bool, output_attentions: bool = False,
): ):
if self.config._attn_implementation == "flash_attention_2": if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and (attention_mask == 0.0).any(): if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask return attention_mask
return None return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
return attention_mask
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
@ -537,7 +528,7 @@ class SuperModel(SuperPreTrainedModel):
): ):
return None return None
dtype, device = input_tensor.dtype, input_tensor.device dtype = input_tensor.dtype
sequence_length = input_tensor.shape[1] sequence_length = input_tensor.shape[1]
if using_static_cache: if using_static_cache:
target_length = past_key_values.get_max_cache_shape() target_length = past_key_values.get_max_cache_shape()
@ -554,7 +545,6 @@ class SuperModel(SuperPreTrainedModel):
sequence_length=sequence_length, sequence_length=sequence_length,
target_length=target_length, target_length=target_length,
dtype=dtype, dtype=dtype,
device=device,
cache_position=cache_position, cache_position=cache_position,
batch_size=input_tensor.shape[0], batch_size=input_tensor.shape[0],
) )
@ -562,7 +552,7 @@ class SuperModel(SuperPreTrainedModel):
if ( if (
self.config._attn_implementation == "sdpa" self.config._attn_implementation == "sdpa"
and attention_mask is not None and attention_mask is not None
and attention_mask.device.type in ["cuda", "xpu"] and attention_mask.device.type in ["cuda", "xpu", "npu"]
and not output_attentions and not output_attentions
): ):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
@ -579,7 +569,6 @@ class SuperModel(SuperPreTrainedModel):
sequence_length: int, sequence_length: int,
target_length: int, target_length: int,
dtype: torch.dtype, dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor, cache_position: torch.Tensor,
batch_size: int, batch_size: int,
**kwargs, **kwargs,
@ -599,8 +588,6 @@ class SuperModel(SuperPreTrainedModel):
to account for the 0 padding, the part of the cache that is not filled yet. to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`): dtype (`torch.dtype`):
The dtype to use for the 4D attention mask. The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`): cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence. Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`): batch_size (`torch.Tensor`):
@ -612,11 +599,11 @@ class SuperModel(SuperPreTrainedModel):
else: else:
min_dtype = torch.finfo(dtype).min min_dtype = torch.finfo(dtype).min
causal_mask = torch.full( causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
) )
if sequence_length != 1: if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1) causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None: if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
@ -5,7 +5,7 @@
# modular_switch_function.py file directly. One of our CI enforces this. # modular_switch_function.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# Note that llama and cohere have different definitions for rotate_half # Note that llama and cohere have different definitions for rotate_half
from typing import Callable, Optional from typing import Callable, Optional, Tuple
import torch import torch
from torch import nn from torch import nn
@ -123,12 +123,12 @@ class SwitchFunctionAttention(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor], position_embeddings: Tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor], attention_mask: Optional[torch.Tensor],
past_key_value: Optional[Cache] = None, past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs], **kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
input_shape = hidden_states.shape[:-1] input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim) hidden_shape = (*input_shape, -1, self.head_dim)
@ -145,6 +145,7 @@ class SwitchFunctionAttention(nn.Module):
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
attention_interface: Callable = eager_attention_forward attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager": if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
logger.warning_once( logger.warning_once(
File diff suppressed because it is too large
@ -16,9 +16,7 @@ from transformers.models.clip.modeling_clip import (
CLIPAttention, CLIPAttention,
CLIPEncoder, CLIPEncoder,
CLIPEncoderLayer, CLIPEncoderLayer,
CLIPFlashAttention2,
CLIPPreTrainedModel, CLIPPreTrainedModel,
CLIPSdpaAttention,
CLIPVisionModel, CLIPVisionModel,
CLIPVisionTransformer, CLIPVisionTransformer,
) )
@ -29,23 +27,6 @@ class Multimodal2VisionAttention(CLIPAttention):
pass pass
# Check that adding the second base class correctly set the parent, even though in clip it does not have the "Vision" part
class Multimodal2VisionSdpaAttention(CLIPSdpaAttention, Multimodal2VisionAttention):
pass
# Check that adding the second base class correctly set the parent, even though in clip it does not have the "Vision" part
class Multimodal2VisionFlashAttention2(CLIPFlashAttention2, Multimodal2VisionAttention):
pass
MULTIMODAL2_VISION_ATTENTION_CLASSES = {
"eager": Multimodal2VisionAttention,
"sdpa": Multimodal2VisionSdpaAttention,
"flash_attention_2": Multimodal2VisionFlashAttention2,
}
class Multimodal2VisionMLP(CLIPMLP): class Multimodal2VisionMLP(CLIPMLP):
pass pass
@ -53,7 +34,6 @@ class Multimodal2VisionMLP(CLIPMLP):
class Multimodal2VisionEncoderLayer(CLIPEncoderLayer): class Multimodal2VisionEncoderLayer(CLIPEncoderLayer):
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()
self.self_attn = MULTIMODAL2_VISION_ATTENTION_CLASSES[config._attn_implementation](config)
self.mlp = Multimodal2VisionMLP(config) self.mlp = Multimodal2VisionMLP(config)
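The per-backend subclasses and the `MULTIMODAL2_VISION_ATTENTION_CLASSES` mapping removed here are superseded by the run-time dispatch visible in the attention forwards elsewhere in this diff (`attention_interface = ALL_ATTENTION_FUNCTIONS[...]`). A toy, self-contained sketch of that registry pattern (the names below are illustrative, not the library's):

from typing import Callable, Dict
import torch
import torch.nn.functional as F

def eager_attention(q, k, v, mask=None):
    scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
    if mask is not None:
        scores = scores + mask
    return torch.softmax(scores, dim=-1) @ v

def sdpa_attention(q, k, v, mask=None):
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

ATTENTION_FUNCTIONS: Dict[str, Callable] = {"eager": eager_attention, "sdpa": sdpa_attention}

def attend(q, k, v, impl="sdpa", mask=None):
    # One forward pass; the backend is chosen by a config string instead of by subclassing
    return ATTENTION_FUNCTIONS[impl](q, k, v, mask=mask)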
@ -0,0 +1,7 @@
from transformers.models.deformable_detr.modeling_deformable_detr import DeformableDetrModel
# Here, the old and new model have by essence a common "detr" suffix. Make sure everything is correctly named
# in this case (i.e., we do not wrongly detect `Detr` as part of a suffix to remove)
class TestDetrModel(DeformableDetrModel):
pass
@ -1466,6 +1466,10 @@ class ModularFileMapper(ModuleMapper):
suffix = common_partial_suffix(class_name, modeling_bases[0]) suffix = common_partial_suffix(class_name, modeling_bases[0])
if len(suffix) > 0 and suffix[0].isupper(): if len(suffix) > 0 and suffix[0].isupper():
cased_model_name = class_name.replace(suffix, "") cased_model_name = class_name.replace(suffix, "")
# If both the old model and new model share the last part of their name, it is detected as a common
# suffix, but it should not be the case -> use the full name in this case
if len(cased_model_name) < len(cased_default_name) and cased_default_name in class_name:
cased_model_name = cased_default_name
prefix_model_name_mapping[filename].update([cased_model_name]) prefix_model_name_mapping[filename].update([cased_model_name])
# Check if we found multiple prefixes for some modeling files # Check if we found multiple prefixes for some modeling files
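The added check targets exactly the situation exercised by the `modular_test_detr.py` example above: `TestDetrModel(DeformableDetrModel)` shares the `DetrModel` suffix with its base, so stripping the common suffix alone would yield the wrong prefix `Test` instead of `TestDetr`. A small illustrative walk-through; the `common_suffix` helper below is hypothetical and only stands in for the converter's `common_partial_suffix`:

def common_suffix(a: str, b: str) -> str:
    i = 0
    while i < min(len(a), len(b)) and a[-1 - i] == b[-1 - i]:
        i += 1
    return a[len(a) - i:]

class_name, base_name = "TestDetrModel", "DeformableDetrModel"
cased_default_name = "TestDetr"                    # assumed to come from the modular file name
suffix = common_suffix(class_name, base_name)      # "DetrModel"
cased_model_name = class_name.replace(suffix, "")  # "Test" -> would rename classes incorrectly
if len(cased_model_name) < len(cased_default_name) and cased_default_name in class_name:
    cased_model_name = cased_default_name          # "TestDetr" -> correct prefix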
@ -1761,6 +1765,17 @@ if __name__ == "__main__":
args.files_to_parse = glob.glob("src/transformers/models/**/modular_*.py", recursive=True) args.files_to_parse = glob.glob("src/transformers/models/**/modular_*.py", recursive=True)
if args.files_to_parse == ["examples"]: if args.files_to_parse == ["examples"]:
args.files_to_parse = glob.glob("examples/**/modular_*.py", recursive=True) args.files_to_parse = glob.glob("examples/**/modular_*.py", recursive=True)
else:
for i, model_name in enumerate(args.files_to_parse):
if os.sep not in model_name:
full_path = os.path.join("src", "transformers", "models", model_name, f"modular_{model_name}.py")
# If it does not exist, try in the examples section
if not os.path.isfile(full_path):
full_path = os.path.join("examples", "modular-transformers", f"modular_{model_name}.py")
# We did not find it anywhere
if not os.path.isfile(full_path):
raise ValueError(f"Cannot find a modular file for {model_name}. Please provide the full path.")
args.files_to_parse[i] = full_path
priority_list, _ = find_priority_list(args.files_to_parse) priority_list, _ = find_priority_list(args.files_to_parse)
assert len(priority_list) == len(args.files_to_parse), "Some files will not be converted" assert len(priority_list) == len(args.files_to_parse), "Some files will not be converted"
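With this fallback, the converter also accepts bare model names and resolves them to the corresponding modular file. A hedged sketch of the resolution logic as a standalone function (the wrapper function and the example below are illustrative; check the script's argument parser for the exact CLI flag):

import os

def resolve_modular_path(model_name: str) -> str:
    if os.sep in model_name:  # already a path -> use as-is
        return model_name
    full_path = os.path.join("src", "transformers", "models", model_name, f"modular_{model_name}.py")
    if not os.path.isfile(full_path):  # fall back to the examples section
        full_path = os.path.join("examples", "modular-transformers", f"modular_{model_name}.py")
    if not os.path.isfile(full_path):
        raise ValueError(f"Cannot find a modular file for {model_name}. Please provide the full path.")
    return full_path

# e.g. passing "my_model" on the command line would now resolve to
# src/transformers/models/my_model/modular_my_model.py when that file exists.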