diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index 8f552cfc815..ebddfcfe711 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -673,7 +673,7 @@ class AriaPreTrainedModel(PreTrainedModel): _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = False # MoE models don't work with torch.compile (dynamic slicing) - _supports_attention_backend = False + _supports_attention_backend = True def _init_weights(self, module): std = self.config.initializer_range diff --git a/src/transformers/models/aria/modular_aria.py b/src/transformers/models/aria/modular_aria.py index 561f94e4e73..348da86bb14 100644 --- a/src/transformers/models/aria/modular_aria.py +++ b/src/transformers/models/aria/modular_aria.py @@ -1299,7 +1299,7 @@ class AriaPreTrainedModel(LlamaPreTrainedModel): config_class = AriaConfig base_model_prefix = "" _supports_static_cache = False # MoE models don't work with torch.compile (dynamic slicing) - _supports_attention_backend = False + _supports_attention_backend = True def _init_weights(self, module): std = self.config.initializer_range diff --git a/src/transformers/models/aya_vision/modeling_aya_vision.py b/src/transformers/models/aya_vision/modeling_aya_vision.py index e074d4b1193..519800f9285 100644 --- a/src/transformers/models/aya_vision/modeling_aya_vision.py +++ b/src/transformers/models/aya_vision/modeling_aya_vision.py @@ -282,7 +282,6 @@ class AyaVisionModel(AyaVisionPreTrainedModel): output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - image_sizes: torch.Tensor = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, AyaVisionModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -310,7 +309,6 @@ class AyaVisionModel(AyaVisionPreTrainedModel): pixel_values=pixel_values, vision_feature_layer=vision_feature_layer, vision_feature_select_strategy=vision_feature_select_strategy, - image_sizes=image_sizes, ) if input_ids is None: diff --git a/src/transformers/models/aya_vision/modular_aya_vision.py b/src/transformers/models/aya_vision/modular_aya_vision.py index 7dcd206bab8..533a7f44447 100644 --- a/src/transformers/models/aya_vision/modular_aya_vision.py +++ b/src/transformers/models/aya_vision/modular_aya_vision.py @@ -24,12 +24,14 @@ from transformers.models.llava.modeling_llava import ( LlavaCausalLMOutputWithPast, LlavaForConditionalGeneration, LlavaModel, + LlavaModelOutputWithPast, LlavaPreTrainedModel, ) from ...activations import ACT2FN +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...processing_utils import Unpack -from ...utils import logging +from ...utils import auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging from .configuration_aya_vision import AyaVisionConfig @@ -110,10 +112,154 @@ class AyaVisionCausalLMOutputWithPast(LlavaCausalLMOutputWithPast): pass -class AyaVisionModel(LlavaModel): +class AyaVisionModelOutputWithPast(LlavaModelOutputWithPast): pass +class AyaVisionModel(LlavaModel): + # Unlike LLaVA, the model doesn't have to deal with Pixtral-style image states + def get_image_features( + self, + pixel_values: torch.FloatTensor, + vision_feature_layer: Optional[Union[int, List[int]]] = None, + vision_feature_select_strategy: Optional[str] = None, + **kwargs, + ): + """ + Obtains image last hidden states from the vision tower and apply multimodal projection. + + Args: + pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`): + The tensors corresponding to the input images. + vision_feature_layer (`Union[int, List[int]]`, *optional*): + The index of the layer to select the vision feature. If multiple indices are provided, + the vision feature of the corresponding indices will be concatenated to form the + vision features. + vision_feature_select_strategy (`str`, *optional*): + The feature selection strategy used to select the vision feature from the vision backbone. + Can be one of `"default"` or `"full"` + Returns: + image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`). + """ + vision_feature_layer = ( + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + ) + vision_feature_select_strategy = ( + vision_feature_select_strategy + if vision_feature_select_strategy is not None + else self.config.vision_feature_select_strategy + ) + + if vision_feature_select_strategy not in ["default", "full"]: + raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}") + + kwargs = {k: v for k, v in kwargs.items() if v is not None} + # this is not memory efficient at all (output_hidden_states=True) will save all the hidden states. + image_outputs = self.vision_tower(pixel_values, output_hidden_states=True, **kwargs) + + # If we have one vision feature layer, return the corresponding hidden states, + # otherwise, select the hidden states of each feature layer and concatenate them + if isinstance(vision_feature_layer, int): + selected_image_feature = image_outputs.hidden_states[vision_feature_layer] + if vision_feature_select_strategy == "default": + selected_image_feature = selected_image_feature[:, 1:] + else: + hs_pool = [image_outputs.hidden_states[layer_idx] for layer_idx in vision_feature_layer] + # For default; crop CLS from each hidden state in the hidden state pool + if vision_feature_select_strategy == "default": + hs_pool = [hs[:, 1:] for hs in hs_pool] + selected_image_feature = torch.cat(hs_pool, dim=-1) + + image_features = self.multi_modal_projector(selected_image_feature) + return image_features + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + vision_feature_layer: Optional[Union[int, List[int]]] = None, + vision_feature_select_strategy: Optional[str] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[Tuple, AyaVisionModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + vision_feature_layer = ( + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + ) + vision_feature_select_strategy = ( + vision_feature_select_strategy + if vision_feature_select_strategy is not None + else self.config.vision_feature_select_strategy + ) + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + if pixel_values is not None: + image_features = self.get_image_features( + pixel_values=pixel_values, + vision_feature_layer=vision_feature_layer, + vision_feature_select_strategy=vision_feature_select_strategy, + ) + + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)[0] + else: + special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) + special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + n_image_tokens = (input_ids == self.config.image_token_id).sum() + + if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): + n_image_tokens = (input_ids == self.config.image_token_id).sum() + n_image_features = image_features.shape[0] * image_features.shape[1] + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + + outputs = self.language_model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + **kwargs, + ) + + return AyaVisionModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=image_features if pixel_values is not None else None, + ) + + class AyaVisionForConditionalGeneration(LlavaForConditionalGeneration): def forward( self, diff --git a/src/transformers/models/chameleon/configuration_chameleon.py b/src/transformers/models/chameleon/configuration_chameleon.py index 5955ef48940..70d9a7cd7a0 100644 --- a/src/transformers/models/chameleon/configuration_chameleon.py +++ b/src/transformers/models/chameleon/configuration_chameleon.py @@ -247,6 +247,7 @@ class ChameleonConfig(PretrainedConfig): self.vq_config = ChameleonVQVAEConfig(**vq_config) self.vocabulary_map = vocabulary_map + self.image_token_id = vocabulary_map.get("") if vocabulary_map is not None else None super().__init__( pad_token_id=pad_token_id, diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py index e7a46b43cf3..c816fe61053 100644 --- a/src/transformers/models/chameleon/modeling_chameleon.py +++ b/src/transformers/models/chameleon/modeling_chameleon.py @@ -904,12 +904,6 @@ class ChameleonModel(ChameleonPreTrainedModel): self.embed_tokens = value def get_image_tokens(self, pixel_values: torch.FloatTensor): - logger.warning( - "`model.get_image_tokens()` is deprecated and will be removed in v4.58. To obtain discrete token use `model.get_image_features()`" - ) - return self.get_image_featues(pixel_values) - - def get_image_features(self, pixel_values: torch.FloatTensor): """ Tokenizes images into discrete tokens with VQGAN module. Converts obtained image tokens into BPE tokens and wraps with "boi" and "eoi" @@ -925,6 +919,19 @@ class ChameleonModel(ChameleonPreTrainedModel): bpe_toks = bpe_toks.view(batch_size, -1) return bpe_toks + def get_image_features(self, pixel_values: torch.FloatTensor): + """ + Tokenizes images into discrete tokens with VQGAN module and embeds + them with text embeddings layer + + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)): + The tensors corresponding to the input images. + """ + image_tokens = self.get_image_tokens(pixel_values) + vision_embeddings = self.get_input_embeddings()(image_tokens) + return vision_embeddings + @auto_docstring def forward( self, @@ -963,7 +970,7 @@ class ChameleonModel(ChameleonPreTrainedModel): ) if pixel_values is not None: - image_tokens = self.get_image_features(pixel_values) + image_tokens = self.get_image_tokens(pixel_values) special_image_mask = input_ids == self.vocabulary_mapping.image_token_id if not is_torchdynamo_compiling() and input_ids[special_image_mask].numel() != image_tokens.numel(): n_image_tokens_in_text = (input_ids == self.vocabulary_mapping.image_token_id).sum() diff --git a/src/transformers/models/emu3/configuration_emu3.py b/src/transformers/models/emu3/configuration_emu3.py index 60e5e55ab44..19315003dfb 100644 --- a/src/transformers/models/emu3/configuration_emu3.py +++ b/src/transformers/models/emu3/configuration_emu3.py @@ -320,6 +320,7 @@ class Emu3Config(PretrainedConfig): self.vq_config = vq_config self.text_config = text_config self.vocabulary_map = vocabulary_map + self.image_token_id = vocabulary_map.get("") if vocabulary_map is not None else None super().__init__(**kwargs) diff --git a/src/transformers/models/emu3/modeling_emu3.py b/src/transformers/models/emu3/modeling_emu3.py index 31f01db1b5a..6540e9fc714 100644 --- a/src/transformers/models/emu3/modeling_emu3.py +++ b/src/transformers/models/emu3/modeling_emu3.py @@ -1451,12 +1451,6 @@ class Emu3Model(Emu3PreTrainedModel): self.text_model.set_input_embeddings(value) def get_image_tokens(self, pixel_values: torch.FloatTensor, image_sizes: torch.LongTensor): - logger.warning( - "`model.get_image_tokens()` is deprecated and will be removed in v4.58. To obtain discrete token use `model.get_image_features()`" - ) - return self.get_image_featues(pixel_values) - - def get_image_features(self, pixel_values: torch.FloatTensor, image_sizes: torch.LongTensor): """ Tokenizes images into discrete tokens with VQGAN module. Converts obtained image tokens into BPE tokens and wraps with "boi" and "eoi" @@ -1473,6 +1467,24 @@ class Emu3Model(Emu3PreTrainedModel): bpe_tokens = torch.cat(bpe_tokens_list) return bpe_tokens + def get_image_features(self, pixel_values: torch.FloatTensor, image_sizes: torch.LongTensor): + """ + Tokenizes images into discrete tokens with VQGAN module and embeds + them with text embeddings layer + + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)): + The tensors corresponding to the input images. + """ + image_tokens = self.get_image_tokens(pixel_values, image_sizes) + split_sizes = [ + (height // self.vqmodel.vision_spatial_factor) * (width // self.vqmodel.vision_spatial_factor + 1) + for height, width in image_sizes + ] + image_features = self.get_input_embeddings()(image_tokens) + image_features = torch.split(image_features, split_sizes) + return image_features + @torch.no_grad def decode_image_tokens(self, image_tokens: torch.LongTensor, height: int, width: int): """ @@ -1533,7 +1545,7 @@ class Emu3Model(Emu3PreTrainedModel): ) if pixel_values is not None: - image_tokens = self.get_image_features(pixel_values, image_sizes) + image_tokens = self.get_image_tokens(pixel_values, image_sizes) special_image_mask = input_ids == self.vocabulary_mapping.image_token_id image_tokens = image_tokens.to(input_ids.device, input_ids.dtype) input_ids = input_ids.masked_scatter(special_image_mask, image_tokens) diff --git a/src/transformers/models/emu3/modular_emu3.py b/src/transformers/models/emu3/modular_emu3.py index 8c86f81d523..d6e34eb14af 100644 --- a/src/transformers/models/emu3/modular_emu3.py +++ b/src/transformers/models/emu3/modular_emu3.py @@ -938,12 +938,6 @@ class Emu3Model(Emu3PreTrainedModel): self.text_model.set_input_embeddings(value) def get_image_tokens(self, pixel_values: torch.FloatTensor, image_sizes: torch.LongTensor): - logger.warning( - "`model.get_image_tokens()` is deprecated and will be removed in v4.58. To obtain discrete token use `model.get_image_features()`" - ) - return self.get_image_featues(pixel_values) - - def get_image_features(self, pixel_values: torch.FloatTensor, image_sizes: torch.LongTensor): """ Tokenizes images into discrete tokens with VQGAN module. Converts obtained image tokens into BPE tokens and wraps with "boi" and "eoi" @@ -960,6 +954,24 @@ class Emu3Model(Emu3PreTrainedModel): bpe_tokens = torch.cat(bpe_tokens_list) return bpe_tokens + def get_image_features(self, pixel_values: torch.FloatTensor, image_sizes: torch.LongTensor): + """ + Tokenizes images into discrete tokens with VQGAN module and embeds + them with text embeddings layer + + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)): + The tensors corresponding to the input images. + """ + image_tokens = self.get_image_tokens(pixel_values, image_sizes) + split_sizes = [ + (height // self.vqmodel.vision_spatial_factor) * (width // self.vqmodel.vision_spatial_factor + 1) + for height, width in image_sizes + ] + image_features = self.get_input_embeddings()(image_tokens) + image_features = torch.split(image_features, split_sizes) + return image_features + @torch.no_grad def decode_image_tokens(self, image_tokens: torch.LongTensor, height: int, width: int): """ @@ -1020,7 +1032,7 @@ class Emu3Model(Emu3PreTrainedModel): ) if pixel_values is not None: - image_tokens = self.get_image_features(pixel_values, image_sizes) + image_tokens = self.get_image_tokens(pixel_values, image_sizes) special_image_mask = input_ids == self.vocabulary_mapping.image_token_id image_tokens = image_tokens.to(input_ids.device, input_ids.dtype) input_ids = input_ids.masked_scatter(special_image_mask, image_tokens) diff --git a/src/transformers/models/fuyu/configuration_fuyu.py b/src/transformers/models/fuyu/configuration_fuyu.py index 50a1975c763..4584184b0ae 100644 --- a/src/transformers/models/fuyu/configuration_fuyu.py +++ b/src/transformers/models/fuyu/configuration_fuyu.py @@ -16,7 +16,7 @@ from ...configuration_utils import PretrainedConfig from ...utils import logging -from ..auto import CONFIG_MAPPING +from ..auto import CONFIG_MAPPING, AutoConfig logger = logging.get_logger(__name__) @@ -89,6 +89,8 @@ class FuyuConfig(PretrainedConfig): The id of the *beginning-of-sequence* token. eos_token_id (`Union[int, List[int]]`, *optional*, defaults to 2): The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. + image_token_id (`int`, *optional*, defaults to 71011): + The id of the image placeholder token. text_config (`dict`, *optional*): Dictionary of configuration options used to initialize the `language``[`Aut`]. @@ -100,6 +102,7 @@ class FuyuConfig(PretrainedConfig): ```""" model_type = "fuyu" + sub_configs = {"text_config": AutoConfig} keys_to_ignore_at_inference = ["past_key_values"] def __init__( @@ -127,6 +130,7 @@ class FuyuConfig(PretrainedConfig): pad_token_id=None, bos_token_id=1, eos_token_id=2, + image_token_id=71011, text_config=None, **kwargs, ): @@ -176,6 +180,7 @@ class FuyuConfig(PretrainedConfig): self.hidden_dropout = hidden_dropout self.attention_dropout = attention_dropout self.partial_rotary_factor = partial_rotary_factor + self.image_token_id = image_token_id self._rope_scaling_validation() super().__init__( diff --git a/src/transformers/models/fuyu/modeling_fuyu.py b/src/transformers/models/fuyu/modeling_fuyu.py index 590cf8f8d17..78b7ae7d4db 100644 --- a/src/transformers/models/fuyu/modeling_fuyu.py +++ b/src/transformers/models/fuyu/modeling_fuyu.py @@ -202,11 +202,12 @@ class FuyuModel(FuyuPreTrainedModel): inputs_embeds = self.language_model.get_input_embeddings()(input_ids) if image_patches is not None and past_key_values is None: patch_embeddings = self.get_image_features(image_patches) - inputs_embeds = self.gather_continuous_embeddings( - word_embeddings=inputs_embeds, - continuous_embeddings=patch_embeddings, - image_patch_input_indices=image_patches_indices, - ) + patch_embeddings = torch.cat(patch_embeddings, dim=0) + + special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) + special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + patch_embeddings = patch_embeddings.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, patch_embeddings) outputs = self.language_model( inputs_embeds=inputs_embeds, diff --git a/src/transformers/models/fuyu/processing_fuyu.py b/src/transformers/models/fuyu/processing_fuyu.py index 4852f3aaf9e..3c71c34c6f9 100644 --- a/src/transformers/models/fuyu/processing_fuyu.py +++ b/src/transformers/models/fuyu/processing_fuyu.py @@ -620,12 +620,16 @@ class FuyuProcessor(ProcessorMixin): width_scale_factor = padded_width / image_size[1] optimal_scale_factor = min(height_scale_factor, width_scale_factor) - # We can use torch here because Fuyu processor has hard dependency on torch + image_unpadded_h = min(int(image_size[0] * optimal_scale_factor), image_size[0]) + image_unpadded_w = min(int(image_size[0] * optimal_scale_factor), image_size[0]) + + # We can use torch here because Fuyu processor has hard dependency on torch. NOTE: Fuyu can't do multi-image + # thus the below (1, 1, 1) is hardcoded. Same as when calling the processor model_image_input = self.image_processor.preprocess_with_tokenizer_info( image_input=torch.zeros(1, 1, 3, padded_height, padded_width), image_present=torch.ones(1, 1, 1), - image_unpadded_h=torch.tensor([[int(image_size[0] * optimal_scale_factor)]]), - image_unpadded_w=torch.tensor([[int(image_size[1] * optimal_scale_factor)]]), + image_unpadded_h=torch.tensor([[image_unpadded_h]]), + image_unpadded_w=torch.tensor([[image_unpadded_w]]), image_placeholder_id=0, # dummy ids, we can be sure `id=0` is never out-of-range image_newline_id=0, variable_sized=True, diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 2a296089198..95b50039eb9 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -377,7 +377,7 @@ class GemmaModel(GemmaPreTrainedModel): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - **kwargs, # NOOP kwarg for now + **kwargs: Unpack[FlashAttentionKwargs], ) -> BaseModelOutputWithPast: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -446,6 +446,7 @@ class GemmaModel(GemmaPreTrainedModel): use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, + **kwargs, ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/gemma/modular_gemma.py b/src/transformers/models/gemma/modular_gemma.py index e934df7ef80..648caea73d7 100644 --- a/src/transformers/models/gemma/modular_gemma.py +++ b/src/transformers/models/gemma/modular_gemma.py @@ -23,7 +23,9 @@ from torch import nn from ...cache_utils import Cache, DynamicCache from ...configuration_utils import PretrainedConfig from ...masking_utils import create_causal_mask +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast +from ...processing_utils import Unpack from ...tokenization_utils import AddedToken, PreTrainedTokenizer from ...utils import logging from ..llama.modeling_llama import ( @@ -378,7 +380,7 @@ class GemmaModel(LlamaModel): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - **kwargs, # NOOP kwarg for now + **kwargs: Unpack[FlashAttentionKwargs], ) -> BaseModelOutputWithPast: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -447,6 +449,7 @@ class GemmaModel(LlamaModel): use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, + **kwargs, ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/idefics3/modeling_idefics3.py b/src/transformers/models/idefics3/modeling_idefics3.py index 6bcdd3594e4..41043b3e745 100644 --- a/src/transformers/models/idefics3/modeling_idefics3.py +++ b/src/transformers/models/idefics3/modeling_idefics3.py @@ -693,8 +693,10 @@ class Idefics3Model(Idefics3PreTrainedModel): - To fit the format of that sequence, `input_ids`, `input_embeds`, `attention_mask` are all 3 adapted to insert the image hidden states. """ special_image_token_mask = input_ids == self.image_token_id - # Fixes RuntimeError: a leaf Variable that requires grad is being used in an in-place operation. + # Fixes RuntimeError: a leaf Variable that requires grad is being used in an in-place operation. new_inputs_embeds = inputs_embeds.clone() + # Flatten `image_hidden_states` if not flat yet + image_hidden_states = image_hidden_states.view(-1, image_hidden_states.shape[-1]) # cast to the dtype of the input_embeds to support quantized models image_hidden_states = image_hidden_states.to(inputs_embeds.device, inputs_embeds.dtype) new_inputs_embeds[special_image_token_mask] = image_hidden_states @@ -742,7 +744,6 @@ class Idefics3Model(Idefics3PreTrainedModel): # Modality projection & resampling image_hidden_states = self.connector(image_hidden_states.last_hidden_state) - image_hidden_states = image_hidden_states.view(-1, image_hidden_states.shape[-1]) return image_hidden_states @can_return_tuple @@ -807,9 +808,6 @@ class Idefics3Model(Idefics3PreTrainedModel): past_key_values = DynamicCache() past_seen_tokens = past_key_values.get_seq_length() - if inputs_embeds is not None and input_ids is None and past_seen_tokens == 0: - raise ValueError("When first calling the model, if input_embeds are passed, input_ids should not be None.") - if inputs_embeds is None: inputs_embeds = self.text_model.get_input_embeddings()(input_ids).to(self.device) @@ -821,7 +819,7 @@ class Idefics3Model(Idefics3PreTrainedModel): elif image_hidden_states is not None: image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device) - if past_seen_tokens == 0 and inputs_embeds is not None and image_hidden_states is not None: + if past_seen_tokens == 0 and input_ids is not None and image_hidden_states is not None: # When we generate, we don't want to replace the potential image_token_id that we generated by images # that simply don't exist inputs_embeds = self.inputs_merger( diff --git a/src/transformers/models/internvl/modeling_internvl.py b/src/transformers/models/internvl/modeling_internvl.py index c73f84d9222..b2e539e7bce 100644 --- a/src/transformers/models/internvl/modeling_internvl.py +++ b/src/transformers/models/internvl/modeling_internvl.py @@ -627,8 +627,8 @@ class InternVLModel(InternVLPreTrainedModel): def get_image_features( self, pixel_values: torch.FloatTensor, - vision_feature_layer: Union[int, List[int]], - vision_feature_select_strategy: str, + vision_feature_layer: Optional[Union[int, List[int]]] = None, + vision_feature_select_strategy: Optional[str] = None, **kwargs, ): """ @@ -642,6 +642,15 @@ class InternVLModel(InternVLPreTrainedModel): Returns: vision_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`. """ + vision_feature_layer = ( + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + ) + vision_feature_select_strategy = ( + vision_feature_select_strategy + if vision_feature_select_strategy is not None + else self.config.vision_feature_select_strategy + ) + downsample_ratio = self.config.downsample_ratio if vision_feature_layer == -1: vision_features = self.vision_tower(pixel_values=pixel_values).last_hidden_state @@ -666,7 +675,6 @@ class InternVLModel(InternVLPreTrainedModel): # Project features through multi-modal projector vision_features = self.multi_modal_projector(vision_features) - return vision_features @can_return_tuple @@ -686,7 +694,6 @@ class InternVLModel(InternVLPreTrainedModel): output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - image_sizes: torch.Tensor = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, InternVLModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -714,7 +721,6 @@ class InternVLModel(InternVLPreTrainedModel): pixel_values=pixel_values, vision_feature_layer=vision_feature_layer, vision_feature_select_strategy=vision_feature_select_strategy, - image_sizes=image_sizes, ) if input_ids is None: diff --git a/src/transformers/models/internvl/modular_internvl.py b/src/transformers/models/internvl/modular_internvl.py index 91e53ec5233..fcf4958f623 100644 --- a/src/transformers/models/internvl/modular_internvl.py +++ b/src/transformers/models/internvl/modular_internvl.py @@ -27,7 +27,7 @@ from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import auto_docstring, can_return_tuple, logging, torch_int +from ...utils import auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging, torch_int from ..clip.modeling_clip import CLIPMLP from ..janus.modeling_janus import JanusVisionAttention from ..llama.modeling_llama import LlamaRMSNorm @@ -35,6 +35,7 @@ from ..llava.modeling_llava import ( LlavaCausalLMOutputWithPast, LlavaForConditionalGeneration, LlavaModel, + LlavaModelOutputWithPast, LlavaPreTrainedModel, ) from .configuration_internvl import InternVLConfig, InternVLVisionConfig @@ -510,6 +511,10 @@ class InternVLMultiModalProjector(nn.Module): return hidden_states +class InternVLModelOutputWithPast(LlavaModelOutputWithPast): + pass + + class InternVLModel(LlavaModel): def pixel_shuffle(self, vision_features: torch.Tensor, scale_factor: float = 0.5): """Perform pixel shuffle downsampling on vision features. @@ -549,8 +554,8 @@ class InternVLModel(LlavaModel): def get_image_features( self, pixel_values: torch.FloatTensor, - vision_feature_layer: Union[int, List[int]], - vision_feature_select_strategy: str, + vision_feature_layer: Optional[Union[int, List[int]]] = None, + vision_feature_select_strategy: Optional[str] = None, **kwargs, ): """ @@ -564,6 +569,15 @@ class InternVLModel(LlavaModel): Returns: vision_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`. """ + vision_feature_layer = ( + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + ) + vision_feature_select_strategy = ( + vision_feature_select_strategy + if vision_feature_select_strategy is not None + else self.config.vision_feature_select_strategy + ) + downsample_ratio = self.config.downsample_ratio if vision_feature_layer == -1: vision_features = self.vision_tower(pixel_values=pixel_values).last_hidden_state @@ -588,9 +602,94 @@ class InternVLModel(LlavaModel): # Project features through multi-modal projector vision_features = self.multi_modal_projector(vision_features) - return vision_features + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + vision_feature_layer: Optional[Union[int, List[int]]] = None, + vision_feature_select_strategy: Optional[str] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[Tuple, InternVLModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + vision_feature_layer = ( + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + ) + vision_feature_select_strategy = ( + vision_feature_select_strategy + if vision_feature_select_strategy is not None + else self.config.vision_feature_select_strategy + ) + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + if pixel_values is not None: + image_features = self.get_image_features( + pixel_values=pixel_values, + vision_feature_layer=vision_feature_layer, + vision_feature_select_strategy=vision_feature_select_strategy, + ) + + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)[0] + else: + special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) + special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + n_image_tokens = (input_ids == self.config.image_token_id).sum() + + if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): + n_image_tokens = (input_ids == self.config.image_token_id).sum() + n_image_features = image_features.shape[0] * image_features.shape[1] + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + + outputs = self.language_model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + **kwargs, + ) + + return InternVLModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=image_features if pixel_values is not None else None, + ) + class InternVLCausalLMOutputWithPast(LlavaCausalLMOutputWithPast): pass diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py index 1fcb00e6e58..0ab2333d1ba 100644 --- a/src/transformers/models/llava/modeling_llava.py +++ b/src/transformers/models/llava/modeling_llava.py @@ -233,6 +233,15 @@ class LlavaModel(LlavaPreTrainedModel): selected_image_feature = torch.cat(hs_pool, dim=-1) image_features = self.multi_modal_projector(selected_image_feature) + + if "image_sizes" in kwargs: + split_sizes = [ + (height // self.vision_tower.patch_size) * (width // self.vision_tower.patch_size) + for height, width in kwargs["image_sizes"] + ] + image_features = torch.split(image_features.squeeze(0), split_sizes) + else: + image_features = list(image_features) return image_features @can_return_tuple @@ -282,6 +291,7 @@ class LlavaModel(LlavaPreTrainedModel): vision_feature_select_strategy=vision_feature_select_strategy, image_sizes=image_sizes, ) + image_features = torch.cat(image_features, dim=0) if input_ids is None: special_image_mask = inputs_embeds == self.get_input_embeddings()( diff --git a/src/transformers/models/llava_next/modeling_llava_next.py b/src/transformers/models/llava_next/modeling_llava_next.py index fb56a9b1748..bd1e1aae253 100644 --- a/src/transformers/models/llava_next/modeling_llava_next.py +++ b/src/transformers/models/llava_next/modeling_llava_next.py @@ -357,9 +357,8 @@ class LlavaNextModel(LlavaNextPreTrainedModel): image_feature = torch.cat((image_feature, image_newline[None].to(image_feature)), dim=0) new_image_features.append(image_feature) feature_lens.append(image_feature.size(0)) - image_features = torch.cat(new_image_features, dim=0) - feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features.device) - return image_features, feature_lens + feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features[0].device) + return new_image_features, feature_lens def get_image_features( self, @@ -429,6 +428,14 @@ class LlavaNextModel(LlavaNextPreTrainedModel): image_features = self.multi_modal_projector(selected_image_feature) image_features = torch.split(image_features, image_num_patches, dim=0) + + # NOTE we only support multimodal_patch_merge_type == "spatial_unpad" + image_features, feature_lens = self.pack_image_features( + image_features, + image_sizes, + vision_feature_select_strategy=vision_feature_select_strategy, + image_newline=self.image_newline, + ) return image_features @can_return_tuple @@ -489,14 +496,7 @@ class LlavaNextModel(LlavaNextPreTrainedModel): vision_feature_layer=vision_feature_layer, vision_feature_select_strategy=vision_feature_select_strategy, ) - - # NOTE we only support multimodal_patch_merge_type == "spatial_unpad" - image_features, feature_lens = self.pack_image_features( - image_features, - image_sizes, - vision_feature_select_strategy=vision_feature_select_strategy, - image_newline=self.image_newline, - ) + image_features = torch.cat(image_features, dim=0) special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) diff --git a/src/transformers/models/llava_next_video/modeling_llava_next_video.py b/src/transformers/models/llava_next_video/modeling_llava_next_video.py index c38ce78bc99..20b1647c0bd 100644 --- a/src/transformers/models/llava_next_video/modeling_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modeling_llava_next_video.py @@ -411,9 +411,8 @@ class LlavaNextVideoModel(LlavaNextVideoPreTrainedModel): image_feature = torch.cat((image_feature, image_newline[None].to(image_feature)), dim=0) new_image_features.append(image_feature) feature_lens.append(image_feature.size(0)) - image_features = torch.cat(new_image_features, dim=0) - feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features.device) - return image_features, feature_lens + feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features[0].device) + return new_image_features, feature_lens def get_image_features( self, @@ -482,6 +481,13 @@ class LlavaNextVideoModel(LlavaNextVideoPreTrainedModel): selected_image_feature = selected_image_feature image_features = self.multi_modal_projector(selected_image_feature) image_features = torch.split(image_features, image_num_patches, dim=0) + + image_features, feature_lens = self.pack_image_features( + image_features, + image_sizes, + vision_feature_select_strategy, + image_newline=self.image_newline, + ) return image_features @can_return_tuple @@ -544,12 +550,7 @@ class LlavaNextVideoModel(LlavaNextVideoPreTrainedModel): vision_feature_layer=self.vision_feature_layer, vision_feature_select_strategy=self.vision_feature_select_strategy, ) - image_features, feature_lens = self.pack_image_features( - image_features, - image_sizes, - self.vision_feature_select_strategy, - image_newline=self.image_newline, - ) + image_features = torch.cat(image_features, dim=0) special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) diff --git a/src/transformers/models/llava_next_video/modular_llava_next_video.py b/src/transformers/models/llava_next_video/modular_llava_next_video.py index 91d16caab6b..e5d2fd92cb4 100644 --- a/src/transformers/models/llava_next_video/modular_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modular_llava_next_video.py @@ -315,6 +315,13 @@ class LlavaNextVideoModel(LlavaNextModel): selected_image_feature = selected_image_feature image_features = self.multi_modal_projector(selected_image_feature) image_features = torch.split(image_features, image_num_patches, dim=0) + + image_features, feature_lens = self.pack_image_features( + image_features, + image_sizes, + vision_feature_select_strategy, + image_newline=self.image_newline, + ) return image_features def get_video_features( @@ -430,12 +437,7 @@ class LlavaNextVideoModel(LlavaNextModel): vision_feature_layer=self.vision_feature_layer, vision_feature_select_strategy=self.vision_feature_select_strategy, ) - image_features, feature_lens = self.pack_image_features( - image_features, - image_sizes, - self.vision_feature_select_strategy, - image_newline=self.image_newline, - ) + image_features = torch.cat(image_features, dim=0) special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) diff --git a/src/transformers/models/llava_onevision/modeling_llava_onevision.py b/src/transformers/models/llava_onevision/modeling_llava_onevision.py index 9d40b2cef4d..a43674d2942 100644 --- a/src/transformers/models/llava_onevision/modeling_llava_onevision.py +++ b/src/transformers/models/llava_onevision/modeling_llava_onevision.py @@ -409,18 +409,19 @@ class LlavaOnevisionModel(LlavaOnevisionPreTrainedModel): image_feature = image_feature[0] if image_newline is not None: image_feature = torch.cat((image_feature, image_newline[None].to(image_feature)), dim=0) + image_feature = image_feature.flatten(0, 1) new_image_features.append(image_feature) feature_lens.append(image_feature.size(0)) - image_features = torch.cat(new_image_features, dim=0) - feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features.device) - return image_features, feature_lens + feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features[0].device) + return new_image_features, feature_lens def get_image_features( self, pixel_values: torch.FloatTensor, image_sizes: torch.Tensor, - vision_feature_layer: Union[int, List[int]], - vision_feature_select_strategy: str, + vision_feature_layer: Optional[Union[int, List[int]]] = None, + vision_feature_select_strategy: Optional[str] = None, + vision_aspect_ratio: Optional[str] = None, batch_num_images: Optional[torch.LongTensor] = None, ): """ @@ -444,6 +445,18 @@ class LlavaOnevisionModel(LlavaOnevisionPreTrainedModel): image_features (List[`torch.Tensor`]): List of image feature tensor, each contains all the visual feature of all patches and are of shape `(num_patches, image_length, embed_dim)`). """ + vision_feature_layer = ( + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + ) + vision_feature_select_strategy = ( + vision_feature_select_strategy + if vision_feature_select_strategy is not None + else self.config.vision_feature_select_strategy + ) + vision_aspect_ratio = ( + vision_aspect_ratio if vision_aspect_ratio is not None else self.config.vision_aspect_ratio + ) + # ! infer image_num_patches from image_sizes if batch_num_images is None: # treat this as a single-image case for backward compatibility @@ -483,6 +496,13 @@ class LlavaOnevisionModel(LlavaOnevisionPreTrainedModel): selected_image_feature = selected_image_feature image_features = self.multi_modal_projector(selected_image_feature) image_features = torch.split(image_features, image_num_patches, dim=0) + + image_features, feature_lens = self.pack_image_features( + image_features, + image_sizes, + image_newline=self.image_newline, + vision_aspect_ratio=vision_aspect_ratio, + ) return image_features @can_return_tuple @@ -564,12 +584,7 @@ class LlavaOnevisionModel(LlavaOnevisionPreTrainedModel): vision_feature_select_strategy=vision_feature_select_strategy, batch_num_images=batch_num_images, ) - image_features, feature_lens = self.pack_image_features( - image_features, - image_sizes, - image_newline=self.image_newline, - vision_aspect_ratio=vision_aspect_ratio, - ) + image_features = torch.cat(image_features, dim=0) special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) diff --git a/src/transformers/models/llava_onevision/modular_llava_onevision.py b/src/transformers/models/llava_onevision/modular_llava_onevision.py index f838c5f703f..e61eca445e9 100644 --- a/src/transformers/models/llava_onevision/modular_llava_onevision.py +++ b/src/transformers/models/llava_onevision/modular_llava_onevision.py @@ -316,11 +316,11 @@ class LlavaOnevisionModel(LlavaNextVideoModel): image_feature = image_feature[0] if image_newline is not None: image_feature = torch.cat((image_feature, image_newline[None].to(image_feature)), dim=0) + image_feature = image_feature.flatten(0, 1) new_image_features.append(image_feature) feature_lens.append(image_feature.size(0)) - image_features = torch.cat(new_image_features, dim=0) - feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features.device) - return image_features, feature_lens + feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features[0].device) + return new_image_features, feature_lens def apply_pooling(self, image_features): height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size @@ -340,8 +340,9 @@ class LlavaOnevisionModel(LlavaNextVideoModel): self, pixel_values: torch.FloatTensor, image_sizes: torch.Tensor, - vision_feature_layer: Union[int, List[int]], - vision_feature_select_strategy: str, + vision_feature_layer: Optional[Union[int, List[int]]] = None, + vision_feature_select_strategy: Optional[str] = None, + vision_aspect_ratio: Optional[str] = None, batch_num_images: Optional[torch.LongTensor] = None, ): """ @@ -365,6 +366,18 @@ class LlavaOnevisionModel(LlavaNextVideoModel): image_features (List[`torch.Tensor`]): List of image feature tensor, each contains all the visual feature of all patches and are of shape `(num_patches, image_length, embed_dim)`). """ + vision_feature_layer = ( + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + ) + vision_feature_select_strategy = ( + vision_feature_select_strategy + if vision_feature_select_strategy is not None + else self.config.vision_feature_select_strategy + ) + vision_aspect_ratio = ( + vision_aspect_ratio if vision_aspect_ratio is not None else self.config.vision_aspect_ratio + ) + # ! infer image_num_patches from image_sizes if batch_num_images is None: # treat this as a single-image case for backward compatibility @@ -404,6 +417,13 @@ class LlavaOnevisionModel(LlavaNextVideoModel): selected_image_feature = selected_image_feature image_features = self.multi_modal_projector(selected_image_feature) image_features = torch.split(image_features, image_num_patches, dim=0) + + image_features, feature_lens = self.pack_image_features( + image_features, + image_sizes, + image_newline=self.image_newline, + vision_aspect_ratio=vision_aspect_ratio, + ) return image_features def get_video_features( @@ -529,12 +549,7 @@ class LlavaOnevisionModel(LlavaNextVideoModel): vision_feature_select_strategy=vision_feature_select_strategy, batch_num_images=batch_num_images, ) - image_features, feature_lens = self.pack_image_features( - image_features, - image_sizes, - image_newline=self.image_newline, - vision_aspect_ratio=vision_aspect_ratio, - ) + image_features = torch.cat(image_features, dim=0) special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) diff --git a/src/transformers/models/mistral3/modeling_mistral3.py b/src/transformers/models/mistral3/modeling_mistral3.py index 4da2570090b..c5a5205947f 100644 --- a/src/transformers/models/mistral3/modeling_mistral3.py +++ b/src/transformers/models/mistral3/modeling_mistral3.py @@ -285,6 +285,9 @@ class Mistral3Model(Mistral3PreTrainedModel): selected_image_feature = torch.cat(hs_pool, dim=-1) image_features = self.multi_modal_projector(selected_image_feature.squeeze(0), image_sizes) + downsample_ratio = self.vision_tower.patch_size * self.config.spatial_merge_size + split_sizes = [(height // downsample_ratio) * (width // downsample_ratio) for height, width in image_sizes] + image_features = torch.split(image_features.squeeze(0), split_sizes) return image_features @can_return_tuple @@ -332,6 +335,7 @@ class Mistral3Model(Mistral3PreTrainedModel): vision_feature_layer=vision_feature_layer, image_sizes=image_sizes, ) + image_features = torch.cat(image_features, dim=0) special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) diff --git a/src/transformers/models/mistral3/modular_mistral3.py b/src/transformers/models/mistral3/modular_mistral3.py index 69232f1ec86..666b4cde4d5 100644 --- a/src/transformers/models/mistral3/modular_mistral3.py +++ b/src/transformers/models/mistral3/modular_mistral3.py @@ -170,6 +170,9 @@ class Mistral3Model(LlavaModel): selected_image_feature = torch.cat(hs_pool, dim=-1) image_features = self.multi_modal_projector(selected_image_feature.squeeze(0), image_sizes) + downsample_ratio = self.vision_tower.patch_size * self.config.spatial_merge_size + split_sizes = [(height // downsample_ratio) * (width // downsample_ratio) for height, width in image_sizes] + image_features = torch.split(image_features.squeeze(0), split_sizes) return image_features def forward( @@ -215,6 +218,7 @@ class Mistral3Model(LlavaModel): vision_feature_layer=vision_feature_layer, image_sizes=image_sizes, ) + image_features = torch.cat(image_features, dim=0) special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index 3f8983083e2..a5376d651ce 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -19,8 +19,7 @@ # limitations under the License. """PyTorch Persimmon model.""" -import math -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch import torch.utils.checkpoint @@ -30,6 +29,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -37,7 +37,8 @@ from ...modeling_outputs import ( TokenClassifierOutput, ) from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging from .configuration_persimmon import PersimmonConfig @@ -137,6 +138,29 @@ class PersimmonMLP(nn.Module): return hidden_states +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + class PersimmonAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -166,6 +190,7 @@ class PersimmonAttention(nn.Module): self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=True) self.dense = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=True) self.qk_layernorm = config.qk_layernorm + self.scaling = self.head_dim**-0.5 if self.qk_layernorm: self.q_layernorm = nn.LayerNorm( @@ -203,6 +228,7 @@ class PersimmonAttention(nn.Module): use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -249,27 +275,22 @@ class PersimmonAttention(nn.Module): } key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query_states.dtype) - attn_weights = self.attention_dropout(attn_weights) - - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.config.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + attn_output = attn_output.reshape(bsz, q_len, -1) attn_output = self.dense(attn_output) if not output_attentions: @@ -298,6 +319,7 @@ class PersimmonDecoderLayer(nn.Module): use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -337,6 +359,7 @@ class PersimmonDecoderLayer(nn.Module): use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, + **kwargs, ) hidden_states = residual + hidden_states @@ -369,6 +392,9 @@ class PersimmonPreTrainedModel(PreTrainedModel): _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True + _supports_sdpa = True + _supports_flash_attn_2 = True + _supports_attention_backend = True def _init_weights(self, module): std = self.config.initializer_range @@ -430,6 +456,7 @@ class PersimmonModel(PersimmonPreTrainedModel): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], ) -> BaseModelOutputWithPast: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -512,6 +539,7 @@ class PersimmonModel(PersimmonPreTrainedModel): use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, + **kwargs, ) hidden_states = layer_outputs[0] @@ -758,6 +786,7 @@ class PersimmonForCausalLM(PersimmonPreTrainedModel, GenerationMixin): output_attentions=output_attentions, output_hidden_states=output_hidden_states, cache_position=cache_position, + **kwargs, ) hidden_states = outputs.last_hidden_state diff --git a/src/transformers/models/pixtral/processing_pixtral.py b/src/transformers/models/pixtral/processing_pixtral.py index 8a15fa8e1e5..9225d4df6b2 100644 --- a/src/transformers/models/pixtral/processing_pixtral.py +++ b/src/transformers/models/pixtral/processing_pixtral.py @@ -257,7 +257,7 @@ class PixtralProcessor(ProcessorMixin): num_image_tokens = [] for height, width in image_sizes: resized_height, resized_width = get_resize_output_image_size( - image=np.zeros((height, width, 3)), + np.zeros((height, width, 3)), size=(size["longest_edge"], size["longest_edge"]), patch_size=(patch_size, patch_size), ) diff --git a/src/transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py index 06db2e83f1f..c790cceefc9 100644 --- a/src/transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py @@ -20,7 +20,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ...configuration_utils import PretrainedConfig +from ...configuration_utils import PretrainedConfig, layer_type_validation from ...modeling_rope_utils import rope_config_validation from ...utils import logging @@ -257,6 +257,8 @@ class Qwen2_5OmniTextConfig(PretrainedConfig): Sliding window attention (SWA) window size. If not specified, will default to `4096`. max_window_layers (`int`, *optional*, defaults to 28): The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention. + layer_types (`list`, *optional*): + Attention pattern for each layer. attention_dropout (`float`, *optional*, defaults to 0.0): The dropout ratio for the attention probabilities. rope_scaling (`Dict`, *optional*): @@ -358,6 +360,7 @@ class Qwen2_5OmniTextConfig(PretrainedConfig): use_sliding_window=False, sliding_window=32768, max_window_layers=28, + layer_types=None, attention_dropout=0.0, **kwargs, ): @@ -396,6 +399,16 @@ class Qwen2_5OmniTextConfig(PretrainedConfig): if self.rope_scaling is None: self.rope_scaling = {"mrope_section": [16, 24, 24], "rope_type": "default", "type": "default"} + self.layer_types = layer_types + if self.layer_types is None: + self.layer_types = [ + "sliding_attention" + if self.sliding_window is not None and i >= self.max_window_layers + else "full_attention" + for i in range(self.num_hidden_layers) + ] + layer_type_validation(self.layer_types) + class Qwen2_5OmniThinkerConfig(PretrainedConfig): r""" diff --git a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py index 84e55346533..be1c55051be 100644 --- a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py @@ -22,7 +22,7 @@ import math from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import torch @@ -31,19 +31,20 @@ from torch import nn from torch.nn import Parameter from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...modeling_attn_mask_utils import AttentionMaskConverter -from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available +from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, ModelOutput from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack from ...utils import ( auto_docstring, check_torch_load_is_safe, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, - is_torch_flex_attn_available, logging, ) from ...utils.hub import cached_file @@ -68,15 +69,6 @@ else: apply_rotary_emb = None -if is_flash_attn_available(): - from ...modeling_flash_attention_utils import _flash_attention_forward - -if is_torch_flex_attn_available(): - from torch.nn.attention.flex_attention import BlockMask - - from ...integrations.flex_attention import make_flex_block_causal_mask - - logger = logging.get_logger(__name__) @@ -110,7 +102,8 @@ class Qwen2_5OmniPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_cache_class = True - _supports_static_cache = True + _supports_static_cache = False + _supports_attention_backend = True def _init_weights(self, module): # important: this ported version of Qwen2.5OmniThinker isn't meant for training from scratch - only @@ -1444,6 +1437,32 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + class Qwen2_5OmniAttention(nn.Module): """ Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer @@ -1469,11 +1488,13 @@ class Qwen2_5OmniAttention(nn.Module): self.is_causal = True self.attention_dropout = config.attention_dropout self.rope_scaling = config.rope_scaling + self.scaling = self.head_dim**-0.5 self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True) self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None self.rotary_emb = Qwen2_5OmniRotaryEmbedding(config=config) @@ -1487,6 +1508,7 @@ class Qwen2_5OmniAttention(nn.Module): use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -1507,40 +1529,24 @@ class Qwen2_5OmniAttention(nn.Module): cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - # Fix precision issues in Qwen2-VL float16 inference - # Replace inf values with zeros in attention weights to prevent NaN propagation - if query_states.dtype == torch.float16: - attn_weights = torch.where(torch.isinf(attn_weights), torch.zeros_like(attn_weights), attn_weights) - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=self.sliding_window, + **kwargs, + ) + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - return attn_output, attn_weights, past_key_value @@ -1558,216 +1564,7 @@ class Qwen2MLP(nn.Module): return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) -class Qwen2_5OmniFlashAttention2(Qwen2_5OmniAttention): - """ - Qwen2_5Omni flash attention module, following Qwen2_5Omni attention module. This module inherits from `Qwen2_5OmniAttention` - as the weights of the module stays untouched. The only required change would be on the forward pass - where it needs to correctly call the public API of flash attention and deal with padding tokens - in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom - config.max_window_layers layers. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - ): - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - - # Because the input can be padded, the absolute sequence length depends on the max position id. - cos, sin = position_embeddings - query_states, key_states = apply_multimodal_rotary_pos_emb( - query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] - ) - - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - dropout_rate = 0.0 if not self.training else self.attention_dropout - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in float16 just to be sure everything works as expected. - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - # Reashape to the expected shape for Flash Attention - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - if ( - self.config.use_sliding_window - and getattr(self.config, "sliding_window", None) is not None - and self.layer_idx >= self.config.max_window_layers - ): - sliding_window = self.config.sliding_window - else: - sliding_window = None - - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - dropout=dropout_rate, - sliding_window=sliding_window, - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - ) - - attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class Qwen2_5OmniSdpaAttention(Qwen2_5OmniAttention): - """ - Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `Qwen2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to - SDPA API. - """ - - # Adapted from Qwen2Attention.forward - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "Qwen2_5OmniModel is using Qwen2_5OmniSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - - cos, sin = position_embeddings - query_states, key_states = apply_multimodal_rotary_pos_emb( - query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] - ) - - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - causal_mask = attention_mask - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and attention_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. - is_causal = True if causal_mask is None and q_len > 1 else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=is_causal, - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, -1) - - attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value - - -QWEN2_5_OMNI_ATTENTION_CLASSES = { - "eager": Qwen2_5OmniAttention, - "flash_attention_2": Qwen2_5OmniFlashAttention2, - "sdpa": Qwen2_5OmniSdpaAttention, -} - - -class Qwen2_5OmniDecoderLayer(nn.Module): +class Qwen2_5OmniDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: Qwen2_5OmniTextConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size @@ -1777,11 +1574,12 @@ class Qwen2_5OmniDecoderLayer(nn.Module): f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " "unexpected results may be encountered." ) - self.self_attn = QWEN2_5_OMNI_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) + self.self_attn = Qwen2_5OmniAttention(config, layer_idx) self.mlp = Qwen2MLP(config) self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.attention_type = config.layer_types[layer_idx] def forward( self, @@ -1793,7 +1591,7 @@ class Qwen2_5OmniDecoderLayer(nn.Module): use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - **kwargs, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -1831,6 +1629,7 @@ class Qwen2_5OmniDecoderLayer(nn.Module): use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, + **kwargs, ) hidden_states = residual + hidden_states @@ -1868,6 +1667,7 @@ class Qwen2_5OmniThinkerTextModel(Qwen2_5OmniPreTrainedModel): self._attn_implementation = config._attn_implementation self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = Qwen2_5OmniRotaryEmbedding(config=config) + self.has_sliding_layers = "sliding_attention" in self.config.layer_types self.gradient_checkpointing = False # Initialize weights and apply final processing @@ -1892,6 +1692,7 @@ class Qwen2_5OmniThinkerTextModel(Qwen2_5OmniPreTrainedModel): output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -1930,9 +1731,23 @@ class Qwen2_5OmniThinkerTextModel(Qwen2_5OmniPreTrainedModel): elif position_ids.dim() == 2: position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) - causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions - ) + # It may already have been prepared by e.g. `generate` + if not isinstance(causal_mask_mapping := attention_mask, dict): + # Prepare mask arguments + mask_kwargs = { + "config": self.config, + "input_embeds": inputs_embeds, + "attention_mask": attention_mask, + "cache_position": cache_position, + "past_key_values": past_key_values, + } + # Create the masks + causal_mask_mapping = { + "full_attention": create_causal_mask(**mask_kwargs), + } + # The sliding window alternating layers are not always activated depending on the config + if self.has_sliding_layers: + causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs) hidden_states = inputs_embeds @@ -1952,7 +1767,7 @@ class Qwen2_5OmniThinkerTextModel(Qwen2_5OmniPreTrainedModel): layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, - causal_mask, + causal_mask_mapping[decoder_layer.attention_type], position_ids, past_key_values, output_attentions, @@ -1963,13 +1778,14 @@ class Qwen2_5OmniThinkerTextModel(Qwen2_5OmniPreTrainedModel): else: layer_outputs = decoder_layer( hidden_states, - attention_mask=causal_mask, + attention_mask=causal_mask_mapping[decoder_layer.attention_type], position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, + **kwargs, ) hidden_states = layer_outputs[0] @@ -1997,161 +1813,6 @@ class Qwen2_5OmniThinkerTextModel(Qwen2_5OmniPreTrainedModel): attentions=all_self_attns, ) - def _update_causal_mask( - self, - attention_mask: Union[torch.Tensor, "BlockMask"], - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - past_key_values: Cache, - output_attentions: bool = False, - ): - if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and past_key_values is not None: - is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] - if is_padding_right: - raise ValueError( - "You are attempting to perform batched generation with padding_side='right'" - " this may lead to unexpected behaviour for Flash Attention version of Qwen25OmniThinker. Make sure to " - " call `tokenizer.padding_side = 'left'` before tokenizing the input. " - ) - if attention_mask is not None and 0.0 in attention_mask: - return attention_mask - return None - if self.config._attn_implementation == "flex_attention": - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask) - return attention_mask - - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in - # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail - # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) - using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) - - # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if ( - self.config._attn_implementation == "sdpa" - and not (using_static_cache or using_sliding_window_cache) - and not output_attentions - ): - if AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, - inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, - sliding_window=self.config.sliding_window, - is_training=self.training, - ): - return None - - dtype = input_tensor.dtype - min_dtype = torch.finfo(dtype).min - sequence_length = input_tensor.shape[1] - # SlidingWindowCache or StaticCache - if using_sliding_window_cache or using_static_cache: - target_length = past_key_values.get_max_cache_shape() - # DynamicCache or no cache - else: - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else past_seen_tokens + sequence_length + 1 - ) - - # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). - causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=target_length, - dtype=dtype, - cache_position=cache_position, - batch_size=input_tensor.shape[0], - config=self.config, - past_key_values=past_key_values, - ) - - if ( - self.config._attn_implementation == "sdpa" - and attention_mask is not None - and attention_mask.device.type in ["cuda", "xpu", "npu"] - and not output_attentions - ): - # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when - # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. - # Details: https://github.com/pytorch/pytorch/issues/110213 - causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) - - return causal_mask - - @staticmethod - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - cache_position: torch.Tensor, - batch_size: int, - config: Qwen2_5OmniConfig, - past_key_values: Cache, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - config (`Qwen25OmniThinkerConfig`): - The model's configuration class - past_key_values (`Cache`): - The cache class that is being used currently to generate - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device - ) - diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape( - -1, 1 - ) - text_config = config.get_text_config() - if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None: - # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also - # the check is needed to verify is current checkpoint was trained with sliding window or not - if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: - sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= ( - cache_position.reshape(-1, 1) - text_config.sliding_window - ) - diagonal_attend_mask.bitwise_or_(sliding_attend_mask) - causal_mask *= diagonal_attend_mask - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - if attention_mask.shape[-1] > target_length: - attention_mask = attention_mask[:, :target_length] - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( - causal_mask.device - ) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - return causal_mask - @auto_docstring( custom_intro=""" @@ -2398,9 +2059,6 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCo video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) - if attention_mask is not None: - attention_mask = attention_mask.to(inputs_embeds.device) - if feature_attention_mask is not None: audio_feature_lengths = torch.sum(feature_attention_mask, dim=1) else: @@ -2574,6 +2232,7 @@ class Qwen2_5OmniTalkerModel(Qwen2_5OmniPreTrainedModel): self._attn_implementation = config._attn_implementation self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = Qwen2_5OmniRotaryEmbedding(config=config) + self.has_sliding_layers = "sliding_attention" in self.config.layer_types self.gradient_checkpointing = False # Initialize weights and apply final processing @@ -2598,6 +2257,7 @@ class Qwen2_5OmniTalkerModel(Qwen2_5OmniPreTrainedModel): output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -2636,9 +2296,23 @@ class Qwen2_5OmniTalkerModel(Qwen2_5OmniPreTrainedModel): elif position_ids.dim() == 2: position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) - causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions - ) + # It may already have been prepared by e.g. `generate` + if not isinstance(causal_mask_mapping := attention_mask, dict): + # Prepare mask arguments + mask_kwargs = { + "config": self.config, + "input_embeds": inputs_embeds, + "attention_mask": attention_mask, + "cache_position": cache_position, + "past_key_values": past_key_values, + } + # Create the masks + causal_mask_mapping = { + "full_attention": create_causal_mask(**mask_kwargs), + } + # The sliding window alternating layers are not always activated depending on the config + if self.has_sliding_layers: + causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs) hidden_states = inputs_embeds @@ -2658,7 +2332,7 @@ class Qwen2_5OmniTalkerModel(Qwen2_5OmniPreTrainedModel): layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, - causal_mask, + causal_mask_mapping[decoder_layer.attention_type], position_ids, past_key_values, output_attentions, @@ -2669,13 +2343,14 @@ class Qwen2_5OmniTalkerModel(Qwen2_5OmniPreTrainedModel): else: layer_outputs = decoder_layer( hidden_states, - attention_mask=causal_mask, + attention_mask=causal_mask_mapping[decoder_layer.attention_type], position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, + **kwargs, ) hidden_states = layer_outputs[0] @@ -2703,161 +2378,6 @@ class Qwen2_5OmniTalkerModel(Qwen2_5OmniPreTrainedModel): attentions=all_self_attns, ) - def _update_causal_mask( - self, - attention_mask: Union[torch.Tensor, "BlockMask"], - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - past_key_values: Cache, - output_attentions: bool = False, - ): - if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and past_key_values is not None: - is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] - if is_padding_right: - raise ValueError( - "You are attempting to perform batched generation with padding_side='right'" - " this may lead to unexpected behaviour for Flash Attention version of Qwen2_5Omni. Make sure to " - " call `tokenizer.padding_side = 'left'` before tokenizing the input. " - ) - if attention_mask is not None and 0.0 in attention_mask: - return attention_mask - return None - if self.config._attn_implementation == "flex_attention": - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask) - return attention_mask - - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in - # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail - # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) - using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) - - # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if ( - self.config._attn_implementation == "sdpa" - and not (using_static_cache or using_sliding_window_cache) - and not output_attentions - ): - if AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, - inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, - sliding_window=self.config.sliding_window, - is_training=self.training, - ): - return None - - dtype = input_tensor.dtype - min_dtype = torch.finfo(dtype).min - sequence_length = input_tensor.shape[1] - # SlidingWindowCache or StaticCache - if using_sliding_window_cache or using_static_cache: - target_length = past_key_values.get_max_cache_shape() - # DynamicCache or no cache - else: - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else past_seen_tokens + sequence_length + 1 - ) - - # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). - causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=target_length, - dtype=dtype, - cache_position=cache_position, - batch_size=input_tensor.shape[0], - config=self.config, - past_key_values=past_key_values, - ) - - if ( - self.config._attn_implementation == "sdpa" - and attention_mask is not None - and attention_mask.device.type in ["cuda", "xpu", "npu"] - and not output_attentions - ): - # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when - # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. - # Details: https://github.com/pytorch/pytorch/issues/110213 - causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) - - return causal_mask - - @staticmethod - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - cache_position: torch.Tensor, - batch_size: int, - config: Qwen2_5OmniConfig, - past_key_values: Cache, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - config (`Qwen2_5OmniConfig`): - The model's configuration class - past_key_values (`Cache`): - The cache class that is being used currently to generate - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device - ) - diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape( - -1, 1 - ) - text_config = config.get_text_config() - if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None: - # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also - # the check is needed to verify is current checkpoint was trained with sliding window or not - if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: - sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= ( - cache_position.reshape(-1, 1) - text_config.sliding_window - ) - diagonal_attend_mask.bitwise_or_(sliding_attend_mask) - causal_mask *= diagonal_attend_mask - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - if attention_mask.shape[-1] > target_length: - attention_mask = attention_mask[:, :target_length] - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( - causal_mask.device - ) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - return causal_mask - class Qwen2_5OmniTalkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForConditionalGeneration, GenerationMixin): config_class = Qwen2_5OmniTalkerConfig diff --git a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py index a75c0fafa40..3ced88af768 100644 --- a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py @@ -41,7 +41,7 @@ from transformers.models.qwen2_audio.configuration_qwen2_audio import Qwen2Audio from transformers.models.qwen2_audio.modeling_qwen2_audio import Qwen2AudioEncoderLayer from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLRotaryEmbedding -from ...configuration_utils import PretrainedConfig +from ...configuration_utils import PretrainedConfig, layer_type_validation from ...generation import GenerationMixin from ...modeling_outputs import BaseModelOutput, ModelOutput from ...modeling_rope_utils import rope_config_validation @@ -298,6 +298,8 @@ class Qwen2_5OmniTextConfig(PretrainedConfig): Sliding window attention (SWA) window size. If not specified, will default to `4096`. max_window_layers (`int`, *optional*, defaults to 28): The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention. + layer_types (`list`, *optional*): + Attention pattern for each layer. attention_dropout (`float`, *optional*, defaults to 0.0): The dropout ratio for the attention probabilities. rope_scaling (`Dict`, *optional*): @@ -399,6 +401,7 @@ class Qwen2_5OmniTextConfig(PretrainedConfig): use_sliding_window=False, sliding_window=32768, max_window_layers=28, + layer_types=None, attention_dropout=0.0, **kwargs, ): @@ -437,6 +440,16 @@ class Qwen2_5OmniTextConfig(PretrainedConfig): if self.rope_scaling is None: self.rope_scaling = {"mrope_section": [16, 24, 24], "rope_type": "default", "type": "default"} + self.layer_types = layer_types + if self.layer_types is None: + self.layer_types = [ + "sliding_attention" + if self.sliding_window is not None and i >= self.max_window_layers + else "full_attention" + for i in range(self.num_hidden_layers) + ] + layer_type_validation(self.layer_types) + class Qwen2_5OmniThinkerConfig(PretrainedConfig): r""" @@ -1104,7 +1117,7 @@ class Qwen2_5OmniConfig(PretrainedConfig): class Qwen2_5OmniPreTrainedModel(Qwen2_5_VLPreTrainedModel): config_class = Qwen2_5OmniConfig - _supports_static_cache = True + _supports_static_cache = False def _init_weights(self, module): # important: this ported version of Qwen2.5OmniThinker isn't meant for training from scratch - only @@ -2184,11 +2197,13 @@ class Qwen2_5OmniAttention(Qwen2_5_VLAttention, nn.Module): self.is_causal = True self.attention_dropout = config.attention_dropout self.rope_scaling = config.rope_scaling + self.scaling = self.head_dim**-0.5 self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True) self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None self.rotary_emb = Qwen2_5OmniRotaryEmbedding(config=config) @@ -2450,9 +2465,6 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCo video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) - if attention_mask is not None: - attention_mask = attention_mask.to(inputs_embeds.device) - if feature_attention_mask is not None: audio_feature_lengths = torch.sum(feature_attention_mask, dim=1) else: diff --git a/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py index a700a992886..c053de54b39 100644 --- a/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py @@ -23,7 +23,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from ...configuration_utils import PretrainedConfig +from ...configuration_utils import PretrainedConfig, layer_type_validation from ...modeling_rope_utils import rope_config_validation @@ -117,6 +117,8 @@ class Qwen2_5_VLTextConfig(PretrainedConfig): Sliding window attention (SWA) window size. If not specified, will default to `4096`. max_window_layers (`int`, *optional*, defaults to 80): The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention. + layer_types (`list`, *optional*): + Attention pattern for each layer. attention_dropout (`float`, *optional*, defaults to 0.0): The dropout ratio for the attention probabilities. rope_scaling (`Dict`, *optional*): @@ -211,6 +213,7 @@ class Qwen2_5_VLTextConfig(PretrainedConfig): use_sliding_window=False, sliding_window=4096, max_window_layers=80, + layer_types=None, attention_dropout=0.0, rope_scaling=None, image_token_id=None, @@ -224,7 +227,7 @@ class Qwen2_5_VLTextConfig(PretrainedConfig): self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads self.use_sliding_window = use_sliding_window - self.sliding_window = sliding_window + self.sliding_window = sliding_window if self.use_sliding_window else None self.max_window_layers = max_window_layers # for backward compatibility @@ -240,6 +243,16 @@ class Qwen2_5_VLTextConfig(PretrainedConfig): self.attention_dropout = attention_dropout self.rope_scaling = rope_scaling + self.layer_types = layer_types + if self.layer_types is None: + self.layer_types = [ + "sliding_attention" + if self.sliding_window is not None and i >= self.max_window_layers + else "full_attention" + for i in range(self.num_hidden_layers) + ] + layer_type_validation(self.layer_types) + # Validate the correctness of rotary position embeddings parameters # BC: if there is a 'type' field, move it to 'rope_type'. # and change type from 'mrope' to 'default' because `mrope` does default RoPE calculations diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index f2c71f991fe..b9d8985ebbc 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -26,21 +26,23 @@ import math from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch import torch.nn as nn import torch.nn.functional as F from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...modeling_attn_mask_utils import AttentionMaskConverter -from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available +from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask +from ...modeling_flash_attention_utils import FlashAttentionKwargs, is_flash_attn_available +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update -from ...modeling_utils import PreTrainedModel -from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, is_torchdynamo_compiling, logging +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging from .configuration_qwen2_5_vl import Qwen2_5_VLConfig, Qwen2_5_VLTextConfig, Qwen2_5_VLVisionConfig @@ -48,15 +50,6 @@ if is_flash_attn_available(): from ...modeling_flash_attention_utils import apply_rotary_emb, flash_attn_varlen_func -if is_flash_attn_available(): - from ...modeling_flash_attention_utils import _flash_attention_forward - -if is_torch_flex_attn_available(): - from torch.nn.attention.flex_attention import BlockMask - - from ...integrations.flex_attention import make_flex_block_causal_mask - - logger = logging.get_logger(__name__) @@ -359,6 +352,7 @@ class Qwen2_5_VLPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_cache_class = True _supports_static_cache = True + _supports_attention_backend = True def _init_weights(self, module): std = self.config.get_text_config().initializer_range @@ -681,6 +675,32 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + class Qwen2_5_VLAttention(nn.Module): """ Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer @@ -706,6 +726,7 @@ class Qwen2_5_VLAttention(nn.Module): self.is_causal = True self.attention_dropout = config.attention_dropout self.rope_scaling = config.rope_scaling + self.scaling = self.head_dim**-0.5 if (self.head_dim * self.num_heads) != self.hidden_size: raise ValueError( @@ -716,6 +737,7 @@ class Qwen2_5_VLAttention(nn.Module): self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None self.rotary_emb = Qwen2_5_VLRotaryEmbedding(config=config) @@ -729,6 +751,7 @@ class Qwen2_5_VLAttention(nn.Module): use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -749,253 +772,28 @@ class Qwen2_5_VLAttention(nn.Module): cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - # Fix precision issues in Qwen2-VL float16 inference - # Replace inf values with zeros in attention weights to prevent NaN propagation - if query_states.dtype == torch.float16: - attn_weights = torch.where(torch.isinf(attn_weights), torch.zeros_like(attn_weights), attn_weights) - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, -1) - - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class Qwen2_5_VLFlashAttention2(Qwen2_5_VLAttention): - """ - Qwen2_5_VL flash attention module, following Qwen2_5_VL attention module. This module inherits from `Qwen2_5_VLAttention` - as the weights of the module stays untouched. The only required change would be on the forward pass - where it needs to correctly call the public API of flash attention and deal with padding tokens - in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom - config.max_window_layers layers. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - ): - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - - # Because the input can be padded, the absolute sequence length depends on the max position id. - cos, sin = position_embeddings - query_states, key_states = apply_multimodal_rotary_pos_emb( - query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] - ) - - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - dropout_rate = 0.0 if not self.training else self.attention_dropout - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in float16 just to be sure everything works as expected. - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - # Reashape to the expected shape for Flash Attention - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - if ( - self.config.use_sliding_window - and getattr(self.config, "sliding_window", None) is not None - and self.layer_idx >= self.config.max_window_layers - ): - sliding_window = self.config.sliding_window - else: - sliding_window = None - - attn_output = _flash_attention_forward( + attn_output, attn_weights = attention_interface( + self, query_states, key_states, value_states, attention_mask, - q_len, - dropout=dropout_rate, - sliding_window=sliding_window, - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=self.sliding_window, + **kwargs, ) attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - return attn_output, attn_weights, past_key_value -class Qwen2_5_VLSdpaAttention(Qwen2_5_VLAttention): - """ - Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `Qwen2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to - SDPA API. - """ - - # Adapted from Qwen2Attention.forward - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "Qwen2_5_VLModel is using Qwen2_5_VLSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - - cos, sin = position_embeddings - query_states, key_states = apply_multimodal_rotary_pos_emb( - query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] - ) - - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - causal_mask = attention_mask - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and attention_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. - is_causal = True if causal_mask is None and q_len > 1 else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=is_causal, - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, -1) - - attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value - - -QWEN2_5_VL_ATTENTION_CLASSES = { - "eager": Qwen2_5_VLAttention, - "flash_attention_2": Qwen2_5_VLFlashAttention2, - "sdpa": Qwen2_5_VLSdpaAttention, -} - - -class Qwen2_5_VLDecoderLayer(nn.Module): +class Qwen2_5_VLDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: Qwen2_5_VLTextConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size @@ -1005,11 +803,12 @@ class Qwen2_5_VLDecoderLayer(nn.Module): f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " "unexpected results may be encountered." ) - self.self_attn = QWEN2_5_VL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) + self.self_attn = Qwen2_5_VLAttention(config, layer_idx) self.mlp = Qwen2MLP(config) self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.attention_type = config.layer_types[layer_idx] def forward( self, @@ -1021,7 +820,7 @@ class Qwen2_5_VLDecoderLayer(nn.Module): use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - **kwargs, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -1059,6 +858,7 @@ class Qwen2_5_VLDecoderLayer(nn.Module): use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, + **kwargs, ) hidden_states = residual + hidden_states @@ -1095,6 +895,7 @@ class Qwen2_5_VLTextModel(Qwen2_5_VLPreTrainedModel): self._attn_implementation = config._attn_implementation self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = Qwen2_5_VLRotaryEmbedding(config=config) + self.has_sliding_layers = "sliding_attention" in self.config.layer_types self.gradient_checkpointing = False # Initialize weights and apply final processing @@ -1119,6 +920,7 @@ class Qwen2_5_VLTextModel(Qwen2_5_VLPreTrainedModel): output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -1157,9 +959,23 @@ class Qwen2_5_VLTextModel(Qwen2_5_VLPreTrainedModel): elif position_ids.dim() == 2: position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) - causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions - ) + # It may already have been prepared by e.g. `generate` + if not isinstance(causal_mask_mapping := attention_mask, dict): + # Prepare mask arguments + mask_kwargs = { + "config": self.config, + "input_embeds": inputs_embeds, + "attention_mask": attention_mask, + "cache_position": cache_position, + "past_key_values": past_key_values, + } + # Create the masks + causal_mask_mapping = { + "full_attention": create_causal_mask(**mask_kwargs), + } + # The sliding window alternating layers are not always activated depending on the config + if self.has_sliding_layers: + causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs) hidden_states = inputs_embeds @@ -1179,7 +995,7 @@ class Qwen2_5_VLTextModel(Qwen2_5_VLPreTrainedModel): layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, - causal_mask, + causal_mask_mapping[decoder_layer.attention_type], position_ids, past_key_values, output_attentions, @@ -1190,13 +1006,14 @@ class Qwen2_5_VLTextModel(Qwen2_5_VLPreTrainedModel): else: layer_outputs = decoder_layer( hidden_states, - attention_mask=causal_mask, + attention_mask=causal_mask_mapping[decoder_layer.attention_type], position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, + **kwargs, ) hidden_states = layer_outputs[0] @@ -1224,160 +1041,8 @@ class Qwen2_5_VLTextModel(Qwen2_5_VLPreTrainedModel): attentions=all_self_attns, ) - def _update_causal_mask( - self, - attention_mask: Union[torch.Tensor, "BlockMask"], - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - past_key_values: Cache, - output_attentions: bool = False, - ): - if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and past_key_values is not None: - is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] - if is_padding_right: - raise ValueError( - "You are attempting to perform batched generation with padding_side='right'" - " this may lead to unexpected behaviour for Flash Attention version of Qwen2_5_VL. Make sure to " - " call `tokenizer.padding_side = 'left'` before tokenizing the input. " - ) - if attention_mask is not None and 0.0 in attention_mask: - return attention_mask - return None - if self.config._attn_implementation == "flex_attention": - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask) - return attention_mask - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in - # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail - # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) - using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) - - # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if ( - self.config._attn_implementation == "sdpa" - and not (using_static_cache or using_sliding_window_cache) - and not output_attentions - ): - if AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, - inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, - sliding_window=self.config.sliding_window, - is_training=self.training, - ): - return None - - dtype = input_tensor.dtype - min_dtype = torch.finfo(dtype).min - sequence_length = input_tensor.shape[1] - # SlidingWindowCache or StaticCache - if using_sliding_window_cache or using_static_cache: - target_length = past_key_values.get_max_cache_shape() - # DynamicCache or no cache - else: - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else past_seen_tokens + sequence_length + 1 - ) - - # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). - causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=target_length, - dtype=dtype, - cache_position=cache_position, - batch_size=input_tensor.shape[0], - config=self.config, - past_key_values=past_key_values, - ) - - if ( - self.config._attn_implementation == "sdpa" - and attention_mask is not None - and attention_mask.device.type in ["cuda", "xpu", "npu"] - and not output_attentions - ): - # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when - # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. - # Details: https://github.com/pytorch/pytorch/issues/110213 - causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) - - return causal_mask - - @staticmethod - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - cache_position: torch.Tensor, - batch_size: int, - config: Qwen2_5_VLConfig, - past_key_values: Cache, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - config (`Qwen2_5_VLConfig`): - The model's configuration class - past_key_values (`Cache`): - The cache class that is being used currently to generate - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device - ) - diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape( - -1, 1 - ) - text_config = config.get_text_config() - if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None: - # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also - # the check is needed to verify is current checkpoint was trained with sliding window or not - if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: - sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= ( - cache_position.reshape(-1, 1) - text_config.sliding_window - ) - diagonal_attend_mask.bitwise_or_(sliding_attend_mask) - causal_mask *= diagonal_attend_mask - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - if attention_mask.shape[-1] > target_length: - attention_mask = attention_mask[:, :target_length] - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( - causal_mask.device - ) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - return causal_mask +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... @auto_docstring @@ -1598,6 +1263,8 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel): """ pixel_values_videos = pixel_values_videos.type(self.visual.dtype) video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) + split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() + video_embeds = torch.split(video_embeds, split_sizes) return video_embeds def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None): @@ -1612,6 +1279,8 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel): """ pixel_values = pixel_values.type(self.visual.dtype) image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) + split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() + image_embeds = torch.split(image_embeds, split_sizes) return image_embeds @auto_docstring @@ -1633,6 +1302,7 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel): rope_deltas: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, second_per_grid_ts: Optional[torch.Tensor] = None, + **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, Qwen2_5_VLModelOutputWithPast]: r""" pixel_values_videos (`torch.FloatTensor` of shape `(seq_length, num_channels * temporal_size * image_size * image_size)): @@ -1659,6 +1329,7 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel): inputs_embeds = self.get_input_embeddings()(input_ids) if pixel_values is not None: image_embeds = self.get_image_features(pixel_values, image_grid_thw) + image_embeds = torch.cat(image_embeds, dim=0) n_image_tokens = (input_ids == self.config.image_token_id).sum() n_image_features = image_embeds.shape[0] if not is_torchdynamo_compiling() and n_image_tokens != n_image_features: @@ -1676,6 +1347,7 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel): if pixel_values_videos is not None: video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw) + video_embeds = torch.cat(video_embeds, dim=0) n_video_tokens = (input_ids == self.config.video_token_id).sum() n_video_features = video_embeds.shape[0] if not is_torchdynamo_compiling() and n_video_tokens != n_video_features: @@ -1691,15 +1363,14 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel): video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) - if attention_mask is not None: - attention_mask = attention_mask.to(inputs_embeds.device) - if position_ids is None: - attention_mask_2d = attention_mask - if attention_mask is not None and attention_mask.ndim == 4: - attention_mask_2d = torch.diagonal(attention_mask_2d[:, 0], dim1=1, dim2=2) - attention_mask_2d = attention_mask_2d / torch.finfo(attention_mask_2d.dtype).min - attention_mask_2d = (1.0 - attention_mask_2d).int() + attention_mask_tensor = ( + attention_mask if not isinstance(attention_mask, dict) else attention_mask["full_attention"] + ) + if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4: + attention_mask_tensor = torch.diagonal(attention_mask_tensor[:, 0], dim1=1, dim2=2) + attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min + attention_mask_tensor = (1.0 - attention_mask_tensor).int() # Calculate RoPE index once per generation in the pre-fill stage only. # When compiling, we can't check tensor values thus we check only input length @@ -1719,7 +1390,7 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel): image_grid_thw, video_grid_thw, second_per_grid_ts=second_per_grid_ts, - attention_mask=attention_mask_2d, + attention_mask=attention_mask_tensor, ) self.rope_deltas = rope_deltas # then use the prev pre-calculated rope-deltas to get the correct position ids @@ -1748,6 +1419,7 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel): output_hidden_states=output_hidden_states, return_dict=True, cache_position=cache_position, + **kwargs, ) output = Qwen2_5_VLModelOutputWithPast( @@ -1759,61 +1431,6 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel): ) return output if return_dict else output.to_tuple() - @staticmethod - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - cache_position: torch.Tensor, - batch_size: int, - **kwargs, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape - `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, - to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device - ) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( - causal_mask.device - ) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - @dataclass class Qwen2_5_VLCausalLMOutputWithPast(ModelOutput): @@ -1916,6 +1533,7 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi rope_deltas: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, second_per_grid_ts: Optional[torch.Tensor] = None, + **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -1988,6 +1606,7 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, + **kwargs, ) hidden_states = outputs[0] diff --git a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py index 530d74d51b5..e38f2952144 100644 --- a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py @@ -30,6 +30,7 @@ import torch.utils.checkpoint from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLConfig, Qwen2VLTextConfig from transformers.models.qwen2_vl.modeling_qwen2_vl import ( + KwargsForCausalLM, PatchEmbed, PatchMerger, Qwen2RMSNorm, @@ -622,6 +623,7 @@ class Qwen2_5_VLModel(Qwen2VLModel): rope_deltas: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, second_per_grid_ts: Optional[torch.Tensor] = None, + **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, Qwen2_5_VLModelOutputWithPast]: r""" pixel_values_videos (`torch.FloatTensor` of shape `(seq_length, num_channels * temporal_size * image_size * image_size)): @@ -648,6 +650,7 @@ class Qwen2_5_VLModel(Qwen2VLModel): inputs_embeds = self.get_input_embeddings()(input_ids) if pixel_values is not None: image_embeds = self.get_image_features(pixel_values, image_grid_thw) + image_embeds = torch.cat(image_embeds, dim=0) n_image_tokens = (input_ids == self.config.image_token_id).sum() n_image_features = image_embeds.shape[0] if not is_torchdynamo_compiling() and n_image_tokens != n_image_features: @@ -665,6 +668,7 @@ class Qwen2_5_VLModel(Qwen2VLModel): if pixel_values_videos is not None: video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw) + video_embeds = torch.cat(video_embeds, dim=0) n_video_tokens = (input_ids == self.config.video_token_id).sum() n_video_features = video_embeds.shape[0] if not is_torchdynamo_compiling() and n_video_tokens != n_video_features: @@ -680,15 +684,14 @@ class Qwen2_5_VLModel(Qwen2VLModel): video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) - if attention_mask is not None: - attention_mask = attention_mask.to(inputs_embeds.device) - if position_ids is None: - attention_mask_2d = attention_mask - if attention_mask is not None and attention_mask.ndim == 4: - attention_mask_2d = torch.diagonal(attention_mask_2d[:, 0], dim1=1, dim2=2) - attention_mask_2d = attention_mask_2d / torch.finfo(attention_mask_2d.dtype).min - attention_mask_2d = (1.0 - attention_mask_2d).int() + attention_mask_tensor = ( + attention_mask if not isinstance(attention_mask, dict) else attention_mask["full_attention"] + ) + if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4: + attention_mask_tensor = torch.diagonal(attention_mask_tensor[:, 0], dim1=1, dim2=2) + attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min + attention_mask_tensor = (1.0 - attention_mask_tensor).int() # Calculate RoPE index once per generation in the pre-fill stage only. # When compiling, we can't check tensor values thus we check only input length @@ -708,7 +711,7 @@ class Qwen2_5_VLModel(Qwen2VLModel): image_grid_thw, video_grid_thw, second_per_grid_ts=second_per_grid_ts, - attention_mask=attention_mask_2d, + attention_mask=attention_mask_tensor, ) self.rope_deltas = rope_deltas # then use the prev pre-calculated rope-deltas to get the correct position ids @@ -737,6 +740,7 @@ class Qwen2_5_VLModel(Qwen2VLModel): output_hidden_states=output_hidden_states, return_dict=True, cache_position=cache_position, + **kwargs, ) output = Qwen2_5_VLModelOutputWithPast( @@ -774,6 +778,7 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2VLForConditionalGeneration): rope_deltas: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, second_per_grid_ts: Optional[torch.Tensor] = None, + **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -846,6 +851,7 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2VLForConditionalGeneration): output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, + **kwargs, ) hidden_states = outputs[0] diff --git a/src/transformers/models/qwen2_vl/configuration_qwen2_vl.py b/src/transformers/models/qwen2_vl/configuration_qwen2_vl.py index 5f4842ce604..03bfa66c41f 100644 --- a/src/transformers/models/qwen2_vl/configuration_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/configuration_qwen2_vl.py @@ -14,7 +14,7 @@ # limitations under the License. """Qwen2VL model configuration""" -from ...configuration_utils import PretrainedConfig +from ...configuration_utils import PretrainedConfig, layer_type_validation from ...modeling_rope_utils import rope_config_validation from ...utils import logging @@ -106,6 +106,8 @@ class Qwen2VLTextConfig(PretrainedConfig): Sliding window attention (SWA) window size. If not specified, will default to `4096`. max_window_layers (`int`, *optional*, defaults to 80): The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention. + layer_types (`list`, *optional*): + Attention pattern for each layer. attention_dropout (`float`, *optional*, defaults to 0.0): The dropout ratio for the attention probabilities. rope_scaling (`Dict`, *optional*): @@ -200,6 +202,7 @@ class Qwen2VLTextConfig(PretrainedConfig): use_sliding_window=False, sliding_window=4096, max_window_layers=80, + layer_types=None, attention_dropout=0.0, rope_scaling=None, image_token_id=None, @@ -213,7 +216,7 @@ class Qwen2VLTextConfig(PretrainedConfig): self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads self.use_sliding_window = use_sliding_window - self.sliding_window = sliding_window + self.sliding_window = sliding_window if self.use_sliding_window else None self.max_window_layers = max_window_layers # for backward compatibility @@ -229,6 +232,16 @@ class Qwen2VLTextConfig(PretrainedConfig): self.attention_dropout = attention_dropout self.rope_scaling = rope_scaling + self.layer_types = layer_types + if self.layer_types is None: + self.layer_types = [ + "sliding_attention" + if self.sliding_window is not None and i >= self.max_window_layers + else "full_attention" + for i in range(self.num_hidden_layers) + ] + layer_type_validation(self.layer_types) + # Validate the correctness of rotary position embeddings parameters # BC: if there is a 'type' field, move it to 'rope_type'. # and change type from 'mrope' to 'default' because `mrope` does default RoPE calculations diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index b9f166074c8..4581ed08279 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -21,7 +21,7 @@ import math from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -30,24 +30,27 @@ import torch.utils.checkpoint from torch.nn import LayerNorm from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...modeling_attn_mask_utils import AttentionMaskConverter -from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available +from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask +from ...modeling_flash_attention_utils import FlashAttentionKwargs, is_flash_attn_available +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update -from ...modeling_utils import PreTrainedModel -from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, is_torchdynamo_compiling, logging +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import ( + LossKwargs, + auto_docstring, + can_return_tuple, + is_torchdynamo_compiling, + logging, +) from .configuration_qwen2_vl import Qwen2VLConfig, Qwen2VLTextConfig, Qwen2VLVisionConfig if is_flash_attn_available(): - from ...modeling_flash_attention_utils import _flash_attention_forward, flash_attn_varlen_func - -if is_torch_flex_attn_available(): - from torch.nn.attention.flex_attention import BlockMask - - from ...integrations.flex_attention import make_flex_block_causal_mask + from ...modeling_flash_attention_utils import flash_attn_varlen_func logger = logging.get_logger(__name__) @@ -516,6 +519,32 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + class Qwen2VLAttention(nn.Module): """ Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer @@ -541,6 +570,7 @@ class Qwen2VLAttention(nn.Module): self.is_causal = True self.attention_dropout = config.attention_dropout self.rope_scaling = config.rope_scaling + self.scaling = self.head_dim**-0.5 if (self.head_dim * self.num_heads) != self.hidden_size: raise ValueError( @@ -551,6 +581,7 @@ class Qwen2VLAttention(nn.Module): self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None self.rotary_emb = Qwen2VLRotaryEmbedding(config=config) @@ -564,6 +595,7 @@ class Qwen2VLAttention(nn.Module): use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -584,253 +616,28 @@ class Qwen2VLAttention(nn.Module): cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - # Fix precision issues in Qwen2-VL float16 inference - # Replace inf values with zeros in attention weights to prevent NaN propagation - if query_states.dtype == torch.float16: - attn_weights = torch.where(torch.isinf(attn_weights), torch.zeros_like(attn_weights), attn_weights) - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, -1) - - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class Qwen2VLFlashAttention2(Qwen2VLAttention): - """ - Qwen2VL flash attention module, following Qwen2VL attention module. This module inherits from `Qwen2VLAttention` - as the weights of the module stays untouched. The only required change would be on the forward pass - where it needs to correctly call the public API of flash attention and deal with padding tokens - in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom - config.max_window_layers layers. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - ): - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - - # Because the input can be padded, the absolute sequence length depends on the max position id. - cos, sin = position_embeddings - query_states, key_states = apply_multimodal_rotary_pos_emb( - query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] - ) - - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - dropout_rate = 0.0 if not self.training else self.attention_dropout - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in float16 just to be sure everything works as expected. - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - # Reashape to the expected shape for Flash Attention - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - if ( - self.config.use_sliding_window - and getattr(self.config, "sliding_window", None) is not None - and self.layer_idx >= self.config.max_window_layers - ): - sliding_window = self.config.sliding_window - else: - sliding_window = None - - attn_output = _flash_attention_forward( + attn_output, attn_weights = attention_interface( + self, query_states, key_states, value_states, attention_mask, - q_len, - dropout=dropout_rate, - sliding_window=sliding_window, - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=self.sliding_window, + **kwargs, ) attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - return attn_output, attn_weights, past_key_value -class Qwen2VLSdpaAttention(Qwen2VLAttention): - """ - Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `Qwen2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to - SDPA API. - """ - - # Adapted from Qwen2Attention.forward - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "Qwen2VLModel is using Qwen2VLSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - - cos, sin = position_embeddings - query_states, key_states = apply_multimodal_rotary_pos_emb( - query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] - ) - - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - causal_mask = attention_mask - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and attention_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. - is_causal = True if causal_mask is None and q_len > 1 else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=is_causal, - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, -1) - - attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value - - -QWEN2_VL_ATTENTION_CLASSES = { - "eager": Qwen2VLAttention, - "flash_attention_2": Qwen2VLFlashAttention2, - "sdpa": Qwen2VLSdpaAttention, -} - - -class Qwen2VLDecoderLayer(nn.Module): +class Qwen2VLDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: Qwen2VLTextConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size @@ -840,11 +647,12 @@ class Qwen2VLDecoderLayer(nn.Module): f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " "unexpected results may be encountered." ) - self.self_attn = QWEN2_VL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) + self.self_attn = Qwen2VLAttention(config, layer_idx) self.mlp = Qwen2MLP(config) self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.attention_type = config.layer_types[layer_idx] def forward( self, @@ -856,7 +664,7 @@ class Qwen2VLDecoderLayer(nn.Module): use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - **kwargs, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -894,6 +702,7 @@ class Qwen2VLDecoderLayer(nn.Module): use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, + **kwargs, ) hidden_states = residual + hidden_states @@ -925,6 +734,7 @@ class Qwen2VLPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_cache_class = True _supports_static_cache = True + _supports_attention_backend = True def _init_weights(self, module): std = self.config.get_text_config().initializer_range @@ -1053,6 +863,7 @@ class Qwen2VLTextModel(Qwen2VLPreTrainedModel): self._attn_implementation = config._attn_implementation self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = Qwen2VLRotaryEmbedding(config=config) + self.has_sliding_layers = "sliding_attention" in self.config.layer_types self.gradient_checkpointing = False # Initialize weights and apply final processing @@ -1077,6 +888,7 @@ class Qwen2VLTextModel(Qwen2VLPreTrainedModel): output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -1115,9 +927,23 @@ class Qwen2VLTextModel(Qwen2VLPreTrainedModel): elif position_ids.dim() == 2: position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) - causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions - ) + # It may already have been prepared by e.g. `generate` + if not isinstance(causal_mask_mapping := attention_mask, dict): + # Prepare mask arguments + mask_kwargs = { + "config": self.config, + "input_embeds": inputs_embeds, + "attention_mask": attention_mask, + "cache_position": cache_position, + "past_key_values": past_key_values, + } + # Create the masks + causal_mask_mapping = { + "full_attention": create_causal_mask(**mask_kwargs), + } + # The sliding window alternating layers are not always activated depending on the config + if self.has_sliding_layers: + causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs) hidden_states = inputs_embeds @@ -1137,7 +963,7 @@ class Qwen2VLTextModel(Qwen2VLPreTrainedModel): layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, - causal_mask, + causal_mask_mapping[decoder_layer.attention_type], position_ids, past_key_values, output_attentions, @@ -1148,13 +974,14 @@ class Qwen2VLTextModel(Qwen2VLPreTrainedModel): else: layer_outputs = decoder_layer( hidden_states, - attention_mask=causal_mask, + attention_mask=causal_mask_mapping[decoder_layer.attention_type], position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, + **kwargs, ) hidden_states = layer_outputs[0] @@ -1182,162 +1009,8 @@ class Qwen2VLTextModel(Qwen2VLPreTrainedModel): attentions=all_self_attns, ) - # Copied from transformers.models.phimoe.modeling_phimoe.PhimoeModel._update_causal_mask with Phimoe->Qwen2VL - def _update_causal_mask( - self, - attention_mask: Union[torch.Tensor, "BlockMask"], - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - past_key_values: Cache, - output_attentions: bool = False, - ): - if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and past_key_values is not None: - is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] - if is_padding_right: - raise ValueError( - "You are attempting to perform batched generation with padding_side='right'" - " this may lead to unexpected behaviour for Flash Attention version of Qwen2VL. Make sure to " - " call `tokenizer.padding_side = 'left'` before tokenizing the input. " - ) - if attention_mask is not None and 0.0 in attention_mask: - return attention_mask - return None - if self.config._attn_implementation == "flex_attention": - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask) - return attention_mask - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in - # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail - # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) - using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) - - # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if ( - self.config._attn_implementation == "sdpa" - and not (using_static_cache or using_sliding_window_cache) - and not output_attentions - ): - if AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, - inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, - sliding_window=self.config.sliding_window, - is_training=self.training, - ): - return None - - dtype = input_tensor.dtype - min_dtype = torch.finfo(dtype).min - sequence_length = input_tensor.shape[1] - # SlidingWindowCache or StaticCache - if using_sliding_window_cache or using_static_cache: - target_length = past_key_values.get_max_cache_shape() - # DynamicCache or no cache - else: - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else past_seen_tokens + sequence_length + 1 - ) - - # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). - causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=target_length, - dtype=dtype, - cache_position=cache_position, - batch_size=input_tensor.shape[0], - config=self.config, - past_key_values=past_key_values, - ) - - if ( - self.config._attn_implementation == "sdpa" - and attention_mask is not None - and attention_mask.device.type in ["cuda", "xpu", "npu"] - and not output_attentions - ): - # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when - # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. - # Details: https://github.com/pytorch/pytorch/issues/110213 - causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) - - return causal_mask - - @staticmethod - # Copied from transformers.models.phimoe.modeling_phimoe.PhimoeModel._prepare_4d_causal_attention_mask_with_cache_position with Phimoe->Qwen2VL - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - cache_position: torch.Tensor, - batch_size: int, - config: Qwen2VLConfig, - past_key_values: Cache, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - config (`Qwen2VLConfig`): - The model's configuration class - past_key_values (`Cache`): - The cache class that is being used currently to generate - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device - ) - diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape( - -1, 1 - ) - text_config = config.get_text_config() - if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None: - # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also - # the check is needed to verify is current checkpoint was trained with sliding window or not - if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: - sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= ( - cache_position.reshape(-1, 1) - text_config.sliding_window - ) - diagonal_attend_mask.bitwise_or_(sliding_attend_mask) - causal_mask *= diagonal_attend_mask - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - if attention_mask.shape[-1] > target_length: - attention_mask = attention_mask[:, :target_length] - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( - causal_mask.device - ) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - return causal_mask +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... @auto_docstring @@ -1523,6 +1196,8 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel): """ pixel_values_videos = pixel_values_videos.type(self.visual.dtype) video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) + split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() + video_embeds = torch.split(video_embeds, split_sizes) return video_embeds def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None): @@ -1537,6 +1212,8 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel): """ pixel_values = pixel_values.type(self.visual.dtype) image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) + split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() + image_embeds = torch.split(image_embeds, split_sizes) return image_embeds @auto_docstring @@ -1557,6 +1234,7 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel): video_grid_thw: Optional[torch.LongTensor] = None, rope_deltas: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, Qwen2VLModelOutputWithPast]: r""" pixel_values_videos (`torch.FloatTensor` of shape `(seq_length, num_channels * temporal_size * image_size * image_size)): @@ -1581,6 +1259,7 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel): inputs_embeds = self.get_input_embeddings()(input_ids) if pixel_values is not None: image_embeds = self.get_image_features(pixel_values, image_grid_thw) + image_embeds = torch.cat(image_embeds, dim=0) n_image_tokens = (input_ids == self.config.image_token_id).sum().item() n_image_features = image_embeds.shape[0] if not is_torchdynamo_compiling() and n_image_tokens != n_image_features: @@ -1598,6 +1277,7 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel): if pixel_values_videos is not None: video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw) + video_embeds = torch.cat(video_embeds, dim=0) n_video_tokens = (input_ids == self.config.video_token_id).sum().item() n_video_features = video_embeds.shape[0] if not is_torchdynamo_compiling() and n_video_tokens != n_video_features: @@ -1613,15 +1293,14 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel): video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) - if attention_mask is not None: - attention_mask = attention_mask.to(inputs_embeds.device) - if position_ids is None: - attention_mask_2d = attention_mask - if attention_mask is not None and attention_mask.ndim == 4: - attention_mask_2d = torch.diagonal(attention_mask_2d[:, 0], dim1=1, dim2=2) - attention_mask_2d = attention_mask_2d / torch.finfo(attention_mask_2d.dtype).min - attention_mask_2d = (1.0 - attention_mask_2d).int() + attention_mask_tensor = ( + attention_mask if not isinstance(attention_mask, dict) else attention_mask["full_attention"] + ) + if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4: + attention_mask_tensor = torch.diagonal(attention_mask_tensor[:, 0], dim1=1, dim2=2) + attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min + attention_mask_tensor = (1.0 - attention_mask_tensor).int() # Calculate RoPE index once per generation in the pre-fill stage only. # When compiling, we can't check tensor values thus we check only input length @@ -1637,7 +1316,7 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel): ) if (prefill_compiled_stage or prefill_noncompiled_stage) or self.rope_deltas is None: position_ids, rope_deltas = self.get_rope_index( - input_ids, image_grid_thw, video_grid_thw, attention_mask_2d + input_ids, image_grid_thw, video_grid_thw, attention_mask_tensor ) self.rope_deltas = rope_deltas # then use the prev pre-calculated rope-deltas to get the correct position ids @@ -1663,6 +1342,7 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel): output_hidden_states=output_hidden_states, return_dict=True, cache_position=cache_position, + **kwargs, ) output = Qwen2VLModelOutputWithPast( @@ -1674,62 +1354,6 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel): ) return output if return_dict else output.to_tuple() - @staticmethod - # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - cache_position: torch.Tensor, - batch_size: int, - **kwargs, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape - `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, - to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device - ) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( - causal_mask.device - ) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin): _checkpoint_conversion_mapping = { @@ -1792,6 +1416,7 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin): video_grid_thw: Optional[torch.LongTensor] = None, rope_deltas: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -1861,6 +1486,7 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin): output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, + **kwargs, ) hidden_states = outputs[0] diff --git a/src/transformers/models/vipllava/modeling_vipllava.py b/src/transformers/models/vipllava/modeling_vipllava.py index dbe93bb6160..70abf26ee20 100644 --- a/src/transformers/models/vipllava/modeling_vipllava.py +++ b/src/transformers/models/vipllava/modeling_vipllava.py @@ -181,7 +181,9 @@ class VipLlavaModel(VipLlavaPreTrainedModel): def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) - def get_image_features(self, pixel_values: torch.FloatTensor, vision_feature_layers: Union[int, List[int]]): + def get_image_features( + self, pixel_values: torch.FloatTensor, vision_feature_layers: Optional[Union[int, List[int]]] = None + ): """ Obtains image last hidden states from the vision tower and apply multimodal projection. @@ -194,6 +196,9 @@ class VipLlavaModel(VipLlavaPreTrainedModel): Returns: image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`). """ + vision_feature_layers = ( + vision_feature_layers if vision_feature_layers is not None else self.config.vision_feature_layers + ) image_outputs = self.vision_tower(pixel_values, output_hidden_states=True) # If multiple feature layers are provided (which is usually the case) diff --git a/src/transformers/models/vipllava/modular_vipllava.py b/src/transformers/models/vipllava/modular_vipllava.py index 57df2fd7006..5ae33e771c6 100644 --- a/src/transformers/models/vipllava/modular_vipllava.py +++ b/src/transformers/models/vipllava/modular_vipllava.py @@ -71,7 +71,9 @@ class VipLlavaPreTrainedModel(LlavaPreTrainedModel): class VipLlavaModel(LlavaModel): - def get_image_features(self, pixel_values: torch.FloatTensor, vision_feature_layers: Union[int, List[int]]): + def get_image_features( + self, pixel_values: torch.FloatTensor, vision_feature_layers: Optional[Union[int, List[int]]] = None + ): """ Obtains image last hidden states from the vision tower and apply multimodal projection. @@ -84,6 +86,9 @@ class VipLlavaModel(LlavaModel): Returns: image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`). """ + vision_feature_layers = ( + vision_feature_layers if vision_feature_layers is not None else self.config.vision_feature_layers + ) image_outputs = self.vision_tower(pixel_values, output_hidden_states=True) # If multiple feature layers are provided (which is usually the case) diff --git a/tests/models/emu3/test_modeling_emu3.py b/tests/models/emu3/test_modeling_emu3.py index b6d93ef73f2..ddbb8d02d9a 100644 --- a/tests/models/emu3/test_modeling_emu3.py +++ b/tests/models/emu3/test_modeling_emu3.py @@ -326,9 +326,7 @@ class Emu3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, Pipeline def setUp(self): self.model_tester = Emu3Vision2TextModelTester(self) - self.config_tester = ConfigTester( - self, config_class=Emu3Config, has_text_modality=False, common_properties=["vocabulary_map"] - ) + self.config_tester = ConfigTester(self, config_class=Emu3Config, has_text_modality=False, hidden_size=37) def test_config(self): self.config_tester.run_common_tests() diff --git a/utils/check_config_attributes.py b/utils/check_config_attributes.py index 238da721bf6..9fc992049a4 100644 --- a/utils/check_config_attributes.py +++ b/utils/check_config_attributes.py @@ -46,7 +46,10 @@ SPECIAL_CASES_TO_ALLOW = { ], "Qwen2Config": ["use_sliding_window", "max_window_layers"], "Qwen2MoeConfig": ["use_sliding_window"], - "Qwen2VLConfig": ["use_sliding_window"], + "Qwen2VLTextConfig": ["use_sliding_window", "max_window_layers"], + "Qwen2_5_VLTextConfig": ["use_sliding_window", "max_window_layers"], + "Qwen2_5OmniTextConfig": ["use_sliding_window", "max_window_layers"], + "Qwen2_5OmniTalkerConfig": ["use_sliding_window", "max_window_layers"], "Qwen3Config": ["max_window_layers", "use_sliding_window"], # now use `layer_types` instead "Qwen3MoeConfig": ["max_window_layers", "use_sliding_window"], # `cache_implementation` should be in the default generation config, but we don't yet support per-model