Optimize Qwen2VL vision model by precomputing cos/sin embeds before ViT blocks (#35837)
* Optimize Qwen2VL vision model by precomputing cos/sin embeds before ViT blocks
* Make rotary_pos_emb optional & fix type
* Adapt pre-computed cos/sin to Qwen2.5VL
* More concise
parent d72642bccc
commit 5f0fd1185b
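In short, the vision tower now converts the rotary frequencies to cos/sin once, before the loop over ViT blocks, and every attention layer reuses that pair instead of redoing the trigonometry per layer. A minimal sketch of the pattern (toy module names, not the transformers classes):

import torch
from torch import nn


class ToyVisionBlock(nn.Module):
    """Stand-in for a ViT block: it only consumes precomputed (cos, sin)."""

    def forward(self, hidden_states, position_embeddings):
        cos, sin = position_embeddings  # no .cos()/.sin() inside the block anymore
        # a real block would apply RoPE to q/k here; this toy just passes through
        return hidden_states


blocks = nn.ModuleList([ToyVisionBlock() for _ in range(4)])
seq_len, half_dim, hidden = 16, 40, 128
rotary_pos_emb = torch.randn(seq_len, half_dim)  # per-patch RoPE frequencies
hidden_states = torch.randn(seq_len, hidden)

# precompute once, before the block loop (this is the core of the change)
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
position_embeddings = (emb.cos(), emb.sin())

for blk in blocks:
    hidden_states = blk(hidden_states, position_embeddings=position_embeddings)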
@@ -160,12 +160,14 @@ class Qwen2_5_VLPatchMerger(nn.Module):
         return x


-def apply_rotary_pos_emb_flashatt(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
-    tensor_ = tensor.float()
-    cos = freqs.cos().float()
-    sin = freqs.sin().float()
-    output = apply_rotary_emb(tensor_, cos, sin).type_as(tensor)
-    return output
+def apply_rotary_pos_emb_flashatt(
+    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    cos = cos.chunk(2, dim=-1)[0].contiguous()
+    sin = sin.chunk(2, dim=-1)[0].contiguous()
+    q_embed = apply_rotary_emb(q.float(), cos, sin).type_as(q)
+    k_embed = apply_rotary_emb(k.float(), cos, sin).type_as(k)
+    return q_embed, k_embed


 class Qwen2_5_VLVisionFlashAttention2(nn.Module):
@@ -179,12 +181,26 @@ class Qwen2_5_VLVisionFlashAttention2(nn.Module):
         self,
         hidden_states: torch.Tensor,
         cu_seqlens: torch.Tensor,
-        rotary_pos_emb: torch.Tensor = None,
+        rotary_pos_emb: Optional[torch.Tensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
     ) -> torch.Tensor:
         seq_length = hidden_states.shape[0]
         q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
-        q = apply_rotary_pos_emb_flashatt(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
-        k = apply_rotary_pos_emb_flashatt(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
+        if position_embeddings is None:
+            logger.warning_once(
+                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+                "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
+                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
+                "removed and `position_embeddings` will be mandatory."
+            )
+            emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+            cos = emb.cos().float()
+            sin = emb.sin().float()
+        else:
+            cos, sin = position_embeddings
+        q, k = apply_rotary_pos_emb_flashatt(q.unsqueeze(0), k.unsqueeze(0), cos, sin)
+        q = q.squeeze(0)
+        k = k.squeeze(0)

         max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
         attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
@@ -201,16 +217,18 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)


-def apply_rotary_pos_emb_vision(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
-    orig_dtype = tensor.dtype
-    tensor = tensor.float()
-    cos = freqs.cos()
-    sin = freqs.sin()
-    cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
-    sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
-    output = (tensor * cos) + (rotate_half(tensor) * sin)
-    output = output.to(orig_dtype)
-    return output
+def apply_rotary_pos_emb_vision(
+    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    orig_q_dtype = q.dtype
+    orig_k_dtype = k.dtype
+    q, k = q.float(), k.float()
+    cos, sin = cos.unsqueeze(-2), sin.unsqueeze(-2)
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    q_embed = q_embed.to(orig_q_dtype)
+    k_embed = k_embed.to(orig_k_dtype)
+    return q_embed, k_embed


 class Qwen2_5_VLVisionAttention(nn.Module):
@@ -222,12 +240,27 @@ class Qwen2_5_VLVisionAttention(nn.Module):
         self.proj = nn.Linear(dim, dim)

     def forward(
-        self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
+        self,
+        hidden_states: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+        rotary_pos_emb: Optional[torch.Tensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
     ) -> torch.Tensor:
         seq_length = hidden_states.shape[0]
         q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
-        q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
-        k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
+        if position_embeddings is None:
+            logger.warning_once(
+                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+                "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
+                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
+                "removed and `position_embeddings` will be mandatory."
+            )
+            emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+            cos = emb.cos().float()
+            sin = emb.sin().float()
+        else:
+            cos, sin = position_embeddings
+        q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)

         attention_mask = torch.full(
             [1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype
@@ -256,12 +289,27 @@ class Qwen2_5_VLVisionSdpaAttention(nn.Module):
         self.proj = nn.Linear(dim, dim)

     def forward(
-        self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
+        self,
+        hidden_states: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+        rotary_pos_emb: Optional[torch.Tensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
     ) -> torch.Tensor:
         seq_length = hidden_states.shape[0]
         q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
-        q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
-        k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
+        if position_embeddings is None:
+            logger.warning_once(
+                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+                "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
+                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
+                "removed and `position_embeddings` will be mandatory."
+            )
+            emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+            cos = emb.cos().float()
+            sin = emb.sin().float()
+        else:
+            cos, sin = position_embeddings
+        q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)

         attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool)
         for i in range(1, len(cu_seqlens)):
@@ -293,11 +341,18 @@ class Qwen2_5_VLVisionBlock(nn.Module):
         )
         self.mlp = Qwen2_5_VLMLP(config, bias=True)

-    def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor:
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+        rotary_pos_emb: Optional[torch.Tensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+    ) -> torch.Tensor:
         hidden_states = hidden_states + self.attn(
             self.norm1(hidden_states),
             cu_seqlens=cu_seqlens,
             rotary_pos_emb=rotary_pos_emb,
+            position_embeddings=position_embeddings,
         )
         hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
         return hidden_states
@@ -477,6 +532,8 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):
         rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
         rotary_pos_emb = rotary_pos_emb[window_index, :, :]
         rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
+        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+        position_embeddings = (emb.cos(), emb.sin())

         cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
             dim=0,
@@ -495,14 +552,10 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):
                 cu_seqlens_now = cu_window_seqlens
             if self.gradient_checkpointing and self.training:
                 hidden_states = self._gradient_checkpointing_func(
-                    blk.__call__, hidden_states, cu_seqlens_now, rotary_pos_emb
+                    blk.__call__, hidden_states, cu_seqlens_now, None, position_embeddings
                 )
             else:
-                hidden_states = blk(
-                    hidden_states,
-                    cu_seqlens=cu_seqlens_now,
-                    rotary_pos_emb=rotary_pos_emb,
-                )
+                hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens_now, position_embeddings=position_embeddings)

         hidden_states = self.merger(hidden_states)
         reverse_indices = torch.argsort(window_index)

@@ -51,7 +51,7 @@ from ...feature_extraction_utils import BatchFeature
 from ...image_utils import ImageInput, VideoInput
 from ...processing_utils import ProcessingKwargs, Unpack, VideosKwargs
 from ...tokenization_utils_base import PreTokenizedInput, TextInput
-from ...utils import is_flash_attn_2_available, is_torchdynamo_compiling
+from ...utils import is_flash_attn_2_available, is_torchdynamo_compiling, logging


 if is_flash_attn_2_available():
@@ -63,12 +63,17 @@ else:
     apply_rotary_emb = None


-def apply_rotary_pos_emb_flashatt(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
-    tensor_ = tensor.float()
-    cos = freqs.cos().float()
-    sin = freqs.sin().float()
-    output = apply_rotary_emb(tensor_, cos, sin).type_as(tensor)
-    return output
+logger = logging.get_logger(__name__)
+
+
+def apply_rotary_pos_emb_flashatt(
+    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    cos = cos.chunk(2, dim=-1)[0].contiguous()
+    sin = sin.chunk(2, dim=-1)[0].contiguous()
+    q_embed = apply_rotary_emb(q.float(), cos, sin).type_as(q)
+    k_embed = apply_rotary_emb(k.float(), cos, sin).type_as(k)
+    return q_embed, k_embed


 class Qwen2_5_VLVisionConfig(PretrainedConfig):
@@ -153,12 +158,26 @@ class Qwen2_5_VLVisionFlashAttention2(nn.Module):
         self,
         hidden_states: torch.Tensor,
         cu_seqlens: torch.Tensor,
-        rotary_pos_emb: torch.Tensor = None,
+        rotary_pos_emb: Optional[torch.Tensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
     ) -> torch.Tensor:
         seq_length = hidden_states.shape[0]
         q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
-        q = apply_rotary_pos_emb_flashatt(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
-        k = apply_rotary_pos_emb_flashatt(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
+        if position_embeddings is None:
+            logger.warning_once(
+                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+                "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
+                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
+                "removed and `position_embeddings` will be mandatory."
+            )
+            emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+            cos = emb.cos().float()
+            sin = emb.sin().float()
+        else:
+            cos, sin = position_embeddings
+        q, k = apply_rotary_pos_emb_flashatt(q.unsqueeze(0), k.unsqueeze(0), cos, sin)
+        q = q.squeeze(0)
+        k = k.squeeze(0)

         max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
         attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
@@ -193,11 +212,18 @@ class Qwen2_5_VLVisionBlock(nn.Module):
         )
         self.mlp = Qwen2_5_VLMLP(config, bias=True)

-    def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor:
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+        rotary_pos_emb: Optional[torch.Tensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+    ) -> torch.Tensor:
         hidden_states = hidden_states + self.attn(
             self.norm1(hidden_states),
             cu_seqlens=cu_seqlens,
             rotary_pos_emb=rotary_pos_emb,
+            position_embeddings=position_embeddings,
         )
         hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
         return hidden_states
@@ -337,6 +363,8 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):
         rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
         rotary_pos_emb = rotary_pos_emb[window_index, :, :]
         rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
+        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+        position_embeddings = (emb.cos(), emb.sin())

         cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
             dim=0,
@@ -355,14 +383,10 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):
                 cu_seqlens_now = cu_window_seqlens
             if self.gradient_checkpointing and self.training:
                 hidden_states = self._gradient_checkpointing_func(
-                    blk.__call__, hidden_states, cu_seqlens_now, rotary_pos_emb
+                    blk.__call__, hidden_states, cu_seqlens_now, None, position_embeddings
                 )
             else:
-                hidden_states = blk(
-                    hidden_states,
-                    cu_seqlens=cu_seqlens_now,
-                    rotary_pos_emb=rotary_pos_emb,
-                )
+                hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens_now, position_embeddings=position_embeddings)

         hidden_states = self.merger(hidden_states)
         reverse_indices = torch.argsort(window_index)

@@ -214,16 +214,18 @@ def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim
     return q_embed, k_embed


-def apply_rotary_pos_emb_vision(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
-    orig_dtype = tensor.dtype
-    tensor = tensor.float()
-    cos = freqs.cos()
-    sin = freqs.sin()
-    cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
-    sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
-    output = (tensor * cos) + (rotate_half(tensor) * sin)
-    output = output.to(orig_dtype)
-    return output
+def apply_rotary_pos_emb_vision(
+    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    orig_q_dtype = q.dtype
+    orig_k_dtype = k.dtype
+    q, k = q.float(), k.float()
+    cos, sin = cos.unsqueeze(-2), sin.unsqueeze(-2)
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    q_embed = q_embed.to(orig_q_dtype)
+    k_embed = k_embed.to(orig_k_dtype)
+    return q_embed, k_embed


 class VisionRotaryEmbedding(nn.Module):
@@ -300,12 +302,27 @@ class VisionAttention(nn.Module):
         self.proj = nn.Linear(dim, dim)

     def forward(
-        self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
+        self,
+        hidden_states: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+        rotary_pos_emb: Optional[torch.Tensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
     ) -> torch.Tensor:
         seq_length = hidden_states.shape[0]
         q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
-        q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
-        k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
+        if position_embeddings is None:
+            logger.warning_once(
+                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+                "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
+                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
+                "removed and `position_embeddings` will be mandatory."
+            )
+            emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+            cos = emb.cos().float()
+            sin = emb.sin().float()
+        else:
+            cos, sin = position_embeddings
+        q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)

         attention_mask = torch.full(
             [1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype
@@ -334,12 +351,27 @@ class VisionFlashAttention2(nn.Module):
         self.proj = nn.Linear(dim, dim)

     def forward(
-        self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
+        self,
+        hidden_states: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+        rotary_pos_emb: Optional[torch.Tensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
     ) -> torch.Tensor:
         seq_length = hidden_states.shape[0]
         q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
-        q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
-        k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
+        if position_embeddings is None:
+            logger.warning_once(
+                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+                "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
+                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
+                "removed and `position_embeddings` will be mandatory."
+            )
+            emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+            cos = emb.cos().float()
+            sin = emb.sin().float()
+        else:
+            cos, sin = position_embeddings
+        q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)

         max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
         attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
@@ -357,12 +389,27 @@ class VisionSdpaAttention(nn.Module):
         self.proj = nn.Linear(dim, dim)

     def forward(
-        self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
+        self,
+        hidden_states: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+        rotary_pos_emb: Optional[torch.Tensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
         seq_length = hidden_states.shape[0]
         q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
-        q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
-        k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
+        if position_embeddings is None:
+            logger.warning_once(
+                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+                "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
+                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
+                "removed and `position_embeddings` will be mandatory."
+            )
+            emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+            cos = emb.cos().float()
+            sin = emb.sin().float()
+        else:
+            cos, sin = position_embeddings
+        q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)

         attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool)
         for i in range(1, len(cu_seqlens)):
@@ -396,9 +443,18 @@ class Qwen2VLVisionBlock(nn.Module):
         )
         self.mlp = VisionMlp(dim=config.embed_dim, hidden_dim=mlp_hidden_dim, hidden_act=config.hidden_act)

-    def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor:
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+        rotary_pos_emb: Optional[torch.Tensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+    ) -> torch.Tensor:
         hidden_states = hidden_states + self.attn(
-            self.norm1(hidden_states), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
+            self.norm1(hidden_states),
+            cu_seqlens=cu_seqlens,
+            rotary_pos_emb=rotary_pos_emb,
+            position_embeddings=position_embeddings,
         )
         hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
         return hidden_states
@@ -961,6 +1017,8 @@ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
     def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
         hidden_states = self.patch_embed(hidden_states)
         rotary_pos_emb = self.rot_pos_emb(grid_thw)
+        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+        position_embeddings = (emb.cos(), emb.sin())

         cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
             dim=0,
@@ -975,10 +1033,10 @@ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
         for blk in self.blocks:
             if self.gradient_checkpointing and self.training:
                 hidden_states = self._gradient_checkpointing_func(
-                    blk.__call__, hidden_states, cu_seqlens, rotary_pos_emb
+                    blk.__call__, hidden_states, cu_seqlens, None, position_embeddings
                 )
             else:
-                hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)
+                hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings)

         return self.merger(hidden_states)
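
For reference, a standalone numerical check (a sketch written for this summary, not test code from the repository; rope_old/rope_new are illustrative names) showing that the retired per-layer path and the new precomputed cos/sin path apply the same rotation:

import torch


def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def rope_old(t, freqs):
    # old style: trig on the half-dim freqs, tiled up to head_dim inside every layer
    cos = freqs.cos().unsqueeze(1).repeat(1, 1, 2).unsqueeze(0)
    sin = freqs.sin().unsqueeze(1).repeat(1, 1, 2).unsqueeze(0)
    return (t * cos) + (rotate_half(t) * sin)


def rope_new(t, cos, sin):
    # new style: cos/sin arrive precomputed from torch.cat((freqs, freqs), dim=-1)
    cos, sin = cos.unsqueeze(-2), sin.unsqueeze(-2)
    return (t * cos) + (rotate_half(t) * sin)


seq_len, num_heads, head_dim = 16, 4, 8
freqs = torch.randn(seq_len, head_dim // 2)
q = torch.randn(seq_len, num_heads, head_dim)

emb = torch.cat((freqs, freqs), dim=-1)  # done once in the vision tower
out_old = rope_old(q.unsqueeze(0), freqs).squeeze(0)
out_new = rope_new(q, emb.cos(), emb.sin())
print(torch.allclose(out_old, out_new))  # True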