mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 04:40:06 +06:00
🔴 [VLM] modeling updates (#38317)
Some checks are pending
Self-hosted runner (benchmark) / Benchmark (aws-g5-4xlarge-cache) (push) Waiting to run
Build documentation / build (push) Waiting to run
New model PR merged notification / Notify new model (push) Waiting to run
Slow tests on important models (on Push - A10) / Get all modified files (push) Waiting to run
Slow tests on important models (on Push - A10) / Slow & FA2 tests (push) Blocked by required conditions
Self-hosted runner (push-caller) / Check if setup was changed (push) Waiting to run
Self-hosted runner (push-caller) / build-docker-containers (push) Blocked by required conditions
Self-hosted runner (push-caller) / Trigger Push CI (push) Blocked by required conditions
Secret Leaks / trufflehog (push) Waiting to run
Update Transformers metadata / build_and_package (push) Waiting to run
Some checks are pending
Self-hosted runner (benchmark) / Benchmark (aws-g5-4xlarge-cache) (push) Waiting to run
Build documentation / build (push) Waiting to run
New model PR merged notification / Notify new model (push) Waiting to run
Slow tests on important models (on Push - A10) / Get all modified files (push) Waiting to run
Slow tests on important models (on Push - A10) / Slow & FA2 tests (push) Blocked by required conditions
Self-hosted runner (push-caller) / Check if setup was changed (push) Waiting to run
Self-hosted runner (push-caller) / build-docker-containers (push) Blocked by required conditions
Self-hosted runner (push-caller) / Trigger Push CI (push) Blocked by required conditions
Secret Leaks / trufflehog (push) Waiting to run
Update Transformers metadata / build_and_package (push) Waiting to run
* updates * fixup * fix tests * fix test * fix * let it be here for now, till monday * two more fixes * persimmon * fixup * fix * fixup * make sure fuyu runs now that LM has new attn API * fixup + tests * qwen vl uses new mask interface as well * qwen image features format * update * remove image_sizes * address comments * i am dumb...
This commit is contained in:
parent
a6f7acb603
commit
ad9dd3d17b
@ -673,7 +673,7 @@ class AriaPreTrainedModel(PreTrainedModel):
|
||||
_supports_cache_class = True
|
||||
_supports_quantized_cache = True
|
||||
_supports_static_cache = False # MoE models don't work with torch.compile (dynamic slicing)
|
||||
_supports_attention_backend = False
|
||||
_supports_attention_backend = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.initializer_range
|
||||
|
@ -1299,7 +1299,7 @@ class AriaPreTrainedModel(LlamaPreTrainedModel):
|
||||
config_class = AriaConfig
|
||||
base_model_prefix = ""
|
||||
_supports_static_cache = False # MoE models don't work with torch.compile (dynamic slicing)
|
||||
_supports_attention_backend = False
|
||||
_supports_attention_backend = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.initializer_range
|
||||
|
@ -282,7 +282,6 @@ class AyaVisionModel(AyaVisionPreTrainedModel):
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
image_sizes: torch.Tensor = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Union[Tuple, AyaVisionModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
@ -310,7 +309,6 @@ class AyaVisionModel(AyaVisionPreTrainedModel):
|
||||
pixel_values=pixel_values,
|
||||
vision_feature_layer=vision_feature_layer,
|
||||
vision_feature_select_strategy=vision_feature_select_strategy,
|
||||
image_sizes=image_sizes,
|
||||
)
|
||||
|
||||
if input_ids is None:
|
||||
|
@ -24,12 +24,14 @@ from transformers.models.llava.modeling_llava import (
|
||||
LlavaCausalLMOutputWithPast,
|
||||
LlavaForConditionalGeneration,
|
||||
LlavaModel,
|
||||
LlavaModelOutputWithPast,
|
||||
LlavaPreTrainedModel,
|
||||
)
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import logging
|
||||
from ...utils import auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging
|
||||
from .configuration_aya_vision import AyaVisionConfig
|
||||
|
||||
|
||||
@ -110,10 +112,154 @@ class AyaVisionCausalLMOutputWithPast(LlavaCausalLMOutputWithPast):
|
||||
pass
|
||||
|
||||
|
||||
class AyaVisionModel(LlavaModel):
|
||||
class AyaVisionModelOutputWithPast(LlavaModelOutputWithPast):
|
||||
pass
|
||||
|
||||
|
||||
class AyaVisionModel(LlavaModel):
|
||||
# Unlike LLaVA, the model doesn't have to deal with Pixtral-style image states
|
||||
def get_image_features(
|
||||
self,
|
||||
pixel_values: torch.FloatTensor,
|
||||
vision_feature_layer: Optional[Union[int, List[int]]] = None,
|
||||
vision_feature_select_strategy: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Obtains image last hidden states from the vision tower and apply multimodal projection.
|
||||
|
||||
Args:
|
||||
pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`):
|
||||
The tensors corresponding to the input images.
|
||||
vision_feature_layer (`Union[int, List[int]]`, *optional*):
|
||||
The index of the layer to select the vision feature. If multiple indices are provided,
|
||||
the vision feature of the corresponding indices will be concatenated to form the
|
||||
vision features.
|
||||
vision_feature_select_strategy (`str`, *optional*):
|
||||
The feature selection strategy used to select the vision feature from the vision backbone.
|
||||
Can be one of `"default"` or `"full"`
|
||||
Returns:
|
||||
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
|
||||
"""
|
||||
vision_feature_layer = (
|
||||
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
|
||||
)
|
||||
vision_feature_select_strategy = (
|
||||
vision_feature_select_strategy
|
||||
if vision_feature_select_strategy is not None
|
||||
else self.config.vision_feature_select_strategy
|
||||
)
|
||||
|
||||
if vision_feature_select_strategy not in ["default", "full"]:
|
||||
raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}")
|
||||
|
||||
kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
||||
# this is not memory efficient at all (output_hidden_states=True) will save all the hidden states.
|
||||
image_outputs = self.vision_tower(pixel_values, output_hidden_states=True, **kwargs)
|
||||
|
||||
# If we have one vision feature layer, return the corresponding hidden states,
|
||||
# otherwise, select the hidden states of each feature layer and concatenate them
|
||||
if isinstance(vision_feature_layer, int):
|
||||
selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
|
||||
if vision_feature_select_strategy == "default":
|
||||
selected_image_feature = selected_image_feature[:, 1:]
|
||||
else:
|
||||
hs_pool = [image_outputs.hidden_states[layer_idx] for layer_idx in vision_feature_layer]
|
||||
# For default; crop CLS from each hidden state in the hidden state pool
|
||||
if vision_feature_select_strategy == "default":
|
||||
hs_pool = [hs[:, 1:] for hs in hs_pool]
|
||||
selected_image_feature = torch.cat(hs_pool, dim=-1)
|
||||
|
||||
image_features = self.multi_modal_projector(selected_image_feature)
|
||||
return image_features
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
pixel_values: torch.FloatTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
vision_feature_layer: Optional[Union[int, List[int]]] = None,
|
||||
vision_feature_select_strategy: Optional[str] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Union[Tuple, AyaVisionModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
vision_feature_layer = (
|
||||
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
|
||||
)
|
||||
vision_feature_select_strategy = (
|
||||
vision_feature_select_strategy
|
||||
if vision_feature_select_strategy is not None
|
||||
else self.config.vision_feature_select_strategy
|
||||
)
|
||||
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
if pixel_values is not None:
|
||||
image_features = self.get_image_features(
|
||||
pixel_values=pixel_values,
|
||||
vision_feature_layer=vision_feature_layer,
|
||||
vision_feature_select_strategy=vision_feature_select_strategy,
|
||||
)
|
||||
|
||||
if input_ids is None:
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)[0]
|
||||
else:
|
||||
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
n_image_tokens = (input_ids == self.config.image_token_id).sum()
|
||||
|
||||
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
|
||||
n_image_tokens = (input_ids == self.config.image_token_id).sum()
|
||||
n_image_features = image_features.shape[0] * image_features.shape[1]
|
||||
raise ValueError(
|
||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||
)
|
||||
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
|
||||
|
||||
outputs = self.language_model(
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=True,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return AyaVisionModelOutputWithPast(
|
||||
last_hidden_state=outputs.last_hidden_state,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
image_hidden_states=image_features if pixel_values is not None else None,
|
||||
)
|
||||
|
||||
|
||||
class AyaVisionForConditionalGeneration(LlavaForConditionalGeneration):
|
||||
def forward(
|
||||
self,
|
||||
|
@ -247,6 +247,7 @@ class ChameleonConfig(PretrainedConfig):
|
||||
self.vq_config = ChameleonVQVAEConfig(**vq_config)
|
||||
|
||||
self.vocabulary_map = vocabulary_map
|
||||
self.image_token_id = vocabulary_map.get("<image>") if vocabulary_map is not None else None
|
||||
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
|
@ -904,12 +904,6 @@ class ChameleonModel(ChameleonPreTrainedModel):
|
||||
self.embed_tokens = value
|
||||
|
||||
def get_image_tokens(self, pixel_values: torch.FloatTensor):
|
||||
logger.warning(
|
||||
"`model.get_image_tokens()` is deprecated and will be removed in v4.58. To obtain discrete token use `model.get_image_features()`"
|
||||
)
|
||||
return self.get_image_featues(pixel_values)
|
||||
|
||||
def get_image_features(self, pixel_values: torch.FloatTensor):
|
||||
"""
|
||||
Tokenizes images into discrete tokens with VQGAN module. Converts
|
||||
obtained image tokens into BPE tokens and wraps with "boi" and "eoi"
|
||||
@ -925,6 +919,19 @@ class ChameleonModel(ChameleonPreTrainedModel):
|
||||
bpe_toks = bpe_toks.view(batch_size, -1)
|
||||
return bpe_toks
|
||||
|
||||
def get_image_features(self, pixel_values: torch.FloatTensor):
|
||||
"""
|
||||
Tokenizes images into discrete tokens with VQGAN module and embeds
|
||||
them with text embeddings layer
|
||||
|
||||
Args:
|
||||
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)):
|
||||
The tensors corresponding to the input images.
|
||||
"""
|
||||
image_tokens = self.get_image_tokens(pixel_values)
|
||||
vision_embeddings = self.get_input_embeddings()(image_tokens)
|
||||
return vision_embeddings
|
||||
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -963,7 +970,7 @@ class ChameleonModel(ChameleonPreTrainedModel):
|
||||
)
|
||||
|
||||
if pixel_values is not None:
|
||||
image_tokens = self.get_image_features(pixel_values)
|
||||
image_tokens = self.get_image_tokens(pixel_values)
|
||||
special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
|
||||
if not is_torchdynamo_compiling() and input_ids[special_image_mask].numel() != image_tokens.numel():
|
||||
n_image_tokens_in_text = (input_ids == self.vocabulary_mapping.image_token_id).sum()
|
||||
|
@ -320,6 +320,7 @@ class Emu3Config(PretrainedConfig):
|
||||
self.vq_config = vq_config
|
||||
self.text_config = text_config
|
||||
self.vocabulary_map = vocabulary_map
|
||||
self.image_token_id = vocabulary_map.get("<image>") if vocabulary_map is not None else None
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
@ -1451,12 +1451,6 @@ class Emu3Model(Emu3PreTrainedModel):
|
||||
self.text_model.set_input_embeddings(value)
|
||||
|
||||
def get_image_tokens(self, pixel_values: torch.FloatTensor, image_sizes: torch.LongTensor):
|
||||
logger.warning(
|
||||
"`model.get_image_tokens()` is deprecated and will be removed in v4.58. To obtain discrete token use `model.get_image_features()`"
|
||||
)
|
||||
return self.get_image_featues(pixel_values)
|
||||
|
||||
def get_image_features(self, pixel_values: torch.FloatTensor, image_sizes: torch.LongTensor):
|
||||
"""
|
||||
Tokenizes images into discrete tokens with VQGAN module. Converts
|
||||
obtained image tokens into BPE tokens and wraps with "boi" and "eoi"
|
||||
@ -1473,6 +1467,24 @@ class Emu3Model(Emu3PreTrainedModel):
|
||||
bpe_tokens = torch.cat(bpe_tokens_list)
|
||||
return bpe_tokens
|
||||
|
||||
def get_image_features(self, pixel_values: torch.FloatTensor, image_sizes: torch.LongTensor):
|
||||
"""
|
||||
Tokenizes images into discrete tokens with VQGAN module and embeds
|
||||
them with text embeddings layer
|
||||
|
||||
Args:
|
||||
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)):
|
||||
The tensors corresponding to the input images.
|
||||
"""
|
||||
image_tokens = self.get_image_tokens(pixel_values, image_sizes)
|
||||
split_sizes = [
|
||||
(height // self.vqmodel.vision_spatial_factor) * (width // self.vqmodel.vision_spatial_factor + 1)
|
||||
for height, width in image_sizes
|
||||
]
|
||||
image_features = self.get_input_embeddings()(image_tokens)
|
||||
image_features = torch.split(image_features, split_sizes)
|
||||
return image_features
|
||||
|
||||
@torch.no_grad
|
||||
def decode_image_tokens(self, image_tokens: torch.LongTensor, height: int, width: int):
|
||||
"""
|
||||
@ -1533,7 +1545,7 @@ class Emu3Model(Emu3PreTrainedModel):
|
||||
)
|
||||
|
||||
if pixel_values is not None:
|
||||
image_tokens = self.get_image_features(pixel_values, image_sizes)
|
||||
image_tokens = self.get_image_tokens(pixel_values, image_sizes)
|
||||
special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
|
||||
image_tokens = image_tokens.to(input_ids.device, input_ids.dtype)
|
||||
input_ids = input_ids.masked_scatter(special_image_mask, image_tokens)
|
||||
|
@ -938,12 +938,6 @@ class Emu3Model(Emu3PreTrainedModel):
|
||||
self.text_model.set_input_embeddings(value)
|
||||
|
||||
def get_image_tokens(self, pixel_values: torch.FloatTensor, image_sizes: torch.LongTensor):
|
||||
logger.warning(
|
||||
"`model.get_image_tokens()` is deprecated and will be removed in v4.58. To obtain discrete token use `model.get_image_features()`"
|
||||
)
|
||||
return self.get_image_featues(pixel_values)
|
||||
|
||||
def get_image_features(self, pixel_values: torch.FloatTensor, image_sizes: torch.LongTensor):
|
||||
"""
|
||||
Tokenizes images into discrete tokens with VQGAN module. Converts
|
||||
obtained image tokens into BPE tokens and wraps with "boi" and "eoi"
|
||||
@ -960,6 +954,24 @@ class Emu3Model(Emu3PreTrainedModel):
|
||||
bpe_tokens = torch.cat(bpe_tokens_list)
|
||||
return bpe_tokens
|
||||
|
||||
def get_image_features(self, pixel_values: torch.FloatTensor, image_sizes: torch.LongTensor):
|
||||
"""
|
||||
Tokenizes images into discrete tokens with VQGAN module and embeds
|
||||
them with text embeddings layer
|
||||
|
||||
Args:
|
||||
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)):
|
||||
The tensors corresponding to the input images.
|
||||
"""
|
||||
image_tokens = self.get_image_tokens(pixel_values, image_sizes)
|
||||
split_sizes = [
|
||||
(height // self.vqmodel.vision_spatial_factor) * (width // self.vqmodel.vision_spatial_factor + 1)
|
||||
for height, width in image_sizes
|
||||
]
|
||||
image_features = self.get_input_embeddings()(image_tokens)
|
||||
image_features = torch.split(image_features, split_sizes)
|
||||
return image_features
|
||||
|
||||
@torch.no_grad
|
||||
def decode_image_tokens(self, image_tokens: torch.LongTensor, height: int, width: int):
|
||||
"""
|
||||
@ -1020,7 +1032,7 @@ class Emu3Model(Emu3PreTrainedModel):
|
||||
)
|
||||
|
||||
if pixel_values is not None:
|
||||
image_tokens = self.get_image_features(pixel_values, image_sizes)
|
||||
image_tokens = self.get_image_tokens(pixel_values, image_sizes)
|
||||
special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
|
||||
image_tokens = image_tokens.to(input_ids.device, input_ids.dtype)
|
||||
input_ids = input_ids.masked_scatter(special_image_mask, image_tokens)
|
||||
|
@ -16,7 +16,7 @@
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...utils import logging
|
||||
from ..auto import CONFIG_MAPPING
|
||||
from ..auto import CONFIG_MAPPING, AutoConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@ -89,6 +89,8 @@ class FuyuConfig(PretrainedConfig):
|
||||
The id of the *beginning-of-sequence* token.
|
||||
eos_token_id (`Union[int, List[int]]`, *optional*, defaults to 2):
|
||||
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
|
||||
image_token_id (`int`, *optional*, defaults to 71011):
|
||||
The id of the image placeholder token.
|
||||
text_config (`dict`, *optional*):
|
||||
Dictionary of configuration options used to initialize the `language``[`Aut`].
|
||||
|
||||
@ -100,6 +102,7 @@ class FuyuConfig(PretrainedConfig):
|
||||
```"""
|
||||
|
||||
model_type = "fuyu"
|
||||
sub_configs = {"text_config": AutoConfig}
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
|
||||
def __init__(
|
||||
@ -127,6 +130,7 @@ class FuyuConfig(PretrainedConfig):
|
||||
pad_token_id=None,
|
||||
bos_token_id=1,
|
||||
eos_token_id=2,
|
||||
image_token_id=71011,
|
||||
text_config=None,
|
||||
**kwargs,
|
||||
):
|
||||
@ -176,6 +180,7 @@ class FuyuConfig(PretrainedConfig):
|
||||
self.hidden_dropout = hidden_dropout
|
||||
self.attention_dropout = attention_dropout
|
||||
self.partial_rotary_factor = partial_rotary_factor
|
||||
self.image_token_id = image_token_id
|
||||
self._rope_scaling_validation()
|
||||
|
||||
super().__init__(
|
||||
|
@ -202,11 +202,12 @@ class FuyuModel(FuyuPreTrainedModel):
|
||||
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
|
||||
if image_patches is not None and past_key_values is None:
|
||||
patch_embeddings = self.get_image_features(image_patches)
|
||||
inputs_embeds = self.gather_continuous_embeddings(
|
||||
word_embeddings=inputs_embeds,
|
||||
continuous_embeddings=patch_embeddings,
|
||||
image_patch_input_indices=image_patches_indices,
|
||||
)
|
||||
patch_embeddings = torch.cat(patch_embeddings, dim=0)
|
||||
|
||||
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
patch_embeddings = patch_embeddings.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, patch_embeddings)
|
||||
|
||||
outputs = self.language_model(
|
||||
inputs_embeds=inputs_embeds,
|
||||
|
@ -620,12 +620,16 @@ class FuyuProcessor(ProcessorMixin):
|
||||
width_scale_factor = padded_width / image_size[1]
|
||||
optimal_scale_factor = min(height_scale_factor, width_scale_factor)
|
||||
|
||||
# We can use torch here because Fuyu processor has hard dependency on torch
|
||||
image_unpadded_h = min(int(image_size[0] * optimal_scale_factor), image_size[0])
|
||||
image_unpadded_w = min(int(image_size[0] * optimal_scale_factor), image_size[0])
|
||||
|
||||
# We can use torch here because Fuyu processor has hard dependency on torch. NOTE: Fuyu can't do multi-image
|
||||
# thus the below (1, 1, 1) is hardcoded. Same as when calling the processor
|
||||
model_image_input = self.image_processor.preprocess_with_tokenizer_info(
|
||||
image_input=torch.zeros(1, 1, 3, padded_height, padded_width),
|
||||
image_present=torch.ones(1, 1, 1),
|
||||
image_unpadded_h=torch.tensor([[int(image_size[0] * optimal_scale_factor)]]),
|
||||
image_unpadded_w=torch.tensor([[int(image_size[1] * optimal_scale_factor)]]),
|
||||
image_unpadded_h=torch.tensor([[image_unpadded_h]]),
|
||||
image_unpadded_w=torch.tensor([[image_unpadded_w]]),
|
||||
image_placeholder_id=0, # dummy ids, we can be sure `id=0` is never out-of-range
|
||||
image_newline_id=0,
|
||||
variable_sized=True,
|
||||
|
@ -377,7 +377,7 @@ class GemmaModel(GemmaPreTrainedModel):
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs, # NOOP kwarg for now
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> BaseModelOutputWithPast:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
@ -446,6 +446,7 @@ class GemmaModel(GemmaPreTrainedModel):
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
@ -23,7 +23,9 @@ from torch import nn
|
||||
from ...cache_utils import Cache, DynamicCache
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...masking_utils import create_causal_mask
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_outputs import BaseModelOutputWithPast
|
||||
from ...processing_utils import Unpack
|
||||
from ...tokenization_utils import AddedToken, PreTrainedTokenizer
|
||||
from ...utils import logging
|
||||
from ..llama.modeling_llama import (
|
||||
@ -378,7 +380,7 @@ class GemmaModel(LlamaModel):
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs, # NOOP kwarg for now
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> BaseModelOutputWithPast:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
@ -447,6 +449,7 @@ class GemmaModel(LlamaModel):
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
@ -693,8 +693,10 @@ class Idefics3Model(Idefics3PreTrainedModel):
|
||||
- To fit the format of that sequence, `input_ids`, `input_embeds`, `attention_mask` are all 3 adapted to insert the image hidden states.
|
||||
"""
|
||||
special_image_token_mask = input_ids == self.image_token_id
|
||||
# Fixes RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.
|
||||
# Fixes RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.
|
||||
new_inputs_embeds = inputs_embeds.clone()
|
||||
# Flatten `image_hidden_states` if not flat yet
|
||||
image_hidden_states = image_hidden_states.view(-1, image_hidden_states.shape[-1])
|
||||
# cast to the dtype of the input_embeds to support quantized models
|
||||
image_hidden_states = image_hidden_states.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
new_inputs_embeds[special_image_token_mask] = image_hidden_states
|
||||
@ -742,7 +744,6 @@ class Idefics3Model(Idefics3PreTrainedModel):
|
||||
|
||||
# Modality projection & resampling
|
||||
image_hidden_states = self.connector(image_hidden_states.last_hidden_state)
|
||||
image_hidden_states = image_hidden_states.view(-1, image_hidden_states.shape[-1])
|
||||
return image_hidden_states
|
||||
|
||||
@can_return_tuple
|
||||
@ -807,9 +808,6 @@ class Idefics3Model(Idefics3PreTrainedModel):
|
||||
past_key_values = DynamicCache()
|
||||
past_seen_tokens = past_key_values.get_seq_length()
|
||||
|
||||
if inputs_embeds is not None and input_ids is None and past_seen_tokens == 0:
|
||||
raise ValueError("When first calling the model, if input_embeds are passed, input_ids should not be None.")
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.text_model.get_input_embeddings()(input_ids).to(self.device)
|
||||
|
||||
@ -821,7 +819,7 @@ class Idefics3Model(Idefics3PreTrainedModel):
|
||||
elif image_hidden_states is not None:
|
||||
image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device)
|
||||
|
||||
if past_seen_tokens == 0 and inputs_embeds is not None and image_hidden_states is not None:
|
||||
if past_seen_tokens == 0 and input_ids is not None and image_hidden_states is not None:
|
||||
# When we generate, we don't want to replace the potential image_token_id that we generated by images
|
||||
# that simply don't exist
|
||||
inputs_embeds = self.inputs_merger(
|
||||
|
@ -627,8 +627,8 @@ class InternVLModel(InternVLPreTrainedModel):
|
||||
def get_image_features(
|
||||
self,
|
||||
pixel_values: torch.FloatTensor,
|
||||
vision_feature_layer: Union[int, List[int]],
|
||||
vision_feature_select_strategy: str,
|
||||
vision_feature_layer: Optional[Union[int, List[int]]] = None,
|
||||
vision_feature_select_strategy: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@ -642,6 +642,15 @@ class InternVLModel(InternVLPreTrainedModel):
|
||||
Returns:
|
||||
vision_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
|
||||
"""
|
||||
vision_feature_layer = (
|
||||
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
|
||||
)
|
||||
vision_feature_select_strategy = (
|
||||
vision_feature_select_strategy
|
||||
if vision_feature_select_strategy is not None
|
||||
else self.config.vision_feature_select_strategy
|
||||
)
|
||||
|
||||
downsample_ratio = self.config.downsample_ratio
|
||||
if vision_feature_layer == -1:
|
||||
vision_features = self.vision_tower(pixel_values=pixel_values).last_hidden_state
|
||||
@ -666,7 +675,6 @@ class InternVLModel(InternVLPreTrainedModel):
|
||||
|
||||
# Project features through multi-modal projector
|
||||
vision_features = self.multi_modal_projector(vision_features)
|
||||
|
||||
return vision_features
|
||||
|
||||
@can_return_tuple
|
||||
@ -686,7 +694,6 @@ class InternVLModel(InternVLPreTrainedModel):
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
image_sizes: torch.Tensor = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Union[Tuple, InternVLModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
@ -714,7 +721,6 @@ class InternVLModel(InternVLPreTrainedModel):
|
||||
pixel_values=pixel_values,
|
||||
vision_feature_layer=vision_feature_layer,
|
||||
vision_feature_select_strategy=vision_feature_select_strategy,
|
||||
image_sizes=image_sizes,
|
||||
)
|
||||
|
||||
if input_ids is None:
|
||||
|
@ -27,7 +27,7 @@ from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import auto_docstring, can_return_tuple, logging, torch_int
|
||||
from ...utils import auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging, torch_int
|
||||
from ..clip.modeling_clip import CLIPMLP
|
||||
from ..janus.modeling_janus import JanusVisionAttention
|
||||
from ..llama.modeling_llama import LlamaRMSNorm
|
||||
@ -35,6 +35,7 @@ from ..llava.modeling_llava import (
|
||||
LlavaCausalLMOutputWithPast,
|
||||
LlavaForConditionalGeneration,
|
||||
LlavaModel,
|
||||
LlavaModelOutputWithPast,
|
||||
LlavaPreTrainedModel,
|
||||
)
|
||||
from .configuration_internvl import InternVLConfig, InternVLVisionConfig
|
||||
@ -510,6 +511,10 @@ class InternVLMultiModalProjector(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class InternVLModelOutputWithPast(LlavaModelOutputWithPast):
|
||||
pass
|
||||
|
||||
|
||||
class InternVLModel(LlavaModel):
|
||||
def pixel_shuffle(self, vision_features: torch.Tensor, scale_factor: float = 0.5):
|
||||
"""Perform pixel shuffle downsampling on vision features.
|
||||
@ -549,8 +554,8 @@ class InternVLModel(LlavaModel):
|
||||
def get_image_features(
|
||||
self,
|
||||
pixel_values: torch.FloatTensor,
|
||||
vision_feature_layer: Union[int, List[int]],
|
||||
vision_feature_select_strategy: str,
|
||||
vision_feature_layer: Optional[Union[int, List[int]]] = None,
|
||||
vision_feature_select_strategy: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@ -564,6 +569,15 @@ class InternVLModel(LlavaModel):
|
||||
Returns:
|
||||
vision_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
|
||||
"""
|
||||
vision_feature_layer = (
|
||||
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
|
||||
)
|
||||
vision_feature_select_strategy = (
|
||||
vision_feature_select_strategy
|
||||
if vision_feature_select_strategy is not None
|
||||
else self.config.vision_feature_select_strategy
|
||||
)
|
||||
|
||||
downsample_ratio = self.config.downsample_ratio
|
||||
if vision_feature_layer == -1:
|
||||
vision_features = self.vision_tower(pixel_values=pixel_values).last_hidden_state
|
||||
@ -588,9 +602,94 @@ class InternVLModel(LlavaModel):
|
||||
|
||||
# Project features through multi-modal projector
|
||||
vision_features = self.multi_modal_projector(vision_features)
|
||||
|
||||
return vision_features
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
pixel_values: torch.FloatTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
vision_feature_layer: Optional[Union[int, List[int]]] = None,
|
||||
vision_feature_select_strategy: Optional[str] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Union[Tuple, InternVLModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
vision_feature_layer = (
|
||||
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
|
||||
)
|
||||
vision_feature_select_strategy = (
|
||||
vision_feature_select_strategy
|
||||
if vision_feature_select_strategy is not None
|
||||
else self.config.vision_feature_select_strategy
|
||||
)
|
||||
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
if pixel_values is not None:
|
||||
image_features = self.get_image_features(
|
||||
pixel_values=pixel_values,
|
||||
vision_feature_layer=vision_feature_layer,
|
||||
vision_feature_select_strategy=vision_feature_select_strategy,
|
||||
)
|
||||
|
||||
if input_ids is None:
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)[0]
|
||||
else:
|
||||
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
n_image_tokens = (input_ids == self.config.image_token_id).sum()
|
||||
|
||||
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
|
||||
n_image_tokens = (input_ids == self.config.image_token_id).sum()
|
||||
n_image_features = image_features.shape[0] * image_features.shape[1]
|
||||
raise ValueError(
|
||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||
)
|
||||
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
|
||||
|
||||
outputs = self.language_model(
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=True,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return InternVLModelOutputWithPast(
|
||||
last_hidden_state=outputs.last_hidden_state,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
image_hidden_states=image_features if pixel_values is not None else None,
|
||||
)
|
||||
|
||||
|
||||
class InternVLCausalLMOutputWithPast(LlavaCausalLMOutputWithPast):
|
||||
pass
|
||||
|
@ -233,6 +233,15 @@ class LlavaModel(LlavaPreTrainedModel):
|
||||
selected_image_feature = torch.cat(hs_pool, dim=-1)
|
||||
|
||||
image_features = self.multi_modal_projector(selected_image_feature)
|
||||
|
||||
if "image_sizes" in kwargs:
|
||||
split_sizes = [
|
||||
(height // self.vision_tower.patch_size) * (width // self.vision_tower.patch_size)
|
||||
for height, width in kwargs["image_sizes"]
|
||||
]
|
||||
image_features = torch.split(image_features.squeeze(0), split_sizes)
|
||||
else:
|
||||
image_features = list(image_features)
|
||||
return image_features
|
||||
|
||||
@can_return_tuple
|
||||
@ -282,6 +291,7 @@ class LlavaModel(LlavaPreTrainedModel):
|
||||
vision_feature_select_strategy=vision_feature_select_strategy,
|
||||
image_sizes=image_sizes,
|
||||
)
|
||||
image_features = torch.cat(image_features, dim=0)
|
||||
|
||||
if input_ids is None:
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
|
@ -357,9 +357,8 @@ class LlavaNextModel(LlavaNextPreTrainedModel):
|
||||
image_feature = torch.cat((image_feature, image_newline[None].to(image_feature)), dim=0)
|
||||
new_image_features.append(image_feature)
|
||||
feature_lens.append(image_feature.size(0))
|
||||
image_features = torch.cat(new_image_features, dim=0)
|
||||
feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features.device)
|
||||
return image_features, feature_lens
|
||||
feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features[0].device)
|
||||
return new_image_features, feature_lens
|
||||
|
||||
def get_image_features(
|
||||
self,
|
||||
@ -429,6 +428,14 @@ class LlavaNextModel(LlavaNextPreTrainedModel):
|
||||
|
||||
image_features = self.multi_modal_projector(selected_image_feature)
|
||||
image_features = torch.split(image_features, image_num_patches, dim=0)
|
||||
|
||||
# NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
|
||||
image_features, feature_lens = self.pack_image_features(
|
||||
image_features,
|
||||
image_sizes,
|
||||
vision_feature_select_strategy=vision_feature_select_strategy,
|
||||
image_newline=self.image_newline,
|
||||
)
|
||||
return image_features
|
||||
|
||||
@can_return_tuple
|
||||
@ -489,14 +496,7 @@ class LlavaNextModel(LlavaNextPreTrainedModel):
|
||||
vision_feature_layer=vision_feature_layer,
|
||||
vision_feature_select_strategy=vision_feature_select_strategy,
|
||||
)
|
||||
|
||||
# NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
|
||||
image_features, feature_lens = self.pack_image_features(
|
||||
image_features,
|
||||
image_sizes,
|
||||
vision_feature_select_strategy=vision_feature_select_strategy,
|
||||
image_newline=self.image_newline,
|
||||
)
|
||||
image_features = torch.cat(image_features, dim=0)
|
||||
|
||||
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
|
@ -411,9 +411,8 @@ class LlavaNextVideoModel(LlavaNextVideoPreTrainedModel):
|
||||
image_feature = torch.cat((image_feature, image_newline[None].to(image_feature)), dim=0)
|
||||
new_image_features.append(image_feature)
|
||||
feature_lens.append(image_feature.size(0))
|
||||
image_features = torch.cat(new_image_features, dim=0)
|
||||
feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features.device)
|
||||
return image_features, feature_lens
|
||||
feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features[0].device)
|
||||
return new_image_features, feature_lens
|
||||
|
||||
def get_image_features(
|
||||
self,
|
||||
@ -482,6 +481,13 @@ class LlavaNextVideoModel(LlavaNextVideoPreTrainedModel):
|
||||
selected_image_feature = selected_image_feature
|
||||
image_features = self.multi_modal_projector(selected_image_feature)
|
||||
image_features = torch.split(image_features, image_num_patches, dim=0)
|
||||
|
||||
image_features, feature_lens = self.pack_image_features(
|
||||
image_features,
|
||||
image_sizes,
|
||||
vision_feature_select_strategy,
|
||||
image_newline=self.image_newline,
|
||||
)
|
||||
return image_features
|
||||
|
||||
@can_return_tuple
|
||||
@ -544,12 +550,7 @@ class LlavaNextVideoModel(LlavaNextVideoPreTrainedModel):
|
||||
vision_feature_layer=self.vision_feature_layer,
|
||||
vision_feature_select_strategy=self.vision_feature_select_strategy,
|
||||
)
|
||||
image_features, feature_lens = self.pack_image_features(
|
||||
image_features,
|
||||
image_sizes,
|
||||
self.vision_feature_select_strategy,
|
||||
image_newline=self.image_newline,
|
||||
)
|
||||
image_features = torch.cat(image_features, dim=0)
|
||||
|
||||
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
|
@ -315,6 +315,13 @@ class LlavaNextVideoModel(LlavaNextModel):
|
||||
selected_image_feature = selected_image_feature
|
||||
image_features = self.multi_modal_projector(selected_image_feature)
|
||||
image_features = torch.split(image_features, image_num_patches, dim=0)
|
||||
|
||||
image_features, feature_lens = self.pack_image_features(
|
||||
image_features,
|
||||
image_sizes,
|
||||
vision_feature_select_strategy,
|
||||
image_newline=self.image_newline,
|
||||
)
|
||||
return image_features
|
||||
|
||||
def get_video_features(
|
||||
@ -430,12 +437,7 @@ class LlavaNextVideoModel(LlavaNextModel):
|
||||
vision_feature_layer=self.vision_feature_layer,
|
||||
vision_feature_select_strategy=self.vision_feature_select_strategy,
|
||||
)
|
||||
image_features, feature_lens = self.pack_image_features(
|
||||
image_features,
|
||||
image_sizes,
|
||||
self.vision_feature_select_strategy,
|
||||
image_newline=self.image_newline,
|
||||
)
|
||||
image_features = torch.cat(image_features, dim=0)
|
||||
|
||||
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
|
@ -409,18 +409,19 @@ class LlavaOnevisionModel(LlavaOnevisionPreTrainedModel):
|
||||
image_feature = image_feature[0]
|
||||
if image_newline is not None:
|
||||
image_feature = torch.cat((image_feature, image_newline[None].to(image_feature)), dim=0)
|
||||
image_feature = image_feature.flatten(0, 1)
|
||||
new_image_features.append(image_feature)
|
||||
feature_lens.append(image_feature.size(0))
|
||||
image_features = torch.cat(new_image_features, dim=0)
|
||||
feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features.device)
|
||||
return image_features, feature_lens
|
||||
feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features[0].device)
|
||||
return new_image_features, feature_lens
|
||||
|
||||
def get_image_features(
|
||||
self,
|
||||
pixel_values: torch.FloatTensor,
|
||||
image_sizes: torch.Tensor,
|
||||
vision_feature_layer: Union[int, List[int]],
|
||||
vision_feature_select_strategy: str,
|
||||
vision_feature_layer: Optional[Union[int, List[int]]] = None,
|
||||
vision_feature_select_strategy: Optional[str] = None,
|
||||
vision_aspect_ratio: Optional[str] = None,
|
||||
batch_num_images: Optional[torch.LongTensor] = None,
|
||||
):
|
||||
"""
|
||||
@ -444,6 +445,18 @@ class LlavaOnevisionModel(LlavaOnevisionPreTrainedModel):
|
||||
image_features (List[`torch.Tensor`]): List of image feature tensor, each contains all the visual feature of all patches
|
||||
and are of shape `(num_patches, image_length, embed_dim)`).
|
||||
"""
|
||||
vision_feature_layer = (
|
||||
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
|
||||
)
|
||||
vision_feature_select_strategy = (
|
||||
vision_feature_select_strategy
|
||||
if vision_feature_select_strategy is not None
|
||||
else self.config.vision_feature_select_strategy
|
||||
)
|
||||
vision_aspect_ratio = (
|
||||
vision_aspect_ratio if vision_aspect_ratio is not None else self.config.vision_aspect_ratio
|
||||
)
|
||||
|
||||
# ! infer image_num_patches from image_sizes
|
||||
if batch_num_images is None:
|
||||
# treat this as a single-image case for backward compatibility
|
||||
@ -483,6 +496,13 @@ class LlavaOnevisionModel(LlavaOnevisionPreTrainedModel):
|
||||
selected_image_feature = selected_image_feature
|
||||
image_features = self.multi_modal_projector(selected_image_feature)
|
||||
image_features = torch.split(image_features, image_num_patches, dim=0)
|
||||
|
||||
image_features, feature_lens = self.pack_image_features(
|
||||
image_features,
|
||||
image_sizes,
|
||||
image_newline=self.image_newline,
|
||||
vision_aspect_ratio=vision_aspect_ratio,
|
||||
)
|
||||
return image_features
|
||||
|
||||
@can_return_tuple
|
||||
@ -564,12 +584,7 @@ class LlavaOnevisionModel(LlavaOnevisionPreTrainedModel):
|
||||
vision_feature_select_strategy=vision_feature_select_strategy,
|
||||
batch_num_images=batch_num_images,
|
||||
)
|
||||
image_features, feature_lens = self.pack_image_features(
|
||||
image_features,
|
||||
image_sizes,
|
||||
image_newline=self.image_newline,
|
||||
vision_aspect_ratio=vision_aspect_ratio,
|
||||
)
|
||||
image_features = torch.cat(image_features, dim=0)
|
||||
|
||||
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
|
@ -316,11 +316,11 @@ class LlavaOnevisionModel(LlavaNextVideoModel):
|
||||
image_feature = image_feature[0]
|
||||
if image_newline is not None:
|
||||
image_feature = torch.cat((image_feature, image_newline[None].to(image_feature)), dim=0)
|
||||
image_feature = image_feature.flatten(0, 1)
|
||||
new_image_features.append(image_feature)
|
||||
feature_lens.append(image_feature.size(0))
|
||||
image_features = torch.cat(new_image_features, dim=0)
|
||||
feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features.device)
|
||||
return image_features, feature_lens
|
||||
feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features[0].device)
|
||||
return new_image_features, feature_lens
|
||||
|
||||
def apply_pooling(self, image_features):
|
||||
height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size
|
||||
@ -340,8 +340,9 @@ class LlavaOnevisionModel(LlavaNextVideoModel):
|
||||
self,
|
||||
pixel_values: torch.FloatTensor,
|
||||
image_sizes: torch.Tensor,
|
||||
vision_feature_layer: Union[int, List[int]],
|
||||
vision_feature_select_strategy: str,
|
||||
vision_feature_layer: Optional[Union[int, List[int]]] = None,
|
||||
vision_feature_select_strategy: Optional[str] = None,
|
||||
vision_aspect_ratio: Optional[str] = None,
|
||||
batch_num_images: Optional[torch.LongTensor] = None,
|
||||
):
|
||||
"""
|
||||
@ -365,6 +366,18 @@ class LlavaOnevisionModel(LlavaNextVideoModel):
|
||||
image_features (List[`torch.Tensor`]): List of image feature tensor, each contains all the visual feature of all patches
|
||||
and are of shape `(num_patches, image_length, embed_dim)`).
|
||||
"""
|
||||
vision_feature_layer = (
|
||||
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
|
||||
)
|
||||
vision_feature_select_strategy = (
|
||||
vision_feature_select_strategy
|
||||
if vision_feature_select_strategy is not None
|
||||
else self.config.vision_feature_select_strategy
|
||||
)
|
||||
vision_aspect_ratio = (
|
||||
vision_aspect_ratio if vision_aspect_ratio is not None else self.config.vision_aspect_ratio
|
||||
)
|
||||
|
||||
# ! infer image_num_patches from image_sizes
|
||||
if batch_num_images is None:
|
||||
# treat this as a single-image case for backward compatibility
|
||||
@ -404,6 +417,13 @@ class LlavaOnevisionModel(LlavaNextVideoModel):
|
||||
selected_image_feature = selected_image_feature
|
||||
image_features = self.multi_modal_projector(selected_image_feature)
|
||||
image_features = torch.split(image_features, image_num_patches, dim=0)
|
||||
|
||||
image_features, feature_lens = self.pack_image_features(
|
||||
image_features,
|
||||
image_sizes,
|
||||
image_newline=self.image_newline,
|
||||
vision_aspect_ratio=vision_aspect_ratio,
|
||||
)
|
||||
return image_features
|
||||
|
||||
def get_video_features(
|
||||
@ -529,12 +549,7 @@ class LlavaOnevisionModel(LlavaNextVideoModel):
|
||||
vision_feature_select_strategy=vision_feature_select_strategy,
|
||||
batch_num_images=batch_num_images,
|
||||
)
|
||||
image_features, feature_lens = self.pack_image_features(
|
||||
image_features,
|
||||
image_sizes,
|
||||
image_newline=self.image_newline,
|
||||
vision_aspect_ratio=vision_aspect_ratio,
|
||||
)
|
||||
image_features = torch.cat(image_features, dim=0)
|
||||
|
||||
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
|
@ -285,6 +285,9 @@ class Mistral3Model(Mistral3PreTrainedModel):
|
||||
selected_image_feature = torch.cat(hs_pool, dim=-1)
|
||||
|
||||
image_features = self.multi_modal_projector(selected_image_feature.squeeze(0), image_sizes)
|
||||
downsample_ratio = self.vision_tower.patch_size * self.config.spatial_merge_size
|
||||
split_sizes = [(height // downsample_ratio) * (width // downsample_ratio) for height, width in image_sizes]
|
||||
image_features = torch.split(image_features.squeeze(0), split_sizes)
|
||||
return image_features
|
||||
|
||||
@can_return_tuple
|
||||
@ -332,6 +335,7 @@ class Mistral3Model(Mistral3PreTrainedModel):
|
||||
vision_feature_layer=vision_feature_layer,
|
||||
image_sizes=image_sizes,
|
||||
)
|
||||
image_features = torch.cat(image_features, dim=0)
|
||||
|
||||
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
|
@ -170,6 +170,9 @@ class Mistral3Model(LlavaModel):
|
||||
selected_image_feature = torch.cat(hs_pool, dim=-1)
|
||||
|
||||
image_features = self.multi_modal_projector(selected_image_feature.squeeze(0), image_sizes)
|
||||
downsample_ratio = self.vision_tower.patch_size * self.config.spatial_merge_size
|
||||
split_sizes = [(height // downsample_ratio) * (width // downsample_ratio) for height, width in image_sizes]
|
||||
image_features = torch.split(image_features.squeeze(0), split_sizes)
|
||||
return image_features
|
||||
|
||||
def forward(
|
||||
@ -215,6 +218,7 @@ class Mistral3Model(LlavaModel):
|
||||
vision_feature_layer=vision_feature_layer,
|
||||
image_sizes=image_sizes,
|
||||
)
|
||||
image_features = torch.cat(image_features, dim=0)
|
||||
|
||||
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
|
@ -19,8 +19,7 @@
|
||||
# limitations under the License.
|
||||
"""PyTorch Persimmon model."""
|
||||
|
||||
import math
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
@ -30,6 +29,7 @@ from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutputWithPast,
|
||||
CausalLMOutputWithPast,
|
||||
@ -37,7 +37,8 @@ from ...modeling_outputs import (
|
||||
TokenClassifierOutput,
|
||||
)
|
||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
|
||||
from .configuration_persimmon import PersimmonConfig
|
||||
|
||||
@ -137,6 +138,29 @@ class PersimmonMLP(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
def eager_attention_forward(
|
||||
module: nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
scaling: float,
|
||||
dropout: float = 0.0,
|
||||
**kwargs,
|
||||
):
|
||||
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
attn_output = torch.matmul(attn_weights, value)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class PersimmonAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
@ -166,6 +190,7 @@ class PersimmonAttention(nn.Module):
|
||||
self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=True)
|
||||
self.dense = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=True)
|
||||
self.qk_layernorm = config.qk_layernorm
|
||||
self.scaling = self.head_dim**-0.5
|
||||
|
||||
if self.qk_layernorm:
|
||||
self.q_layernorm = nn.LayerNorm(
|
||||
@ -203,6 +228,7 @@ class PersimmonAttention(nn.Module):
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
@ -249,27 +275,22 @@ class PersimmonAttention(nn.Module):
|
||||
}
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
if attention_mask is not None: # no matter the length, we just slice it
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query_states.dtype)
|
||||
attn_weights = self.attention_dropout(attn_weights)
|
||||
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
|
||||
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
attn_output, attn_weights = attention_interface(
|
||||
self,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
dropout=0.0 if not self.training else self.config.attention_dropout,
|
||||
scaling=self.scaling,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
attn_output = attn_output.reshape(bsz, q_len, -1)
|
||||
attn_output = self.dense(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
@ -298,6 +319,7 @@ class PersimmonDecoderLayer(nn.Module):
|
||||
use_cache: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
"""
|
||||
Args:
|
||||
@ -337,6 +359,7 @@ class PersimmonDecoderLayer(nn.Module):
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
@ -369,6 +392,9 @@ class PersimmonPreTrainedModel(PreTrainedModel):
|
||||
_supports_cache_class = True
|
||||
_supports_quantized_cache = True
|
||||
_supports_static_cache = True
|
||||
_supports_sdpa = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_attention_backend = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.initializer_range
|
||||
@ -430,6 +456,7 @@ class PersimmonModel(PersimmonPreTrainedModel):
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> BaseModelOutputWithPast:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
@ -512,6 +539,7 @@ class PersimmonModel(PersimmonPreTrainedModel):
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
@ -758,6 +786,7 @@ class PersimmonForCausalLM(PersimmonPreTrainedModel, GenerationMixin):
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs.last_hidden_state
|
||||
|
@ -257,7 +257,7 @@ class PixtralProcessor(ProcessorMixin):
|
||||
num_image_tokens = []
|
||||
for height, width in image_sizes:
|
||||
resized_height, resized_width = get_resize_output_image_size(
|
||||
image=np.zeros((height, width, 3)),
|
||||
np.zeros((height, width, 3)),
|
||||
size=(size["longest_edge"], size["longest_edge"]),
|
||||
patch_size=(patch_size, patch_size),
|
||||
)
|
||||
|
@ -20,7 +20,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...configuration_utils import PretrainedConfig, layer_type_validation
|
||||
from ...modeling_rope_utils import rope_config_validation
|
||||
from ...utils import logging
|
||||
|
||||
@ -257,6 +257,8 @@ class Qwen2_5OmniTextConfig(PretrainedConfig):
|
||||
Sliding window attention (SWA) window size. If not specified, will default to `4096`.
|
||||
max_window_layers (`int`, *optional*, defaults to 28):
|
||||
The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
|
||||
layer_types (`list`, *optional*):
|
||||
Attention pattern for each layer.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
rope_scaling (`Dict`, *optional*):
|
||||
@ -358,6 +360,7 @@ class Qwen2_5OmniTextConfig(PretrainedConfig):
|
||||
use_sliding_window=False,
|
||||
sliding_window=32768,
|
||||
max_window_layers=28,
|
||||
layer_types=None,
|
||||
attention_dropout=0.0,
|
||||
**kwargs,
|
||||
):
|
||||
@ -396,6 +399,16 @@ class Qwen2_5OmniTextConfig(PretrainedConfig):
|
||||
if self.rope_scaling is None:
|
||||
self.rope_scaling = {"mrope_section": [16, 24, 24], "rope_type": "default", "type": "default"}
|
||||
|
||||
self.layer_types = layer_types
|
||||
if self.layer_types is None:
|
||||
self.layer_types = [
|
||||
"sliding_attention"
|
||||
if self.sliding_window is not None and i >= self.max_window_layers
|
||||
else "full_attention"
|
||||
for i in range(self.num_hidden_layers)
|
||||
]
|
||||
layer_type_validation(self.layer_types)
|
||||
|
||||
|
||||
class Qwen2_5OmniThinkerConfig(PretrainedConfig):
|
||||
r"""
|
||||
|
@ -22,7 +22,7 @@
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -31,19 +31,20 @@ from torch import nn
|
||||
from torch.nn import Parameter
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
|
||||
from ...cache_utils import Cache, DynamicCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||
from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
|
||||
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, ModelOutput
|
||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import (
|
||||
auto_docstring,
|
||||
check_torch_load_is_safe,
|
||||
is_flash_attn_2_available,
|
||||
is_flash_attn_greater_or_equal_2_10,
|
||||
is_torch_flex_attn_available,
|
||||
logging,
|
||||
)
|
||||
from ...utils.hub import cached_file
|
||||
@ -68,15 +69,6 @@ else:
|
||||
apply_rotary_emb = None
|
||||
|
||||
|
||||
if is_flash_attn_available():
|
||||
from ...modeling_flash_attention_utils import _flash_attention_forward
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@ -110,7 +102,8 @@ class Qwen2_5OmniPreTrainedModel(PreTrainedModel):
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_cache_class = True
|
||||
_supports_static_cache = True
|
||||
_supports_static_cache = False
|
||||
_supports_attention_backend = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
# important: this ported version of Qwen2.5OmniThinker isn't meant for training from scratch - only
|
||||
@ -1444,6 +1437,32 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
||||
|
||||
|
||||
def eager_attention_forward(
|
||||
module: nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
scaling: float,
|
||||
dropout: float = 0.0,
|
||||
**kwargs,
|
||||
):
|
||||
key_states = repeat_kv(key, module.num_key_value_groups)
|
||||
value_states = repeat_kv(value, module.num_key_value_groups)
|
||||
|
||||
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class Qwen2_5OmniAttention(nn.Module):
|
||||
"""
|
||||
Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
|
||||
@ -1469,11 +1488,13 @@ class Qwen2_5OmniAttention(nn.Module):
|
||||
self.is_causal = True
|
||||
self.attention_dropout = config.attention_dropout
|
||||
self.rope_scaling = config.rope_scaling
|
||||
self.scaling = self.head_dim**-0.5
|
||||
|
||||
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
|
||||
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
|
||||
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
|
||||
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
|
||||
self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None
|
||||
|
||||
self.rotary_emb = Qwen2_5OmniRotaryEmbedding(config=config)
|
||||
|
||||
@ -1487,6 +1508,7 @@ class Qwen2_5OmniAttention(nn.Module):
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
@ -1507,40 +1529,24 @@ class Qwen2_5OmniAttention(nn.Module):
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||
|
||||
if attention_mask is not None: # no matter the length, we just slice it
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
|
||||
# Fix precision issues in Qwen2-VL float16 inference
|
||||
# Replace inf values with zeros in attention weights to prevent NaN propagation
|
||||
if query_states.dtype == torch.float16:
|
||||
attn_weights = torch.where(torch.isinf(attn_weights), torch.zeros_like(attn_weights), attn_weights)
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
|
||||
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.reshape(bsz, q_len, -1)
|
||||
attn_output, attn_weights = attention_interface(
|
||||
self,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
dropout=0.0 if not self.training else self.attention_dropout,
|
||||
scaling=self.scaling,
|
||||
sliding_window=self.sliding_window,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
@ -1558,216 +1564,7 @@ class Qwen2MLP(nn.Module):
|
||||
return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
|
||||
|
||||
|
||||
class Qwen2_5OmniFlashAttention2(Qwen2_5OmniAttention):
|
||||
"""
|
||||
Qwen2_5Omni flash attention module, following Qwen2_5Omni attention module. This module inherits from `Qwen2_5OmniAttention`
|
||||
as the weights of the module stays untouched. The only required change would be on the forward pass
|
||||
where it needs to correctly call the public API of flash attention and deal with padding tokens
|
||||
in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom
|
||||
config.max_window_layers layers.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
||||
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
||||
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
||||
self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
||||
):
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||
|
||||
# Because the input can be padded, the absolute sequence length depends on the max position id.
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_multimodal_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
|
||||
)
|
||||
|
||||
if past_key_value is not None:
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
dropout_rate = 0.0 if not self.training else self.attention_dropout
|
||||
|
||||
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
||||
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
||||
# cast them back in float16 just to be sure everything works as expected.
|
||||
input_dtype = query_states.dtype
|
||||
if input_dtype == torch.float32:
|
||||
if torch.is_autocast_enabled():
|
||||
target_dtype = torch.get_autocast_gpu_dtype()
|
||||
# Handle the case where the model is quantized
|
||||
elif hasattr(self.config, "_pre_quantization_dtype"):
|
||||
target_dtype = self.config._pre_quantization_dtype
|
||||
else:
|
||||
target_dtype = self.q_proj.weight.dtype
|
||||
|
||||
logger.warning_once(
|
||||
f"The input hidden states seems to be silently casted in float32, this might be related to"
|
||||
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
||||
f" {target_dtype}."
|
||||
)
|
||||
|
||||
query_states = query_states.to(target_dtype)
|
||||
key_states = key_states.to(target_dtype)
|
||||
value_states = value_states.to(target_dtype)
|
||||
|
||||
# Reashape to the expected shape for Flash Attention
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
value_states = value_states.transpose(1, 2)
|
||||
|
||||
if (
|
||||
self.config.use_sliding_window
|
||||
and getattr(self.config, "sliding_window", None) is not None
|
||||
and self.layer_idx >= self.config.max_window_layers
|
||||
):
|
||||
sliding_window = self.config.sliding_window
|
||||
else:
|
||||
sliding_window = None
|
||||
|
||||
attn_output = _flash_attention_forward(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
q_len,
|
||||
dropout=dropout_rate,
|
||||
sliding_window=sliding_window,
|
||||
is_causal=self.is_causal,
|
||||
use_top_left_mask=self._flash_attn_uses_top_left_mask,
|
||||
)
|
||||
|
||||
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
class Qwen2_5OmniSdpaAttention(Qwen2_5OmniAttention):
|
||||
"""
|
||||
Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
|
||||
`Qwen2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
|
||||
SDPA API.
|
||||
"""
|
||||
|
||||
# Adapted from Qwen2Attention.forward
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
if output_attentions:
|
||||
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
|
||||
logger.warning_once(
|
||||
"Qwen2_5OmniModel is using Qwen2_5OmniSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
|
||||
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
return super().forward(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
)
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_multimodal_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
|
||||
)
|
||||
|
||||
if past_key_value is not None:
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
causal_mask = attention_mask
|
||||
if attention_mask is not None: # no matter the length, we just slice it
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
|
||||
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
|
||||
# Reference: https://github.com/pytorch/pytorch/issues/112577.
|
||||
if query_states.device.type == "cuda" and attention_mask is not None:
|
||||
query_states = query_states.contiguous()
|
||||
key_states = key_states.contiguous()
|
||||
value_states = value_states.contiguous()
|
||||
|
||||
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
|
||||
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
||||
# The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
|
||||
is_causal = True if causal_mask is None and q_len > 1 else False
|
||||
|
||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_mask=causal_mask,
|
||||
dropout_p=self.attention_dropout if self.training else 0.0,
|
||||
is_causal=is_causal,
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.view(bsz, q_len, -1)
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
return attn_output, None, past_key_value
|
||||
|
||||
|
||||
QWEN2_5_OMNI_ATTENTION_CLASSES = {
|
||||
"eager": Qwen2_5OmniAttention,
|
||||
"flash_attention_2": Qwen2_5OmniFlashAttention2,
|
||||
"sdpa": Qwen2_5OmniSdpaAttention,
|
||||
}
|
||||
|
||||
|
||||
class Qwen2_5OmniDecoderLayer(nn.Module):
|
||||
class Qwen2_5OmniDecoderLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config: Qwen2_5OmniTextConfig, layer_idx: int):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
@ -1777,11 +1574,12 @@ class Qwen2_5OmniDecoderLayer(nn.Module):
|
||||
f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
|
||||
"unexpected results may be encountered."
|
||||
)
|
||||
self.self_attn = QWEN2_5_OMNI_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
|
||||
self.self_attn = Qwen2_5OmniAttention(config, layer_idx)
|
||||
|
||||
self.mlp = Qwen2MLP(config)
|
||||
self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.attention_type = config.layer_types[layer_idx]
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -1793,7 +1591,7 @@ class Qwen2_5OmniDecoderLayer(nn.Module):
|
||||
use_cache: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
||||
**kwargs,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
"""
|
||||
Args:
|
||||
@ -1831,6 +1629,7 @@ class Qwen2_5OmniDecoderLayer(nn.Module):
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
@ -1868,6 +1667,7 @@ class Qwen2_5OmniThinkerTextModel(Qwen2_5OmniPreTrainedModel):
|
||||
self._attn_implementation = config._attn_implementation
|
||||
self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.rotary_emb = Qwen2_5OmniRotaryEmbedding(config=config)
|
||||
self.has_sliding_layers = "sliding_attention" in self.config.layer_types
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
# Initialize weights and apply final processing
|
||||
@ -1892,6 +1692,7 @@ class Qwen2_5OmniThinkerTextModel(Qwen2_5OmniPreTrainedModel):
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
@ -1930,9 +1731,23 @@ class Qwen2_5OmniThinkerTextModel(Qwen2_5OmniPreTrainedModel):
|
||||
elif position_ids.dim() == 2:
|
||||
position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
|
||||
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
||||
)
|
||||
# It may already have been prepared by e.g. `generate`
|
||||
if not isinstance(causal_mask_mapping := attention_mask, dict):
|
||||
# Prepare mask arguments
|
||||
mask_kwargs = {
|
||||
"config": self.config,
|
||||
"input_embeds": inputs_embeds,
|
||||
"attention_mask": attention_mask,
|
||||
"cache_position": cache_position,
|
||||
"past_key_values": past_key_values,
|
||||
}
|
||||
# Create the masks
|
||||
causal_mask_mapping = {
|
||||
"full_attention": create_causal_mask(**mask_kwargs),
|
||||
}
|
||||
# The sliding window alternating layers are not always activated depending on the config
|
||||
if self.has_sliding_layers:
|
||||
causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
@ -1952,7 +1767,7 @@ class Qwen2_5OmniThinkerTextModel(Qwen2_5OmniPreTrainedModel):
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
causal_mask,
|
||||
causal_mask_mapping[decoder_layer.attention_type],
|
||||
position_ids,
|
||||
past_key_values,
|
||||
output_attentions,
|
||||
@ -1963,13 +1778,14 @@ class Qwen2_5OmniThinkerTextModel(Qwen2_5OmniPreTrainedModel):
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=causal_mask,
|
||||
attention_mask=causal_mask_mapping[decoder_layer.attention_type],
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
@ -1997,161 +1813,6 @@ class Qwen2_5OmniThinkerTextModel(Qwen2_5OmniPreTrainedModel):
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: Union[torch.Tensor, "BlockMask"],
|
||||
input_tensor: torch.Tensor,
|
||||
cache_position: torch.Tensor,
|
||||
past_key_values: Cache,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
if attention_mask is not None and past_key_values is not None:
|
||||
is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
|
||||
if is_padding_right:
|
||||
raise ValueError(
|
||||
"You are attempting to perform batched generation with padding_side='right'"
|
||||
" this may lead to unexpected behaviour for Flash Attention version of Qwen25OmniThinker. Make sure to "
|
||||
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
|
||||
)
|
||||
if attention_mask is not None and 0.0 in attention_mask:
|
||||
return attention_mask
|
||||
return None
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask)
|
||||
return attention_mask
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
# to infer the attention mask.
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
using_static_cache = isinstance(past_key_values, StaticCache)
|
||||
using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
|
||||
|
||||
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
||||
if (
|
||||
self.config._attn_implementation == "sdpa"
|
||||
and not (using_static_cache or using_sliding_window_cache)
|
||||
and not output_attentions
|
||||
):
|
||||
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
||||
attention_mask,
|
||||
inputs_embeds=input_tensor,
|
||||
past_key_values_length=past_seen_tokens,
|
||||
sliding_window=self.config.sliding_window,
|
||||
is_training=self.training,
|
||||
):
|
||||
return None
|
||||
|
||||
dtype = input_tensor.dtype
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
sequence_length = input_tensor.shape[1]
|
||||
# SlidingWindowCache or StaticCache
|
||||
if using_sliding_window_cache or using_static_cache:
|
||||
target_length = past_key_values.get_max_cache_shape()
|
||||
# DynamicCache or no cache
|
||||
else:
|
||||
target_length = (
|
||||
attention_mask.shape[-1]
|
||||
if isinstance(attention_mask, torch.Tensor)
|
||||
else past_seen_tokens + sequence_length + 1
|
||||
)
|
||||
|
||||
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
||||
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask,
|
||||
sequence_length=sequence_length,
|
||||
target_length=target_length,
|
||||
dtype=dtype,
|
||||
cache_position=cache_position,
|
||||
batch_size=input_tensor.shape[0],
|
||||
config=self.config,
|
||||
past_key_values=past_key_values,
|
||||
)
|
||||
|
||||
if (
|
||||
self.config._attn_implementation == "sdpa"
|
||||
and attention_mask is not None
|
||||
and attention_mask.device.type in ["cuda", "xpu", "npu"]
|
||||
and not output_attentions
|
||||
):
|
||||
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
||||
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
||||
# Details: https://github.com/pytorch/pytorch/issues/110213
|
||||
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
||||
|
||||
return causal_mask
|
||||
|
||||
@staticmethod
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
target_length: int,
|
||||
dtype: torch.dtype,
|
||||
cache_position: torch.Tensor,
|
||||
batch_size: int,
|
||||
config: Qwen2_5OmniConfig,
|
||||
past_key_values: Cache,
|
||||
):
|
||||
"""
|
||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||
|
||||
Args:
|
||||
attention_mask (`torch.Tensor`):
|
||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
|
||||
sequence_length (`int`):
|
||||
The sequence length being processed.
|
||||
target_length (`int`):
|
||||
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
|
||||
dtype (`torch.dtype`):
|
||||
The dtype to use for the 4D attention mask.
|
||||
cache_position (`torch.Tensor`):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
batch_size (`torch.Tensor`):
|
||||
Batch size.
|
||||
config (`Qwen25OmniThinkerConfig`):
|
||||
The model's configuration class
|
||||
past_key_values (`Cache`):
|
||||
The cache class that is being used currently to generate
|
||||
"""
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
|
||||
)
|
||||
diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
|
||||
-1, 1
|
||||
)
|
||||
text_config = config.get_text_config()
|
||||
if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
|
||||
# if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
|
||||
# the check is needed to verify is current checkpoint was trained with sliding window or not
|
||||
if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
|
||||
sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
|
||||
cache_position.reshape(-1, 1) - text_config.sliding_window
|
||||
)
|
||||
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
|
||||
causal_mask *= diagonal_attend_mask
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
if attention_mask.shape[-1] > target_length:
|
||||
attention_mask = attention_mask[:, :target_length]
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
||||
causal_mask.device
|
||||
)
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
return causal_mask
|
||||
|
||||
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
@ -2398,9 +2059,6 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCo
|
||||
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask.to(inputs_embeds.device)
|
||||
|
||||
if feature_attention_mask is not None:
|
||||
audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
|
||||
else:
|
||||
@ -2574,6 +2232,7 @@ class Qwen2_5OmniTalkerModel(Qwen2_5OmniPreTrainedModel):
|
||||
self._attn_implementation = config._attn_implementation
|
||||
self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.rotary_emb = Qwen2_5OmniRotaryEmbedding(config=config)
|
||||
self.has_sliding_layers = "sliding_attention" in self.config.layer_types
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
# Initialize weights and apply final processing
|
||||
@ -2598,6 +2257,7 @@ class Qwen2_5OmniTalkerModel(Qwen2_5OmniPreTrainedModel):
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
@ -2636,9 +2296,23 @@ class Qwen2_5OmniTalkerModel(Qwen2_5OmniPreTrainedModel):
|
||||
elif position_ids.dim() == 2:
|
||||
position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
|
||||
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
||||
)
|
||||
# It may already have been prepared by e.g. `generate`
|
||||
if not isinstance(causal_mask_mapping := attention_mask, dict):
|
||||
# Prepare mask arguments
|
||||
mask_kwargs = {
|
||||
"config": self.config,
|
||||
"input_embeds": inputs_embeds,
|
||||
"attention_mask": attention_mask,
|
||||
"cache_position": cache_position,
|
||||
"past_key_values": past_key_values,
|
||||
}
|
||||
# Create the masks
|
||||
causal_mask_mapping = {
|
||||
"full_attention": create_causal_mask(**mask_kwargs),
|
||||
}
|
||||
# The sliding window alternating layers are not always activated depending on the config
|
||||
if self.has_sliding_layers:
|
||||
causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
@ -2658,7 +2332,7 @@ class Qwen2_5OmniTalkerModel(Qwen2_5OmniPreTrainedModel):
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
causal_mask,
|
||||
causal_mask_mapping[decoder_layer.attention_type],
|
||||
position_ids,
|
||||
past_key_values,
|
||||
output_attentions,
|
||||
@ -2669,13 +2343,14 @@ class Qwen2_5OmniTalkerModel(Qwen2_5OmniPreTrainedModel):
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=causal_mask,
|
||||
attention_mask=causal_mask_mapping[decoder_layer.attention_type],
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
@ -2703,161 +2378,6 @@ class Qwen2_5OmniTalkerModel(Qwen2_5OmniPreTrainedModel):
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: Union[torch.Tensor, "BlockMask"],
|
||||
input_tensor: torch.Tensor,
|
||||
cache_position: torch.Tensor,
|
||||
past_key_values: Cache,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
if attention_mask is not None and past_key_values is not None:
|
||||
is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
|
||||
if is_padding_right:
|
||||
raise ValueError(
|
||||
"You are attempting to perform batched generation with padding_side='right'"
|
||||
" this may lead to unexpected behaviour for Flash Attention version of Qwen2_5Omni. Make sure to "
|
||||
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
|
||||
)
|
||||
if attention_mask is not None and 0.0 in attention_mask:
|
||||
return attention_mask
|
||||
return None
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask)
|
||||
return attention_mask
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
# to infer the attention mask.
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
using_static_cache = isinstance(past_key_values, StaticCache)
|
||||
using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
|
||||
|
||||
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
||||
if (
|
||||
self.config._attn_implementation == "sdpa"
|
||||
and not (using_static_cache or using_sliding_window_cache)
|
||||
and not output_attentions
|
||||
):
|
||||
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
||||
attention_mask,
|
||||
inputs_embeds=input_tensor,
|
||||
past_key_values_length=past_seen_tokens,
|
||||
sliding_window=self.config.sliding_window,
|
||||
is_training=self.training,
|
||||
):
|
||||
return None
|
||||
|
||||
dtype = input_tensor.dtype
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
sequence_length = input_tensor.shape[1]
|
||||
# SlidingWindowCache or StaticCache
|
||||
if using_sliding_window_cache or using_static_cache:
|
||||
target_length = past_key_values.get_max_cache_shape()
|
||||
# DynamicCache or no cache
|
||||
else:
|
||||
target_length = (
|
||||
attention_mask.shape[-1]
|
||||
if isinstance(attention_mask, torch.Tensor)
|
||||
else past_seen_tokens + sequence_length + 1
|
||||
)
|
||||
|
||||
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
||||
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask,
|
||||
sequence_length=sequence_length,
|
||||
target_length=target_length,
|
||||
dtype=dtype,
|
||||
cache_position=cache_position,
|
||||
batch_size=input_tensor.shape[0],
|
||||
config=self.config,
|
||||
past_key_values=past_key_values,
|
||||
)
|
||||
|
||||
if (
|
||||
self.config._attn_implementation == "sdpa"
|
||||
and attention_mask is not None
|
||||
and attention_mask.device.type in ["cuda", "xpu", "npu"]
|
||||
and not output_attentions
|
||||
):
|
||||
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
||||
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
||||
# Details: https://github.com/pytorch/pytorch/issues/110213
|
||||
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
||||
|
||||
return causal_mask
|
||||
|
||||
@staticmethod
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
target_length: int,
|
||||
dtype: torch.dtype,
|
||||
cache_position: torch.Tensor,
|
||||
batch_size: int,
|
||||
config: Qwen2_5OmniConfig,
|
||||
past_key_values: Cache,
|
||||
):
|
||||
"""
|
||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||
|
||||
Args:
|
||||
attention_mask (`torch.Tensor`):
|
||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
|
||||
sequence_length (`int`):
|
||||
The sequence length being processed.
|
||||
target_length (`int`):
|
||||
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
|
||||
dtype (`torch.dtype`):
|
||||
The dtype to use for the 4D attention mask.
|
||||
cache_position (`torch.Tensor`):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
batch_size (`torch.Tensor`):
|
||||
Batch size.
|
||||
config (`Qwen2_5OmniConfig`):
|
||||
The model's configuration class
|
||||
past_key_values (`Cache`):
|
||||
The cache class that is being used currently to generate
|
||||
"""
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
|
||||
)
|
||||
diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
|
||||
-1, 1
|
||||
)
|
||||
text_config = config.get_text_config()
|
||||
if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
|
||||
# if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
|
||||
# the check is needed to verify is current checkpoint was trained with sliding window or not
|
||||
if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
|
||||
sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
|
||||
cache_position.reshape(-1, 1) - text_config.sliding_window
|
||||
)
|
||||
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
|
||||
causal_mask *= diagonal_attend_mask
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
if attention_mask.shape[-1] > target_length:
|
||||
attention_mask = attention_mask[:, :target_length]
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
||||
causal_mask.device
|
||||
)
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
return causal_mask
|
||||
|
||||
|
||||
class Qwen2_5OmniTalkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForConditionalGeneration, GenerationMixin):
|
||||
config_class = Qwen2_5OmniTalkerConfig
|
||||
|
@ -41,7 +41,7 @@ from transformers.models.qwen2_audio.configuration_qwen2_audio import Qwen2Audio
|
||||
from transformers.models.qwen2_audio.modeling_qwen2_audio import Qwen2AudioEncoderLayer
|
||||
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLRotaryEmbedding
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...configuration_utils import PretrainedConfig, layer_type_validation
|
||||
from ...generation import GenerationMixin
|
||||
from ...modeling_outputs import BaseModelOutput, ModelOutput
|
||||
from ...modeling_rope_utils import rope_config_validation
|
||||
@ -298,6 +298,8 @@ class Qwen2_5OmniTextConfig(PretrainedConfig):
|
||||
Sliding window attention (SWA) window size. If not specified, will default to `4096`.
|
||||
max_window_layers (`int`, *optional*, defaults to 28):
|
||||
The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
|
||||
layer_types (`list`, *optional*):
|
||||
Attention pattern for each layer.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
rope_scaling (`Dict`, *optional*):
|
||||
@ -399,6 +401,7 @@ class Qwen2_5OmniTextConfig(PretrainedConfig):
|
||||
use_sliding_window=False,
|
||||
sliding_window=32768,
|
||||
max_window_layers=28,
|
||||
layer_types=None,
|
||||
attention_dropout=0.0,
|
||||
**kwargs,
|
||||
):
|
||||
@ -437,6 +440,16 @@ class Qwen2_5OmniTextConfig(PretrainedConfig):
|
||||
if self.rope_scaling is None:
|
||||
self.rope_scaling = {"mrope_section": [16, 24, 24], "rope_type": "default", "type": "default"}
|
||||
|
||||
self.layer_types = layer_types
|
||||
if self.layer_types is None:
|
||||
self.layer_types = [
|
||||
"sliding_attention"
|
||||
if self.sliding_window is not None and i >= self.max_window_layers
|
||||
else "full_attention"
|
||||
for i in range(self.num_hidden_layers)
|
||||
]
|
||||
layer_type_validation(self.layer_types)
|
||||
|
||||
|
||||
class Qwen2_5OmniThinkerConfig(PretrainedConfig):
|
||||
r"""
|
||||
@ -1104,7 +1117,7 @@ class Qwen2_5OmniConfig(PretrainedConfig):
|
||||
|
||||
class Qwen2_5OmniPreTrainedModel(Qwen2_5_VLPreTrainedModel):
|
||||
config_class = Qwen2_5OmniConfig
|
||||
_supports_static_cache = True
|
||||
_supports_static_cache = False
|
||||
|
||||
def _init_weights(self, module):
|
||||
# important: this ported version of Qwen2.5OmniThinker isn't meant for training from scratch - only
|
||||
@ -2184,11 +2197,13 @@ class Qwen2_5OmniAttention(Qwen2_5_VLAttention, nn.Module):
|
||||
self.is_causal = True
|
||||
self.attention_dropout = config.attention_dropout
|
||||
self.rope_scaling = config.rope_scaling
|
||||
self.scaling = self.head_dim**-0.5
|
||||
|
||||
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
|
||||
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
|
||||
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
|
||||
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
|
||||
self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None
|
||||
|
||||
self.rotary_emb = Qwen2_5OmniRotaryEmbedding(config=config)
|
||||
|
||||
@ -2450,9 +2465,6 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCo
|
||||
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask.to(inputs_embeds.device)
|
||||
|
||||
if feature_attention_mask is not None:
|
||||
audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
|
||||
else:
|
||||
|
@ -23,7 +23,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...configuration_utils import PretrainedConfig, layer_type_validation
|
||||
from ...modeling_rope_utils import rope_config_validation
|
||||
|
||||
|
||||
@ -117,6 +117,8 @@ class Qwen2_5_VLTextConfig(PretrainedConfig):
|
||||
Sliding window attention (SWA) window size. If not specified, will default to `4096`.
|
||||
max_window_layers (`int`, *optional*, defaults to 80):
|
||||
The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
|
||||
layer_types (`list`, *optional*):
|
||||
Attention pattern for each layer.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
rope_scaling (`Dict`, *optional*):
|
||||
@ -211,6 +213,7 @@ class Qwen2_5_VLTextConfig(PretrainedConfig):
|
||||
use_sliding_window=False,
|
||||
sliding_window=4096,
|
||||
max_window_layers=80,
|
||||
layer_types=None,
|
||||
attention_dropout=0.0,
|
||||
rope_scaling=None,
|
||||
image_token_id=None,
|
||||
@ -224,7 +227,7 @@ class Qwen2_5_VLTextConfig(PretrainedConfig):
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.use_sliding_window = use_sliding_window
|
||||
self.sliding_window = sliding_window
|
||||
self.sliding_window = sliding_window if self.use_sliding_window else None
|
||||
self.max_window_layers = max_window_layers
|
||||
|
||||
# for backward compatibility
|
||||
@ -240,6 +243,16 @@ class Qwen2_5_VLTextConfig(PretrainedConfig):
|
||||
self.attention_dropout = attention_dropout
|
||||
self.rope_scaling = rope_scaling
|
||||
|
||||
self.layer_types = layer_types
|
||||
if self.layer_types is None:
|
||||
self.layer_types = [
|
||||
"sliding_attention"
|
||||
if self.sliding_window is not None and i >= self.max_window_layers
|
||||
else "full_attention"
|
||||
for i in range(self.num_hidden_layers)
|
||||
]
|
||||
layer_type_validation(self.layer_types)
|
||||
|
||||
# Validate the correctness of rotary position embeddings parameters
|
||||
# BC: if there is a 'type' field, move it to 'rope_type'.
|
||||
# and change type from 'mrope' to 'default' because `mrope` does default RoPE calculations
|
||||
|
@ -26,21 +26,23 @@
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
|
||||
from ...cache_utils import Cache, DynamicCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||
from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
|
||||
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs, is_flash_attn_available
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
|
||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, is_torchdynamo_compiling, logging
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import LossKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging
|
||||
from .configuration_qwen2_5_vl import Qwen2_5_VLConfig, Qwen2_5_VLTextConfig, Qwen2_5_VLVisionConfig
|
||||
|
||||
|
||||
@ -48,15 +50,6 @@ if is_flash_attn_available():
|
||||
from ...modeling_flash_attention_utils import apply_rotary_emb, flash_attn_varlen_func
|
||||
|
||||
|
||||
if is_flash_attn_available():
|
||||
from ...modeling_flash_attention_utils import _flash_attention_forward
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@ -359,6 +352,7 @@ class Qwen2_5_VLPreTrainedModel(PreTrainedModel):
|
||||
_supports_sdpa = True
|
||||
_supports_cache_class = True
|
||||
_supports_static_cache = True
|
||||
_supports_attention_backend = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.get_text_config().initializer_range
|
||||
@ -681,6 +675,32 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
||||
|
||||
|
||||
def eager_attention_forward(
|
||||
module: nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
scaling: float,
|
||||
dropout: float = 0.0,
|
||||
**kwargs,
|
||||
):
|
||||
key_states = repeat_kv(key, module.num_key_value_groups)
|
||||
value_states = repeat_kv(value, module.num_key_value_groups)
|
||||
|
||||
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class Qwen2_5_VLAttention(nn.Module):
|
||||
"""
|
||||
Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
|
||||
@ -706,6 +726,7 @@ class Qwen2_5_VLAttention(nn.Module):
|
||||
self.is_causal = True
|
||||
self.attention_dropout = config.attention_dropout
|
||||
self.rope_scaling = config.rope_scaling
|
||||
self.scaling = self.head_dim**-0.5
|
||||
|
||||
if (self.head_dim * self.num_heads) != self.hidden_size:
|
||||
raise ValueError(
|
||||
@ -716,6 +737,7 @@ class Qwen2_5_VLAttention(nn.Module):
|
||||
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
|
||||
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
|
||||
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
|
||||
self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None
|
||||
|
||||
self.rotary_emb = Qwen2_5_VLRotaryEmbedding(config=config)
|
||||
|
||||
@ -729,6 +751,7 @@ class Qwen2_5_VLAttention(nn.Module):
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
@ -749,253 +772,28 @@ class Qwen2_5_VLAttention(nn.Module):
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||
|
||||
if attention_mask is not None: # no matter the length, we just slice it
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
|
||||
# Fix precision issues in Qwen2-VL float16 inference
|
||||
# Replace inf values with zeros in attention weights to prevent NaN propagation
|
||||
if query_states.dtype == torch.float16:
|
||||
attn_weights = torch.where(torch.isinf(attn_weights), torch.zeros_like(attn_weights), attn_weights)
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
|
||||
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.reshape(bsz, q_len, -1)
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
class Qwen2_5_VLFlashAttention2(Qwen2_5_VLAttention):
|
||||
"""
|
||||
Qwen2_5_VL flash attention module, following Qwen2_5_VL attention module. This module inherits from `Qwen2_5_VLAttention`
|
||||
as the weights of the module stays untouched. The only required change would be on the forward pass
|
||||
where it needs to correctly call the public API of flash attention and deal with padding tokens
|
||||
in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom
|
||||
config.max_window_layers layers.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
||||
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
||||
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
||||
self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
||||
):
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||
|
||||
# Because the input can be padded, the absolute sequence length depends on the max position id.
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_multimodal_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
|
||||
)
|
||||
|
||||
if past_key_value is not None:
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
dropout_rate = 0.0 if not self.training else self.attention_dropout
|
||||
|
||||
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
||||
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
||||
# cast them back in float16 just to be sure everything works as expected.
|
||||
input_dtype = query_states.dtype
|
||||
if input_dtype == torch.float32:
|
||||
if torch.is_autocast_enabled():
|
||||
target_dtype = torch.get_autocast_gpu_dtype()
|
||||
# Handle the case where the model is quantized
|
||||
elif hasattr(self.config, "_pre_quantization_dtype"):
|
||||
target_dtype = self.config._pre_quantization_dtype
|
||||
else:
|
||||
target_dtype = self.q_proj.weight.dtype
|
||||
|
||||
logger.warning_once(
|
||||
f"The input hidden states seems to be silently casted in float32, this might be related to"
|
||||
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
||||
f" {target_dtype}."
|
||||
)
|
||||
|
||||
query_states = query_states.to(target_dtype)
|
||||
key_states = key_states.to(target_dtype)
|
||||
value_states = value_states.to(target_dtype)
|
||||
|
||||
# Reashape to the expected shape for Flash Attention
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
value_states = value_states.transpose(1, 2)
|
||||
|
||||
if (
|
||||
self.config.use_sliding_window
|
||||
and getattr(self.config, "sliding_window", None) is not None
|
||||
and self.layer_idx >= self.config.max_window_layers
|
||||
):
|
||||
sliding_window = self.config.sliding_window
|
||||
else:
|
||||
sliding_window = None
|
||||
|
||||
attn_output = _flash_attention_forward(
|
||||
attn_output, attn_weights = attention_interface(
|
||||
self,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
q_len,
|
||||
dropout=dropout_rate,
|
||||
sliding_window=sliding_window,
|
||||
is_causal=self.is_causal,
|
||||
use_top_left_mask=self._flash_attn_uses_top_left_mask,
|
||||
dropout=0.0 if not self.training else self.attention_dropout,
|
||||
scaling=self.scaling,
|
||||
sliding_window=self.sliding_window,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
class Qwen2_5_VLSdpaAttention(Qwen2_5_VLAttention):
|
||||
"""
|
||||
Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
|
||||
`Qwen2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
|
||||
SDPA API.
|
||||
"""
|
||||
|
||||
# Adapted from Qwen2Attention.forward
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
if output_attentions:
|
||||
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
|
||||
logger.warning_once(
|
||||
"Qwen2_5_VLModel is using Qwen2_5_VLSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
|
||||
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
return super().forward(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
)
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_multimodal_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
|
||||
)
|
||||
|
||||
if past_key_value is not None:
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
causal_mask = attention_mask
|
||||
if attention_mask is not None: # no matter the length, we just slice it
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
|
||||
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
|
||||
# Reference: https://github.com/pytorch/pytorch/issues/112577.
|
||||
if query_states.device.type == "cuda" and attention_mask is not None:
|
||||
query_states = query_states.contiguous()
|
||||
key_states = key_states.contiguous()
|
||||
value_states = value_states.contiguous()
|
||||
|
||||
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
|
||||
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
||||
# The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
|
||||
is_causal = True if causal_mask is None and q_len > 1 else False
|
||||
|
||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_mask=causal_mask,
|
||||
dropout_p=self.attention_dropout if self.training else 0.0,
|
||||
is_causal=is_causal,
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.view(bsz, q_len, -1)
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
return attn_output, None, past_key_value
|
||||
|
||||
|
||||
QWEN2_5_VL_ATTENTION_CLASSES = {
|
||||
"eager": Qwen2_5_VLAttention,
|
||||
"flash_attention_2": Qwen2_5_VLFlashAttention2,
|
||||
"sdpa": Qwen2_5_VLSdpaAttention,
|
||||
}
|
||||
|
||||
|
||||
class Qwen2_5_VLDecoderLayer(nn.Module):
|
||||
class Qwen2_5_VLDecoderLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config: Qwen2_5_VLTextConfig, layer_idx: int):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
@ -1005,11 +803,12 @@ class Qwen2_5_VLDecoderLayer(nn.Module):
|
||||
f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
|
||||
"unexpected results may be encountered."
|
||||
)
|
||||
self.self_attn = QWEN2_5_VL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
|
||||
self.self_attn = Qwen2_5_VLAttention(config, layer_idx)
|
||||
|
||||
self.mlp = Qwen2MLP(config)
|
||||
self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.attention_type = config.layer_types[layer_idx]
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -1021,7 +820,7 @@ class Qwen2_5_VLDecoderLayer(nn.Module):
|
||||
use_cache: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
||||
**kwargs,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
"""
|
||||
Args:
|
||||
@ -1059,6 +858,7 @@ class Qwen2_5_VLDecoderLayer(nn.Module):
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
@ -1095,6 +895,7 @@ class Qwen2_5_VLTextModel(Qwen2_5_VLPreTrainedModel):
|
||||
self._attn_implementation = config._attn_implementation
|
||||
self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.rotary_emb = Qwen2_5_VLRotaryEmbedding(config=config)
|
||||
self.has_sliding_layers = "sliding_attention" in self.config.layer_types
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
# Initialize weights and apply final processing
|
||||
@ -1119,6 +920,7 @@ class Qwen2_5_VLTextModel(Qwen2_5_VLPreTrainedModel):
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
@ -1157,9 +959,23 @@ class Qwen2_5_VLTextModel(Qwen2_5_VLPreTrainedModel):
|
||||
elif position_ids.dim() == 2:
|
||||
position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
|
||||
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
||||
)
|
||||
# It may already have been prepared by e.g. `generate`
|
||||
if not isinstance(causal_mask_mapping := attention_mask, dict):
|
||||
# Prepare mask arguments
|
||||
mask_kwargs = {
|
||||
"config": self.config,
|
||||
"input_embeds": inputs_embeds,
|
||||
"attention_mask": attention_mask,
|
||||
"cache_position": cache_position,
|
||||
"past_key_values": past_key_values,
|
||||
}
|
||||
# Create the masks
|
||||
causal_mask_mapping = {
|
||||
"full_attention": create_causal_mask(**mask_kwargs),
|
||||
}
|
||||
# The sliding window alternating layers are not always activated depending on the config
|
||||
if self.has_sliding_layers:
|
||||
causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
@ -1179,7 +995,7 @@ class Qwen2_5_VLTextModel(Qwen2_5_VLPreTrainedModel):
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
causal_mask,
|
||||
causal_mask_mapping[decoder_layer.attention_type],
|
||||
position_ids,
|
||||
past_key_values,
|
||||
output_attentions,
|
||||
@ -1190,13 +1006,14 @@ class Qwen2_5_VLTextModel(Qwen2_5_VLPreTrainedModel):
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=causal_mask,
|
||||
attention_mask=causal_mask_mapping[decoder_layer.attention_type],
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
@ -1224,160 +1041,8 @@ class Qwen2_5_VLTextModel(Qwen2_5_VLPreTrainedModel):
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: Union[torch.Tensor, "BlockMask"],
|
||||
input_tensor: torch.Tensor,
|
||||
cache_position: torch.Tensor,
|
||||
past_key_values: Cache,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
if attention_mask is not None and past_key_values is not None:
|
||||
is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
|
||||
if is_padding_right:
|
||||
raise ValueError(
|
||||
"You are attempting to perform batched generation with padding_side='right'"
|
||||
" this may lead to unexpected behaviour for Flash Attention version of Qwen2_5_VL. Make sure to "
|
||||
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
|
||||
)
|
||||
if attention_mask is not None and 0.0 in attention_mask:
|
||||
return attention_mask
|
||||
return None
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask)
|
||||
return attention_mask
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
# to infer the attention mask.
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
using_static_cache = isinstance(past_key_values, StaticCache)
|
||||
using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
|
||||
|
||||
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
||||
if (
|
||||
self.config._attn_implementation == "sdpa"
|
||||
and not (using_static_cache or using_sliding_window_cache)
|
||||
and not output_attentions
|
||||
):
|
||||
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
||||
attention_mask,
|
||||
inputs_embeds=input_tensor,
|
||||
past_key_values_length=past_seen_tokens,
|
||||
sliding_window=self.config.sliding_window,
|
||||
is_training=self.training,
|
||||
):
|
||||
return None
|
||||
|
||||
dtype = input_tensor.dtype
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
sequence_length = input_tensor.shape[1]
|
||||
# SlidingWindowCache or StaticCache
|
||||
if using_sliding_window_cache or using_static_cache:
|
||||
target_length = past_key_values.get_max_cache_shape()
|
||||
# DynamicCache or no cache
|
||||
else:
|
||||
target_length = (
|
||||
attention_mask.shape[-1]
|
||||
if isinstance(attention_mask, torch.Tensor)
|
||||
else past_seen_tokens + sequence_length + 1
|
||||
)
|
||||
|
||||
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
||||
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask,
|
||||
sequence_length=sequence_length,
|
||||
target_length=target_length,
|
||||
dtype=dtype,
|
||||
cache_position=cache_position,
|
||||
batch_size=input_tensor.shape[0],
|
||||
config=self.config,
|
||||
past_key_values=past_key_values,
|
||||
)
|
||||
|
||||
if (
|
||||
self.config._attn_implementation == "sdpa"
|
||||
and attention_mask is not None
|
||||
and attention_mask.device.type in ["cuda", "xpu", "npu"]
|
||||
and not output_attentions
|
||||
):
|
||||
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
||||
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
||||
# Details: https://github.com/pytorch/pytorch/issues/110213
|
||||
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
||||
|
||||
return causal_mask
|
||||
|
||||
@staticmethod
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
target_length: int,
|
||||
dtype: torch.dtype,
|
||||
cache_position: torch.Tensor,
|
||||
batch_size: int,
|
||||
config: Qwen2_5_VLConfig,
|
||||
past_key_values: Cache,
|
||||
):
|
||||
"""
|
||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||
|
||||
Args:
|
||||
attention_mask (`torch.Tensor`):
|
||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
|
||||
sequence_length (`int`):
|
||||
The sequence length being processed.
|
||||
target_length (`int`):
|
||||
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
|
||||
dtype (`torch.dtype`):
|
||||
The dtype to use for the 4D attention mask.
|
||||
cache_position (`torch.Tensor`):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
batch_size (`torch.Tensor`):
|
||||
Batch size.
|
||||
config (`Qwen2_5_VLConfig`):
|
||||
The model's configuration class
|
||||
past_key_values (`Cache`):
|
||||
The cache class that is being used currently to generate
|
||||
"""
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
|
||||
)
|
||||
diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
|
||||
-1, 1
|
||||
)
|
||||
text_config = config.get_text_config()
|
||||
if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
|
||||
# if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
|
||||
# the check is needed to verify is current checkpoint was trained with sliding window or not
|
||||
if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
|
||||
sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
|
||||
cache_position.reshape(-1, 1) - text_config.sliding_window
|
||||
)
|
||||
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
|
||||
causal_mask *= diagonal_attend_mask
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
if attention_mask.shape[-1] > target_length:
|
||||
attention_mask = attention_mask[:, :target_length]
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
||||
causal_mask.device
|
||||
)
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
return causal_mask
|
||||
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||
|
||||
|
||||
@auto_docstring
|
||||
@ -1598,6 +1263,8 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
|
||||
"""
|
||||
pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
|
||||
video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
|
||||
split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
|
||||
video_embeds = torch.split(video_embeds, split_sizes)
|
||||
return video_embeds
|
||||
|
||||
def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None):
|
||||
@ -1612,6 +1279,8 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
|
||||
"""
|
||||
pixel_values = pixel_values.type(self.visual.dtype)
|
||||
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
|
||||
split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
|
||||
image_embeds = torch.split(image_embeds, split_sizes)
|
||||
return image_embeds
|
||||
|
||||
@auto_docstring
|
||||
@ -1633,6 +1302,7 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
|
||||
rope_deltas: Optional[torch.LongTensor] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
second_per_grid_ts: Optional[torch.Tensor] = None,
|
||||
**kwargs: Unpack[KwargsForCausalLM],
|
||||
) -> Union[Tuple, Qwen2_5_VLModelOutputWithPast]:
|
||||
r"""
|
||||
pixel_values_videos (`torch.FloatTensor` of shape `(seq_length, num_channels * temporal_size * image_size * image_size)):
|
||||
@ -1659,6 +1329,7 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
if pixel_values is not None:
|
||||
image_embeds = self.get_image_features(pixel_values, image_grid_thw)
|
||||
image_embeds = torch.cat(image_embeds, dim=0)
|
||||
n_image_tokens = (input_ids == self.config.image_token_id).sum()
|
||||
n_image_features = image_embeds.shape[0]
|
||||
if not is_torchdynamo_compiling() and n_image_tokens != n_image_features:
|
||||
@ -1676,6 +1347,7 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
|
||||
|
||||
if pixel_values_videos is not None:
|
||||
video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
|
||||
video_embeds = torch.cat(video_embeds, dim=0)
|
||||
n_video_tokens = (input_ids == self.config.video_token_id).sum()
|
||||
n_video_features = video_embeds.shape[0]
|
||||
if not is_torchdynamo_compiling() and n_video_tokens != n_video_features:
|
||||
@ -1691,15 +1363,14 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
|
||||
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask.to(inputs_embeds.device)
|
||||
|
||||
if position_ids is None:
|
||||
attention_mask_2d = attention_mask
|
||||
if attention_mask is not None and attention_mask.ndim == 4:
|
||||
attention_mask_2d = torch.diagonal(attention_mask_2d[:, 0], dim1=1, dim2=2)
|
||||
attention_mask_2d = attention_mask_2d / torch.finfo(attention_mask_2d.dtype).min
|
||||
attention_mask_2d = (1.0 - attention_mask_2d).int()
|
||||
attention_mask_tensor = (
|
||||
attention_mask if not isinstance(attention_mask, dict) else attention_mask["full_attention"]
|
||||
)
|
||||
if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4:
|
||||
attention_mask_tensor = torch.diagonal(attention_mask_tensor[:, 0], dim1=1, dim2=2)
|
||||
attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min
|
||||
attention_mask_tensor = (1.0 - attention_mask_tensor).int()
|
||||
|
||||
# Calculate RoPE index once per generation in the pre-fill stage only.
|
||||
# When compiling, we can't check tensor values thus we check only input length
|
||||
@ -1719,7 +1390,7 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
|
||||
image_grid_thw,
|
||||
video_grid_thw,
|
||||
second_per_grid_ts=second_per_grid_ts,
|
||||
attention_mask=attention_mask_2d,
|
||||
attention_mask=attention_mask_tensor,
|
||||
)
|
||||
self.rope_deltas = rope_deltas
|
||||
# then use the prev pre-calculated rope-deltas to get the correct position ids
|
||||
@ -1748,6 +1419,7 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=True,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
output = Qwen2_5_VLModelOutputWithPast(
|
||||
@ -1759,61 +1431,6 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
|
||||
)
|
||||
return output if return_dict else output.to_tuple()
|
||||
|
||||
@staticmethod
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
target_length: int,
|
||||
dtype: torch.dtype,
|
||||
cache_position: torch.Tensor,
|
||||
batch_size: int,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||
|
||||
Args:
|
||||
attention_mask (`torch.Tensor`):
|
||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
|
||||
`(batch_size, 1, query_length, key_value_length)`.
|
||||
sequence_length (`int`):
|
||||
The sequence length being processed.
|
||||
target_length (`int`):
|
||||
The target length: when generating with static cache, the mask should be as long as the static cache,
|
||||
to account for the 0 padding, the part of the cache that is not filled yet.
|
||||
dtype (`torch.dtype`):
|
||||
The dtype to use for the 4D attention mask.
|
||||
cache_position (`torch.Tensor`):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
batch_size (`torch.Tensor`):
|
||||
Batch size.
|
||||
"""
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
|
||||
)
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
||||
causal_mask.device
|
||||
)
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
|
||||
return causal_mask
|
||||
|
||||
|
||||
@dataclass
|
||||
class Qwen2_5_VLCausalLMOutputWithPast(ModelOutput):
|
||||
@ -1916,6 +1533,7 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi
|
||||
rope_deltas: Optional[torch.LongTensor] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
second_per_grid_ts: Optional[torch.Tensor] = None,
|
||||
**kwargs: Unpack[KwargsForCausalLM],
|
||||
) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
@ -1988,6 +1606,7 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
|
@ -30,6 +30,7 @@ import torch.utils.checkpoint
|
||||
|
||||
from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLConfig, Qwen2VLTextConfig
|
||||
from transformers.models.qwen2_vl.modeling_qwen2_vl import (
|
||||
KwargsForCausalLM,
|
||||
PatchEmbed,
|
||||
PatchMerger,
|
||||
Qwen2RMSNorm,
|
||||
@ -622,6 +623,7 @@ class Qwen2_5_VLModel(Qwen2VLModel):
|
||||
rope_deltas: Optional[torch.LongTensor] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
second_per_grid_ts: Optional[torch.Tensor] = None,
|
||||
**kwargs: Unpack[KwargsForCausalLM],
|
||||
) -> Union[Tuple, Qwen2_5_VLModelOutputWithPast]:
|
||||
r"""
|
||||
pixel_values_videos (`torch.FloatTensor` of shape `(seq_length, num_channels * temporal_size * image_size * image_size)):
|
||||
@ -648,6 +650,7 @@ class Qwen2_5_VLModel(Qwen2VLModel):
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
if pixel_values is not None:
|
||||
image_embeds = self.get_image_features(pixel_values, image_grid_thw)
|
||||
image_embeds = torch.cat(image_embeds, dim=0)
|
||||
n_image_tokens = (input_ids == self.config.image_token_id).sum()
|
||||
n_image_features = image_embeds.shape[0]
|
||||
if not is_torchdynamo_compiling() and n_image_tokens != n_image_features:
|
||||
@ -665,6 +668,7 @@ class Qwen2_5_VLModel(Qwen2VLModel):
|
||||
|
||||
if pixel_values_videos is not None:
|
||||
video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
|
||||
video_embeds = torch.cat(video_embeds, dim=0)
|
||||
n_video_tokens = (input_ids == self.config.video_token_id).sum()
|
||||
n_video_features = video_embeds.shape[0]
|
||||
if not is_torchdynamo_compiling() and n_video_tokens != n_video_features:
|
||||
@ -680,15 +684,14 @@ class Qwen2_5_VLModel(Qwen2VLModel):
|
||||
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask.to(inputs_embeds.device)
|
||||
|
||||
if position_ids is None:
|
||||
attention_mask_2d = attention_mask
|
||||
if attention_mask is not None and attention_mask.ndim == 4:
|
||||
attention_mask_2d = torch.diagonal(attention_mask_2d[:, 0], dim1=1, dim2=2)
|
||||
attention_mask_2d = attention_mask_2d / torch.finfo(attention_mask_2d.dtype).min
|
||||
attention_mask_2d = (1.0 - attention_mask_2d).int()
|
||||
attention_mask_tensor = (
|
||||
attention_mask if not isinstance(attention_mask, dict) else attention_mask["full_attention"]
|
||||
)
|
||||
if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4:
|
||||
attention_mask_tensor = torch.diagonal(attention_mask_tensor[:, 0], dim1=1, dim2=2)
|
||||
attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min
|
||||
attention_mask_tensor = (1.0 - attention_mask_tensor).int()
|
||||
|
||||
# Calculate RoPE index once per generation in the pre-fill stage only.
|
||||
# When compiling, we can't check tensor values thus we check only input length
|
||||
@ -708,7 +711,7 @@ class Qwen2_5_VLModel(Qwen2VLModel):
|
||||
image_grid_thw,
|
||||
video_grid_thw,
|
||||
second_per_grid_ts=second_per_grid_ts,
|
||||
attention_mask=attention_mask_2d,
|
||||
attention_mask=attention_mask_tensor,
|
||||
)
|
||||
self.rope_deltas = rope_deltas
|
||||
# then use the prev pre-calculated rope-deltas to get the correct position ids
|
||||
@ -737,6 +740,7 @@ class Qwen2_5_VLModel(Qwen2VLModel):
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=True,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
output = Qwen2_5_VLModelOutputWithPast(
|
||||
@ -774,6 +778,7 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2VLForConditionalGeneration):
|
||||
rope_deltas: Optional[torch.LongTensor] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
second_per_grid_ts: Optional[torch.Tensor] = None,
|
||||
**kwargs: Unpack[KwargsForCausalLM],
|
||||
) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
@ -846,6 +851,7 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2VLForConditionalGeneration):
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
|
@ -14,7 +14,7 @@
|
||||
# limitations under the License.
|
||||
"""Qwen2VL model configuration"""
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...configuration_utils import PretrainedConfig, layer_type_validation
|
||||
from ...modeling_rope_utils import rope_config_validation
|
||||
from ...utils import logging
|
||||
|
||||
@ -106,6 +106,8 @@ class Qwen2VLTextConfig(PretrainedConfig):
|
||||
Sliding window attention (SWA) window size. If not specified, will default to `4096`.
|
||||
max_window_layers (`int`, *optional*, defaults to 80):
|
||||
The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
|
||||
layer_types (`list`, *optional*):
|
||||
Attention pattern for each layer.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
rope_scaling (`Dict`, *optional*):
|
||||
@ -200,6 +202,7 @@ class Qwen2VLTextConfig(PretrainedConfig):
|
||||
use_sliding_window=False,
|
||||
sliding_window=4096,
|
||||
max_window_layers=80,
|
||||
layer_types=None,
|
||||
attention_dropout=0.0,
|
||||
rope_scaling=None,
|
||||
image_token_id=None,
|
||||
@ -213,7 +216,7 @@ class Qwen2VLTextConfig(PretrainedConfig):
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.use_sliding_window = use_sliding_window
|
||||
self.sliding_window = sliding_window
|
||||
self.sliding_window = sliding_window if self.use_sliding_window else None
|
||||
self.max_window_layers = max_window_layers
|
||||
|
||||
# for backward compatibility
|
||||
@ -229,6 +232,16 @@ class Qwen2VLTextConfig(PretrainedConfig):
|
||||
self.attention_dropout = attention_dropout
|
||||
self.rope_scaling = rope_scaling
|
||||
|
||||
self.layer_types = layer_types
|
||||
if self.layer_types is None:
|
||||
self.layer_types = [
|
||||
"sliding_attention"
|
||||
if self.sliding_window is not None and i >= self.max_window_layers
|
||||
else "full_attention"
|
||||
for i in range(self.num_hidden_layers)
|
||||
]
|
||||
layer_type_validation(self.layer_types)
|
||||
|
||||
# Validate the correctness of rotary position embeddings parameters
|
||||
# BC: if there is a 'type' field, move it to 'rope_type'.
|
||||
# and change type from 'mrope' to 'default' because `mrope` does default RoPE calculations
|
||||
|
@ -21,7 +21,7 @@
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -30,24 +30,27 @@ import torch.utils.checkpoint
|
||||
from torch.nn import LayerNorm
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
|
||||
from ...cache_utils import Cache, DynamicCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||
from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
|
||||
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs, is_flash_attn_available
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
|
||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, is_torchdynamo_compiling, logging
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import (
|
||||
LossKwargs,
|
||||
auto_docstring,
|
||||
can_return_tuple,
|
||||
is_torchdynamo_compiling,
|
||||
logging,
|
||||
)
|
||||
from .configuration_qwen2_vl import Qwen2VLConfig, Qwen2VLTextConfig, Qwen2VLVisionConfig
|
||||
|
||||
|
||||
if is_flash_attn_available():
|
||||
from ...modeling_flash_attention_utils import _flash_attention_forward, flash_attn_varlen_func
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
from ...modeling_flash_attention_utils import flash_attn_varlen_func
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@ -516,6 +519,32 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
||||
|
||||
|
||||
def eager_attention_forward(
|
||||
module: nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
scaling: float,
|
||||
dropout: float = 0.0,
|
||||
**kwargs,
|
||||
):
|
||||
key_states = repeat_kv(key, module.num_key_value_groups)
|
||||
value_states = repeat_kv(value, module.num_key_value_groups)
|
||||
|
||||
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class Qwen2VLAttention(nn.Module):
|
||||
"""
|
||||
Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
|
||||
@ -541,6 +570,7 @@ class Qwen2VLAttention(nn.Module):
|
||||
self.is_causal = True
|
||||
self.attention_dropout = config.attention_dropout
|
||||
self.rope_scaling = config.rope_scaling
|
||||
self.scaling = self.head_dim**-0.5
|
||||
|
||||
if (self.head_dim * self.num_heads) != self.hidden_size:
|
||||
raise ValueError(
|
||||
@ -551,6 +581,7 @@ class Qwen2VLAttention(nn.Module):
|
||||
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
|
||||
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
|
||||
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
|
||||
self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None
|
||||
|
||||
self.rotary_emb = Qwen2VLRotaryEmbedding(config=config)
|
||||
|
||||
@ -564,6 +595,7 @@ class Qwen2VLAttention(nn.Module):
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
@ -584,253 +616,28 @@ class Qwen2VLAttention(nn.Module):
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||
|
||||
if attention_mask is not None: # no matter the length, we just slice it
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
|
||||
# Fix precision issues in Qwen2-VL float16 inference
|
||||
# Replace inf values with zeros in attention weights to prevent NaN propagation
|
||||
if query_states.dtype == torch.float16:
|
||||
attn_weights = torch.where(torch.isinf(attn_weights), torch.zeros_like(attn_weights), attn_weights)
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
|
||||
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.reshape(bsz, q_len, -1)
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
class Qwen2VLFlashAttention2(Qwen2VLAttention):
|
||||
"""
|
||||
Qwen2VL flash attention module, following Qwen2VL attention module. This module inherits from `Qwen2VLAttention`
|
||||
as the weights of the module stays untouched. The only required change would be on the forward pass
|
||||
where it needs to correctly call the public API of flash attention and deal with padding tokens
|
||||
in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom
|
||||
config.max_window_layers layers.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
||||
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
||||
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
||||
self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
||||
):
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||
|
||||
# Because the input can be padded, the absolute sequence length depends on the max position id.
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_multimodal_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
|
||||
)
|
||||
|
||||
if past_key_value is not None:
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
dropout_rate = 0.0 if not self.training else self.attention_dropout
|
||||
|
||||
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
||||
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
||||
# cast them back in float16 just to be sure everything works as expected.
|
||||
input_dtype = query_states.dtype
|
||||
if input_dtype == torch.float32:
|
||||
if torch.is_autocast_enabled():
|
||||
target_dtype = torch.get_autocast_gpu_dtype()
|
||||
# Handle the case where the model is quantized
|
||||
elif hasattr(self.config, "_pre_quantization_dtype"):
|
||||
target_dtype = self.config._pre_quantization_dtype
|
||||
else:
|
||||
target_dtype = self.q_proj.weight.dtype
|
||||
|
||||
logger.warning_once(
|
||||
f"The input hidden states seems to be silently casted in float32, this might be related to"
|
||||
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
||||
f" {target_dtype}."
|
||||
)
|
||||
|
||||
query_states = query_states.to(target_dtype)
|
||||
key_states = key_states.to(target_dtype)
|
||||
value_states = value_states.to(target_dtype)
|
||||
|
||||
# Reashape to the expected shape for Flash Attention
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
value_states = value_states.transpose(1, 2)
|
||||
|
||||
if (
|
||||
self.config.use_sliding_window
|
||||
and getattr(self.config, "sliding_window", None) is not None
|
||||
and self.layer_idx >= self.config.max_window_layers
|
||||
):
|
||||
sliding_window = self.config.sliding_window
|
||||
else:
|
||||
sliding_window = None
|
||||
|
||||
attn_output = _flash_attention_forward(
|
||||
attn_output, attn_weights = attention_interface(
|
||||
self,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
q_len,
|
||||
dropout=dropout_rate,
|
||||
sliding_window=sliding_window,
|
||||
is_causal=self.is_causal,
|
||||
use_top_left_mask=self._flash_attn_uses_top_left_mask,
|
||||
dropout=0.0 if not self.training else self.attention_dropout,
|
||||
scaling=self.scaling,
|
||||
sliding_window=self.sliding_window,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
class Qwen2VLSdpaAttention(Qwen2VLAttention):
|
||||
"""
|
||||
Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
|
||||
`Qwen2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
|
||||
SDPA API.
|
||||
"""
|
||||
|
||||
# Adapted from Qwen2Attention.forward
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
if output_attentions:
|
||||
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
|
||||
logger.warning_once(
|
||||
"Qwen2VLModel is using Qwen2VLSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
|
||||
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
return super().forward(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
)
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_multimodal_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
|
||||
)
|
||||
|
||||
if past_key_value is not None:
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
causal_mask = attention_mask
|
||||
if attention_mask is not None: # no matter the length, we just slice it
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
|
||||
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
|
||||
# Reference: https://github.com/pytorch/pytorch/issues/112577.
|
||||
if query_states.device.type == "cuda" and attention_mask is not None:
|
||||
query_states = query_states.contiguous()
|
||||
key_states = key_states.contiguous()
|
||||
value_states = value_states.contiguous()
|
||||
|
||||
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
|
||||
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
||||
# The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
|
||||
is_causal = True if causal_mask is None and q_len > 1 else False
|
||||
|
||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_mask=causal_mask,
|
||||
dropout_p=self.attention_dropout if self.training else 0.0,
|
||||
is_causal=is_causal,
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.view(bsz, q_len, -1)
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
return attn_output, None, past_key_value
|
||||
|
||||
|
||||
QWEN2_VL_ATTENTION_CLASSES = {
|
||||
"eager": Qwen2VLAttention,
|
||||
"flash_attention_2": Qwen2VLFlashAttention2,
|
||||
"sdpa": Qwen2VLSdpaAttention,
|
||||
}
|
||||
|
||||
|
||||
class Qwen2VLDecoderLayer(nn.Module):
|
||||
class Qwen2VLDecoderLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config: Qwen2VLTextConfig, layer_idx: int):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
@ -840,11 +647,12 @@ class Qwen2VLDecoderLayer(nn.Module):
|
||||
f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
|
||||
"unexpected results may be encountered."
|
||||
)
|
||||
self.self_attn = QWEN2_VL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
|
||||
self.self_attn = Qwen2VLAttention(config, layer_idx)
|
||||
|
||||
self.mlp = Qwen2MLP(config)
|
||||
self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.attention_type = config.layer_types[layer_idx]
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -856,7 +664,7 @@ class Qwen2VLDecoderLayer(nn.Module):
|
||||
use_cache: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
||||
**kwargs,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
"""
|
||||
Args:
|
||||
@ -894,6 +702,7 @@ class Qwen2VLDecoderLayer(nn.Module):
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
@ -925,6 +734,7 @@ class Qwen2VLPreTrainedModel(PreTrainedModel):
|
||||
_supports_sdpa = True
|
||||
_supports_cache_class = True
|
||||
_supports_static_cache = True
|
||||
_supports_attention_backend = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.get_text_config().initializer_range
|
||||
@ -1053,6 +863,7 @@ class Qwen2VLTextModel(Qwen2VLPreTrainedModel):
|
||||
self._attn_implementation = config._attn_implementation
|
||||
self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.rotary_emb = Qwen2VLRotaryEmbedding(config=config)
|
||||
self.has_sliding_layers = "sliding_attention" in self.config.layer_types
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
# Initialize weights and apply final processing
|
||||
@ -1077,6 +888,7 @@ class Qwen2VLTextModel(Qwen2VLPreTrainedModel):
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
@ -1115,9 +927,23 @@ class Qwen2VLTextModel(Qwen2VLPreTrainedModel):
|
||||
elif position_ids.dim() == 2:
|
||||
position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
|
||||
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
||||
)
|
||||
# It may already have been prepared by e.g. `generate`
|
||||
if not isinstance(causal_mask_mapping := attention_mask, dict):
|
||||
# Prepare mask arguments
|
||||
mask_kwargs = {
|
||||
"config": self.config,
|
||||
"input_embeds": inputs_embeds,
|
||||
"attention_mask": attention_mask,
|
||||
"cache_position": cache_position,
|
||||
"past_key_values": past_key_values,
|
||||
}
|
||||
# Create the masks
|
||||
causal_mask_mapping = {
|
||||
"full_attention": create_causal_mask(**mask_kwargs),
|
||||
}
|
||||
# The sliding window alternating layers are not always activated depending on the config
|
||||
if self.has_sliding_layers:
|
||||
causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
@ -1137,7 +963,7 @@ class Qwen2VLTextModel(Qwen2VLPreTrainedModel):
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
causal_mask,
|
||||
causal_mask_mapping[decoder_layer.attention_type],
|
||||
position_ids,
|
||||
past_key_values,
|
||||
output_attentions,
|
||||
@ -1148,13 +974,14 @@ class Qwen2VLTextModel(Qwen2VLPreTrainedModel):
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=causal_mask,
|
||||
attention_mask=causal_mask_mapping[decoder_layer.attention_type],
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
@ -1182,162 +1009,8 @@ class Qwen2VLTextModel(Qwen2VLPreTrainedModel):
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
# Copied from transformers.models.phimoe.modeling_phimoe.PhimoeModel._update_causal_mask with Phimoe->Qwen2VL
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: Union[torch.Tensor, "BlockMask"],
|
||||
input_tensor: torch.Tensor,
|
||||
cache_position: torch.Tensor,
|
||||
past_key_values: Cache,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
if attention_mask is not None and past_key_values is not None:
|
||||
is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
|
||||
if is_padding_right:
|
||||
raise ValueError(
|
||||
"You are attempting to perform batched generation with padding_side='right'"
|
||||
" this may lead to unexpected behaviour for Flash Attention version of Qwen2VL. Make sure to "
|
||||
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
|
||||
)
|
||||
if attention_mask is not None and 0.0 in attention_mask:
|
||||
return attention_mask
|
||||
return None
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask)
|
||||
return attention_mask
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
# to infer the attention mask.
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
using_static_cache = isinstance(past_key_values, StaticCache)
|
||||
using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
|
||||
|
||||
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
||||
if (
|
||||
self.config._attn_implementation == "sdpa"
|
||||
and not (using_static_cache or using_sliding_window_cache)
|
||||
and not output_attentions
|
||||
):
|
||||
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
||||
attention_mask,
|
||||
inputs_embeds=input_tensor,
|
||||
past_key_values_length=past_seen_tokens,
|
||||
sliding_window=self.config.sliding_window,
|
||||
is_training=self.training,
|
||||
):
|
||||
return None
|
||||
|
||||
dtype = input_tensor.dtype
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
sequence_length = input_tensor.shape[1]
|
||||
# SlidingWindowCache or StaticCache
|
||||
if using_sliding_window_cache or using_static_cache:
|
||||
target_length = past_key_values.get_max_cache_shape()
|
||||
# DynamicCache or no cache
|
||||
else:
|
||||
target_length = (
|
||||
attention_mask.shape[-1]
|
||||
if isinstance(attention_mask, torch.Tensor)
|
||||
else past_seen_tokens + sequence_length + 1
|
||||
)
|
||||
|
||||
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
||||
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask,
|
||||
sequence_length=sequence_length,
|
||||
target_length=target_length,
|
||||
dtype=dtype,
|
||||
cache_position=cache_position,
|
||||
batch_size=input_tensor.shape[0],
|
||||
config=self.config,
|
||||
past_key_values=past_key_values,
|
||||
)
|
||||
|
||||
if (
|
||||
self.config._attn_implementation == "sdpa"
|
||||
and attention_mask is not None
|
||||
and attention_mask.device.type in ["cuda", "xpu", "npu"]
|
||||
and not output_attentions
|
||||
):
|
||||
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
||||
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
||||
# Details: https://github.com/pytorch/pytorch/issues/110213
|
||||
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
||||
|
||||
return causal_mask
|
||||
|
||||
@staticmethod
|
||||
# Copied from transformers.models.phimoe.modeling_phimoe.PhimoeModel._prepare_4d_causal_attention_mask_with_cache_position with Phimoe->Qwen2VL
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
target_length: int,
|
||||
dtype: torch.dtype,
|
||||
cache_position: torch.Tensor,
|
||||
batch_size: int,
|
||||
config: Qwen2VLConfig,
|
||||
past_key_values: Cache,
|
||||
):
|
||||
"""
|
||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||
|
||||
Args:
|
||||
attention_mask (`torch.Tensor`):
|
||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
|
||||
sequence_length (`int`):
|
||||
The sequence length being processed.
|
||||
target_length (`int`):
|
||||
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
|
||||
dtype (`torch.dtype`):
|
||||
The dtype to use for the 4D attention mask.
|
||||
cache_position (`torch.Tensor`):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
batch_size (`torch.Tensor`):
|
||||
Batch size.
|
||||
config (`Qwen2VLConfig`):
|
||||
The model's configuration class
|
||||
past_key_values (`Cache`):
|
||||
The cache class that is being used currently to generate
|
||||
"""
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
|
||||
)
|
||||
diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
|
||||
-1, 1
|
||||
)
|
||||
text_config = config.get_text_config()
|
||||
if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
|
||||
# if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
|
||||
# the check is needed to verify is current checkpoint was trained with sliding window or not
|
||||
if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
|
||||
sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
|
||||
cache_position.reshape(-1, 1) - text_config.sliding_window
|
||||
)
|
||||
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
|
||||
causal_mask *= diagonal_attend_mask
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
if attention_mask.shape[-1] > target_length:
|
||||
attention_mask = attention_mask[:, :target_length]
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
||||
causal_mask.device
|
||||
)
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
return causal_mask
|
||||
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||
|
||||
|
||||
@auto_docstring
|
||||
@ -1523,6 +1196,8 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
|
||||
"""
|
||||
pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
|
||||
video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
|
||||
split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
|
||||
video_embeds = torch.split(video_embeds, split_sizes)
|
||||
return video_embeds
|
||||
|
||||
def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None):
|
||||
@ -1537,6 +1212,8 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
|
||||
"""
|
||||
pixel_values = pixel_values.type(self.visual.dtype)
|
||||
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
|
||||
split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
|
||||
image_embeds = torch.split(image_embeds, split_sizes)
|
||||
return image_embeds
|
||||
|
||||
@auto_docstring
|
||||
@ -1557,6 +1234,7 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
|
||||
video_grid_thw: Optional[torch.LongTensor] = None,
|
||||
rope_deltas: Optional[torch.LongTensor] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs: Unpack[KwargsForCausalLM],
|
||||
) -> Union[Tuple, Qwen2VLModelOutputWithPast]:
|
||||
r"""
|
||||
pixel_values_videos (`torch.FloatTensor` of shape `(seq_length, num_channels * temporal_size * image_size * image_size)):
|
||||
@ -1581,6 +1259,7 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
if pixel_values is not None:
|
||||
image_embeds = self.get_image_features(pixel_values, image_grid_thw)
|
||||
image_embeds = torch.cat(image_embeds, dim=0)
|
||||
n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
|
||||
n_image_features = image_embeds.shape[0]
|
||||
if not is_torchdynamo_compiling() and n_image_tokens != n_image_features:
|
||||
@ -1598,6 +1277,7 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
|
||||
|
||||
if pixel_values_videos is not None:
|
||||
video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
|
||||
video_embeds = torch.cat(video_embeds, dim=0)
|
||||
n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
|
||||
n_video_features = video_embeds.shape[0]
|
||||
if not is_torchdynamo_compiling() and n_video_tokens != n_video_features:
|
||||
@ -1613,15 +1293,14 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
|
||||
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask.to(inputs_embeds.device)
|
||||
|
||||
if position_ids is None:
|
||||
attention_mask_2d = attention_mask
|
||||
if attention_mask is not None and attention_mask.ndim == 4:
|
||||
attention_mask_2d = torch.diagonal(attention_mask_2d[:, 0], dim1=1, dim2=2)
|
||||
attention_mask_2d = attention_mask_2d / torch.finfo(attention_mask_2d.dtype).min
|
||||
attention_mask_2d = (1.0 - attention_mask_2d).int()
|
||||
attention_mask_tensor = (
|
||||
attention_mask if not isinstance(attention_mask, dict) else attention_mask["full_attention"]
|
||||
)
|
||||
if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4:
|
||||
attention_mask_tensor = torch.diagonal(attention_mask_tensor[:, 0], dim1=1, dim2=2)
|
||||
attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min
|
||||
attention_mask_tensor = (1.0 - attention_mask_tensor).int()
|
||||
|
||||
# Calculate RoPE index once per generation in the pre-fill stage only.
|
||||
# When compiling, we can't check tensor values thus we check only input length
|
||||
@ -1637,7 +1316,7 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
|
||||
)
|
||||
if (prefill_compiled_stage or prefill_noncompiled_stage) or self.rope_deltas is None:
|
||||
position_ids, rope_deltas = self.get_rope_index(
|
||||
input_ids, image_grid_thw, video_grid_thw, attention_mask_2d
|
||||
input_ids, image_grid_thw, video_grid_thw, attention_mask_tensor
|
||||
)
|
||||
self.rope_deltas = rope_deltas
|
||||
# then use the prev pre-calculated rope-deltas to get the correct position ids
|
||||
@ -1663,6 +1342,7 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=True,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
output = Qwen2VLModelOutputWithPast(
|
||||
@ -1674,62 +1354,6 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
|
||||
)
|
||||
return output if return_dict else output.to_tuple()
|
||||
|
||||
@staticmethod
|
||||
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
target_length: int,
|
||||
dtype: torch.dtype,
|
||||
cache_position: torch.Tensor,
|
||||
batch_size: int,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||
|
||||
Args:
|
||||
attention_mask (`torch.Tensor`):
|
||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
|
||||
`(batch_size, 1, query_length, key_value_length)`.
|
||||
sequence_length (`int`):
|
||||
The sequence length being processed.
|
||||
target_length (`int`):
|
||||
The target length: when generating with static cache, the mask should be as long as the static cache,
|
||||
to account for the 0 padding, the part of the cache that is not filled yet.
|
||||
dtype (`torch.dtype`):
|
||||
The dtype to use for the 4D attention mask.
|
||||
cache_position (`torch.Tensor`):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
batch_size (`torch.Tensor`):
|
||||
Batch size.
|
||||
"""
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
|
||||
)
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
||||
causal_mask.device
|
||||
)
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
|
||||
return causal_mask
|
||||
|
||||
|
||||
class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
|
||||
_checkpoint_conversion_mapping = {
|
||||
@ -1792,6 +1416,7 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
|
||||
video_grid_thw: Optional[torch.LongTensor] = None,
|
||||
rope_deltas: Optional[torch.LongTensor] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs: Unpack[KwargsForCausalLM],
|
||||
) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
@ -1861,6 +1486,7 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
|
@ -181,7 +181,9 @@ class VipLlavaModel(VipLlavaPreTrainedModel):
|
||||
def set_input_embeddings(self, value):
|
||||
self.language_model.set_input_embeddings(value)
|
||||
|
||||
def get_image_features(self, pixel_values: torch.FloatTensor, vision_feature_layers: Union[int, List[int]]):
|
||||
def get_image_features(
|
||||
self, pixel_values: torch.FloatTensor, vision_feature_layers: Optional[Union[int, List[int]]] = None
|
||||
):
|
||||
"""
|
||||
Obtains image last hidden states from the vision tower and apply multimodal projection.
|
||||
|
||||
@ -194,6 +196,9 @@ class VipLlavaModel(VipLlavaPreTrainedModel):
|
||||
Returns:
|
||||
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
|
||||
"""
|
||||
vision_feature_layers = (
|
||||
vision_feature_layers if vision_feature_layers is not None else self.config.vision_feature_layers
|
||||
)
|
||||
image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
|
||||
|
||||
# If multiple feature layers are provided (which is usually the case)
|
||||
|
@ -71,7 +71,9 @@ class VipLlavaPreTrainedModel(LlavaPreTrainedModel):
|
||||
|
||||
|
||||
class VipLlavaModel(LlavaModel):
|
||||
def get_image_features(self, pixel_values: torch.FloatTensor, vision_feature_layers: Union[int, List[int]]):
|
||||
def get_image_features(
|
||||
self, pixel_values: torch.FloatTensor, vision_feature_layers: Optional[Union[int, List[int]]] = None
|
||||
):
|
||||
"""
|
||||
Obtains image last hidden states from the vision tower and apply multimodal projection.
|
||||
|
||||
@ -84,6 +86,9 @@ class VipLlavaModel(LlavaModel):
|
||||
Returns:
|
||||
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
|
||||
"""
|
||||
vision_feature_layers = (
|
||||
vision_feature_layers if vision_feature_layers is not None else self.config.vision_feature_layers
|
||||
)
|
||||
image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
|
||||
|
||||
# If multiple feature layers are provided (which is usually the case)
|
||||
|
@ -326,9 +326,7 @@ class Emu3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, Pipeline
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = Emu3Vision2TextModelTester(self)
|
||||
self.config_tester = ConfigTester(
|
||||
self, config_class=Emu3Config, has_text_modality=False, common_properties=["vocabulary_map"]
|
||||
)
|
||||
self.config_tester = ConfigTester(self, config_class=Emu3Config, has_text_modality=False, hidden_size=37)
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
@ -46,7 +46,10 @@ SPECIAL_CASES_TO_ALLOW = {
|
||||
],
|
||||
"Qwen2Config": ["use_sliding_window", "max_window_layers"],
|
||||
"Qwen2MoeConfig": ["use_sliding_window"],
|
||||
"Qwen2VLConfig": ["use_sliding_window"],
|
||||
"Qwen2VLTextConfig": ["use_sliding_window", "max_window_layers"],
|
||||
"Qwen2_5_VLTextConfig": ["use_sliding_window", "max_window_layers"],
|
||||
"Qwen2_5OmniTextConfig": ["use_sliding_window", "max_window_layers"],
|
||||
"Qwen2_5OmniTalkerConfig": ["use_sliding_window", "max_window_layers"],
|
||||
"Qwen3Config": ["max_window_layers", "use_sliding_window"], # now use `layer_types` instead
|
||||
"Qwen3MoeConfig": ["max_window_layers", "use_sliding_window"],
|
||||
# `cache_implementation` should be in the default generation config, but we don't yet support per-model
|
||||
|
Loading…
Reference in New Issue
Block a user