diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py
index 649447ca8f7..13da327dab0 100644
--- a/src/transformers/modeling_flash_attention_utils.py
+++ b/src/transformers/modeling_flash_attention_utils.py
@@ -508,6 +508,22 @@ def _flash_attention_forward(
         query_states, key_states, value_states, target_dtype
     )
 
+    # We will use `flash_attn_varlen_func` to prevent cross-example attention and also allow a padding-free approach
+    # under two cases:
+    # Case 1. `position_ids` is provided and not all examples contain only one sequence: if the tensor is monotonically
+    # increasing we likely have one sequence, otherwise it is packed. Also check that we are in the pre-fill/training stage.
+    # Case 2. Some models directly pass pre-computed `cu_seqlens`, so we don't need to infer them from position ids. It is safe
+    # to use `flash_attn_varlen_func` knowing we already have all the necessary kwargs. NOTE: it is the user's responsibility
+    # to take care of flattening `position_ids` if that's needed by the model. See #39121 for more information
+    is_fa2_with_position_ids = (
+        position_ids is not None
+        and query_states.shape[0] == 1
+        and (max_length_q is not None or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all()))
+    )
+    is_fa2_with_varlen_kwargs = all(
+        kwarg is not None for kwarg in (cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k)
+    )
+
     # Contains at least one padding token in the sequence
     if attention_mask is not None:
         batch_size = query_states.shape[0]
@@ -531,14 +547,7 @@ def _flash_attention_forward(
         )
         attn_output = _pad_input(attn_output_unpad, indices_q, batch_size, query_length)
 
-    # If position_ids is provided and check all examples do not contain only 1 sequence, If tensor in increasing
-    # then we probably have one sequence, otherwise it is packed. Additionally check we are in pre-fill/training stage.
- # Use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach - elif ( - position_ids is not None - and query_states.shape[0] == 1 - and (max_length_q is not None or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all())) - ): + elif is_fa2_with_varlen_kwargs or is_fa2_with_position_ids: batch_size = query_states.size(0) if cu_seq_lens_q is None or cu_seq_lens_k is None: diff --git a/src/transformers/models/glm4v/modeling_glm4v.py b/src/transformers/models/glm4v/modeling_glm4v.py index 65ec7f0b79c..5f0e737b43d 100644 --- a/src/transformers/models/glm4v/modeling_glm4v.py +++ b/src/transformers/models/glm4v/modeling_glm4v.py @@ -279,14 +279,15 @@ def eager_attention_forward( class Glm4vVisionAttention(nn.Module): def __init__(self, config: Glm4vVisionConfig) -> None: super().__init__() - self.config = config + self.dim = config.hidden_size self.num_heads = config.num_heads - self.head_dim = config.hidden_size // self.num_heads - self.num_key_value_groups = 1 - self.scale = self.head_dim**-0.5 - self.attention_dropout = config.attention_dropout + self.head_dim = self.dim // self.num_heads + self.num_key_value_groups = 1 # needed for eager attention self.qkv = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=config.attention_bias) self.proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False) + self.scaling = self.head_dim**-0.5 + self.config = config + self.attention_dropout = config.attention_dropout self.is_causal = False def forward( @@ -295,23 +296,31 @@ class Glm4vVisionAttention(nn.Module): cu_seqlens: torch.Tensor, rotary_pos_emb: Optional[torch.Tensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, - **kwargs: Unpack[FlashAttentionKwargs], + attention_mask: Optional[torch.Tensor] = None, + **kwargs, ) -> torch.Tensor: seq_length = hidden_states.shape[0] query_states, key_states, value_states = ( self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) ) - - cos, sin = position_embeddings + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be " + "removed and `position_embeddings` will be mandatory." 
+            )
+            emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+            cos = emb.cos()
+            sin = emb.sin()
+        else:
+            cos, sin = position_embeddings
         query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)
 
         query_states = query_states.transpose(0, 1).unsqueeze(0)
         key_states = key_states.transpose(0, 1).unsqueeze(0)
         value_states = value_states.transpose(0, 1).unsqueeze(0)
-
-        attention_mask = torch.zeros([1, 1, seq_length, seq_length], device=query_states.device, dtype=torch.bool)
-        for i in range(1, len(cu_seqlens)):
-            attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True
+        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
 
         attention_interface: Callable = eager_attention_forward
         if self.config._attn_implementation != "eager":
@@ -322,13 +331,17 @@ class Glm4vVisionAttention(nn.Module):
             query_states,
             key_states,
             value_states,
-            attention_mask,
+            attention_mask=attention_mask,
             dropout=0.0 if not self.training else self.attention_dropout,
-            scaling=self.scale,
-            is_causal=self.is_causal,
+            scaling=self.scaling,
+            cu_seq_lens_q=cu_seqlens,  # pass cu seq lens for FA2
+            cu_seq_lens_k=cu_seqlens,
+            max_length_q=max_seqlen,
+            max_length_k=max_seqlen,
+            is_causal=False,
             **kwargs,
         )
-        attn_output = attn_output.squeeze(0)
+        attn_output = attn_output.reshape(seq_length, -1).contiguous()
         attn_output = self.proj(attn_output)
         return attn_output
 
@@ -348,6 +361,7 @@ class Glm4vVisionBlock(GradientCheckpointingLayer):
         cu_seqlens: torch.Tensor,
         rotary_pos_emb: Optional[torch.Tensor] = None,
         position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> torch.Tensor:
         hidden_states = hidden_states + self.attn(
@@ -355,6 +369,7 @@ class Glm4vVisionBlock(GradientCheckpointingLayer):
             cu_seqlens=cu_seqlens,
             rotary_pos_emb=rotary_pos_emb,
             position_embeddings=position_embeddings,
+            attention_mask=attention_mask,
             **kwargs,
         )
         hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
@@ -452,6 +467,25 @@ class Glm4vVisionModel(Glm4vPreTrainedModel):
         rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
         return rotary_pos_emb, pos_ids
 
+    def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
+        # Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen`
+        # NOTE: the created attention mask only approximates the ragged FA2 attention by
+        # allowing bidirectional attention within `cu_seqlens` blocks, and not attending between
+        # blocks. 
Though it will not be a 100% match for FA2's `varlen` path + if self.config._attn_implementation == "flash_attention_2": + return None + + seq_length = inputs_tensor.shape[0] + attention_mask = torch.full( + [1, 1, seq_length, seq_length], + torch.finfo(inputs_tensor.dtype).min, + device=inputs_tensor.device, + dtype=inputs_tensor.dtype, + ) + for i in range(1, len(cu_seqlens)): + attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0 + return attention_mask + def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor: """ Args: @@ -481,14 +515,15 @@ class Glm4vVisionModel(Glm4vPreTrainedModel): cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() hidden_states = self.embeddings(hidden_states, seqlens, grid_thw, image_type_ids[:, 0], image_type_ids[:, 1]) + attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens=cu_seqlens) for blk in self.blocks: - if self.gradient_checkpointing and self.training: - hidden_states = self._gradient_checkpointing_func( - blk.__call__, hidden_states, cu_seqlens, None, position_embeddings - ) - else: - hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings) + hidden_states = blk( + hidden_states, + cu_seqlens=cu_seqlens, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + ) hidden_states = self.post_layernorm(hidden_states) diff --git a/src/transformers/models/glm4v/modular_glm4v.py b/src/transformers/models/glm4v/modular_glm4v.py index c6ca61b2153..d95e457fc4e 100644 --- a/src/transformers/models/glm4v/modular_glm4v.py +++ b/src/transformers/models/glm4v/modular_glm4v.py @@ -50,8 +50,8 @@ from ..qwen2_5_vl.modeling_qwen2_5_vl import ( Qwen2_5_VLPreTrainedModel, Qwen2_5_VLRotaryEmbedding, Qwen2_5_VLTextModel, + Qwen2_5_VLVisionAttention, Qwen2_5_VLVisionBlock, - apply_rotary_pos_emb_vision, ) from ..qwen2_5_vl.processing_qwen2_5_vl import ( Qwen2_5_VLProcessor, @@ -505,62 +505,12 @@ class Glm4vVisionEmbeddings(nn.Module): return embeddings -class Glm4vVisionAttention(nn.Module): +class Glm4vVisionAttention(Qwen2_5_VLVisionAttention): def __init__(self, config: Glm4vVisionConfig) -> None: super().__init__() - self.config = config - self.num_heads = config.num_heads - self.head_dim = config.hidden_size // self.num_heads - self.num_key_value_groups = 1 - self.scale = self.head_dim**-0.5 self.attention_dropout = config.attention_dropout self.qkv = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=config.attention_bias) self.proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False) - self.is_causal = False - - def forward( - self, - hidden_states: torch.Tensor, - cu_seqlens: torch.Tensor, - rotary_pos_emb: Optional[torch.Tensor] = None, - position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, - **kwargs: Unpack[FlashAttentionKwargs], - ) -> torch.Tensor: - seq_length = hidden_states.shape[0] - query_states, key_states, value_states = ( - self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) - ) - - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin) - - query_states = query_states.transpose(0, 1).unsqueeze(0) - key_states = key_states.transpose(0, 1).unsqueeze(0) - value_states = value_states.transpose(0, 1).unsqueeze(0) - - attention_mask = torch.zeros([1, 1, seq_length, seq_length], device=query_states.device, 
dtype=torch.bool) - for i in range(1, len(cu_seqlens)): - attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True - - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - - attn_output, _ = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask, - dropout=0.0 if not self.training else self.attention_dropout, - scaling=self.scale, - is_causal=self.is_causal, - **kwargs, - ) - attn_output = attn_output.squeeze(0) - attn_output = attn_output.reshape(seq_length, -1).contiguous() - attn_output = self.proj(attn_output) - return attn_output class Glm4vVisionBlock(Qwen2_5_VLVisionBlock): @@ -653,6 +603,25 @@ class Glm4vVisionModel(Glm4vPreTrainedModel): rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) return rotary_pos_emb, pos_ids + def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor: + # Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen` + # NOTE: the created attention masl only approximates the ragged FA2 attention by + # allowing bidirectional attention within `cu_seqlens` blocks, and not attending between + # blocks. Though it will not be a 100% match for FA2's `varlen` path + if self.config._attn_implementation == "flash_attention_2": + return None + + seq_length = inputs_tensor.shape[0] + attention_mask = torch.full( + [1, 1, seq_length, seq_length], + torch.finfo(inputs_tensor.dtype).min, + device=inputs_tensor.device, + dtype=inputs_tensor.dtype, + ) + for i in range(1, len(cu_seqlens)): + attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0 + return attention_mask + def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor: """ Args: @@ -682,14 +651,15 @@ class Glm4vVisionModel(Glm4vPreTrainedModel): cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() hidden_states = self.embeddings(hidden_states, seqlens, grid_thw, image_type_ids[:, 0], image_type_ids[:, 1]) + attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens=cu_seqlens) for blk in self.blocks: - if self.gradient_checkpointing and self.training: - hidden_states = self._gradient_checkpointing_func( - blk.__call__, hidden_states, cu_seqlens, None, position_embeddings - ) - else: - hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings) + hidden_states = blk( + hidden_states, + cu_seqlens=cu_seqlens, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + ) hidden_states = self.post_layernorm(hidden_states) diff --git a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py index c63beb73fac..f7e5f5ba4a5 100644 --- a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py @@ -607,6 +607,7 @@ class Qwen2_5OmniAudioAttention(nn.Module): f" and `num_heads`: {self.num_heads})." 
) self.scaling = self.head_dim**-0.5 + self.attention_dropout = 0.0 self.is_decoder = False self.is_causal = False @@ -619,6 +620,7 @@ class Qwen2_5OmniAudioAttention(nn.Module): self, hidden_states: torch.Tensor, cu_seqlens: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, **kwargs, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -634,15 +636,6 @@ class Qwen2_5OmniAudioAttention(nn.Module): value_states = value_states.transpose(0, 1).unsqueeze(0) max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() - attention_mask = torch.full( - [1, 1, seq_length, key_states.shape[-2]], - torch.finfo(query_states.dtype).min, - device=query_states.device, - dtype=query_states.dtype, - ) - for i in range(1, len(cu_seqlens)): - attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0 - attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] @@ -652,13 +645,13 @@ class Qwen2_5OmniAudioAttention(nn.Module): query_states, key_states, value_states, - attention_mask, - dropout=0.0 if not self.training else self.dropout, + attention_mask=attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, - cu_seqlens_q=cu_seqlens, # pass cu seq lens for FA2 - cu_seqlens_k=cu_seqlens, - max_seqlen_q=max_seqlen, - max_seqlen_k=max_seqlen, + cu_seq_lens_q=cu_seqlens, # pass cu seq lens for FA2 + cu_seq_lens_k=cu_seqlens, + max_length_q=max_seqlen, + max_length_k=max_seqlen, is_causal=False, **kwargs, ) @@ -686,6 +679,7 @@ class Qwen2_5OmniAudioEncoderLayer(GradientCheckpointingLayer): self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: """ @@ -704,6 +698,7 @@ class Qwen2_5OmniAudioEncoderLayer(GradientCheckpointingLayer): hidden_states = self.self_attn( hidden_states=hidden_states, cu_seqlens=cu_seqlens, + attention_mask=attention_mask, **kwargs, ) hidden_states = residual + hidden_states @@ -785,6 +780,25 @@ class Qwen2_5OmniAudioEncoder(Qwen2_5OmniPreTrainedModel): def set_input_embeddings(self, value: nn.Module): self.conv1 = value + def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor: + # Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen` + # NOTE: the created attention masl only approximates the ragged FA2 attention by + # allowing bidirectional attention within `cu_seqlens` blocks, and not attending between + # blocks. 
Though it will not be a 100% match for FA2's `varlen` path + if self.config._attn_implementation == "flash_attention_2": + return None + + seq_length = inputs_tensor.shape[0] + attention_mask = torch.full( + [1, 1, seq_length, seq_length], + torch.finfo(inputs_tensor.dtype).min, + device=inputs_tensor.device, + dtype=inputs_tensor.dtype, + ) + for i in range(1, len(cu_seqlens)): + attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0 + return attention_mask + @auto_docstring def forward( self, @@ -833,9 +847,15 @@ class Qwen2_5OmniAudioEncoder(Qwen2_5OmniPreTrainedModel): padded_mask_after_cnn.sum(1).cumsum(0), ) ).to(torch.int32) + attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens) for encoder_layer in self.layers: - layer_outputs = encoder_layer(hidden_states, cu_seqlens, **kwargs) + layer_outputs = encoder_layer( + hidden_states, + cu_seqlens=cu_seqlens, + attention_mask=attention_mask, + **kwargs, + ) hidden_states = layer_outputs[0] hidden_states_list = hidden_states.split(aftercnn_lens.tolist(), dim=0) @@ -928,12 +948,15 @@ class Qwen2_5OmniVisionAttention(nn.Module): self.scaling = self.head_dim**-0.5 self.num_key_value_groups = 1 # needed for eager attention self.config = config + self.attention_dropout = 0.0 + self.is_causal = False def forward( self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: seq_length = hidden_states.shape[0] @@ -943,18 +966,9 @@ class Qwen2_5OmniVisionAttention(nn.Module): query_states = apply_rotary_pos_emb_vision(query_states.unsqueeze(0), rotary_pos_emb).squeeze(0) key_states = apply_rotary_pos_emb_vision(key_states.unsqueeze(0), rotary_pos_emb).squeeze(0) - attention_mask = torch.full( - [1, 1, seq_length, seq_length], - torch.finfo(query_states.dtype).min, - device=query_states.device, - dtype=query_states.dtype, - ) - for i in range(1, len(cu_seqlens)): - attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0 - - query_states = query_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim - key_states = key_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim - value_states = value_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim + query_states = query_states.transpose(0, 1).unsqueeze(0) + key_states = key_states.transpose(0, 1).unsqueeze(0) + value_states = value_states.transpose(0, 1).unsqueeze(0) max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() attention_interface: Callable = eager_attention_forward @@ -966,13 +980,13 @@ class Qwen2_5OmniVisionAttention(nn.Module): query_states, key_states, value_states, - attention_mask, - dropout=0.0, + attention_mask=attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, - cu_seqlens_q=cu_seqlens, # pass cu seq lens for FA2 - cu_seqlens_k=cu_seqlens, - max_seqlen_q=max_seqlen, - max_seqlen_k=max_seqlen, + cu_seq_lens_q=cu_seqlens, # pass cu seq lens for FA2 + cu_seq_lens_k=cu_seqlens, + max_length_q=max_seqlen, + max_length_k=max_seqlen, is_causal=False, **kwargs, ) @@ -1009,10 +1023,15 @@ class Qwen2_5OmniVisionBlock(GradientCheckpointingLayer): hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: hidden_states = hidden_states + self.attn( - self.norm1(hidden_states), 
cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb, **kwargs + self.norm1(hidden_states), + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + attention_mask=attention_mask, + **kwargs, ) hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) return hidden_states @@ -1171,6 +1190,25 @@ class Qwen2_5OmniVisionEncoder(Qwen2_5OmniPreTrainedModel): return window_index, cu_window_seqlens + def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor: + # Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen` + # NOTE: the created attention masl only approximates the ragged FA2 attention by + # allowing bidirectional attention within `cu_seqlens` blocks, and not attending between + # blocks. Though it will not be a 100% match for FA2's `varlen` path + if self.config._attn_implementation == "flash_attention_2": + return None + + seq_length = inputs_tensor.shape[0] + attention_mask = torch.full( + [1, 1, seq_length, seq_length], + torch.finfo(inputs_tensor.dtype).min, + device=inputs_tensor.device, + dtype=inputs_tensor.dtype, + ) + for i in range(1, len(cu_seqlens)): + attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0 + return attention_mask + def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor: """ Args: @@ -1217,10 +1255,13 @@ class Qwen2_5OmniVisionEncoder(Qwen2_5OmniPreTrainedModel): cu_seqlens_now = cu_seqlens else: cu_seqlens_now = cu_window_seqlens + + attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens_now) hidden_states = blk( hidden_states, cu_seqlens=cu_seqlens_now, rotary_pos_emb=rotary_pos_emb, + attention_mask=attention_mask, **kwargs, ) hidden_states = self.merger(hidden_states) diff --git a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py index 9acc76c9afa..ae7681543e3 100644 --- a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py @@ -1611,6 +1611,7 @@ class Qwen2_5OmniAudioAttention(nn.Module): f" and `num_heads`: {self.num_heads})." 
) self.scaling = self.head_dim**-0.5 + self.attention_dropout = 0.0 self.is_decoder = False self.is_causal = False @@ -1623,6 +1624,7 @@ class Qwen2_5OmniAudioAttention(nn.Module): self, hidden_states: torch.Tensor, cu_seqlens: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, **kwargs, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -1638,15 +1640,6 @@ class Qwen2_5OmniAudioAttention(nn.Module): value_states = value_states.transpose(0, 1).unsqueeze(0) max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() - attention_mask = torch.full( - [1, 1, seq_length, key_states.shape[-2]], - torch.finfo(query_states.dtype).min, - device=query_states.device, - dtype=query_states.dtype, - ) - for i in range(1, len(cu_seqlens)): - attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0 - attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] @@ -1656,13 +1649,13 @@ class Qwen2_5OmniAudioAttention(nn.Module): query_states, key_states, value_states, - attention_mask, - dropout=0.0 if not self.training else self.dropout, + attention_mask=attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, - cu_seqlens_q=cu_seqlens, # pass cu seq lens for FA2 - cu_seqlens_k=cu_seqlens, - max_seqlen_q=max_seqlen, - max_seqlen_k=max_seqlen, + cu_seq_lens_q=cu_seqlens, # pass cu seq lens for FA2 + cu_seq_lens_k=cu_seqlens, + max_length_q=max_seqlen, + max_length_k=max_seqlen, is_causal=False, **kwargs, ) @@ -1682,6 +1675,7 @@ class Qwen2_5OmniAudioEncoderLayer(Qwen2AudioEncoderLayer): self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: residual = hidden_states @@ -1689,6 +1683,7 @@ class Qwen2_5OmniAudioEncoderLayer(Qwen2AudioEncoderLayer): hidden_states = self.self_attn( hidden_states=hidden_states, cu_seqlens=cu_seqlens, + attention_mask=attention_mask, **kwargs, ) hidden_states = residual + hidden_states @@ -1770,6 +1765,25 @@ class Qwen2_5OmniAudioEncoder(Qwen2_5OmniPreTrainedModel): def set_input_embeddings(self, value: nn.Module): self.conv1 = value + def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor: + # Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen` + # NOTE: the created attention masl only approximates the ragged FA2 attention by + # allowing bidirectional attention within `cu_seqlens` blocks, and not attending between + # blocks. 
Though it will not be a 100% match for FA2's `varlen` path + if self.config._attn_implementation == "flash_attention_2": + return None + + seq_length = inputs_tensor.shape[0] + attention_mask = torch.full( + [1, 1, seq_length, seq_length], + torch.finfo(inputs_tensor.dtype).min, + device=inputs_tensor.device, + dtype=inputs_tensor.dtype, + ) + for i in range(1, len(cu_seqlens)): + attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0 + return attention_mask + @auto_docstring def forward( self, @@ -1818,9 +1832,15 @@ class Qwen2_5OmniAudioEncoder(Qwen2_5OmniPreTrainedModel): padded_mask_after_cnn.sum(1).cumsum(0), ) ).to(torch.int32) + attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens) for encoder_layer in self.layers: - layer_outputs = encoder_layer(hidden_states, cu_seqlens, **kwargs) + layer_outputs = encoder_layer( + hidden_states, + cu_seqlens=cu_seqlens, + attention_mask=attention_mask, + **kwargs, + ) hidden_states = layer_outputs[0] hidden_states_list = hidden_states.split(aftercnn_lens.tolist(), dim=0) @@ -1906,12 +1926,15 @@ class Qwen2_5OmniVisionAttention(nn.Module): self.scaling = self.head_dim**-0.5 self.num_key_value_groups = 1 # needed for eager attention self.config = config + self.attention_dropout = 0.0 + self.is_causal = False def forward( self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: seq_length = hidden_states.shape[0] @@ -1921,18 +1944,9 @@ class Qwen2_5OmniVisionAttention(nn.Module): query_states = apply_rotary_pos_emb_vision(query_states.unsqueeze(0), rotary_pos_emb).squeeze(0) key_states = apply_rotary_pos_emb_vision(key_states.unsqueeze(0), rotary_pos_emb).squeeze(0) - attention_mask = torch.full( - [1, 1, seq_length, seq_length], - torch.finfo(query_states.dtype).min, - device=query_states.device, - dtype=query_states.dtype, - ) - for i in range(1, len(cu_seqlens)): - attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0 - - query_states = query_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim - key_states = key_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim - value_states = value_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim + query_states = query_states.transpose(0, 1).unsqueeze(0) + key_states = key_states.transpose(0, 1).unsqueeze(0) + value_states = value_states.transpose(0, 1).unsqueeze(0) max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() attention_interface: Callable = eager_attention_forward @@ -1944,13 +1958,13 @@ class Qwen2_5OmniVisionAttention(nn.Module): query_states, key_states, value_states, - attention_mask, - dropout=0.0, + attention_mask=attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, - cu_seqlens_q=cu_seqlens, # pass cu seq lens for FA2 - cu_seqlens_k=cu_seqlens, - max_seqlen_q=max_seqlen, - max_seqlen_k=max_seqlen, + cu_seq_lens_q=cu_seqlens, # pass cu seq lens for FA2 + cu_seq_lens_k=cu_seqlens, + max_length_q=max_seqlen, + max_length_k=max_seqlen, is_causal=False, **kwargs, ) @@ -1970,10 +1984,15 @@ class Qwen2_5OmniVisionBlock(Qwen2_5_VLVisionBlock): hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: hidden_states = hidden_states + self.attn( - self.norm1(hidden_states), 
cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb, **kwargs + self.norm1(hidden_states), + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + attention_mask=attention_mask, + **kwargs, ) hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) return hidden_states @@ -1987,6 +2006,25 @@ class Qwen2_5OmniVisionEncoder(Qwen2_5_VisionTransformerPretrainedModel): super().__init__(config, *inputs, **kwargs) self.blocks = nn.ModuleList([Qwen2_5OmniVisionBlock(config) for _ in range(config.depth)]) + def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor: + # Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen` + # NOTE: the created attention masl only approximates the ragged FA2 attention by + # allowing bidirectional attention within `cu_seqlens` blocks, and not attending between + # blocks. Though it will not be a 100% match for FA2's `varlen` path + if self.config._attn_implementation == "flash_attention_2": + return None + + seq_length = inputs_tensor.shape[0] + attention_mask = torch.full( + [1, 1, seq_length, seq_length], + torch.finfo(inputs_tensor.dtype).min, + device=inputs_tensor.device, + dtype=inputs_tensor.dtype, + ) + for i in range(1, len(cu_seqlens)): + attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0 + return attention_mask + def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor: """ Args: @@ -2033,10 +2071,13 @@ class Qwen2_5OmniVisionEncoder(Qwen2_5_VisionTransformerPretrainedModel): cu_seqlens_now = cu_seqlens else: cu_seqlens_now = cu_window_seqlens + + attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens_now) hidden_states = blk( hidden_states, cu_seqlens=cu_seqlens_now, rotary_pos_emb=rotary_pos_emb, + attention_mask=attention_mask, **kwargs, ) hidden_states = self.merger(hidden_states) diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index ab318d955ff..f5b40ac4a45 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -206,6 +206,8 @@ class Qwen2_5_VLVisionAttention(nn.Module): self.proj = nn.Linear(self.dim, self.dim) self.scaling = self.head_dim**-0.5 self.config = config + self.attention_dropout = 0.0 + self.is_causal = False def forward( self, @@ -213,6 +215,7 @@ class Qwen2_5_VLVisionAttention(nn.Module): cu_seqlens: torch.Tensor, rotary_pos_emb: Optional[torch.Tensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: seq_length = hidden_states.shape[0] @@ -233,18 +236,9 @@ class Qwen2_5_VLVisionAttention(nn.Module): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin) - attention_mask = torch.full( - [1, 1, seq_length, seq_length], - torch.finfo(value_states.dtype).min, - device=value_states.device, - dtype=value_states.dtype, - ) - for i in range(1, len(cu_seqlens)): - attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0 - - query_states = query_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim - key_states = key_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim - value_states = value_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim + query_states = 
query_states.transpose(0, 1).unsqueeze(0) + key_states = key_states.transpose(0, 1).unsqueeze(0) + value_states = value_states.transpose(0, 1).unsqueeze(0) max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() attention_interface: Callable = eager_attention_forward @@ -256,13 +250,13 @@ class Qwen2_5_VLVisionAttention(nn.Module): query_states, key_states, value_states, - attention_mask, - dropout=0.0, + attention_mask=attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, - cu_seqlens_q=cu_seqlens, # pass cu seq lens for FA2 - cu_seqlens_k=cu_seqlens, - max_seqlen_q=max_seqlen, - max_seqlen_k=max_seqlen, + cu_seq_lens_q=cu_seqlens, # pass cu seq lens for FA2 + cu_seq_lens_k=cu_seqlens, + max_length_q=max_seqlen, + max_length_k=max_seqlen, is_causal=False, **kwargs, ) @@ -286,6 +280,7 @@ class Qwen2_5_VLVisionBlock(GradientCheckpointingLayer): cu_seqlens: torch.Tensor, rotary_pos_emb: Optional[torch.Tensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: hidden_states = hidden_states + self.attn( @@ -293,6 +288,7 @@ class Qwen2_5_VLVisionBlock(GradientCheckpointingLayer): cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb, position_embeddings=position_embeddings, + attention_mask=attention_mask, **kwargs, ) hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) @@ -426,6 +422,25 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel): return window_index, cu_window_seqlens + def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor: + # Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen` + # NOTE: the created attention masl only approximates the ragged FA2 attention by + # allowing bidirectional attention within `cu_seqlens` blocks, and not attending between + # blocks. 
Though it will not be a 100% match for FA2's `varlen` path + if self.config._attn_implementation == "flash_attention_2": + return None + + seq_length = inputs_tensor.shape[0] + attention_mask = torch.full( + [1, 1, seq_length, seq_length], + torch.finfo(inputs_tensor.dtype).min, + device=inputs_tensor.device, + dtype=inputs_tensor.dtype, + ) + for i in range(1, len(cu_seqlens)): + attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0 + return attention_mask + def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor: """ Args: @@ -472,8 +487,14 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel): cu_seqlens_now = cu_seqlens else: cu_seqlens_now = cu_window_seqlens + + attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens_now) hidden_states = blk( - hidden_states, cu_seqlens=cu_seqlens_now, position_embeddings=position_embeddings, **kwargs + hidden_states, + cu_seqlens=cu_seqlens_now, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + **kwargs, ) hidden_states = self.merger(hidden_states) diff --git a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py index 84a7a69ac81..30ddfbda098 100644 --- a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py @@ -159,6 +159,7 @@ class Qwen2_5_VLVisionBlock(GradientCheckpointingLayer): cu_seqlens: torch.Tensor, rotary_pos_emb: Optional[torch.Tensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: hidden_states = hidden_states + self.attn( @@ -166,6 +167,7 @@ class Qwen2_5_VLVisionBlock(GradientCheckpointingLayer): cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb, position_embeddings=position_embeddings, + attention_mask=attention_mask, **kwargs, ) hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) @@ -287,6 +289,25 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel): return window_index, cu_window_seqlens + def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor: + # Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen` + # NOTE: the created attention masl only approximates the ragged FA2 attention by + # allowing bidirectional attention within `cu_seqlens` blocks, and not attending between + # blocks. 
Though it will not be a 100% match for FA2's `varlen` path + if self.config._attn_implementation == "flash_attention_2": + return None + + seq_length = inputs_tensor.shape[0] + attention_mask = torch.full( + [1, 1, seq_length, seq_length], + torch.finfo(inputs_tensor.dtype).min, + device=inputs_tensor.device, + dtype=inputs_tensor.dtype, + ) + for i in range(1, len(cu_seqlens)): + attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0 + return attention_mask + def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor: """ Args: @@ -333,8 +354,14 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel): cu_seqlens_now = cu_seqlens else: cu_seqlens_now = cu_window_seqlens + + attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens_now) hidden_states = blk( - hidden_states, cu_seqlens=cu_seqlens_now, position_embeddings=position_embeddings, **kwargs + hidden_states, + cu_seqlens=cu_seqlens_now, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + **kwargs, ) hidden_states = self.merger(hidden_states) diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index a799e7328e5..a075de666fb 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -324,6 +324,8 @@ class VisionAttention(nn.Module): self.proj = nn.Linear(self.dim, self.dim) self.scaling = self.head_dim**-0.5 self.config = config + self.attention_dropout = 0.0 + self.is_causal = False def forward( self, @@ -331,6 +333,7 @@ class VisionAttention(nn.Module): cu_seqlens: torch.Tensor, rotary_pos_emb: Optional[torch.Tensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: seq_length = hidden_states.shape[0] @@ -351,18 +354,9 @@ class VisionAttention(nn.Module): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin) - attention_mask = torch.full( - [1, 1, seq_length, seq_length], - torch.finfo(value_states.dtype).min, - device=value_states.device, - dtype=value_states.dtype, - ) - for i in range(1, len(cu_seqlens)): - attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0 - - query_states = query_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim - key_states = key_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim - value_states = value_states.transpose(0, 1).unsqueeze(0) # unsqueeze batch_dim + query_states = query_states.transpose(0, 1).unsqueeze(0) + key_states = key_states.transpose(0, 1).unsqueeze(0) + value_states = value_states.transpose(0, 1).unsqueeze(0) max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() attention_interface: Callable = eager_attention_forward @@ -374,13 +368,13 @@ class VisionAttention(nn.Module): query_states, key_states, value_states, - attention_mask, - dropout=0.0, + attention_mask=attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, - cu_seqlens_q=cu_seqlens, # pass cu seq lens for FA2 - cu_seqlens_k=cu_seqlens, - max_seqlen_q=max_seqlen, - max_seqlen_k=max_seqlen, + cu_seq_lens_q=cu_seqlens, # pass cu seq lens for FA2 + cu_seq_lens_k=cu_seqlens, + max_length_q=max_seqlen, + max_length_k=max_seqlen, is_causal=False, **kwargs, ) @@ -406,6 +400,7 @@ 
class Qwen2VLVisionBlock(GradientCheckpointingLayer): cu_seqlens: torch.Tensor, rotary_pos_emb: Optional[torch.Tensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: hidden_states = hidden_states + self.attn( @@ -413,6 +408,7 @@ class Qwen2VLVisionBlock(GradientCheckpointingLayer): cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb, position_embeddings=position_embeddings, + attention_mask=attention_mask, **kwargs, ) hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) @@ -725,6 +721,25 @@ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel): rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) return rotary_pos_emb + def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor: + # Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen` + # NOTE: the created attention masl only approximates the ragged FA2 attention by + # allowing bidirectional attention within `cu_seqlens` blocks, and not attending between + # blocks. Though it will not be a 100% match for FA2's `varlen` path + if self.config._attn_implementation == "flash_attention_2": + return None + + seq_length = inputs_tensor.shape[0] + attention_mask = torch.full( + [1, 1, seq_length, seq_length], + torch.finfo(inputs_tensor.dtype).min, + device=inputs_tensor.device, + dtype=inputs_tensor.dtype, + ) + for i in range(1, len(cu_seqlens)): + attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0 + return attention_mask + @auto_docstring def forward( self, @@ -750,10 +765,15 @@ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel): dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, ) cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens) for blk in self.blocks: hidden_states = blk( - hidden_states, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings, **kwargs + hidden_states, + cu_seqlens=cu_seqlens, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + **kwargs, ) return self.merger(hidden_states) diff --git a/tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py b/tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py index 7894dc69806..9525f6415a2 100644 --- a/tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py +++ b/tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py @@ -184,7 +184,7 @@ class Qwen2_5_VLVisionText2TextModelTester: input_ids[:, self.num_image_tokens - 1] = self.vision_start_token_id inputs_dict = { "pixel_values": pixel_values, - "image_grid_thw": torch.tensor([[1, 1, 1]] * self.batch_size), + "image_grid_thw": torch.tensor([[1, 1, 1]] * self.batch_size, device=torch_device), "input_ids": input_ids, "attention_mask": attention_mask, } diff --git a/tests/models/qwen2_vl/test_modeling_qwen2_vl.py b/tests/models/qwen2_vl/test_modeling_qwen2_vl.py index 72669fd390f..dff92fb5888 100644 --- a/tests/models/qwen2_vl/test_modeling_qwen2_vl.py +++ b/tests/models/qwen2_vl/test_modeling_qwen2_vl.py @@ -176,7 +176,7 @@ class Qwen2VLVisionText2TextModelTester: inputs_dict = { "pixel_values": pixel_values, - "image_grid_thw": torch.tensor([[1, 1, 1]] * self.batch_size), + "image_grid_thw": torch.tensor([[1, 1, 1]] * self.batch_size, device=torch_device), "input_ids": input_ids, "attention_mask": attention_mask, }
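Note on the new `_prepare_attention_mask` helpers: below is a minimal, self-contained sketch (not taken from the patch itself) of the mask they build for eager/SDPA backends, a block-diagonal additive mask derived from `cu_seqlens` so that tokens attend bidirectionally within their own block and never across blocks. Under `flash_attention_2` the helpers return `None` and the varlen path instead consumes `cu_seq_lens_q/k` and `max_length_q/k`. The concrete values used here (`cu_seqlens = [0, 3, 7]`, `float32`, CPU tensors) are illustrative assumptions.

import torch

def block_diagonal_mask(cu_seqlens: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    # Additive mask: 0 inside each `cu_seqlens` block, dtype-min (effectively -inf) everywhere else,
    # mirroring the loop in the `_prepare_attention_mask` helpers added by this patch.
    seq_length = int(cu_seqlens[-1])
    mask = torch.full([1, 1, seq_length, seq_length], torch.finfo(dtype).min, dtype=dtype)
    for i in range(1, len(cu_seqlens)):
        mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
    return mask

# Two packed sequences of lengths 3 and 4 -> a 7x7 mask with two zero blocks on the diagonal.
cu_seqlens = torch.tensor([0, 3, 7], dtype=torch.int32)
print(block_diagonal_mask(cu_seqlens)[0, 0])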